package water;
import static org.junit.Assert.*;
import hex.ConfusionMatrix;
import hex.gbm.DTree.TreeModel;
import hex.glm.GLMModel;
import java.io.File;
import java.util.ArrayList;
import java.util.Arrays;
import org.junit.*;
import org.junit.rules.TestRule;
import org.junit.runner.Description;
import org.junit.runners.model.Statement;
import water.Job.JobState;
import water.deploy.*;
import water.fvec.*;
import water.util.*;
public class TestUtil {
private static int _initial_keycnt = 0;
private static Timer _testClassTimer;
protected static void startCloud(String[] args, int nnodes) {
for( int i = 1; i < nnodes; i++ ) {
Node n = new NodeVM(args);
n.inheritIO();
n.start();
}
H2O.waitForCloudSize(nnodes);
}
@BeforeClass public static void setupCloud() {
H2O.main(new String[] {});
_initial_keycnt = H2O.store_size();
assert Job.all().length == 0; // No outstanding jobs
_testClassTimer = new Timer();
}
/** Execute this rule before each test to print test name and test class */
@Rule public TestRule logRule = new TestRule() {
@Override public Statement apply(Statement base, Description description) {
Log.info("###########################################################");
Log.info(" * Test class name: " + description.getClassName());
Log.info(" * Test method name: " + description.getMethodName());
Log.info("###########################################################");
return base;
}
};
@Rule public TestRule timerRule = new TestRule() {
@Override public Statement apply(Statement base, Description description) {
return new TimerStatement(base, description.getClassName()+"#"+description.getMethodName());
};
class TimerStatement extends Statement {
private final Statement base;
private final String tname;
public TimerStatement(Statement base, String tname) { this.base = base; this.tname = tname;}
@Override public void evaluate() throws Throwable {
Timer t = new Timer();
try {
base.evaluate();
} finally {
Log.info("#### TEST "+tname+" EXECUTION TIME: " + t.toString());
}
}
}
};
@AfterClass public static void checkLeakedKeys() {
Log.info("## TEST CLASS EXECUTION TIME (sum over all tests): " + _testClassTimer.toString());
Job[] jobs = Job.all();
for( Job job : jobs ) {
assert job.state != JobState.RUNNING : ("UNFINISHED JOB: " + job.job_key + " " + job.description + ", end_time = " + job.end_time + ", state=" + job.state ); // No pending job
DKV.remove(job.job_key);
}
DKV.remove(Job.LIST); // Remove all keys
if (Log.LOG_KEY!=null) DKV.remove(Log.LOG_KEY); // The job key does not need to be created if the test does not print into logs
DKV.write_barrier();
int leaked_keys = H2O.store_size() - _initial_keycnt;
int nvecs = 0, nchunks = 0, nframes = 0, nmodels = 0, nothers = 0;
if( leaked_keys > 0 ) {
for( Key k : H2O.localKeySet() ) {
Value value = DKV.get(k);
if( value==null ) { leaked_keys--; continue; }
Object o = value.type() != TypeMap.PRIM_B ? value.get() : "byte[]";
// Ok to leak VectorGroups
if( o instanceof Vec.VectorGroup )
leaked_keys--;
else {
try {
System.err.println("Leaked key: " + k + " = " + o);
} catch (NullPointerException t) {
System.err.println("Leaked key: " + k + " = " + o.getClass().getSimpleName() + " with missing data");
} catch (Throwable t) {
t.printStackTrace();
throw new RuntimeException(t);
}
if (k.isChunkKey()) nchunks++;
else if (k.isVec()) nvecs++;
else if (o instanceof Frame) nframes++;
else if (o instanceof Model) nmodels++;
else nothers++;
}
}
}
assertTrue("Key leak! #keys(" + leaked_keys + ") = #vecs("+nvecs+")+#chunks("+nchunks+")+#frames("+nframes+")+#nmodels("+nmodels+")+#others("+nothers+")", leaked_keys <= 0);
_initial_keycnt = H2O.store_size();
}
// Stall test until we see at least X members of the Cloud
public static void stall_till_cloudsize(int x) {
stall_till_cloudsize(x, 10000);
}
public static void stall_till_cloudsize(int x, long ms) {
H2O.waitForCloudSize(x, ms);
UKV.put(Job.LIST, new Job.List()); // Jobs.LIST must be part of initial keys
}
public static File find_test_file(String fname) {
// When run from eclipse, the working directory is different.
// Try pointing at another likely place
File file = new File(fname);
if( !file.exists() )
file = new File("target/" + fname);
if( !file.exists() )
file = new File("../" + fname);
if( !file.exists() )
file = new File("../target/" + fname);
if( !file.exists() )
file = null;
return file;
}
// Old VA-style wrappers for tests
public static Key load_test_file(String fname) {
return load_test_file(find_test_file(fname));
}
public static Key load_test_file(File file) { return NFSFileVec.make(file); }
public static Key loadAndParseFile(String keyName, String path) {
Key okey = Key.make(keyName);
ParseDataset2.parse(okey, new Key[]{load_test_file(path)});
return okey;
}
public static Key[] load_test_folder(File folder) {
assert folder.isDirectory();
ArrayList<Key> keys = new ArrayList<Key>();
for( File f : folder.listFiles() ) {
if( f.isFile() )
keys.add(load_test_file(f));
}
Key[] res = new Key[keys.size()];
keys.toArray(res);
return res;
}
public static Key loadAndParseFolder(String keyname, String path) {
Key[] keys = load_test_folder(new File(path));
Arrays.sort(keys);
Key okey = Key.make(keyname);
ParseDataset2.parse(okey, keys);
return okey;
}
// Fluid Vectors
public static Frame parseFromH2OFolder(String path) {
File file = new File(VM.h2oFolder(), path);
return FrameUtils.parseFrame(null, file);
}
public static Frame parseFrame(String path) {
return FrameUtils.parseFrame(null, find_test_file(path));
}
public static Frame parseFrame(File file) {
return FrameUtils.parseFrame(null, file);
}
public static Frame parseFrame(Key okey, String path) {
return FrameUtils.parseFrame(okey, find_test_file(path));
}
public static Frame parseFrame(Key okey, File f) {
return FrameUtils.parseFrame(okey, f);
}
public static Vec vec(int...rows) { return vec(null, null, rows); }
public static Vec vec(String[] domain, int ...rows) { return vec(null, domain, rows); }
public static Vec vec(Key k, String[] domain, int ...rows) {
k = (k==null) ? new Vec.VectorGroup().addVec() : k;
Futures fs = new Futures();
AppendableVec avec = new AppendableVec(k);
NewChunk chunk = new NewChunk(avec, 0);
for( int r = 0; r < rows.length; r++ )
chunk.addNum(rows[r]);
chunk.close(0, fs);
Vec vec = avec.close(fs);
fs.blockForPending();
vec._domain = domain;
return vec;
}
public static Frame frame(String name, Vec vec) { return FrameUtils.frame(name, vec); }
public static Frame frame(String[] names, Vec[] vecs) { return FrameUtils.frame(names, vecs); }
public static Frame frame(String[] names, double[]... rows) { return FrameUtils.frame(names, rows); }
public static void dumpKeys(String msg) {
System.err.println("-->> Store dump <<--");
System.err.println(" " + msg);
System.err.println(" Keys: " + H2O.store_size());
for ( Key k : H2O.localKeySet()) System.err.println(" * " + k);
System.err.println("----------------------");
}
public static String[] ar (String ...a) { return a; }
public static long [] ar (long ...a) { return a; }
public static long[][] ar (long[] ...a) { return a; }
public static int [] ari(int ...a) { return a; }
public static int [][] ar (int[] ...a) { return a; }
public static float [] arf(float ...a) { return a; }
public static double[] ard(double ...a) { return a; }
public static double[][] ard(double[] ...a) { return a; }
// Expanded array
public static double[][] ear (double ...a) {
double[][] r = new double[a.length][1];
for (int i=0; i<a.length;i++) r[i][0] = a[i];
return r;
}
public static void assertCM(long[][] expectedCM, long[][] givenCM) {
Assert.assertEquals("Confusion matrix dimension does not match", expectedCM.length, givenCM.length);
String m = "Expected: " + Arrays.deepToString(expectedCM) + ", but was: " + Arrays.deepToString(givenCM);
for (int i=0; i<expectedCM.length; i++) Assert.assertArrayEquals(m, expectedCM[i], givenCM[i]);
}
public static void assertCMEquals(String msg, ConfusionMatrix a, ConfusionMatrix b) {
Assert.assertEquals(msg + " - Confusion matrix should be of the same size", a._arr.length, b._arr.length);
for (int i=0; i< a._arr.length; i++) {
Assert.assertArrayEquals(msg, a._arr[i], b._arr[i]);
}
}
public static void assertModelEquals(Model a, Model b) {
assertArrayEquals("Model names has to equal!", a._names, b._names);
assertEquals("Model has to contain same number of domains!", a._domains.length, b._domains.length);
for (int i=0; i<a._domains.length; i++) {
assertArrayEquals("Model input column "+i+" has to contain same domain names!", a._domains[i], b._domains[i]);
}
}
public static void assertTreeModelEquals(TreeModel a, TreeModel b) {
assertModelEquals(a,b);
assertEquals("Number of demanded trees should be same!", a.N, b.N);
assertEquals("Number of produced trees should be same!", a.ntrees(), b.ntrees());
assertArrayEquals("All error fields should be same (requiring models build without skipping scoring)!", a.errs, b.errs, 0.00000001);
assertEquals("Models shoudl be of the same type!", a.isClassifier(), b.isClassifier());
if (a.isClassifier()) {
assertEquals("The models should contain the same number of CMs", a.cms.length, b.cms.length);
for (int i=0; i<a.cms.length; i++) {
assertCMEquals(i+"-th CM should be same (requiring models build without skipping scoring)!", a.cms[i], b.cms[i]);
}
}
}
public static void assertModelBinaryEquals(Model a, Model b) {
assertArrayEquals("The serialized models are not binary same!", a.write(new AutoBuffer()).buf(), b.write(new AutoBuffer()).buf());
}
public static void sleep(int msec) {
try { Thread.sleep(msec); } catch (InterruptedException e) {}
}
}