package org.deeplearning4j.examples.misc.modelsaving; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.Updater; import org.deeplearning4j.nn.conf.layers.DenseLayer; import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.util.ModelSerializer; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.lossfunctions.LossFunctions; import java.io.File; /** * A very simple example for saving and loading a ComputationGraph * * @author Alex Black */ public class SaveLoadComputationGraph { public static void main(String[] args) throws Exception { //Define a simple ComputationGraph: ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder() .weightInit(WeightInit.XAVIER) .updater(Updater.NESTEROVS) .learningRate(0.1) .graphBuilder() .addInputs("in") .addLayer("layer0", new DenseLayer.Builder().nIn(4).nOut(3).activation(Activation.TANH).build(), "in") .addLayer("layer1", new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD).activation(Activation.SOFTMAX).nIn(3).nOut(3).build(), "layer0") .setOutputs("layer1") .backprop(true).pretrain(false).build(); ComputationGraph net = new ComputationGraph(conf); net.init(); //Save the model File locationToSave = new File("MyComputationGraph.zip"); //Where to save the network. Note: the file is in .zip format - can be opened externally boolean saveUpdater = true; //Updater: i.e., the state for Momentum, RMSProp, Adagrad etc. Save this if you want to train your network more in the future ModelSerializer.writeModel(net, locationToSave, saveUpdater); //Load the model ComputationGraph restored = ModelSerializer.restoreComputationGraph(locationToSave); System.out.println("Saved and loaded parameters are equal: " + net.params().equals(restored.params())); System.out.println("Saved and loaded configurations are equal: " + net.getConfiguration().equals(restored.getConfiguration())); } }