package org.deeplearning4j.util; import org.datavec.api.transform.transform.doubletransform.MinMaxNormalizer; import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator; import org.deeplearning4j.nn.api.OptimizationAlgorithm; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.DenseLayer; import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; import org.junit.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.dataset.api.preprocessor.Normalizer; import org.nd4j.linalg.dataset.api.preprocessor.NormalizerMinMaxScaler; import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.lossfunctions.LossFunctions; import java.io.File; import java.io.FileInputStream; import java.io.FileOutputStream; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotEquals; /** * @author raver119@gmail.com */ public class ModelSerializerTest { @Test public void testWriteMLNModel() throws Exception { int nIn = 5; int nOut = 6; MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).regularization(true).l1(0.01) .l2(0.01).learningRate(0.1).activation(Activation.TANH).weightInit(WeightInit.XAVIER).list() .layer(0, new DenseLayer.Builder().nIn(nIn).nOut(20).build()) .layer(1, new DenseLayer.Builder().nIn(20).nOut(30).build()).layer(2, new OutputLayer.Builder() .lossFunction(LossFunctions.LossFunction.MSE).nIn(30).nOut(nOut).build()) .build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); File tempFile = File.createTempFile("tsfs", "fdfsdf"); tempFile.deleteOnExit(); ModelSerializer.writeModel(net, tempFile, true); MultiLayerNetwork network = ModelSerializer.restoreMultiLayerNetwork(tempFile); assertEquals(network.getLayerWiseConfigurations().toJson(), net.getLayerWiseConfigurations().toJson()); assertEquals(net.params(), network.params()); assertEquals(net.getUpdater().getStateViewArray(), network.getUpdater().getStateViewArray()); } @Test public void testWriteMlnModelInputStream() throws Exception { int nIn = 5; int nOut = 6; MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).regularization(true).l1(0.01) .l2(0.01).learningRate(0.1).activation(Activation.TANH).weightInit(WeightInit.XAVIER).list() .layer(0, new DenseLayer.Builder().nIn(nIn).nOut(20).build()) .layer(1, new DenseLayer.Builder().nIn(20).nOut(30).build()).layer(2, new OutputLayer.Builder() .lossFunction(LossFunctions.LossFunction.MSE).nIn(30).nOut(nOut).build()) .build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); File tempFile = File.createTempFile("tsfs", "fdfsdf"); tempFile.deleteOnExit(); FileOutputStream fos = new FileOutputStream(tempFile); ModelSerializer.writeModel(net, fos, true); // checking adding of DataNormalization to the model file NormalizerMinMaxScaler scaler = new NormalizerMinMaxScaler(); DataSetIterator iter = new IrisDataSetIterator(150, 150); scaler.fit(iter); ModelSerializer.addNormalizerToModel(tempFile, scaler); NormalizerMinMaxScaler restoredScaler = ModelSerializer.restoreNormalizerFromFile(tempFile); assertNotEquals(null, scaler.getMax()); assertEquals(scaler.getMax(), restoredScaler.getMax()); assertEquals(scaler.getMin(), restoredScaler.getMin()); FileInputStream fis = new FileInputStream(tempFile); MultiLayerNetwork network = ModelSerializer.restoreMultiLayerNetwork(fis); assertEquals(network.getLayerWiseConfigurations().toJson(), net.getLayerWiseConfigurations().toJson()); assertEquals(net.params(), network.params()); assertEquals(net.getUpdater().getStateViewArray(), network.getUpdater().getStateViewArray()); } @Test public void testWriteCGModel() throws Exception { ComputationGraphConfiguration config = new NeuralNetConfiguration.Builder() .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).learningRate(0.1) .graphBuilder().addInputs("in") .addLayer("dense", new DenseLayer.Builder().nIn(4).nOut(2).build(), "in").addLayer("out", new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).nIn(2).nOut(3) .build(), "dense") .setOutputs("out").pretrain(false).backprop(true).build(); ComputationGraph cg = new ComputationGraph(config); cg.init(); File tempFile = File.createTempFile("tsfs", "fdfsdf"); tempFile.deleteOnExit(); ModelSerializer.writeModel(cg, tempFile, true); ComputationGraph network = ModelSerializer.restoreComputationGraph(tempFile); assertEquals(network.getConfiguration().toJson(), cg.getConfiguration().toJson()); assertEquals(cg.params(), network.params()); assertEquals(cg.getUpdater().getStateViewArray(), network.getUpdater().getStateViewArray()); } @Test public void testWriteCGModelInputStream() throws Exception { ComputationGraphConfiguration config = new NeuralNetConfiguration.Builder() .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).learningRate(0.1) .graphBuilder().addInputs("in") .addLayer("dense", new DenseLayer.Builder().nIn(4).nOut(2).build(), "in").addLayer("out", new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).nIn(2).nOut(3) .build(), "dense") .setOutputs("out").pretrain(false).backprop(true).build(); ComputationGraph cg = new ComputationGraph(config); cg.init(); File tempFile = File.createTempFile("tsfs", "fdfsdf"); tempFile.deleteOnExit(); ModelSerializer.writeModel(cg, tempFile, true); FileInputStream fis = new FileInputStream(tempFile); ComputationGraph network = ModelSerializer.restoreComputationGraph(fis); assertEquals(network.getConfiguration().toJson(), cg.getConfiguration().toJson()); assertEquals(cg.params(), network.params()); assertEquals(cg.getUpdater().getStateViewArray(), network.getUpdater().getStateViewArray()); } private DataSet trivialDataSet() { INDArray inputs = Nd4j.create(new float[] { 1.0f, 2.0f, 3.0f }); INDArray labels = Nd4j.create(new float[] { 4.0f, 5.0f, 6.0f }); return new DataSet(inputs, labels); } private ComputationGraph simpleComputationGraph() { ComputationGraphConfiguration config = new NeuralNetConfiguration.Builder() .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .learningRate(0.1) .graphBuilder() .addInputs("in") .addLayer("dense", new DenseLayer.Builder().nIn(4).nOut(2).build(), "in") .addLayer("out", new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).nIn(2).nOut(3).build(), "dense") .setOutputs("out") .pretrain(false) .backprop(true) .build(); return new ComputationGraph(config); } @Test public void testSaveRestoreNormalizerFromInputStream() throws Exception { DataSet dataSet = trivialDataSet(); NormalizerStandardize norm = new NormalizerStandardize(); norm.fit(dataSet); ComputationGraph cg = simpleComputationGraph(); cg.init(); File tempFile = File.createTempFile("tsfs", "fdfsdf"); tempFile.deleteOnExit(); ModelSerializer.writeModel(cg, tempFile, true); ModelSerializer.addNormalizerToModel(tempFile, norm); FileInputStream fis = new FileInputStream(tempFile); NormalizerStandardize restored = ModelSerializer.restoreNormalizerFromInputStream(fis); assertNotEquals(null, restored); DataSet dataSet2 = dataSet.copy(); norm.preProcess(dataSet2); assertNotEquals(dataSet.getFeatures(), dataSet2.getFeatures()); restored.revert(dataSet2); assertEquals(dataSet.getFeatures(), dataSet2.getFeatures()); } @Test public void testRestoreUnsavedNormalizerFromInputStream() throws Exception { DataSet dataSet = trivialDataSet(); NormalizerStandardize norm = new NormalizerStandardize(); norm.fit(dataSet); ComputationGraph cg = simpleComputationGraph(); cg.init(); File tempFile = File.createTempFile("tsfs", "fdfsdf"); tempFile.deleteOnExit(); ModelSerializer.writeModel(cg, tempFile, true); FileInputStream fis = new FileInputStream(tempFile); NormalizerStandardize restored = ModelSerializer.restoreNormalizerFromInputStream(fis); assertEquals(null, restored); } }