package hex;
import hex.deeplearning.DeepLearning;
import hex.deeplearning.DeepLearningModel;
import junit.framework.Assert;
import org.junit.BeforeClass;
import org.junit.Test;
import water.JUnitRunnerDebug;
import water.Key;
import water.TestUtil;
import water.UKV;
import water.fvec.Frame;
import water.fvec.NFSFileVec;
import water.fvec.ParseDataset2;
import water.util.Log;
public class DeepLearningSpiralsTest extends TestUtil {
@BeforeClass public static void stall() {
stall_till_cloudsize(JUnitRunnerDebug.NODES);
}
@Test public void run() {
Key file = NFSFileVec.make(find_test_file("smalldata/neural/two_spiral.data"));
Frame frame = ParseDataset2.parse(Key.make(), new Key[]{file});
Key dest = Key.make("spirals2");
for (boolean sparse : new boolean[]{true,false}) {
for (boolean col_major : new boolean[]{false}) {
if (!sparse && col_major) continue;
// build the model
{
DeepLearning p = new DeepLearning();
p.seed = 0xbabe;
p.epochs = 10000;
p.hidden = new int[]{100};
p.sparse = sparse;
p.col_major = col_major;
p.activation = DeepLearning.Activation.Tanh;
p.max_w2 = Float.POSITIVE_INFINITY;
p.l1 = 0;
p.l2 = 0;
p.initial_weight_distribution = DeepLearning.InitialWeightDistribution.Normal;
p.initial_weight_scale = 2.5;
p.loss = DeepLearning.Loss.CrossEntropy;
p.source = frame;
p.response = frame.lastVec();
p.validation = null;
p.score_interval = 2;
p.ignored_cols = null;
p.train_samples_per_iteration = 0; //sync once per period
p.quiet_mode = true;
p.fast_mode = true;
p.ignore_const_cols = true;
p.nesterov_accelerated_gradient = true;
p.classification = true;
p.diagnostics = true;
p.expert_mode = true;
p.score_training_samples = 1000;
p.score_validation_samples = 10000;
p.shuffle_training_data = false;
p.force_load_balance = false;
p.replicate_training_data = false;
p.destination_key = dest;
p.adaptive_rate = true;
p.reproducible = true;
p.rho = 0.99;
p.epsilon = 5e-3;
p.invoke();
}
// score and check result
{
DeepLearningModel mymodel = UKV.get(dest);
double error = mymodel.error();
if (error >= 0.025) {
Assert.fail("Classification error is not less than 0.025, but " + error + ".");
}
mymodel.delete();
mymodel.delete_best_model();
}
}
}
frame.delete();
}
}