package hex.deepwater;
import deepwater.backends.BackendModel;
import deepwater.backends.BackendParams;
import deepwater.backends.BackendTrain;
import deepwater.backends.RuntimeOptions;
import deepwater.datasets.ImageDataSet;
import hex.FrameSplitter;
import hex.Model;
import hex.ModelMetricsBinomial;
import hex.ModelMetricsMultinomial;
import hex.splitframe.ShuffleSplitFrame;
import org.joda.time.format.DateTimeFormat;
import org.joda.time.format.DateTimeFormatter;
import org.junit.*;
import water.*;
import water.exceptions.H2OIllegalArgumentException;
import water.fvec.Frame;
import water.fvec.NFSFileVec;
import water.fvec.Vec;
import water.parser.ParseDataset;
import water.util.FileUtils;
import water.util.Log;
import water.util.StringUtils;
import water.util.TwoDimTable;
import java.io.*;
import java.util.ArrayList;
import java.util.Arrays;
import static hex.deepwater.DeepWaterParameters.Network.*;
import static hex.genmodel.algos.deepwater.DeepwaterMojoModel.createDeepWaterBackend;
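// Backend-agnostic integration tests for DeepWater; concrete subclasses pick the native backend via getBackend().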
public abstract class DeepWaterAbstractIntegrationTest extends TestUtil {
protected BackendTrain backend;
abstract DeepWaterParameters.Backend getBackend();
@BeforeClass
public static void stall() { stall_till_cloudsize(1); }
@BeforeClass
public static void checkBackend() { Assume.assumeTrue(DeepWater.haveBackend()); }
@Before
public void createBackend() throws Exception {
backend = createDeepWaterBackend(getBackend().toString());
Assert.assertNotNull(backend);
}
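// Train and delete a few VGG image models back to back to exercise allocation and cleanup of native backend resources.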
@Test
public void memoryLeakTest() {
DeepWaterModel m = null;
Frame tr = null;
int counter=3;
while(counter-- > 0) {
try {
DeepWaterParameters p = new DeepWaterParameters();
p._backend = getBackend();
p._train = (tr=parse_test_file("bigdata/laptop/deepwater/imagenet/cat_dog_mouse.csv"))._key;
p._response_column = "C2";
p._network = DeepWaterParameters.Network.vgg;
p._learning_rate = 1e-4;
p._mini_batch_size = 8;
p._train_samples_per_iteration = 8;
p._epochs = 1e-3;
m = new DeepWater(p).trainModel().get();
Log.info(m);
} finally {
if (m!=null) m.delete();
if (tr!=null) tr.remove();
}
}
}
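// Helper: trains for 3 epochs with the given train_samples_per_iteration and asserts the resulting number of scoring iterations.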
void trainSamplesPerIteration(int samples, int expected) {
DeepWaterModel m = null;
Frame tr = null;
try {
DeepWaterParameters p = new DeepWaterParameters();
p._backend = getBackend();
p._train = (tr=parse_test_file("bigdata/laptop/deepwater/imagenet/cat_dog_mouse.csv"))._key;
p._response_column = "C2";
p._learning_rate = 1e-3;
p._epochs = 3;
p._train_samples_per_iteration = samples;
m = new DeepWater(p).trainModel().get();
Assert.assertEquals(expected,m.iterations);
} finally {
if (m!=null) m.delete();
if (tr!=null) tr.remove();
}
}
@Test public void trainSamplesPerIteration0() { trainSamplesPerIteration(0,3); }
@Test public void trainSamplesPerIteration_auto() { trainSamplesPerIteration(-2,1); }
@Test public void trainSamplesPerIteration_neg1() { trainSamplesPerIteration(-1,3); }
@Test public void trainSamplesPerIteration_32() { trainSamplesPerIteration(32,26); }
@Test public void trainSamplesPerIteration_1000() { trainSamplesPerIteration(1000,1); }
@Test
public void overWriteWithBestModel() {
DeepWaterModel m = null;
Frame tr = null;
try {
DeepWaterParameters p = new DeepWaterParameters();
p._backend = getBackend();
p._train = (tr=parse_test_file("bigdata/laptop/deepwater/imagenet/cat_dog_mouse.csv"))._key;
p._response_column = "C2";
p._epochs = 50;
p._learning_rate = 0.01;
p._momentum_start = 0.5;
p._momentum_stable = 0.5;
p._stopping_rounds = 0;
p._image_shape = new int[]{28,28};
p._network = lenet;
p._problem_type = DeepWaterParameters.ProblemType.image;
// score a lot
p._train_samples_per_iteration = p._mini_batch_size;
p._score_duty_cycle = 1;
p._score_interval = 0;
p._overwrite_with_best_model = true;
m = new DeepWater(p).trainModel().get();
Log.info(m);
Assert.assertTrue(((ModelMetricsMultinomial)m._output._training_metrics).logloss()<2);
} finally {
if (m!=null) m.remove();
if (tr!=null) tr.remove();
}
}
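// Helper: trains the given network on the cat/dog/mouse images and asserts >90% training accuracy; mini-batch size and learning rate are tuned per network to fit into GPU memory.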
void checkConvergence(int channels, DeepWaterParameters.Network network, int epochs) {
DeepWaterModel m = null;
Frame tr = null;
try {
DeepWaterParameters p = new DeepWaterParameters();
p._backend = getBackend();
p._train = (tr=parse_test_file("bigdata/laptop/deepwater/imagenet/cat_dog_mouse.csv"))._key;
p._response_column = "C2";
p._network = network;
p._learning_rate = 1e-3;
p._epochs = epochs;
p._channels = channels;
if (network == DeepWaterParameters.Network.vgg) {
p._mini_batch_size = 8; //~6GB with mxnet
} else if (network == DeepWaterParameters.Network.resnet) {
p._mini_batch_size = 16; //~6GB with mxnet
p._learning_rate = 1e-4;
} else if (network == DeepWaterParameters.Network.alexnet) {
p._mini_batch_size = 128; //~3GB with mxnet
p._learning_rate = 1e-4;
} else {
p._mini_batch_size = 32; //<=6GB with mxnet
}
p._problem_type = DeepWaterParameters.ProblemType.image;
m = new DeepWater(p).trainModel().get();
Log.info(m);
System.out.println("Accuracy " + m._output._training_metrics.cm().accuracy());
Assert.assertTrue(m._output._training_metrics.cm().accuracy()>0.9);
} finally {
if (m!=null) m.delete();
if (tr!=null) tr.remove();
}
}
@Test public void convergenceInceptionColor() { checkConvergence(3, inception_bn, 150); }
@Test public void convergenceInceptionGrayScale() { checkConvergence(1, inception_bn, 150); }
@Test public void convergenceGoogleNetColor() { checkConvergence(3, DeepWaterParameters.Network.googlenet, 150); }
@Test public void convergenceGoogleNetGrayScale() { checkConvergence(1, DeepWaterParameters.Network.googlenet, 150); }
@Test public void convergenceLenetColor() { checkConvergence(3, lenet, 300); }
@Test public void convergenceLenetGrayScale() { checkConvergence(1, lenet, 150); }
@Test public void convergenceVGGColor() { checkConvergence(3, DeepWaterParameters.Network.vgg, 150); }
@Test public void convergenceVGGGrayScale() { checkConvergence(1, DeepWaterParameters.Network.vgg, 150); }
@Test public void convergenceResnetColor() { checkConvergence(3, resnet, 150); }
@Test public void convergenceResnetGrayScale() { checkConvergence(1, resnet, 150); }
@Test public void convergenceAlexnetColor() { checkConvergence(3, alexnet, 150); }
@Test public void convergenceAlexnetGrayScale() { checkConvergence(1, alexnet, 150); }
//FIXME
@Ignore
@Test
public void reproInitialDistribution() {
final int REPS=3;
double[] values=new double[REPS];
for (int i=0;i<REPS;++i) {
DeepWaterModel m = null;
Frame tr = null;
try {
DeepWaterParameters p = new DeepWaterParameters();
p._backend = getBackend();
p._train = (tr = parse_test_file("bigdata/laptop/deepwater/imagenet/cat_dog_mouse.csv"))._key;
p._response_column = "C2";
p._learning_rate = 0; //no updates to original weights
p._seed = 1234;
p._epochs = 1; //for some reason, can't use 0 epochs
p._channels = 1;
p._train_samples_per_iteration = 0;
m = new DeepWater(p).trainModel().get();
Log.info(m);
values[i]=((ModelMetricsMultinomial)m._output._training_metrics).logloss();
} finally {
if (m != null) m.delete();
if (tr != null) tr.remove();
}
}
for (int i=1;i<REPS;++i) Assert.assertEquals(values[0],values[i],1e-5*values[0]);
}
@Test
public void reproInitialDistributionNegativeTest() {
final int REPS=3;
double[] values=new double[REPS];
for (int i=0;i<REPS;++i) {
DeepWaterModel m = null;
Frame tr = null;
try {
DeepWaterParameters p = new DeepWaterParameters();
p._backend = getBackend();
p._train = (tr=parse_test_file("bigdata/laptop/deepwater/imagenet/cat_dog_mouse.csv"))._key;
p._response_column = "C2";
p._learning_rate = 0; //no updates to original weights
p._seed = i;
p._epochs = 1; //for some reason, can't use 0 epochs
p._channels = 1;
p._train_samples_per_iteration = 0;
m = new DeepWater(p).trainModel().get();
Log.info(m);
values[i] = ((ModelMetricsMultinomial)m._output._training_metrics).logloss();
} finally {
if (m!=null) m.delete();
if (tr!=null) tr.remove();
}
}
for (int i=1;i<REPS;++i) Assert.assertNotEquals(values[0],values[i],1e-5*values[0]);
}
// Pure convenience wrapper
@Ignore @Test public void settingModelInfoAll() {
for (DeepWaterParameters.Network network : DeepWaterParameters.Network.values()) {
if (network== DeepWaterParameters.Network.user) continue;
if (network== DeepWaterParameters.Network.auto) continue;
settingModelInfo(network);
}
}
@Test public void settingModelInfoAlexnet() { settingModelInfo(alexnet); }
@Test public void settingModelInfoLenet() { settingModelInfo(lenet); }
@Test public void settingModelInfoVGG() { settingModelInfo(DeepWaterParameters.Network.vgg); }
@Test public void settingModelInfoInception() { settingModelInfo(inception_bn); }
@Test public void settingModelInfoResnet() { settingModelInfo(resnet); }
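// Helper: trains two models with different seeds, copies the first model's model_info into the second, and checks that the parameter hash and loss then match the first model.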
void settingModelInfo(DeepWaterParameters.Network network) {
DeepWaterModel m1 = null;
DeepWaterModel m2 = null;
Frame tr = null;
try {
DeepWaterParameters p = new DeepWaterParameters();
p._backend = getBackend();
p._train = (tr=parse_test_file("bigdata/laptop/deepwater/imagenet/cat_dog_mouse.csv"))._key;
p._response_column = "C2";
p._network = network;
p._mini_batch_size = 2;
p._epochs = 0.01;
// p._learning_rate = 0; //needed to pass the test for inception/resnet
p._seed = 1234;
p._score_training_samples = 0;
p._train_samples_per_iteration = p._mini_batch_size;
p._problem_type = DeepWaterParameters.ProblemType.image;
// first model
Job j1 = new DeepWater(p).trainModel();
m1 = (DeepWaterModel)j1.get();
int h1 = Arrays.hashCode(m1.model_info()._modelparams);
m1.doScoring(tr,null,j1._key,m1.iterations,true);
double l1 = m1.loss();
// second model (different seed)
p._seed = 4321;
Job j2 = new DeepWater(p).trainModel();
m2 = (DeepWaterModel)j2.get();
// m2.doScoring(tr,null,j2._key,m2.iterations,true);
// double l2 = m2.loss();
int h2 = Arrays.hashCode(m2.model_info()._modelparams);
// turn the second model into the first model
m2.removeNativeState();
DeepWaterModelInfo mi = IcedUtils.deepCopy(m1.model_info());
m2.set_model_info(mi);
m2.doScoring(tr,null,j2._key,m2.iterations,true);
double l3 = m2.loss();
int h3 = Arrays.hashCode(m2.model_info()._modelparams);
Log.info("Checking assertions for network: " + network);
Assert.assertNotEquals(h1, h2);
Assert.assertEquals(h1, h3);
Assert.assertEquals(l1, l3, 1e-5*l1);
} finally {
if (m1!=null) m1.delete();
if (m2!=null) m2.delete();
if (tr!=null) tr.remove();
}
}
//FIXME
@Ignore
@Test
public void reproTraining() {
final int REPS=3;
double[] values=new double[REPS];
for (int i=0;i<REPS;++i) {
DeepWaterModel m = null;
Frame tr = null;
try {
DeepWaterParameters p = new DeepWaterParameters();
p._backend = getBackend();
p._train = (tr=parse_test_file("bigdata/laptop/deepwater/imagenet/cat_dog_mouse.csv"))._key;
p._response_column = "C2";
p._learning_rate = 1e-4;
p._seed = 1234;
p._epochs = 1;
p._channels = 1;
p._train_samples_per_iteration = 0;
m = new DeepWater(p).trainModel().get();
Log.info(m);
values[i] = ((ModelMetricsMultinomial)m._output._training_metrics).logloss();
} finally {
if (m!=null) m.delete();
if (tr!=null) tr.remove();
}
}
for (int i=1;i<REPS;++i) Assert.assertEquals(values[0],values[i],1e-5*values[0]);
}
// Pure convenience wrapper
@Ignore @Test public void deepWaterLoadSaveTestAll() {
for (DeepWaterParameters.Network network : DeepWaterParameters.Network.values()) {
if (network== DeepWaterParameters.Network.auto) continue;
if (network== DeepWaterParameters.Network.user) continue;
deepWaterLoadSaveTest(network);
}
}
@Test public void deepWaterLoadSaveTestAlexnet() { deepWaterLoadSaveTest(alexnet); }
@Test public void deepWaterLoadSaveTestLenet() { deepWaterLoadSaveTest(lenet); }
@Test public void deepWaterLoadSaveTestVGG() { deepWaterLoadSaveTest(DeepWaterParameters.Network.vgg); }
@Test public void deepWaterLoadSaveTestInception() { deepWaterLoadSaveTest(inception_bn); }
@Test public void deepWaterLoadSaveTestResnet() { deepWaterLoadSaveTest(resnet); }
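// Helper: checks that a Java-to-native-to-Java round trip of the model state leaves the stored network and parameters byte-for-byte identical.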
void deepWaterLoadSaveTest(DeepWaterParameters.Network network) {
DeepWaterModel m = null;
Frame tr = null;
try {
DeepWaterParameters p = new DeepWaterParameters();
p._backend = getBackend();
p._train = (tr=parse_test_file("bigdata/laptop/deepwater/imagenet/cat_dog_mouse.csv"))._key;
p._response_column = "C2";
p._network = network;
p._mini_batch_size = 2;
p._epochs = 0.01;
p._seed = 1234;
p._score_training_samples = 0;
p._train_samples_per_iteration = p._mini_batch_size;
p._problem_type = DeepWaterParameters.ProblemType.image;
m = new DeepWater(p).trainModel().get();
Log.info(m);
Assert.assertNull(m.model_info()._backend);
int hashCodeNetwork = java.util.Arrays.hashCode(m.model_info()._network);
int hashCodeParams = java.util.Arrays.hashCode(m.model_info()._modelparams);
Log.info("Hash code for original network: " + hashCodeNetwork);
Log.info("Hash code for original parameters: " + hashCodeParams);
// move stuff back and forth
m.removeNativeState();
m.model_info().javaToNative();
m.model_info().nativeToJava();
int hashCodeNetwork2 = java.util.Arrays.hashCode(m.model_info()._network);
int hashCodeParams2 = java.util.Arrays.hashCode(m.model_info()._modelparams);
Log.info("Hash code for restored network: " + hashCodeNetwork2);
Log.info("Hash code for restored parameters: " + hashCodeParams2);
Assert.assertEquals(hashCodeNetwork, hashCodeNetwork2);
Assert.assertEquals(hashCodeParams, hashCodeParams2);
} finally {
if (m!=null) m.delete();
if (tr!=null) tr.remove();
}
}
@Test
public void deepWaterCV() {
DeepWaterModel m = null;
Frame tr = null;
Frame preds = null;
try {
DeepWaterParameters p = new DeepWaterParameters();
p._backend = getBackend();
p._train = (tr=parse_test_file("bigdata/laptop/deepwater/imagenet/cat_dog_mouse.csv"))._key;
p._response_column = "C2";
p._network = lenet;
p._nfolds = 3;
p._epochs = 2;
m = new DeepWater(p).trainModel().get();
preds = m.score(p._train.get());
Assert.assertTrue(m.testJavaScoring(p._train.get(),preds,1e-3));
Log.info(m);
} finally {
if (m!=null) m.deleteCrossValidationModels();
if (m!=null) m.delete();
if (tr!=null) tr.remove();
if (preds!=null) preds.remove();
}
}
@Test
public void deepWaterCVRegression() {
DeepWaterModel m = null;
Frame tr = null;
Frame preds = null;
try {
DeepWaterParameters p = new DeepWaterParameters();
p._backend = getBackend();
p._train = (tr=parse_test_file("bigdata/laptop/deepwater/imagenet/cat_dog_mouse.csv"))._key;
p._response_column = "C2";
for (String col : new String[]{p._response_column}) {
Vec v = tr.remove(col);
tr.add(col, v.toNumericVec());
v.remove();
}
DKV.put(tr);
p._network = lenet;
p._nfolds = 3;
p._epochs = 2;
m = new DeepWater(p).trainModel().get();
preds = m.score(p._train.get());
Assert.assertTrue(m.testJavaScoring(p._train.get(),preds,1e-3));
Log.info(m);
} finally {
if (m!=null) m.deleteCrossValidationModels();
if (m!=null) m.delete();
if (tr!=null) tr.remove();
if (preds!=null) preds.remove();
}
}
// Pure convenience wrapper
@Ignore @Test public void restoreStateAll() {
for (DeepWaterParameters.Network network : DeepWaterParameters.Network.values()) {
if (network== DeepWaterParameters.Network.user) continue;
if (network== DeepWaterParameters.Network.auto) continue;
restoreState(network);
}
}
@Test public void restoreStateAlexnet() { restoreState(alexnet); }
@Test public void restoreStateLenet() { restoreState(lenet); }
@Test public void restoreStateVGG() { restoreState(vgg); }
@Test public void restoreStateInception() { restoreState(inception_bn); }
@Test public void restoreStateResnet() { restoreState(resnet); }
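// Helper: serializes a trained model to a raw byte[] via AutoBuffer, restores it, and checks that the restored model scores identically to the original.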
public void restoreState(DeepWaterParameters.Network network) {
DeepWaterModel m1 = null;
DeepWaterModel m2 = null;
Frame tr = null;
Frame pred = null;
try {
DeepWaterParameters p = new DeepWaterParameters();
p._backend = getBackend();
p._train = (tr=parse_test_file("bigdata/laptop/deepwater/imagenet/cat_dog_mouse.csv"))._key;
p._network = network;
p._response_column = "C2";
p._mini_batch_size = 2;
p._train_samples_per_iteration = p._mini_batch_size;
p._learning_rate = 0; //no updates to original weights
p._seed = 12345;
p._epochs = 0.01;
p._quiet_mode = true;
p._problem_type = DeepWaterParameters.ProblemType.image;
m1 = new DeepWater(p).trainModel().get();
Log.info("Scoring the original model.");
pred = m1.score(tr);
pred.remove(0).remove();
ModelMetricsMultinomial mm1 = ModelMetricsMultinomial.make(pred, tr.vec(p._response_column));
Log.info("Original LL: " + ((ModelMetricsMultinomial) m1._output._training_metrics).logloss());
Log.info("Scored LL: " + mm1.logloss());
pred.remove();
Log.info("Keeping the raw byte[] of the model.");
byte[] raw = new AutoBuffer().put(m1).buf();
Log.info("Removing the model from the DKV.");
m1.remove();
Log.info("Restoring the model from the raw byte[].");
m2 = new AutoBuffer(raw).get();
Log.info("Scoring the restored model.");
pred = m2.score(tr);
pred.remove(0).remove();
ModelMetricsMultinomial mm2 = ModelMetricsMultinomial.make(pred, tr.vec(p._response_column));
Log.info("Restored LL: " + mm2.logloss());
double precision = 1e-5;
Assert.assertEquals(((ModelMetricsMultinomial) m1._output._training_metrics).logloss(), mm1.logloss(), precision*mm1.logloss()); //make sure scoring is self-consistent
Assert.assertEquals(mm1.logloss(), mm2.logloss(), precision*mm1.logloss());
} finally {
if (m1 !=null) m1.delete();
if (m2!=null) m2.delete();
if (tr!=null) tr.remove();
if (pred!=null) pred.remove();
}
}
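// Low-level backend API tests: build a LeNet directly through the backend and exercise train/save/predict loops on dummy data.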
@Test
public void trainLoop() throws InterruptedException {
int batch_size = 64;
BackendModel m = buildLENET();
float[] data = new float[28*28*1*batch_size];
float[] labels = new float[batch_size];
int count=0;
while(count++<1000) {
Log.info("Iteration: " + count);
backend.train(m, data, labels);
}
}
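// Builds a 28x28x1 LeNet with 10 output classes directly on the backend (GPU device 0, mini-batch size 64).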
private BackendModel buildLENET() {
int batch_size = 64;
int classes = 10;
ImageDataSet dataset = new ImageDataSet(28, 28, 1, classes);
RuntimeOptions opts = new RuntimeOptions();
opts.setUseGPU(true);
opts.setSeed(1234);
opts.setDeviceID(0);
BackendParams bparm = new BackendParams();
bparm.set("mini_batch_size", batch_size);
return backend.buildNet(dataset, opts, bparm, classes, "lenet");
}
@Test
public void saveLoop() throws IOException {
BackendModel m = buildLENET();
File f = File.createTempFile("saveLoop", ".tmp");
for(int count=0; count < 3; count++) {
Log.info("Iteration: " + count);
backend.saveParam(m, f.getAbsolutePath());
}
backend.deleteSavedParam(f.getAbsolutePath());
}
@Test
public void predictLoop() {
BackendModel m = buildLENET();
int batch_size = 64;
float[] data = new float[28*28*1*batch_size];
int count=0;
while(count++<3) {
Log.info("Iteration: " + count);
backend.predict(m, data);
}
}
@Test
public void trainPredictLoop() {
int batch_size = 64;
BackendModel m = buildLENET();
float[] data = new float[28*28*1*batch_size];
float[] labels = new float[batch_size];
int count=0;
while(count++<1000) {
Log.info("Iteration: " + count);
backend.train(m, data,labels);
float[] p = backend.predict(m, data);
}
}
@Test
public void scoreLoop() {
DeepWaterModel m = null;
Frame tr = null;
try {
DeepWaterParameters p = new DeepWaterParameters();
p._backend = getBackend();
p._train = (tr=parse_test_file("bigdata/laptop/deepwater/imagenet/cat_dog_mouse.csv"))._key;
p._network = lenet;
p._response_column = "C2";
p._mini_batch_size = 4;
p._train_samples_per_iteration = p._mini_batch_size;
p._learning_rate = 0; //no updates to original weights
p._seed = 12345;
p._epochs = 0.01;
p._quiet_mode = true;
DeepWater j = new DeepWater(p);
m = j.trainModel().get();
int count=0;
while(count++<100) {
Log.info("Iteration: " + count);
// rescore the same model over and over again
m.doScoring(tr,null,j._job._key,m.iterations,true);
}
} finally {
if (m!=null) m.remove();
if (tr!=null) tr.remove();
}
}
// @Test public void imageToPixels() throws IOException {
// final File imgFile = find_test_file("smalldata/deepwater/imagenet/test2.jpg");
// final float[] dest = new float[28*28*3];
// int count=0;
// Futures fs = new Futures();
// while(count++<10000)
// fs.add(H2O.submitTask(
// new H2O.H2OCountedCompleter() {
// @Override
// public void compute2() {
// try {
// util.img2pixels(imgFile.toString(), 28, 28, 3, dest, 0, null);
// } catch (IOException e) {
// e.printStackTrace();
// }
// tryComplete();
// }
// }));
// fs.blockForPending();
// }
@Test
public void prostateClassification() {
Frame tr = null;
DeepWaterModel m = null;
try {
DeepWaterParameters p = new DeepWaterParameters();
p._backend = getBackend();
p._train = (tr = parse_test_file("smalldata/prostate/prostate.csv"))._key;
p._response_column = "CAPSULE";
p._ignored_columns = new String[]{"ID"};
for (String col : new String[]{"RACE", "DPROS", "DCAPS", "CAPSULE", "GLEASON"}) {
Vec v = tr.remove(col);
tr.add(col, v.toCategoricalVec());
v.remove();
}
DKV.put(tr);
p._seed = 1234;
p._epochs = 500;
DeepWater j = new DeepWater(p);
m = j.trainModel().get();
Assert.assertTrue((m._output._training_metrics).auc_obj()._auc > 0.90);
} finally {
if (tr!=null) tr.remove();
if (m!=null) m.remove();
}
}
@Test
public void prostateRegression() {
Frame tr = null;
Frame preds = null;
DeepWaterModel m = null;
try {
DeepWaterParameters p = new DeepWaterParameters();
p._backend = getBackend();
p._train = (tr = parse_test_file("smalldata/prostate/prostate.csv"))._key;
p._response_column = "AGE";
p._ignored_columns = new String[]{"ID"};
for (String col : new String[]{"RACE", "DPROS", "DCAPS", "CAPSULE", "GLEASON"}) {
Vec v = tr.remove(col);
tr.add(col, v.toCategoricalVec());
v.remove();
}
DKV.put(tr);
p._seed = 1234;
p._epochs = 1000;
// p._epochs = 2000;
// p._learning_rate = 0.005; //5e-7;
// p._momentum_start = 0.9;
DeepWater j = new DeepWater(p);
m = j.trainModel().get();
Assert.assertTrue((m._output._training_metrics).rmse() < 5);
preds = m.score(p._train.get());
Assert.assertTrue(m.testJavaScoring(p._train.get(),preds,1e-3));
} finally {
if (tr!=null) tr.remove();
if (m!=null) m.remove();
if (preds!=null) preds.remove();
}
}
@Test
public void imageURLs() {
Frame tr = null;
Frame preds = null;
DeepWaterModel m = null;
try {
DeepWaterParameters p = new DeepWaterParameters();
p._backend = getBackend();
p._train = (tr = parse_test_file("smalldata/deepwater/imagenet/binomial_image_urls.csv"))._key;
p._response_column = "C2";
p._network = lenet;
p._epochs = 500;
p._seed = 1234;
DeepWater j = new DeepWater(p);
m = j.trainModel().get();
Assert.assertTrue((m._output._training_metrics).auc_obj()._auc > 0.85);
preds = m.score(p._train.get());
Assert.assertTrue(m.testJavaScoring(p._train.get(),preds,1e-3,1e-5,1));
} finally {
if (tr!=null) tr.remove();
if (preds!=null) preds.remove();
if (m!=null) m.remove();
}
}
@Test
public void categorical() {
Frame tr = null;
DeepWaterModel m = null;
try {
DeepWaterParameters p = new DeepWaterParameters();
p._backend = getBackend();
p._train = (tr = parse_test_file("smalldata/gbm_test/alphabet_cattest.csv"))._key;
p._response_column = "y";
for (String col : new String[]{"y"}) {
Vec v = tr.remove(col);
tr.add(col, v.toCategoricalVec());
v.remove();
}
DKV.put(tr);
DeepWater j = new DeepWater(p);
m = j.trainModel().get();
Assert.assertTrue((m._output._training_metrics).auc_obj()._auc > 0.90);
} finally {
if (tr!=null) tr.remove();
if (m!=null) m.remove();
}
}
@Test
public void MNISTLenet() {
Frame tr = null;
Frame va = null;
DeepWaterModel m = null;
try {
DeepWaterParameters p = new DeepWaterParameters();
File file = FileUtils.locateFile("bigdata/laptop/mnist/train.csv.gz");
File valid = FileUtils.locateFile("bigdata/laptop/mnist/test.csv.gz");
if (file != null) {
p._response_column = "C785";
NFSFileVec trainfv = NFSFileVec.make(file);
tr = ParseDataset.parse(Key.make(), trainfv._key);
NFSFileVec validfv = NFSFileVec.make(valid);
va = ParseDataset.parse(Key.make(), validfv._key);
for (String col : new String[]{p._response_column}) {
Vec v = tr.remove(col); tr.add(col, v.toCategoricalVec()); v.remove();
v = va.remove(col); va.add(col, v.toCategoricalVec()); v.remove();
}
DKV.put(tr);
DKV.put(va);
p._backend = getBackend();
p._train = tr._key;
p._valid = va._key;
p._image_shape = new int[]{28,28};
p._ignore_const_cols = false; //to keep it 28x28
p._channels = 1;
p._network = lenet;
DeepWater j = new DeepWater(p);
m = j.trainModel().get();
Assert.assertTrue(((ModelMetricsMultinomial)(m._output._validation_metrics)).mean_per_class_error() < 0.05);
}
} finally {
if (tr!=null) tr.remove();
if (va!=null) va.remove();
if (m!=null) m.remove();
}
}
@Test
public void MNISTSparse() {
Frame tr = null;
Frame va = null;
DeepWaterModel m = null;
try {
DeepWaterParameters p = new DeepWaterParameters();
File file = FileUtils.locateFile("bigdata/laptop/mnist/train.csv.gz");
File valid = FileUtils.locateFile("bigdata/laptop/mnist/test.csv.gz");
if (file != null) {
p._response_column = "C785";
NFSFileVec trainfv = NFSFileVec.make(file);
tr = ParseDataset.parse(Key.make(), trainfv._key);
NFSFileVec validfv = NFSFileVec.make(valid);
va = ParseDataset.parse(Key.make(), validfv._key);
for (String col : new String[]{p._response_column}) {
Vec v = tr.remove(col); tr.add(col, v.toCategoricalVec()); v.remove();
v = va.remove(col); va.add(col, v.toCategoricalVec()); v.remove();
}
DKV.put(tr);
DKV.put(va);
p._backend = getBackend();
p._train = tr._key;
p._valid = va._key;
p._learning_rate = 5e-3;
p._hidden = new int[]{500, 500};
p._sparse = true;
DeepWater j = new DeepWater(p);
m = j.trainModel().get();
Assert.assertTrue(((ModelMetricsMultinomial)(m._output._validation_metrics)).mean_per_class_error() < 0.05);
}
} finally {
if (tr!=null) tr.remove();
if (va!=null) va.remove();
if (m!=null) m.remove();
}
}
@Test
public void MNISTHinton() {
Frame tr = null;
Frame va = null;
DeepWaterModel m = null;
try {
DeepWaterParameters p = new DeepWaterParameters();
File file = FileUtils.locateFile("bigdata/laptop/mnist/train.csv.gz");
File valid = FileUtils.locateFile("bigdata/laptop/mnist/test.csv.gz");
if (file != null) {
p._response_column = "C785";
NFSFileVec trainfv = NFSFileVec.make(file);
tr = ParseDataset.parse(Key.make(), trainfv._key);
NFSFileVec validfv = NFSFileVec.make(valid);
va = ParseDataset.parse(Key.make(), validfv._key);
for (String col : new String[]{p._response_column}) {
Vec v = tr.remove(col); tr.add(col, v.toCategoricalVec()); v.remove();
v = va.remove(col); va.add(col, v.toCategoricalVec()); v.remove();
}
DKV.put(tr);
DKV.put(va);
p._backend = getBackend();
p._hidden = new int[]{1024, 1024, 2048};
p._input_dropout_ratio = 0.1;
p._hidden_dropout_ratios = new double[]{0.5, 0.5, 0.5};
p._stopping_rounds = 0;
p._learning_rate = 1e-3;
p._mini_batch_size = 32;
p._epochs = 20;
p._train = tr._key;
p._valid = va._key;
DeepWater j = new DeepWater(p);
m = j.trainModel().get();
Assert.assertTrue(((ModelMetricsMultinomial)(m._output._validation_metrics)).mean_per_class_error() < 0.05);
}
} finally {
if (tr!=null) tr.remove();
if (va!=null) va.remove();
if (m!=null) m.remove();
}
}
@Test
public void Airlines() {
Frame tr = null;
DeepWaterModel m = null;
Frame[] splits = null;
try {
DeepWaterParameters p = new DeepWaterParameters();
File file = FileUtils.locateFile("smalldata/airlines/allyears2k_headers.zip");
if (file != null) {
p._response_column = "IsDepDelayed";
p._ignored_columns = new String[]{"DepTime","ArrTime","Cancelled","CancellationCode","Diverted","CarrierDelay","WeatherDelay","NASDelay","SecurityDelay","LateAircraftDelay","IsArrDelayed"};
NFSFileVec trainfv = NFSFileVec.make(file);
tr = ParseDataset.parse(Key.make(), trainfv._key);
for (String col : new String[]{p._response_column, "UniqueCarrier", "Origin", "Dest"}) {
Vec v = tr.remove(col); tr.add(col, v.toCategoricalVec()); v.remove();
}
DKV.put(tr);
double[] ratios = ard(0.5, 0.5);
Key[] keys = aro(Key.make("train.hex"), Key.make("valid.hex"));
splits = ShuffleSplitFrame.shuffleSplitFrame(tr, keys, ratios, 42);
p._backend = getBackend();
p._train = keys[0];
p._valid = keys[1];
DeepWater j = new DeepWater(p);
m = j.trainModel().get();
Assert.assertTrue(((ModelMetricsBinomial)(m._output._validation_metrics)).auc() > 0.65);
}
} finally {
if (tr!=null) tr.remove();
if (m!=null) m.remove();
if (splits!=null) for(Frame s: splits) s.remove();
}
}
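// Helper: trains a small image model and checks that MOJO scoring (testJavaScoring) agrees with in-H2O scoring on the training frame.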
private void MOJOTestImage(DeepWaterParameters.Network network) {
Frame tr = null;
DeepWaterModel m = null;
Frame preds = null;
try {
DeepWaterParameters p = new DeepWaterParameters();
p._backend = getBackend();
p._train = (tr=parse_test_file("bigdata/laptop/deepwater/imagenet/cat_dog_mouse.csv"))._key;
p._response_column = "C2";
p._learning_rate = 1e-4;
p._network = network;
p._mini_batch_size = 4;
p._train_samples_per_iteration = 8;
p._epochs = 1e-3;
m = new DeepWater(p).trainModel().get();
// Score original training frame
preds = m.score(tr);
Assert.assertTrue(m.testJavaScoring(tr,preds,1e-3));
preds.remove(0).remove();
double logloss = ModelMetricsMultinomial.make(preds, tr.vec(p._response_column)).logloss();
Assert.assertTrue(Math.abs(logloss - ((ModelMetricsMultinomial)m._output._training_metrics).logloss()) < 1e-3);
} finally {
if (tr!=null) tr.remove();
if (m!=null) m.remove();
if (preds!=null) preds.remove();
}
}
@Test public void MOJOTestImageLenet() { MOJOTestImage(lenet); }
@Test public void MOJOTestImageInception() { MOJOTestImage(inception_bn); }
@Test public void MOJOTestImageAlexnet() { MOJOTestImage(alexnet); }
@Ignore @Test public void MOJOTestImageResnet() { MOJOTestImage(resnet); }
@Test public void MOJOTestImageVGG() { MOJOTestImage(vgg); }
@Ignore @Test public void MOJOTestImageGooglenet() { MOJOTestImage(googlenet); }
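// Helper: trains on prostate with the given categorical encoding and standardization settings, then checks MOJO scoring consistency on the training frame and on re-parsed copies with different column types and order.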
private void MOJOTest(Model.Parameters.CategoricalEncodingScheme categoricalEncodingScheme, boolean enumCols, boolean standardize) {
Frame tr = null;
Frame tr2 = null;
Frame tr3 = null;
DeepWaterModel m = null;
Frame preds = null;
Frame preds2 = null;
Frame preds3 = null;
try {
DeepWaterParameters p = new DeepWaterParameters();
tr = parse_test_file("smalldata/prostate/prostate.csv");
p._response_column = "CAPSULE";
for (String col : new String[]{p._response_column}) {
Vec v = tr.remove(col);
tr.add(col, v.toCategoricalVec());
v.remove();
}
if (enumCols) {
for (String col : new String[]{"RACE", "DPROS", "DCAPS", "GLEASON"}) {
Vec v = tr.remove(col);
tr.add(col, v.toCategoricalVec());
v.remove();
}
}
DKV.put(tr);
p._train = tr._key;
p._ignored_columns = new String[]{"ID"};
p._backend = getBackend();
p._seed = 12345;
p._epochs = 50;
p._categorical_encoding = categoricalEncodingScheme;
p._standardize = standardize;
p._hidden = new int[]{50, 50};
m = new DeepWater(p).trainModel().get();
// Score original training frame
preds = m.score(tr);
Assert.assertTrue(m.testJavaScoring(tr,preds,1e-3));
double auc = ModelMetricsBinomial.make(preds.vec(2), tr.vec(p._response_column)).auc();
Assert.assertTrue(Math.abs(auc - ((ModelMetricsBinomial)m._output._training_metrics).auc()) < 1e-3);
if (standardize)
Assert.assertTrue(auc > 0.7);
// Score all numeric frame (cols in the right order) - do the transformation to enum on the fly
tr2 = parse_test_file("smalldata/prostate/prostate.csv");
for (String col : new String[]{p._response_column}) {
tr2.add(col, tr2.remove(col)); //DO NOT CONVERT TO ENUM
}
if (enumCols) {
for (String col : new String[]{"RACE", "DPROS", "DCAPS", "GLEASON"}) {
tr2.add(col, tr2.remove(col)); //DO NOT CONVERT TO ENUM
}
}
preds2 = m.score(tr2);
auc = ModelMetricsBinomial.make(preds2.vec(2), tr2.vec(p._response_column)).auc();
Assert.assertTrue(Math.abs(auc - ((ModelMetricsBinomial)m._output._training_metrics).auc()) < 1e-3);
if (standardize)
Assert.assertTrue(auc > 0.7);
// Score all numeric frame (cols in the wrong order) - do the transformation to enum on the fly
tr3 = parse_test_file("smalldata/prostate/prostate.csv");
preds3 = m.score(tr3);
auc = ModelMetricsBinomial.make(preds3.vec(2), tr3.vec(p._response_column)).auc();
Assert.assertTrue(Math.abs(auc - ((ModelMetricsBinomial)m._output._training_metrics).auc()) < 1e-3);
if (standardize)
Assert.assertTrue(auc > 0.7);
} finally {
if (tr!=null) tr.remove();
if (tr2!=null) tr2.remove();
if (tr3!=null) tr3.remove();
if (m!=null) m.remove();
if (preds!=null) preds.remove();
if (preds2!=null) preds2.remove();
if (preds3!=null) preds3.remove();
}
}
@Test public void MOJOTestNumericNonStandardized() { MOJOTest(Model.Parameters.CategoricalEncodingScheme.AUTO, false, false);}
@Test public void MOJOTestNumeric() { MOJOTest(Model.Parameters.CategoricalEncodingScheme.AUTO, false, true);}
@Test public void MOJOTestCatInternal() { MOJOTest(Model.Parameters.CategoricalEncodingScheme.OneHotInternal, true, true);}
@Test public void MOJOTestCatExplicit() { MOJOTest(Model.Parameters.CategoricalEncodingScheme.OneHotExplicit, true, true);}
@Test public void MOJOTestCatEigen() { MOJOTest(Model.Parameters.CategoricalEncodingScheme.Eigen, true, true);}
@Test public void MOJOTestCatBinary() { MOJOTest(Model.Parameters.CategoricalEncodingScheme.Binary, true, true);}
@Test
public void testCheckpointForwards() {
Frame tfr = null;
DeepWaterModel dl = null;
DeepWaterModel dl2 = null;
try {
tfr = parse_test_file("./smalldata/iris/iris.csv");
DeepWaterParameters p = new DeepWaterParameters();
p._backend = getBackend();
p._train = tfr._key;
p._epochs = 10;
p._response_column = "C5";
p._hidden = new int[]{2,2};
p._seed = 0xdecaf;
p._stopping_rounds = 0;
dl = new DeepWater(p).trainModel().get();
DeepWaterParameters parms2 = (DeepWaterParameters) p.clone();
parms2._epochs = 20;
parms2._checkpoint = dl._key;
dl2 = new DeepWater(parms2).trainModel().get();
Assert.assertTrue(dl2.epoch_counter > 20);
} finally {
if (tfr != null) tfr.delete();
if (dl != null) dl.delete();
if (dl2 != null) dl2.delete();
}
}
@Test
public void testCheckpointBackwards() {
Frame tfr = null;
DeepWaterModel dl = null;
DeepWaterModel dl2 = null;
try {
tfr = parse_test_file("./smalldata/iris/iris.csv");
DeepWaterParameters p = new DeepWaterParameters();
p._backend = getBackend();
p._train = tfr._key;
p._epochs = 10;
p._response_column = "C5";
p._hidden = new int[]{2, 2};
p._seed = 0xdecaf;
dl = new DeepWater(p).trainModel().get();
DeepWaterParameters parms2 = (DeepWaterParameters) p.clone();
parms2._epochs = 9;
parms2._checkpoint = dl._key;
try {
dl2 = new DeepWater(parms2).trainModel().get();
Assert.fail("Should toss exception instead of reaching here");
} catch (H2OIllegalArgumentException ex) {
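// expected: resuming a checkpoint with fewer epochs than already trained must be rejected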
}
} finally {
if (tfr != null) tfr.delete();
if (dl != null) dl.delete();
if (dl2 != null) dl2.delete();
}
}
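// Checks the scoring-history table across a checkpoint restart: timestamps, durations and epoch/iteration counters must be consistent, and the sleep between the two runs must show up in the timestamps but not in the durations.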
@Test public void checkpointReporting() {
Scope.enter();
Frame frame = null;
try {
File file = FileUtils.locateFile("smalldata/logreg/prostate.csv");
NFSFileVec trainfv = NFSFileVec.make(file);
frame = ParseDataset.parse(Key.make(), trainfv._key);
DeepWaterParameters p = new DeepWaterParameters();
// populate model parameters
p._backend = getBackend();
p._train = frame._key;
p._response_column = "CAPSULE"; // last column is the response
p._activation = DeepWaterParameters.Activation.Rectifier;
p._epochs = 4;
p._train_samples_per_iteration = -1;
p._mini_batch_size = 1;
p._score_duty_cycle = 1;
p._score_interval = 0;
p._overwrite_with_best_model = false;
p._seed = 1234;
// Convert the response 'CAPSULE' to categorical
int ci = frame.find("CAPSULE");
Scope.track(frame.replace(ci, frame.vecs()[ci].toCategoricalVec()));
DKV.put(frame);
long start = System.currentTimeMillis();
try { Thread.sleep(1000); } catch( InterruptedException ex ) { } //to avoid rounding issues with printed time stamp (1 second resolution)
DeepWaterModel model = new DeepWater(p).trainModel().get();
long sleepTime = 5; //seconds
try { Thread.sleep(sleepTime*1000); } catch( InterruptedException ex ) { }
// checkpoint restart after sleep
DeepWaterParameters p2 = (DeepWaterParameters)p.clone();
p2._checkpoint = model._key;
p2._epochs *= 2;
DeepWaterModel model2 = null;
try {
model2 = new DeepWater(p2).trainModel().get();
long end = System.currentTimeMillis();
TwoDimTable table = model2._output._scoring_history;
double priorDurationDouble=0;
long priorTimeStampLong=0;
DateTimeFormatter fmt = DateTimeFormat.forPattern("yyyy-MM-dd HH:mm:ss");
for (int i=0; i<table.getRowDim(); ++i) {
// Check that timestamp is correct, and growing monotonically
String timestamp = (String)table.get(i,0);
long timeStampLong = fmt.parseMillis(timestamp);
Assert.assertTrue("Timestamp must be later than outside timer start", timeStampLong >= start);
Assert.assertTrue("Timestamp must be earlier than outside timer end", timeStampLong <= end);
Assert.assertTrue("Timestamp must increase", timeStampLong >= priorTimeStampLong);
priorTimeStampLong = timeStampLong;
// Check that duration is growing monotonically
String duration = (String)table.get(i,1);
duration = duration.substring(0, duration.length()-4); //"x.xxxx sec"
try {
double durationDouble = Double.parseDouble(duration);
Assert.assertTrue("Duration must be >0: " + durationDouble, durationDouble >= 0);
Assert.assertTrue("Duration must increase: " + priorDurationDouble + " -> " + durationDouble, durationDouble >= priorDurationDouble);
Assert.assertTrue("Duration cannot be more than outside timer delta", durationDouble <= (end - start) / 1e3);
priorDurationDouble = durationDouble;
} catch(NumberFormatException ex) {
//skip
}
// Check that epoch counting is good
Assert.assertTrue("Epoch counter must be contiguous", (Double)table.get(i,3) == i); //1 epoch per step
Assert.assertTrue("Iteration counter must match epochs", (Integer)table.get(i,4) == i); //1 iteration per step
}
try {
// Check that duration doesn't see the sleep
String durationBefore = (String)table.get((int)(p._epochs),1);
durationBefore = durationBefore.substring(0, durationBefore.length()-4);
String durationAfter = (String)table.get((int)(p._epochs+1),1);
durationAfter = durationAfter.substring(0, durationAfter.length()-4);
Assert.assertTrue("Duration must be smooth", Double.parseDouble(durationAfter) - Double.parseDouble(durationBefore) < sleepTime+1);
// Check that time stamp does see the sleep
String timeStampBefore = (String)table.get((int)(p._epochs),0);
long timeStampBeforeLong = fmt.parseMillis(timeStampBefore);
String timeStampAfter = (String)table.get((int)(p._epochs+1),0);
long timeStampAfterLong = fmt.parseMillis(timeStampAfter);
Assert.assertTrue("Time stamp must experience a delay", timeStampAfterLong-timeStampBeforeLong >= (sleepTime-1/*rounding*/)*1000);
// Check that the training speed is similar before and after checkpoint restart
String speedBefore = (String)table.get((int)(p._epochs),2);
speedBefore = speedBefore.substring(0, speedBefore.length()-9);
double speedBeforeDouble = Double.parseDouble(speedBefore);
String speedAfter = (String)table.get((int)(p._epochs+1),2);
speedAfter = speedAfter.substring(0, speedAfter.length()-9);
double speedAfterDouble = Double.parseDouble(speedAfter);
Assert.assertTrue("Speed shouldn't change more than 50%", Math.abs(speedAfterDouble-speedBeforeDouble)/speedBeforeDouble < 0.5); //expect less than 50% change in speed
} catch(NumberFormatException ex) {
//skip runtimes > 1 minute (too hard to parse into seconds here...).
}
} finally {
if (model != null) model.delete();
if (model2 != null) model2.delete();
}
} finally {
if (frame!=null) frame.remove();
Scope.exit();
}
}
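// Trains with an absurdly high learning rate so the model becomes numerically unstable; training, scoring and MOJO export should all fail cleanly.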
@Test
public void testNumericalExplosion() {
for (boolean ae : new boolean[]{
// true,
false
}) {
Frame tfr = null;
DeepWaterModel dl = null;
Frame pred = null;
try {
tfr = parse_test_file("./smalldata/junit/two_spiral.csv");
for (String s : new String[]{
"Class"
}) {
Vec resp = tfr.vec(s).toCategoricalVec();
tfr.remove(s).remove();
tfr.add(s, resp);
DKV.put(tfr);
}
DeepWaterParameters parms = new DeepWaterParameters();
parms._backend = getBackend();
parms._train = tfr._key;
parms._epochs = 100;
parms._response_column = "Class";
parms._autoencoder = ae;
parms._train_samples_per_iteration = 10;
parms._hidden = new int[]{10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10};
parms._learning_rate = 1e10;
parms._standardize = false;
// Training is expected to blow up numerically and fail
DeepWater job = new DeepWater(parms);
try {
dl = job.trainModel().get();
Assert.fail("Should toss exception instead of reaching here");
} catch( RuntimeException de ) {
// OK
}
dl = DKV.getGet(job.dest());
try {
pred = dl.score(tfr);
Assert.fail("Should toss exception instead of reaching here");
} catch ( RuntimeException ex) {
// OK
}
try {
dl.getMojo();
Assert.fail("Should toss exception instead of reaching here");
} catch ( RuntimeException ex) {
System.err.println(ex.getMessage());
// OK
}
Assert.assertTrue(dl.model_info()._unstable);
Assert.assertTrue(dl._output._job.isCrashed());
} finally {
if (tfr != null) tfr.delete();
if (dl != null) dl.delete();
if (pred != null) pred.delete();
}
}
}
// ------- Text conversions
@Test
public void textsToArrayTest() throws IOException {
ArrayList<String> texts = new ArrayList<>();
ArrayList<String> labels = new ArrayList<>();
texts.add("the rock is destined to be the 21st century's new \" conan \" and that he's going to make a splash even greater than arnold schwarzenegger , jean-claud van damme or steven segal .");
texts.add("the gorgeously elaborate continuation of \" the lord of the rings \" trilogy is so huge that a column of words cannot adequately describe co-writer/director peter jackson's expanded vision of j . r . r . tolkien's middle-earth .");
texts.add("effective but too-tepid biopic");
labels.add("pos");
labels.add("pos");
labels.add("pos");
texts.add("simplistic , silly and tedious .");
texts.add("it's so laddish and juvenile , only teenage boys could possibly find it funny .");
texts.add("exploitative and largely devoid of the depth or sophistication that would make watching such a graphic treatment of the crimes bearable .");
labels.add("neg");
labels.add("neg");
labels.add("neg");
ArrayList<int[]> coded = StringUtils.texts2array(texts);
// System.out.println(coded);
for (int[] a : coded) {
System.out.println(Arrays.toString(a));
}
System.out.println("rows " + coded.size() + " cols " + coded.get(0).length);
Assert.assertEquals(6, coded.size());
Assert.assertEquals(38, coded.get(0).length);
}
@Ignore
@Test
public void tweetsToArrayTest() throws IOException {
ArrayList<String> texts = new ArrayList<>();
ArrayList<String> labels = new ArrayList<>();
{
FileInputStream is = new FileInputStream("/home/magnus/tweets.txt");
BufferedReader br = new BufferedReader(new InputStreamReader(is));
String line;
while ((line = br.readLine()) != null) {
texts.add(line);
}
is.close();
}
{
FileInputStream is = new FileInputStream("/home/magnus/labels.txt");
BufferedReader br = new BufferedReader(new InputStreamReader(is));
String line;
while ((line = br.readLine()) != null) {
labels.add(line);
}
is.close();
}
ArrayList<int[]> coded = StringUtils.texts2array(texts);
// System.out.println(coded);
// for (int[] a : coded) {
// System.out.println(Arrays.toString(a));
// }
System.out.println("rows " + coded.size() + " cols " + coded.get(0).length);
Assert.assertEquals(1390, coded.size());
Assert.assertEquals(35, coded.get(0).length);
}
/*
public ArrayList<int[]> texts2arrayOnehot(ArrayList<String> texts) {
int maxlen = 0;
int index = 0;
Map<String, Integer> dict = new HashMap<>();
dict.put(PADDING_SYMBOL, index);
index += 1;
for (String text : texts) {
String[] tokens = tokenize(text);
for (String token : tokens) {
if (!dict.containsKey(token)) {
dict.put(token, index);
index += 1;
}
}
int len = tokens.length;
if (len > maxlen) maxlen = len;
}
System.out.println(dict);
System.out.println("maxlen " + maxlen);
System.out.println("dict size " + dict.size());
Assert.assertEquals(38, maxlen);
Assert.assertEquals(88, index);
Assert.assertEquals(88, dict.size());
ArrayList<int[]> array = new ArrayList<>();
for (String text: texts) {
ArrayList<int[]> data = tokensToArray(tokenize(text), maxlen, dict);
System.out.println(text);
System.out.println(" rows " + data.size() + " cols " + data.get(0).length);
//for (int[] x : data) {
// System.out.println(Arrays.toString(x));
//}
array.addAll(data);
}
return array;
}
*/
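// Check that overwrite_with_best_model keeps (or improves on) the best validation logloss across a checkpoint restart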
@Test
public void testCheckpointOverwriteWithBestModel() {
Frame tfr = null;
DeepWaterModel dl = null;
DeepWaterModel dl2 = null;
Frame train = null, valid = null;
try {
tfr = parse_test_file("./smalldata/iris/iris.csv");
FrameSplitter fs = new FrameSplitter(tfr, new double[]{0.8},new Key[]{Key.make("train"),Key.make("valid")},null);
fs.compute2();
train = fs.getResult()[0];
valid = fs.getResult()[1];
DeepWaterParameters parms = new DeepWaterParameters();
parms._backend = getBackend();
parms._train = train._key;
parms._valid = valid._key;
parms._epochs = 1;
parms._response_column = "C5";
parms._hidden = new int[]{50, 50};
parms._seed = 0xdecaf;
parms._train_samples_per_iteration = 0;
parms._score_duty_cycle = 1;
parms._score_interval = 0;
parms._stopping_rounds = 0;
parms._overwrite_with_best_model = true;
dl = new DeepWater(parms).trainModel().get();
double ll1 = ((ModelMetricsMultinomial)dl._output._validation_metrics).logloss();
DeepWaterParameters parms2 = (DeepWaterParameters)parms.clone();
parms2._epochs = 10;
parms2._checkpoint = dl._key;
dl2 = new DeepWater(parms2).trainModel().get();
double ll2 = ((ModelMetricsMultinomial)dl2._output._validation_metrics).logloss();
Assert.assertTrue(ll2 <= ll1);
} finally {
if (tfr != null) tfr.delete();
if (dl != null) dl.delete();
if (dl2 != null) dl2.delete();
if (train != null) train.delete();
if (valid != null) valid.delete();
}
}
// Check that the restarted model honors the previous model as a best model so far
@Test
public void testCheckpointOverwriteWithBestModel2() {
Frame tfr = null;
DeepWaterModel dl = null;
DeepWaterModel dl2 = null;
Frame train = null, valid = null;
try {
tfr = parse_test_file("./smalldata/iris/iris.csv");
FrameSplitter fs = new FrameSplitter(tfr, new double[]{0.8},new Key[]{Key.make("train"),Key.make("valid")},null);
fs.compute2();
train = fs.getResult()[0];
valid = fs.getResult()[1];
DeepWaterParameters parms = new DeepWaterParameters();
parms._backend = getBackend();
parms._train = train._key;
parms._valid = valid._key;
parms._epochs = 10;
parms._response_column = "C5";
parms._hidden = new int[]{50, 50};
parms._seed = 0xdecaf;
parms._train_samples_per_iteration = 0;
parms._score_duty_cycle = 1;
parms._score_interval = 0;
parms._stopping_rounds = 0;
parms._overwrite_with_best_model = true;
dl = new DeepWater(parms).trainModel().get();
double ll1 = ((ModelMetricsMultinomial)dl._output._validation_metrics).logloss();
DeepWaterParameters parms2 = (DeepWaterParameters)parms.clone();
parms2._epochs = 20;
parms2._checkpoint = dl._key;
dl2 = new DeepWater(parms2).trainModel().get();
double ll2 = ((ModelMetricsMultinomial)dl2._output._validation_metrics).logloss();
Assert.assertTrue(ll2 <= ll1);
} finally {
if (tfr != null) tfr.delete();
if (dl != null) dl.delete();
if (dl2 != null) dl2.delete();
if (train != null) train.delete();
if (valid != null) valid.delete();
}
}
}