package samples;
import hex.Layer;
import hex.Layer.VecSoftmax;
import hex.Layer.VecsInput;
import hex.MnistCanvas;
import hex.NeuralNet;
import hex.Trainer;
import samples.expert.NeuralNetMnist;
import water.fvec.Vec;
import javax.swing.*;
/**
 * MNIST sample that performs greedy layer-wise auto-encoder pre-training of a
 * Tanh network before the (optional) supervised training run.
 */
public class NeuralNetMnistPretrain extends NeuralNetMnist {
  /**
   * Entry point: resolves this class reflectively (so the launcher receives the
   * concrete subclass) and runs it on a single-node local H2O cloud.
   */
  public static void main(String[] args) throws Exception {
    Class job = Class.forName(Thread.currentThread().getStackTrace()[1].getClassName());
    samples.launchers.CloudLocal.launch(job, 1);
  }

  /**
   * Builds the network: input -> Tanh(50) -> Tanh(50) -> softmax, and
   * initializes every layer with a shared parameter object.
   *
   * @param data        input feature vectors
   * @param labels      target labels
   * @param inputStats  normalization statistics for the input layer
   * @param outputStats statistics for the softmax output layer
   * @return the initialized layer stack
   */
  @Override protected Layer[] build(Vec[] data, Vec labels, VecsInput inputStats, VecSoftmax outputStats) {
    Layer[] ls = new Layer[4];
    ls[0] = new VecsInput(data, inputStats);
    // Alternative hidden layers, kept for reference:
    // ls[1] = new Layer.RectifierDropout(1024);
    // ls[2] = new Layer.RectifierDropout(1024);
    ls[1] = new Layer.Tanh(50);
    ls[2] = new Layer.Tanh(50);
    ls[3] = new VecSoftmax(labels, outputStats);

    // Parameters for the MNIST run
    NeuralNet p = new NeuralNet();
    p.rate = 0.01; // only used for the NN run after pre-training
    p.activation = NeuralNet.Activation.Tanh;
    p.loss = NeuralNet.Loss.CrossEntropy;
    // p.rate_annealing = 1e-6f;
    // p.max_w2 = 15;
    // p.momentum_start = 0.5f;
    // p.momentum_ramp = 60000 * 300;
    // p.momentum_stable = 0.99f;
    // p.l1 = .00001f;
    // p.l2 = .00f;
    p.initial_weight_distribution = NeuralNet.InitialWeightDistribution.UniformAdaptive;
    // p.initial_weight_scale = 1;

    for( int i = 0; i < ls.length; i++ ) {
      ls[i].init(ls, i, p);
    }
    return ls;
  }

  /**
   * Pre-trains the hidden layers, then optionally runs supervised training.
   * The supervised run is disabled by default (epochs == 0) — bump {@code epochs}
   * to enable it.
   */
  @Override protected void startTraining(Layer[] ls) {
    int pretrain_epochs = 2;
    preTrain(ls, pretrain_epochs);

    // Actual supervised run (off by default).
    int epochs = 0;
    if (epochs > 0) {
      // Alternative trainers, kept for reference:
      // _trainer = new Trainer.Direct(ls, epochs, self());
      // _trainer = new Trainer.MapReduce(ls, epochs, self());
      _trainer = new Trainer.Threaded(ls, epochs, self(), -1);
      showTrainingWindow("H2O Training");
      _trainer.start();
      _trainer.join();
    }
  }

  /**
   * Greedy layer-wise pre-training: trains each hidden layer in turn as an
   * auto-encoder, reporting the wall-clock time per layer.
   *
   * @param ls     the full network built by {@link #build}
   * @param epochs number of epochs per layer
   */
  private void preTrain(Layer[] ls, int epochs) {
    for( int i = 1; i < ls.length - 1; i++ ) {
      System.out.println("Pre-training level " + i);
      long time = System.nanoTime();
      preTrain(ls, i, epochs);
      System.out.println((int) ((System.nanoTime() - time) / 1e6) + " ms");
    }
  }

  /**
   * Pre-trains the layer at {@code index} as an auto-encoder: a network with the
   * same (frozen) layers below 'index', the target layer as encoder, and a
   * mirrored decoder on top.
   */
  private void preTrain(Layer[] ls, int index, int epochs) {
    // Build a network with the same layers below 'index', and an auto-encoder at the top
    Layer[] pre = new Layer[index + 2];
    VecsInput input = (VecsInput) ls[0];
    pre[0] = new VecsInput(input.vecs, input);
    pre[0].init(pre, 0, ls[0].params); // clone the parameters
    for( int i = 1; i < index; i++ ) {
      // pre[i] = new Layer.Rectifier(ls[i].units);
      pre[i] = new Layer.Tanh(ls[i].units);
      Layer.shareWeights(ls[i], pre[i]);
      pre[i].init(pre, i, ls[i].params); // share the parameters
      pre[i].params.rate = 0; // freeze layers that were already pre-trained
    }
    // Auto-encoder is a layer and a reverse layer on top
    // pre[index] = new Layer.Rectifier(ls[index].units);
    // pre[index + 1] = new Layer.RectifierPrime(ls[index - 1].units);
    pre[index] = new Layer.Tanh(ls[index].units);
    pre[index].init(pre, index, ls[index].params);
    pre[index].params.rate = 1e-5;
    pre[index + 1] = new Layer.TanhPrime(ls[index - 1].units);
    pre[index + 1].init(pre, index + 1, pre[index].params);
    pre[index + 1].params.rate = 1e-5;
    // NOTE(review): both encoder and decoder share ls[index]'s weights —
    // presumably a tied-weights auto-encoder; confirm against Layer.shareWeights.
    Layer.shareWeights(ls[index], pre[index]);
    Layer.shareWeights(ls[index], pre[index + 1]);

    _trainer = new Trainer.Direct(pre, epochs, self());
    showTrainingWindow("H2O Pre-Training");
    _trainer.start();
    _trainer.join();
  }

  /**
   * Opens a basic Swing visualization of images and weights for the current
   * {@code _trainer}. Shared by the pre-training and training phases.
   *
   * @param title the window title
   */
  private void showTrainingWindow(String title) {
    JFrame frame = new JFrame(title);
    frame.setDefaultCloseOperation(WindowConstants.EXIT_ON_CLOSE);
    MnistCanvas canvas = new MnistCanvas(_trainer);
    frame.setContentPane(canvas.init());
    frame.pack();
    frame.setLocationRelativeTo(null);
    frame.setVisible(true);
  }
}