package org.deeplearning4j.examples.misc.modelsaving;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import java.io.File;
/**
 * Minimal demonstration of writing a MultiLayerNetwork to disk and reading it back.
 *
 * <p>The example builds a small two-layer network, saves it (together with its updater
 * state) to a zip file, restores it from that file, and prints whether the restored
 * parameters and configuration are equal to the originals.
 *
 * @author Alex Black
 */
public class SaveLoadMultiLayerNetwork {

    /** Number of inputs to the dense layer. */
    private static final int INPUT_SIZE = 4;

    /** Number of units produced by the dense layer and consumed by the output layer. */
    private static final int HIDDEN_SIZE = 3;

    /** Number of outputs of the output layer. */
    private static final int OUTPUT_SIZE = 3;

    /** Name of the file the network is written to; the file is in .zip format. */
    private static final String MODEL_FILE_NAME = "MyMultiLayerNetwork.zip";

    /**
     * Builds, saves, reloads and compares a simple network.
     *
     * @param args command-line arguments (unused)
     * @throws Exception if any step fails; nothing is caught here
     */
    public static void main(String[] args) throws Exception {
        // Create and initialise a simple network from the example configuration.
        MultiLayerNetwork original = new MultiLayerNetwork(buildConfiguration());
        original.init();

        // Persist the network. The target is a .zip archive, so it can be opened externally.
        File modelFile = new File(MODEL_FILE_NAME);
        // The updater is the state for Momentum, RMSProp, Adagrad etc.
        // Keep it if the network is to be trained further after reloading.
        boolean includeUpdaterState = true;
        ModelSerializer.writeModel(original, modelFile, includeUpdaterState);

        // Read the network back from the same file.
        MultiLayerNetwork reloaded = ModelSerializer.restoreMultiLayerNetwork(modelFile);

        // Report whether the round trip preserved the parameters...
        boolean sameParameters = original.params().equals(reloaded.params());
        System.out.println("Saved and loaded parameters are equal: " + sameParameters);

        // ...and the layer-wise configuration.
        boolean sameConfiguration =
                original.getLayerWiseConfigurations().equals(reloaded.getLayerWiseConfigurations());
        System.out.println("Saved and loaded configurations are equal: " + sameConfiguration);
    }

    /**
     * Assembles the configuration used by this example: a TANH dense layer feeding a
     * SOFTMAX output layer with negative log likelihood loss.
     *
     * @return the configuration for the example network
     */
    private static MultiLayerConfiguration buildConfiguration() {
        DenseLayer hiddenLayer = new DenseLayer.Builder()
                .nIn(INPUT_SIZE)
                .nOut(HIDDEN_SIZE)
                .activation(Activation.TANH)
                .build();

        OutputLayer outputLayer = new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                .activation(Activation.SOFTMAX)
                .nIn(HIDDEN_SIZE)
                .nOut(OUTPUT_SIZE)
                .build();

        return new NeuralNetConfiguration.Builder()
                .weightInit(WeightInit.XAVIER)
                .updater(Updater.NESTEROVS)
                .learningRate(0.1)
                .list()
                .layer(0, hiddenLayer)
                .layer(1, outputLayer)
                .backprop(true)
                .pretrain(false)
                .build();
    }
}