package org.deeplearning4j.ui.play;
import org.deeplearning4j.api.storage.Persistable;
import org.deeplearning4j.api.storage.StatsStorage;
import org.deeplearning4j.api.storage.StatsStorageRouter;
import org.deeplearning4j.api.storage.StorageMetaData;
import org.deeplearning4j.api.storage.impl.CollectionStatsStorageRouter;
import org.deeplearning4j.api.storage.impl.RemoteUIStatsStorageRouter;
import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.ui.api.UIServer;
import org.deeplearning4j.ui.stats.StatsListener;
import org.deeplearning4j.ui.stats.impl.SbeStatsInitializationReport;
import org.deeplearning4j.ui.stats.impl.SbeStatsReport;
import org.deeplearning4j.ui.storage.InMemoryStatsStorage;
import org.deeplearning4j.ui.storage.impl.SbeStorageMetaData;
import org.junit.Ignore;
import org.junit.Test;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import static org.junit.Assert.assertEquals;
/**
* Created by Alex on 10/11/2016.
*/
@Ignore
public class TestRemoteReceiver {
@Test
@Ignore
public void testRemoteBasic() throws Exception {
List<Persistable> updates = new ArrayList<>();
List<Persistable> staticInfo = new ArrayList<>();
List<StorageMetaData> metaData = new ArrayList<>();
CollectionStatsStorageRouter collectionRouter = new CollectionStatsStorageRouter(metaData, staticInfo, updates);
UIServer s = UIServer.getInstance();
s.enableRemoteListener(collectionRouter, false);
RemoteUIStatsStorageRouter remoteRouter = new RemoteUIStatsStorageRouter("http://localhost:9000");
SbeStatsReport update1 = new SbeStatsReport();
update1.setDeviceCurrentBytes(new long[] {1, 2});
update1.reportIterationCount(10);
update1.reportIDs("sid", "tid", "wid", 123456);
update1.reportPerformance(10, 20, 30, 40, 50);
SbeStatsReport update2 = new SbeStatsReport();
update2.setDeviceCurrentBytes(new long[] {3, 4});
update2.reportIterationCount(20);
update2.reportIDs("sid2", "tid2", "wid2", 123456);
update2.reportPerformance(11, 21, 31, 40, 50);
StorageMetaData smd1 = new SbeStorageMetaData(123, "sid", "typeid", "wid", "initTypeClass", "updaterTypeClass");
StorageMetaData smd2 =
new SbeStorageMetaData(456, "sid2", "typeid2", "wid2", "initTypeClass2", "updaterTypeClass2");
SbeStatsInitializationReport init1 = new SbeStatsInitializationReport();
init1.reportIDs("sid", "wid", "tid", 3145253452L);
init1.reportHardwareInfo(1, 2, 3, 4, null, null, "2344253");
remoteRouter.putUpdate(update1);
Thread.sleep(100);
remoteRouter.putStorageMetaData(smd1);
Thread.sleep(100);
remoteRouter.putStaticInfo(init1);
Thread.sleep(100);
remoteRouter.putUpdate(update2);
Thread.sleep(100);
remoteRouter.putStorageMetaData(smd2);
Thread.sleep(2000);
assertEquals(2, metaData.size());
assertEquals(2, updates.size());
assertEquals(1, staticInfo.size());
assertEquals(Arrays.asList(update1, update2), updates);
assertEquals(Arrays.asList(smd1, smd2), metaData);
assertEquals(Collections.singletonList(init1), staticInfo);
}
@Test
@Ignore
public void testRemoteFull() throws Exception {
//Use this in conjunction with startRemoteUI()
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).iterations(1).list()
.layer(0, new DenseLayer.Builder().activation(Activation.TANH).nIn(4).nOut(4).build())
.layer(1, new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX).nIn(4).nOut(3).build())
.pretrain(false).backprop(true).build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
StatsStorageRouter ssr = new RemoteUIStatsStorageRouter("http://localhost:9000");
net.setListeners(new StatsListener(ssr), new ScoreIterationListener(1));
DataSetIterator iter = new IrisDataSetIterator(150, 150);
for (int i = 0; i < 500; i++) {
net.fit(iter);
// Thread.sleep(100);
Thread.sleep(100);
}
}
@Test
@Ignore
public void startRemoteUI() throws Exception {
UIServer s = UIServer.getInstance();
s.enableRemoteListener();
Thread.sleep(1000000);
}
}