package org.deeplearning4j.ui;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.module.SimpleModule;
import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.layers.feedforward.autoencoder.AutoEncoder;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.ui.weights.HistogramIterationListener;
import org.deeplearning4j.ui.weights.ModelAndGradient;
import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.nd4j.shade.serde.jackson.ndarray.NDArrayDeSerializer;
import org.nd4j.shade.serde.jackson.ndarray.NDArraySerializer;
import java.util.Arrays;
import static org.junit.Assert.assertEquals;
/**
 * Round-trip JSON serialization test for {@link ModelAndGradient}.
 *
 * <p>Builds a single AutoEncoder layer, feeds it the Iris feature matrix, wraps it in a
 * {@link ModelAndGradient}, writes that to JSON with a Jackson mapper that has nd4j
 * {@link INDArray} (de)serializers registered, reads it back and checks equality.
 *
 * @author Adam Gibson
 */
public class TestSerialization {

    /**
     * Serializes a {@link ModelAndGradient} wrapping a freshly instantiated AutoEncoder
     * and asserts that deserializing the produced JSON yields an equal object.
     */
    @Test
    public void testModelSerde() throws Exception {
        ObjectMapper jsonMapper = getMapper();

        NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder()
                .momentum(0.9f)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                .iterations(1000)
                .learningRate(1e-1f)
                .layer(new org.deeplearning4j.nn.conf.layers.AutoEncoder.Builder()
                        .nIn(4)
                        .nOut(3)
                        .corruptionLevel(0.6)
                        .sparsity(0.5)
                        .lossFunction(LossFunctions.LossFunction.RECONSTRUCTION_CROSSENTROPY)
                        .build())
                .build();

        // NOTE(review): (150, 150) presumably means one batch covering all 150 Iris rows — confirm.
        INDArray features = new IrisDataSetIterator(150, 150).next().getFeatureMatrix();

        // Parameter buffer of shape (1, n), where n comes from the layer's own initializer.
        INDArray paramView = Nd4j.create(1, conf.getLayer().initializer().numParams(conf));

        AutoEncoder autoEncoder = (AutoEncoder) conf.getLayer().instantiate(
                conf,
                Arrays.asList(new ScoreIterationListener(1), new HistogramIterationListener(1)),
                0,
                paramView,
                true);
        autoEncoder.setInput(features);

        ModelAndGradient original = new ModelAndGradient(autoEncoder);
        ModelAndGradient roundTripped =
                jsonMapper.readValue(jsonMapper.writeValueAsString(original), ModelAndGradient.class);
        assertEquals(original, roundTripped);
    }

    /**
     * Creates a Jackson {@link ObjectMapper} with a module named {@code "nd4j"} that
     * registers nd4j's {@link INDArray} serializer and deserializer.
     *
     * @return a new mapper able to read and write {@link INDArray} values
     */
    public ObjectMapper getMapper() {
        SimpleModule ndArrayModule = new SimpleModule("nd4j");
        ndArrayModule.addSerializer(INDArray.class, new NDArraySerializer());
        ndArrayModule.addDeserializer(INDArray.class, new NDArrayDeSerializer());

        ObjectMapper objectMapper = new ObjectMapper();
        objectMapper.registerModule(ndArrayModule);
        return objectMapper;
    }
}