package hex.deeplearning;
import hex.deeplearning.DeepLearningModel.DeepLearningParameters;
import org.joda.time.format.DateTimeFormat;
import org.joda.time.format.DateTimeFormatter;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Test;
import water.DKV;
import water.Key;
import water.Scope;
import water.TestUtil;
import water.fvec.Frame;
import water.fvec.NFSFileVec;
import water.parser.ParseDataset;
import water.util.FileUtils;
import water.util.TwoDimTable;
import java.io.File;
/**
 * Verifies the scoring history reported by a DeepLearning model that was restarted from a checkpoint.
 *
 * <p>The test trains a model, sleeps for a few seconds, continues training from the checkpoint with
 * twice as many epochs, and then inspects the scoring-history table of the restarted model:
 * <ul>
 *   <li>timestamps lie between the outside timer's start and end and never decrease;</li>
 *   <li>durations are non-negative, never decrease, and never exceed the outside timer's delta;</li>
 *   <li>the epoch and iteration counters equal the row index;</li>
 *   <li>across the checkpoint boundary the duration stays smooth (does not absorb the sleep),
 *       the timestamp does show the sleep, and the training speed changes by less than 50%.</li>
 * </ul>
 */
public class DeepLearningCheckpointReporting extends TestUtil {
  // Columns of the scoring-history table that this test reads.
  private static final int COL_TIMESTAMP = 0;
  private static final int COL_DURATION = 1;
  private static final int COL_SPEED = 2;
  private static final int COL_EPOCHS = 3;
  private static final int COL_ITERATIONS = 4;

  // Trailing characters removed from a duration cell before parsing it ("x.xxxx sec").
  private static final int DURATION_SUFFIX_LENGTH = 4;
  // Trailing characters removed from a training-speed cell before parsing it
  // (presumably a unit suffix such as " rows/sec" -- confirm against the table producer).
  private static final int SPEED_SUFFIX_LENGTH = 9;

  @BeforeClass() public static void setup() { stall_till_cloudsize(1); }

  /**
   * Trains, sleeps, restarts from the checkpoint and checks the resulting scoring history.
   * Models and the parsed frame are removed in {@code finally} blocks regardless of the outcome.
   */
  @Test public void run() {
    Scope.enter();
    Frame frame = null;
    try {
      NFSFileVec trainfv = TestUtil.makeNfsFileVec("smalldata/logreg/prostate.csv");
      frame = ParseDataset.parse(Key.make(), trainfv._key);
      DeepLearningParameters p = new DeepLearningParameters();

      // populate model parameters
      p._train = frame._key;
      p._response_column = "CAPSULE";
      p._activation = DeepLearningParameters.Activation.Rectifier;
      p._epochs = 4;
      p._train_samples_per_iteration = -1;
      p._score_duty_cycle = 1;
      p._score_interval = 0;
      p._overwrite_with_best_model = false;
      p._classification_stop = -1;
      p._seed = 1234;
      p._reproducible = true;

      // Convert the 'CAPSULE' response column to categorical
      int ci = frame.find("CAPSULE");
      Scope.track(frame.replace(ci, frame.vecs()[ci].toCategoricalVec()));
      DKV.put(frame);

      // Both models are declared before the try so that either one is deleted even if a later
      // step (including the checkpoint restart itself) fails.
      DeepLearningModel model = null;
      DeepLearningModel model2 = null;
      try {
        long start = System.currentTimeMillis();
        sleepMillis(1000); //to avoid rounding issues with printed time stamp (1 second resolution)

        model = new DeepLearning(p).trainModel().get();
        long sleepTime = 5; //seconds
        sleepMillis(sleepTime * 1000);

        // checkpoint restart after sleep
        DeepLearningParameters p2 = (DeepLearningParameters) p.clone();
        p2._checkpoint = model._key;
        p2._epochs *= 2;

        model2 = new DeepLearning(p2).trainModel().get();
        long end = System.currentTimeMillis();
        TwoDimTable table = model2._output._scoring_history;
        double priorDurationDouble = 0;
        long priorTimeStampLong = 0;
        DateTimeFormatter fmt = DateTimeFormat.forPattern("yyyy-MM-dd HH:mm:ss");

        for (int i = 0; i < table.getRowDim(); ++i) {
          // Check that timestamp is correct, and growing monotonically
          String timestamp = (String) table.get(i, COL_TIMESTAMP);
          long timeStampLong = fmt.parseMillis(timestamp);
          Assert.assertTrue("Timestamp must be later than outside timer start", timeStampLong >= start);
          Assert.assertTrue("Timestamp must be earlier than outside timer end", timeStampLong <= end);
          Assert.assertTrue("Timestamp must increase", timeStampLong >= priorTimeStampLong);
          priorTimeStampLong = timeStampLong;

          // Check that duration is growing monotonically
          String duration = dropSuffix((String) table.get(i, COL_DURATION), DURATION_SUFFIX_LENGTH);
          try {
            double durationDouble = Double.parseDouble(duration);
            Assert.assertTrue("Duration must be >=0: " + durationDouble, durationDouble >= 0);
            Assert.assertTrue("Duration must increase: " + priorDurationDouble + " -> " + durationDouble, durationDouble >= priorDurationDouble);
            Assert.assertTrue("Duration cannot be more than outside timer delta", durationDouble <= (end - start) / 1e3);
            priorDurationDouble = durationDouble;
          } catch (NumberFormatException ex) {
            // skip rows whose duration cell is not a plain "<number> sec" string
          }

          // Check that epoch counting is good
          Assert.assertTrue("Epoch counter must be contiguous", (Double) table.get(i, COL_EPOCHS) == i); //1 epoch per step
          Assert.assertTrue("Iteration counter must match epochs", (Integer) table.get(i, COL_ITERATIONS) == i); //1 iteration per step
        }

        // Rows on either side of the checkpoint restart: the last row of the first run
        // and the first row produced after continuing from the checkpoint.
        int rowBefore = (int) (p._epochs);
        int rowAfter = (int) (p._epochs + 1);
        try {
          // Check that duration doesn't see the sleep
          String durationBefore = dropSuffix((String) table.get(rowBefore, COL_DURATION), DURATION_SUFFIX_LENGTH);
          String durationAfter = dropSuffix((String) table.get(rowAfter, COL_DURATION), DURATION_SUFFIX_LENGTH);
          double diff = Double.parseDouble(durationAfter) - Double.parseDouble(durationBefore);
          // The bound allows one second more than the sleep time itself.
          Assert.assertTrue("Duration must be smooth; actual " + diff + ", expected at most " + sleepTime + " (before=" + durationBefore + ", after=" + durationAfter + ")", diff < sleepTime + 1);

          // Check that time stamp does see the sleep
          String timeStampBefore = (String) table.get(rowBefore, COL_TIMESTAMP);
          long timeStampBeforeLong = fmt.parseMillis(timeStampBefore);
          String timeStampAfter = (String) table.get(rowAfter, COL_TIMESTAMP);
          long timeStampAfterLong = fmt.parseMillis(timeStampAfter);
          Assert.assertTrue("Time stamp must experience a delay", timeStampAfterLong - timeStampBeforeLong >= (sleepTime - 1/*rounding*/) * 1000);

          // Check that the training speed is similar before and after checkpoint restart
          String speedBefore = dropSuffix((String) table.get(rowBefore, COL_SPEED), SPEED_SUFFIX_LENGTH);
          double speedBeforeDouble = Double.parseDouble(speedBefore);
          String speedAfter = dropSuffix((String) table.get(rowAfter, COL_SPEED), SPEED_SUFFIX_LENGTH);
          double speedAfterDouble = Double.parseDouble(speedAfter);
          Assert.assertTrue("Speed shouldn't change more than 50%", Math.abs(speedAfterDouble - speedBeforeDouble) / speedBeforeDouble < 0.5); //expect less than 50% change in speed
        } catch (NumberFormatException ex) {
          //skip runtimes > 1 minute (too hard to parse into seconds here...).
        }
      } finally {
        if (model != null) model.delete();
        if (model2 != null) model2.delete();
      }
    } finally {
      if (frame != null) frame.remove();
      Scope.exit();
    }
  }

  /**
   * Sleeps for the given number of milliseconds. If the thread is interrupted while sleeping,
   * the interrupt status is restored instead of being silently discarded.
   */
  private static void sleepMillis(long millis) {
    try {
      Thread.sleep(millis);
    } catch (InterruptedException ex) {
      Thread.currentThread().interrupt();
    }
  }

  /** Returns {@code cell} without its last {@code suffixLength} characters. */
  private static String dropSuffix(String cell, int suffixLength) {
    return cell.substring(0, cell.length() - suffixLength);
  }
}