package org.deeplearning4j.util; import org.apache.commons.compress.utils.IOUtils; import org.deeplearning4j.nn.api.Model; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.DenseLayer; import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.modelimport.keras.KerasModelImport; import org.deeplearning4j.nn.modelimport.keras.ModelConfiguration; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; import org.junit.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.api.preprocessor.Normalizer; import org.nd4j.linalg.dataset.api.preprocessor.NormalizerMinMaxScaler; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.io.ClassPathResource; import org.nd4j.linalg.lossfunctions.LossFunctions; import java.io.BufferedOutputStream; import java.io.File; import java.io.FileOutputStream; import java.io.InputStream; import java.util.UUID; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; import static org.junit.Assume.assumeNotNull; /** * Created by agibsonccc on 12/29/16. */ public class ModelGuesserTest { @Test public void testModelGuess() throws Exception { ClassPathResource sequenceResource = new ClassPathResource("modelimport/keras/examples/mnist_mlp/mnist_mlp_tf_model.h5"); assertTrue(sequenceResource.exists()); File f = getTempFile(sequenceResource); Model guess1 = ModelGuesser.loadModelGuess(f.getAbsolutePath()); assumeNotNull(guess1); ClassPathResource sequenceResource2 = new ClassPathResource("modelimport/keras/examples/mnist_cnn/mnist_cnn_tf_model.h5"); assertTrue(sequenceResource2.exists()); File f2 = getTempFile(sequenceResource); Model guess2 = ModelGuesser.loadModelGuess(f2.getAbsolutePath()); assumeNotNull(guess2); } @Test public void testLoadNormalizers() throws Exception { int nIn = 5; int nOut = 6; MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).regularization(true).l1(0.01) .l2(0.01).learningRate(0.1).activation(Activation.TANH).weightInit(WeightInit.XAVIER).list() .layer(0, new DenseLayer.Builder().nIn(nIn).nOut(20).build()) .layer(1, new DenseLayer.Builder().nIn(20).nOut(30).build()).layer(2, new OutputLayer.Builder() .lossFunction(LossFunctions.LossFunction.MSE).nIn(30).nOut(nOut).build()) .build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); File tempFile = File.createTempFile("tsfs", "fdfsdf"); tempFile.deleteOnExit(); ModelSerializer.writeModel(net, tempFile, true); NormalizerMinMaxScaler normalizer = new NormalizerMinMaxScaler(0,1); normalizer.fit(new DataSet(Nd4j.rand(new int[]{2,2}),Nd4j.rand(new int[]{2,2}))); ModelSerializer.addNormalizerToModel(tempFile,normalizer); Model model = ModelGuesser.loadModelGuess(tempFile.getAbsolutePath()); Normalizer<?> normalizer1 = ModelGuesser.loadNormalizer(tempFile.getAbsolutePath()); assertEquals(model,net); assertEquals(normalizer,normalizer1); } @Test public void testModelGuesserDl4jModel() throws Exception { int nIn = 5; int nOut = 6; MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).regularization(true).l1(0.01) .l2(0.01).learningRate(0.1).activation(Activation.TANH).weightInit(WeightInit.XAVIER).list() .layer(0, new DenseLayer.Builder().nIn(nIn).nOut(20).build()) .layer(1, new DenseLayer.Builder().nIn(20).nOut(30).build()).layer(2, new OutputLayer.Builder() .lossFunction(LossFunctions.LossFunction.MSE).nIn(30).nOut(nOut).build()) .build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); File tempFile = File.createTempFile("tsfs", "fdfsdf"); tempFile.deleteOnExit(); ModelSerializer.writeModel(net, tempFile, true); MultiLayerNetwork network = (MultiLayerNetwork) ModelGuesser.loadModelGuess(tempFile.getAbsolutePath()); assertEquals(network.getLayerWiseConfigurations().toJson(), net.getLayerWiseConfigurations().toJson()); assertEquals(net.params(), network.params()); assertEquals(net.getUpdater().getStateViewArray(), network.getUpdater().getStateViewArray()); } @Test public void testModelGuessConfig() throws Exception { ClassPathResource resource = new ClassPathResource("modelimport/keras/configs/cnn_tf_config.json", ModelGuesserTest.class.getClassLoader()); File f = getTempFile(resource); String configFilename = f.getAbsolutePath(); Object conf = ModelGuesser.loadConfigGuess(configFilename); assertTrue(conf instanceof MultiLayerConfiguration); ClassPathResource sequenceResource = new ClassPathResource("/keras/simple/mlp_fapi_multiloss_config.json"); File f2 = getTempFile(sequenceResource); Object sequenceConf = ModelGuesser.loadConfigGuess(f2.getAbsolutePath()); assertTrue(sequenceConf instanceof ComputationGraphConfiguration); ClassPathResource resourceDl4j = new ClassPathResource("model.json"); File fDl4j = getTempFile(resourceDl4j); String configFilenameDl4j = fDl4j.getAbsolutePath(); Object confDl4j = ModelGuesser.loadConfigGuess(configFilenameDl4j); assertTrue(confDl4j instanceof ComputationGraphConfiguration); } private File getTempFile(ClassPathResource classPathResource) throws Exception { InputStream is = classPathResource.getInputStream(); File f = new File(UUID.randomUUID().toString()); f.deleteOnExit(); BufferedOutputStream bos = new BufferedOutputStream(new FileOutputStream(f)); IOUtils.copy(is, bos); bos.flush(); bos.close(); return f; } }