package org.deeplearning4j.parallelism.main;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
import org.deeplearning4j.nn.conf.layers.setup.ConvolutionLayerSetup;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.util.ModelSerializer;
import org.junit.Test;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import java.io.File;
import java.util.Random;
/**
* Created by agibsonccc on 12/29/16.
*/
@Slf4j
public class ParallelWrapperMainTest {
@Test
public void runParallelWrapperMain() throws Exception {
int nChannels = 1;
int outputNum = 10;
// for GPU you usually want to have higher batchSize
int batchSize = 128;
int nEpochs = 10;
int iterations = 1;
int seed = 123;
int uiPort = 9500;
System.setProperty("org.deeplearning4j.ui.port", String.valueOf(uiPort));
log.info("Load data....");
DataSetIterator mnistTrain = new MnistDataSetIterator(batchSize, true, 12345);
DataSetIterator mnistTest = new MnistDataSetIterator(batchSize, false, 12345);
log.info("Build model....");
MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().seed(seed).iterations(iterations)
.regularization(true).l2(0.0005).learningRate(0.01)//.biasLearningRate(0.02)
//.learningRateDecayPolicy(LearningRatePolicy.Inverse).lrPolicyDecayRate(0.001).lrPolicyPower(0.75)
.weightInit(WeightInit.XAVIER)
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).updater(Updater.NESTEROVS)
.momentum(0.9).list()
.layer(0, new ConvolutionLayer.Builder(5, 5)
//nIn and nOut specify depth. nIn here is the nChannels and nOut is the number of filters to be applied
.nIn(nChannels).stride(1, 1).nOut(20).activation(Activation.IDENTITY).build())
.layer(1, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX).kernelSize(2, 2)
.stride(2, 2).build())
.layer(2, new ConvolutionLayer.Builder(5, 5)
//Note that nIn needed be specified in later layers
.stride(1, 1).nOut(50).activation(Activation.IDENTITY).build())
.layer(3, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX).kernelSize(2, 2)
.stride(2, 2).build())
.layer(4, new DenseLayer.Builder().activation(Activation.RELU).nOut(500).build())
.layer(5, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
.nOut(outputNum).activation(Activation.SOFTMAX).build())
.backprop(true).pretrain(false);
// The builder needs the dimensions of the image along with the number of channels. these are 28x28 images in one channel
new ConvolutionLayerSetup(builder, 28, 28, 1);
MultiLayerConfiguration conf = builder.build();
MultiLayerNetwork model = new MultiLayerNetwork(conf);
model.init();
File tempModel = new File("tmpmodel.zip");
tempModel.deleteOnExit();
ModelSerializer.writeModel(model, tempModel, false);
File tmp = new File("tmpmodel.bin");
tmp.deleteOnExit();
ParallelWrapperMain parallelWrapperMain = new ParallelWrapperMain();
parallelWrapperMain.runMain(new String[] {"--modelPath", tempModel.getAbsolutePath(),
"--dataSetIteratorFactoryClazz", MnistDataSetIteratorProviderFactory.class.getName(),
"--modelOutputPath", tmp.getAbsolutePath(), "--uiUrl", "localhost:" + uiPort});
Thread.sleep(30000);
}
}