/**
*
*/
package water.api;
import hex.pca.PCAModel;
import hex.*;
import hex.gbm.GBM;
import hex.glm.*;
import hex.pca.*;
import java.lang.annotation.Annotation;
import java.lang.reflect.Field;
import java.util.*;
import org.junit.*;
import water.*;
import water.api.LogView.LogDownload;
import water.api.Request.API;
import water.api.RequestArguments.Argument;
import water.api.Upload.PostFile;
/**
* The objective of test is to test a stability of API.
*
* It is filled by REST api calls and their arguments used by Python/R code.
* The test tests if the arguments are published by REST API in Java code.
*
* Note: this is a pure JUnit test, no cloud is launched
*/
public class StableAPITest {
/** Mapping between REST API methods and their attributes used by Python code. */
static Map< Class<? extends Request>, String[]> pyAPI = new HashMap<Class<? extends Request>, String[]>();
/** Mapping between REST API methods and their attributes used by R code. */
static Map< Class<? extends Request>, String[]> rAPI = new HashMap<Class<? extends Request>, String[]>();
/** Test compatibility of defined Python calls with REST API published by Java code. */
@Test
public void testPyAPICompatibility() {
testAPICompatibility("Python API", pyAPI);
}
/** Test compatibility of defined R calls with REST API published by Java code. */
@Test
public void testRAPICompatibility() {
testAPICompatibility("R API", rAPI);
}
/** Test given client APIs calls against REST API published by Java code. */
private void testAPICompatibility(String client, Map<Class<? extends Request>, String[]> api ) {
Map<Class<? extends Request>, String[]> unsupportedParams = new HashMap<Class<? extends Request>, String[]>();
for (Map.Entry<Class<? extends Request>, String[]> apiCall : api.entrySet()) {
String[] unsParams = verifyAPICall(client, apiCall.getKey(), apiCall.getValue());
if (unsParams!=null && unsParams.length > 0) unsupportedParams.put(apiCall.getKey(), unsParams);
}
// Do not fail here now
Assert.assertTrue(f(client, unsupportedParams), unsupportedParams.isEmpty());
}
/**
* Verify given <code>api</code> client's call containing given parameters <code>params</code>.
*/
private <T extends Request> String[] verifyAPICall(String client, Class<T> api, String[] params) {
List<String> unsupportedParams = new ArrayList<String>(5);
List<Field> apiParams = getAllParams(api, Request.class, Request2.class.isAssignableFrom(api) ? Request2FFilter : Request1FFilter) ;
T request_v1 = Request2.class.isAssignableFrom(api) ? null : newInstance(api);
//if (Request2.class.isAssignableFrom(api)) System.err.println(apiParams);
for (String par : params) {
// Handle Request2 API - parameters directly corresponds to REST attributes
if (Request2.class.isAssignableFrom(api)) {
if (!contains(par, apiParams)) unsupportedParams.add(par);
} else if (Request.class.isAssignableFrom(Request.class)) {
// Handle original Request
// - in this case we need to look into Argument itself and search for name exposed as REST attribute
assert request_v1 != null;
if (!supportsArg(request_v1, par, apiParams)) unsupportedParams.add(par);
}
}
return unsupportedParams.isEmpty() ? null : unsupportedParams.toArray(new String[unsupportedParams.size()]);
}
static <T extends Request> T newInstance(Class<T> api) {
Assert.assertTrue("The test should instantiat only Request API not Request2 API", !Request2.class.isAssignableFrom(api));
try {
return api.newInstance();
} catch( Exception e ) {
e.printStackTrace();
Assert.assertTrue("Test should be able to instantiate " + api + " via default ctor!", false);
}
return null;
}
static abstract class FFilter { abstract boolean involve(Field f); }
static FFilter Request1FFilter = new FFilter() { @Override boolean involve(Field f) { return Argument.class.isAssignableFrom(f.getType()); } };
static FFilter Request2FFilter = new FFilter() { @Override boolean involve(Field f) { return contains(API.class, f.getDeclaredAnnotations()); } };
static boolean contains(Class<? extends Annotation> annoType, Annotation[] annotations) {
for (Annotation anno : annotations) if (anno.annotationType().equals(annoType)) return true;
return false;
}
static boolean contains(String s, List<Field> params) {
return find(s,params)!=null;
}
static Field find(String name, List<Field> params) {
for (Field f : params) if (name.equals(f.getName())) return f;
return null;
}
static <T extends Request> boolean supportsArg(T api, String name, List<Field> params) {
// Go through all the fields and take their values and search for JSON arg name
for (Field f : params) {
assert Argument.class.isAssignableFrom(f.getType());
try {
Argument arg = (Argument) f.get(api);
if (name.equals(arg._name)) return true;
} catch( Exception e ) {
}
}
// No matching Java API argument found
return false;
}
private static List<Field> getAllParams(Class<?> startClass, Class<?> parentClass, FFilter ffilter) {
List<Field> params = new ArrayList<Field>(10);
Class<?> cls = startClass;
while (cls!=null && cls!=parentClass) {
Field[] fields = cls.getDeclaredFields();
for (Field f : fields) if (ffilter==null || ffilter.involve(f)) { f.setAccessible(true); params.add(f); }
cls = cls.getSuperclass();
}
return params;
}
// Initialize all required static fields, BUT DO NOT START THE CLOUD
@BeforeClass
static public void initTest() throws SecurityException, NoSuchFieldException, IllegalArgumentException, IllegalAccessException {
H2O.NAME = "Test cloud";
// Add a new item into TYPE_MAP
Field fMap = TypeMap.class.getDeclaredField("MAP");
fMap.setAccessible(true);
Map<String,Integer> map = (Map<String, Integer>) fMap.get(null);
map.put(PCAModel.class.getName(), 1000);
}
/**
* Register all existing Python API calls and used parameters.
*
* All of them were collected from Jenkins commands.log.
*/
@BeforeClass
static public void registerPyAPI() {
regPy(Cancel.class, "key");
regPy(Cloud.class);
regPy(ConfusionMatrix.class, "actual", "predict", "vactual", "vpredict");
regPy(DRFModelView.class);
regPy(DRFProgressPage.class);
regPy(Debug.class);
regPy(DownloadDataset.class, "src_key");
regPy(GBM.class, "cols", "destination_key", "learn_rate", "max_depth", "min_rows", "nbins", "ntrees", "response", "source", "validation");
regPy(GBMModelView.class, "_modelKey");
regPy(GBMProgressPage.class);
regPy(Get.class, "key");
regPy(HTTP404.class);
regPy(HTTP500.class);
regPy(IOStatus.class);
regPy(ImportFiles2.class, "path");
regPy(ImportHdfs.class, "path");
regPy(ImportS3.class, "bucket");
regPy(Inspect2.class, "offset", "src_key");
regPy(JStack.class);
regPy(Jobs.class);
regPy(LogView.class);
regPy(NeuralNet.class, "activation", "cols", "destination_key", "epochs", "hidden", "l2", "rate", "response", "source");
regPy(PCA.class, "destination_key", "source", "standardize", "tolerance");
regPy(PCAScore.class, "source", "model", "destination_key", "num_pc");
regPy(Parse2.class, "destination_key", "header", "source_key");
regPy(PostFile.class, "key"); // PostFile has no key attribute - it is hard-coded in Nano
regPy(Predict.class, "data", "model", "prediction");
regPy(Progress2.class);
regPy(PutValue.class, "key", "value");
regPy(QuantilesPage.class);
regPy(Remove.class, "key");
regPy(RemoveAck.class);
regPy(Shutdown.class);
regPy(StoreView.class, "filter", "offset", "view");
regPy(SummaryPage2.class);
regPy(TestPoll.class, "hoho");
regPy(TestRedirect.class);
regPy(Timeline.class);
regPy(Upload.class);
}
/**
* Used R API extracted from R/h2o-package/R/Internal.R
*/
@BeforeClass
static public void registerRAPI() {
regR(Cloud.class);
regPy(DownloadDataset.class, "src_key");
regR(GBM.class, "destination_key", "source", "response", "cols", "ntrees", "max_depth", "learn_rate", "min_rows", "classification");
regR(GBMModelView.class, "_modelKey");
regR(GLM2.class, "source", "destination_key", "response", "ignored_cols", "family", "n_folds", "alpha", "lambda", "standardize", "tweedie_variance_power");
regR(GLMGridProgress.class, "destination_key");
regR(GLMModelView.class, "_modelKey");
regR(ImportHdfs.class, "path");
regR(Inspect2.class, "src_key");
regR(Jobs.class);
regR(LogDownload.class);
regR(PCA.class, "source", "ignored_cols", "destination_key", "max_pc", "tolerance", "standardize");
regR(PCAScore.class, "source", "model", "destination_key", "num_pc");
regR(Predict.class, "model", "data", "prediction");
regR(Remove.class, "key");
regR(StoreView.class);
}
// Register an API method used by Python
static <T extends Request> void regPy(Class<T> api, String... params) { add(pyAPI, api, params); }
// Register an API method used by R
static <T extends Request> void regR(Class<T> api, String... params) { add(rAPI, api, params); }
static <T extends Request> void add(Map< Class<? extends Request>, String[]> apis, Class<T> api, String[] params) {
if (apis.containsKey(api)) {
String[] regParams = apis.get(api); // already registered params
String[] pars = Arrays.copyOf(regParams, regParams.length + params.length);
System.arraycopy(params, 0, pars, regParams.length, params.length);
apis.put(api, pars);
} else apis.put(api, params);
}
static String f(String client, Map<Class<? extends Request>, String[]> params) {
StringBuilder sb = new StringBuilder(client).append(" uses the following unsupported parameters (arguments are not published by REST API)\n");
for (Map.Entry<Class<? extends Request>, String[]> call : params.entrySet() ) {
sb.append(call.getKey()).append(" : ").append(Arrays.toString(call.getValue())).append('\n');
}
return sb.toString();
}
}