package org.deeplearning4j.ui.stats;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.transferlearning.FineTuneConfiguration;
import org.deeplearning4j.nn.transferlearning.TransferLearning;
import org.deeplearning4j.ui.storage.FileStatsStorage;
import org.junit.Test;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import java.io.File;
import java.io.IOException;
import java.nio.file.Files;
import static org.junit.Assert.assertNotNull;
/**
 * Regression test: collecting training statistics via {@link StatsListener} on a network
 * produced by {@link TransferLearning} with a frozen feature-extractor layer must not fail
 * during {@code fit}. Stats collection previously failed on frozen layers.
 *
 * <p>Created by Alex on 07/04/2017.
 */
public class TestTransferStatsCollection {

    @Test
    public void test() throws IOException {
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list()
                .layer(0, new DenseLayer.Builder().nIn(10).nOut(10).build())
                .layer(1, new OutputLayer.Builder().nIn(10).nOut(10).build()).build();
        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();

        // setFeatureExtractor(0) freezes layer 0, so net2 contains the frozen layer under test.
        MultiLayerNetwork net2 = new TransferLearning.Builder(net)
                .fineTuneConfiguration(
                        new FineTuneConfiguration.Builder().learningRate(0.01).build())
                .setFeatureExtractor(0)
                .build();

        // createTempFile is used only to obtain a unique path; the placeholder is removed,
        // presumably so FileStatsStorage can create the file itself. The return value is
        // deliberately ignored.
        File f = Files.createTempFile("dl4jTestTransferStatsCollection", "bin").toFile();
        f.delete();
        // Register cleanup before any work that may throw, so a failing fit() does not
        // leave the temp file behind.
        f.deleteOnExit();

        FileStatsStorage storage = new FileStatsStorage(f);
        try {
            net2.setListeners(new StatsListener(storage));
            // Previously: failed on frozen layers
            net2.fit(new DataSet(Nd4j.rand(8, 10), Nd4j.rand(8, 10)));
            assertNotNull(storage.listSessionIDs());
        } finally {
            // Release the file-backed storage even if fit() fails.
            storage.close();
        }
    }
}