package org.deeplearning4j.ui.stats;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.ui.stats.api.*;
import org.deeplearning4j.ui.stats.impl.SbeStatsInitializationReport;
import org.deeplearning4j.ui.stats.impl.SbeStatsReport;
import org.deeplearning4j.ui.stats.impl.java.JavaStatsInitializationReport;
import org.junit.Assert;
import org.junit.Test;
import java.io.*;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import static org.junit.Assert.*;
/**
* Created by Alex on 01/10/2016.
*/
public class TestStatsClasses {
@Test
public void testStatsInitializationReport() throws Exception {
boolean[] tf = new boolean[] {true, false};
for (boolean useJ7 : new boolean[] {false, true}) {
//IDs
String sessionID = "sid";
String typeID = "tid";
String workerID = "wid";
long timestamp = -1;
//Hardware info
int jvmAvailableProcessors = 1;
int numDevices = 2;
long jvmMaxMemory = 3;
long offHeapMaxMemory = 4;
long[] deviceTotalMemory = new long[] {5, 6};
String[] deviceDescription = new String[] {"7", "8"};
String hwUID = "8a";
//Software info
String arch = "9";
String osName = "10";
String jvmName = "11";
String jvmVersion = "12";
String jvmSpecVersion = "13";
String nd4jBackendClass = "14";
String nd4jDataTypeName = "15";
String hostname = "15a";
String jvmUID = "15b";
Map<String, String> swEnvInfo = new HashMap<>();
swEnvInfo.put("env15c-1", "SomeData");
swEnvInfo.put("env15c-2", "OtherData");
swEnvInfo.put("env15c-3", "EvenMoreData");
//Model info
String modelClassName = "16";
String modelConfigJson = "17";
String[] modelparamNames = new String[] {"18", "19", "20", "21"};
int numLayers = 22;
long numParams = 23;
for (boolean hasHardwareInfo : tf) {
for (boolean hasSoftwareInfo : tf) {
for (boolean hasModelInfo : tf) {
StatsInitializationReport report;
if (useJ7) {
report = new JavaStatsInitializationReport();
} else {
report = new SbeStatsInitializationReport();
}
report.reportIDs(sessionID, typeID, workerID, timestamp);
if (hasHardwareInfo) {
report.reportHardwareInfo(jvmAvailableProcessors, numDevices, jvmMaxMemory,
offHeapMaxMemory, deviceTotalMemory, deviceDescription, hwUID);
}
if (hasSoftwareInfo) {
report.reportSoftwareInfo(arch, osName, jvmName, jvmVersion, jvmSpecVersion,
nd4jBackendClass, nd4jDataTypeName, hostname, jvmUID, swEnvInfo);
}
if (hasModelInfo) {
report.reportModelInfo(modelClassName, modelConfigJson, modelparamNames, numLayers,
numParams);
}
byte[] asBytes = report.encode();
StatsInitializationReport report2;// = new SbeStatsInitializationReport();
if (useJ7) {
report2 = new JavaStatsInitializationReport();
} else {
report2 = new SbeStatsInitializationReport();
}
report2.decode(asBytes);
assertEquals(report, report2);
assertEquals(sessionID, report2.getSessionID());
assertEquals(typeID, report2.getTypeID());
assertEquals(workerID, report2.getWorkerID());
assertEquals(timestamp, report2.getTimeStamp());
if (hasHardwareInfo) {
assertEquals(jvmAvailableProcessors, report2.getHwJvmAvailableProcessors());
assertEquals(numDevices, report2.getHwNumDevices());
assertEquals(jvmMaxMemory, report2.getHwJvmMaxMemory());
assertEquals(offHeapMaxMemory, report2.getHwOffHeapMaxMemory());
assertArrayEquals(deviceTotalMemory, report2.getHwDeviceTotalMemory());
assertArrayEquals(deviceDescription, report2.getHwDeviceDescription());
assertEquals(hwUID, report2.getHwHardwareUID());
assertTrue(report2.hasHardwareInfo());
} else {
assertFalse(report2.hasHardwareInfo());
}
if (hasSoftwareInfo) {
assertEquals(arch, report2.getSwArch());
assertEquals(osName, report2.getSwOsName());
assertEquals(jvmName, report2.getSwJvmName());
assertEquals(jvmVersion, report2.getSwJvmVersion());
assertEquals(jvmSpecVersion, report2.getSwJvmSpecVersion());
assertEquals(nd4jBackendClass, report2.getSwNd4jBackendClass());
assertEquals(nd4jDataTypeName, report2.getSwNd4jDataTypeName());
assertEquals(jvmUID, report2.getSwJvmUID());
assertEquals(hostname, report2.getSwHostName());
assertEquals(swEnvInfo, report2.getSwEnvironmentInfo());
assertTrue(report2.hasSoftwareInfo());
} else {
assertFalse(report2.hasSoftwareInfo());
}
if (hasModelInfo) {
assertEquals(modelClassName, report2.getModelClassName());
assertEquals(modelConfigJson, report2.getModelConfigJson());
assertArrayEquals(modelparamNames, report2.getModelParamNames());
assertEquals(numLayers, report2.getModelNumLayers());
assertEquals(numParams, report2.getModelNumParams());
assertTrue(report2.hasModelInfo());
} else {
assertFalse(report2.hasModelInfo());
}
//Check standard Java serialization
ByteArrayOutputStream baos = new ByteArrayOutputStream();
ObjectOutputStream oos = new ObjectOutputStream(baos);
oos.writeObject(report);
oos.close();
byte[] javaBytes = baos.toByteArray();
ObjectInputStream ois = new ObjectInputStream(new ByteArrayInputStream(javaBytes));
StatsInitializationReport report3 = (StatsInitializationReport) ois.readObject();
assertEquals(report, report3);
}
}
}
}
}
@Test
public void testStatsInitializationReportNullValues() throws Exception {
//Sanity check: shouldn't have any issues with encoding/decoding null values...
boolean[] tf = new boolean[] {true, false};
for (boolean useJ7 : new boolean[] {false, true}) {
//Hardware info
int jvmAvailableProcessors = 1;
int numDevices = 2;
long jvmMaxMemory = 3;
long offHeapMaxMemory = 4;
long[] deviceTotalMemory = null;
String[] deviceDescription = null;
String hwUID = null;
//Software info
String arch = null;
String osName = null;
String jvmName = null;
String jvmVersion = null;
String jvmSpecVersion = null;
String nd4jBackendClass = null;
String nd4jDataTypeName = null;
String hostname = null;
String jvmUID = null;
Map<String, String> swEnvInfo = null;
//Model info
String modelClassName = null;
String modelConfigJson = null;
String[] modelparamNames = null;
int numLayers = 22;
long numParams = 23;
for (boolean hasHardwareInfo : tf) {
for (boolean hasSoftwareInfo : tf) {
for (boolean hasModelInfo : tf) {
System.out.println(hasHardwareInfo + "\t" + hasSoftwareInfo + "\t" + hasModelInfo);
StatsInitializationReport report;
if (useJ7) {
report = new JavaStatsInitializationReport();
} else {
report = new SbeStatsInitializationReport();
}
report.reportIDs(null, null, null, -1);
if (hasHardwareInfo) {
report.reportHardwareInfo(jvmAvailableProcessors, numDevices, jvmMaxMemory,
offHeapMaxMemory, deviceTotalMemory, deviceDescription, hwUID);
}
if (hasSoftwareInfo) {
report.reportSoftwareInfo(arch, osName, jvmName, jvmVersion, jvmSpecVersion,
nd4jBackendClass, nd4jDataTypeName, hostname, jvmUID, swEnvInfo);
}
if (hasModelInfo) {
report.reportModelInfo(modelClassName, modelConfigJson, modelparamNames, numLayers,
numParams);
}
byte[] asBytes = report.encode();
StatsInitializationReport report2;
if (useJ7) {
report2 = new JavaStatsInitializationReport();
} else {
report2 = new SbeStatsInitializationReport();
}
report2.decode(asBytes);
if (hasHardwareInfo) {
assertEquals(jvmAvailableProcessors, report2.getHwJvmAvailableProcessors());
assertEquals(numDevices, report2.getHwNumDevices());
assertEquals(jvmMaxMemory, report2.getHwJvmMaxMemory());
assertEquals(offHeapMaxMemory, report2.getHwOffHeapMaxMemory());
if (useJ7) {
assertArrayEquals(null, report2.getHwDeviceTotalMemory());
assertArrayEquals(null, report2.getHwDeviceDescription());
} else {
assertArrayEquals(new long[] {0, 0}, report2.getHwDeviceTotalMemory()); //Edge case: nDevices = 2, but missing mem data -> expect long[] of 0s out, due to fixed encoding
assertArrayEquals(new String[] {"", ""}, report2.getHwDeviceDescription()); //As above
}
assertNullOrZeroLength(report2.getHwHardwareUID());
assertTrue(report2.hasHardwareInfo());
} else {
assertFalse(report2.hasHardwareInfo());
}
if (hasSoftwareInfo) {
assertNullOrZeroLength(report2.getSwArch());
assertNullOrZeroLength(report2.getSwOsName());
assertNullOrZeroLength(report2.getSwJvmName());
assertNullOrZeroLength(report2.getSwJvmVersion());
assertNullOrZeroLength(report2.getSwJvmSpecVersion());
assertNullOrZeroLength(report2.getSwNd4jBackendClass());
assertNullOrZeroLength(report2.getSwNd4jDataTypeName());
assertNullOrZeroLength(report2.getSwJvmUID());
assertNull(report2.getSwEnvironmentInfo());
assertTrue(report2.hasSoftwareInfo());
} else {
assertFalse(report2.hasSoftwareInfo());
}
if (hasModelInfo) {
assertNullOrZeroLength(report2.getModelClassName());
assertNullOrZeroLength(report2.getModelConfigJson());
assertNullOrZeroLengthArray(report2.getModelParamNames());
assertEquals(numLayers, report2.getModelNumLayers());
assertEquals(numParams, report2.getModelNumParams());
assertTrue(report2.hasModelInfo());
} else {
assertFalse(report2.hasModelInfo());
}
//Check standard Java serialization
ByteArrayOutputStream baos = new ByteArrayOutputStream();
ObjectOutputStream oos = new ObjectOutputStream(baos);
oos.writeObject(report);
oos.close();
byte[] javaBytes = baos.toByteArray();
ObjectInputStream ois = new ObjectInputStream(new ByteArrayInputStream(javaBytes));
StatsInitializationReport report3 = (StatsInitializationReport) ois.readObject();
assertEquals(report, report3);
}
}
}
}
}
private static void assertNullOrZeroLength(String str) {
assertTrue(str == null || str.length() == 0);
}
private static void assertNullOrZeroLengthArray(String[] str) {
assertTrue(str == null || str.length == 0);
}
@Test
public void testSbeStatsUpdate() throws Exception {
String[] paramNames = new String[] {"param0", "param1"};
String[] layerNames = new String[] {"layer0", "layer1"};
//IDs
String sessionID = "sid";
String typeID = "tid";
String workerID = "wid";
long timestamp = -1;
long time = System.currentTimeMillis();
int duration = 123456;
int iterCount = 123;
long perfRuntime = 1;
long perfTotalEx = 2;
long perfTotalMB = 3;
double perfEPS = 4.0;
double perfMBPS = 5.0;
long memJC = 6;
long memJM = 7;
long memOC = 8;
long memOM = 9;
long[] memDC = new long[] {10, 11};
long[] memDM = new long[] {12, 13};
String gc1Name = "14";
int gcdc1 = 16;
int gcdt1 = 17;
String gc2Name = "18";
int gcdc2 = 20;
int gcdt2 = 21;
double score = 22.0;
Map<String, Double> lrByParam = new HashMap<>();
lrByParam.put(paramNames[0], 22.5);
lrByParam.put(paramNames[1], 22.75);
Map<String, Histogram> pHist = new HashMap<>();
pHist.put(paramNames[0], new Histogram(23, 24, 2, new int[] {25, 26}));
pHist.put(paramNames[1], new Histogram(27, 28, 3, new int[] {29, 30, 31}));
Map<String, Histogram> gHist = new HashMap<>();
gHist.put(paramNames[0], new Histogram(230, 240, 2, new int[] {250, 260}));
gHist.put(paramNames[1], new Histogram(270, 280, 3, new int[] {290, 300, 310}));
Map<String, Histogram> uHist = new HashMap<>();
uHist.put(paramNames[0], new Histogram(32, 33, 2, new int[] {34, 35}));
uHist.put(paramNames[1], new Histogram(36, 37, 3, new int[] {38, 39, 40}));
Map<String, Histogram> aHist = new HashMap<>();
aHist.put(layerNames[0], new Histogram(41, 42, 2, new int[] {43, 44}));
aHist.put(layerNames[1], new Histogram(45, 46, 3, new int[] {47, 48, 47}));
Map<String, Double> pMean = new HashMap<>();
pMean.put(paramNames[0], 49.0);
pMean.put(paramNames[1], 50.0);
Map<String, Double> gMean = new HashMap<>();
gMean.put(paramNames[0], 49.1);
gMean.put(paramNames[1], 50.1);
Map<String, Double> uMean = new HashMap<>();
uMean.put(paramNames[0], 51.0);
uMean.put(paramNames[1], 52.0);
Map<String, Double> aMean = new HashMap<>();
aMean.put(layerNames[0], 53.0);
aMean.put(layerNames[1], 54.0);
Map<String, Double> pStd = new HashMap<>();
pStd.put(paramNames[0], 55.0);
pStd.put(paramNames[1], 56.0);
Map<String, Double> gStd = new HashMap<>();
gStd.put(paramNames[0], 55.1);
gStd.put(paramNames[1], 56.1);
Map<String, Double> uStd = new HashMap<>();
uStd.put(paramNames[0], 57.0);
uStd.put(paramNames[1], 58.0);
Map<String, Double> aStd = new HashMap<>();
aStd.put(layerNames[0], 59.0);
aStd.put(layerNames[1], 60.0);
Map<String, Double> pMM = new HashMap<>();
pMM.put(paramNames[0], 61.0);
pMM.put(paramNames[1], 62.0);
Map<String, Double> gMM = new HashMap<>();
gMM.put(paramNames[0], 61.1);
gMM.put(paramNames[1], 62.1);
Map<String, Double> uMM = new HashMap<>();
uMM.put(paramNames[0], 63.0);
uMM.put(paramNames[1], 64.0);
Map<String, Double> aMM = new HashMap<>();
aMM.put(layerNames[0], 65.0);
aMM.put(layerNames[1], 66.0);
List<Serializable> metaDataList = new ArrayList<>();
metaDataList.add("meta1");
metaDataList.add("meta2");
metaDataList.add("meta3");
Class<?> metaDataClass = String.class;
boolean[] tf = new boolean[] {true, false};
boolean[][] tf4 = new boolean[][] {{false, false, false, false}, {true, false, false, false},
{false, true, false, false}, {false, false, true, false}, {false, false, false, true},
{true, true, true, true}};
//Total tests: 2^6 x 6^3 = 13,824 separate tests
int testCount = 0;
for (boolean collectPerformanceStats : tf) {
for (boolean collectMemoryStats : tf) {
for (boolean collectGCStats : tf) {
for (boolean collectScore : tf) {
for (boolean collectLearningRates : tf) {
for (boolean collectMetaData : tf) {
for (boolean[] collectHistograms : tf4) {
for (boolean[] collectMeanStdev : tf4) {
for (boolean[] collectMM : tf4) {
SbeStatsReport report = new SbeStatsReport();
report.reportIDs(sessionID, typeID, workerID, time);
report.reportStatsCollectionDurationMS(duration);
report.reportIterationCount(iterCount);
if (collectPerformanceStats) {
report.reportPerformance(perfRuntime, perfTotalEx, perfTotalMB, perfEPS,
perfMBPS);
}
if (collectMemoryStats) {
report.reportMemoryUse(memJC, memJM, memOC, memOM, memDC, memDM);
}
if (collectGCStats) {
report.reportGarbageCollection(gc1Name, gcdc1, gcdt1);
report.reportGarbageCollection(gc2Name, gcdc2, gcdt2);
}
if (collectScore) {
report.reportScore(score);
}
if (collectLearningRates) {
report.reportLearningRates(lrByParam);
}
if (collectMetaData) {
report.reportDataSetMetaData(metaDataList, metaDataClass);
}
if (collectHistograms[0]) { //Param hist
report.reportHistograms(StatsType.Parameters, pHist);
}
if (collectHistograms[1]) { //Grad hist
report.reportHistograms(StatsType.Gradients, gHist);
}
if (collectHistograms[2]) { //Update hist
report.reportHistograms(StatsType.Updates, uHist);
}
if (collectHistograms[3]) { //Act hist
report.reportHistograms(StatsType.Activations, aHist);
}
if (collectMeanStdev[0]) { //Param mean/stdev
report.reportMean(StatsType.Parameters, pMean);
report.reportStdev(StatsType.Parameters, pStd);
}
if (collectMeanStdev[1]) { //Gradient mean/stdev
report.reportMean(StatsType.Gradients, gMean);
report.reportStdev(StatsType.Gradients, gStd);
}
if (collectMeanStdev[2]) { //Update mean/stdev
report.reportMean(StatsType.Updates, uMean);
report.reportStdev(StatsType.Updates, uStd);
}
if (collectMeanStdev[3]) { //Act mean/stdev
report.reportMean(StatsType.Activations, aMean);
report.reportStdev(StatsType.Activations, aStd);
}
if (collectMM[0]) { //Param mean mag
report.reportMeanMagnitudes(StatsType.Parameters, pMM);
}
if (collectMM[1]) { //Gradient mean mag
report.reportMeanMagnitudes(StatsType.Gradients, gMM);
}
if (collectMM[2]) { //Update mm
report.reportMeanMagnitudes(StatsType.Updates, uMM);
}
if (collectMM[3]) { //Act mm
report.reportMeanMagnitudes(StatsType.Activations, aMM);
}
byte[] bytes = report.encode();
StatsReport report2 = new SbeStatsReport();
report2.decode(bytes);
assertEquals(report, report2);
assertEquals(sessionID, report2.getSessionID());
assertEquals(typeID, report2.getTypeID());
assertEquals(workerID, report2.getWorkerID());
assertEquals(time, report2.getTimeStamp());
assertEquals(time, report2.getTimeStamp());
assertEquals(duration, report2.getStatsCollectionDurationMs());
assertEquals(iterCount, report2.getIterationCount());
if (collectPerformanceStats) {
assertEquals(perfRuntime, report2.getTotalRuntimeMs());
assertEquals(perfTotalEx, report2.getTotalExamples());
assertEquals(perfTotalMB, report2.getTotalMinibatches());
assertEquals(perfEPS, report2.getExamplesPerSecond(), 0.0);
assertEquals(perfMBPS, report2.getMinibatchesPerSecond(), 0.0);
Assert.assertTrue(report2.hasPerformance());
} else {
Assert.assertFalse(report2.hasPerformance());
}
if (collectMemoryStats) {
assertEquals(memJC, report2.getJvmCurrentBytes());
assertEquals(memJM, report2.getJvmMaxBytes());
assertEquals(memOC, report2.getOffHeapCurrentBytes());
assertEquals(memOM, report2.getOffHeapMaxBytes());
assertArrayEquals(memDC, report2.getDeviceCurrentBytes());
assertArrayEquals(memDM, report2.getDeviceMaxBytes());
Assert.assertTrue(report2.hasMemoryUse());
} else {
Assert.assertFalse(report2.hasMemoryUse());
}
if (collectGCStats) {
List<Pair<String, int[]>> gcs = report2.getGarbageCollectionStats();
Assert.assertEquals(2, gcs.size());
Assert.assertEquals(gc1Name, gcs.get(0).getFirst());
Assert.assertArrayEquals(new int[] {gcdc1, gcdt1},
gcs.get(0).getSecond());
Assert.assertEquals(gc2Name, gcs.get(1).getFirst());
Assert.assertArrayEquals(new int[] {gcdc2, gcdt2},
gcs.get(1).getSecond());
Assert.assertTrue(report2.hasGarbageCollection());
} else {
Assert.assertFalse(report2.hasGarbageCollection());
}
if (collectScore) {
assertEquals(score, report2.getScore(), 0.0);
Assert.assertTrue(report2.hasScore());
} else {
Assert.assertFalse(report2.hasScore());
}
if (collectLearningRates) {
assertEquals(lrByParam.keySet(), report2.getLearningRates().keySet());
for (String s : lrByParam.keySet()) {
assertEquals(lrByParam.get(s), report2.getLearningRates().get(s),
1e-6);
}
Assert.assertTrue(report2.hasLearningRates());
} else {
Assert.assertFalse(report2.hasLearningRates());
}
if (collectMetaData) {
assertNotNull(report2.getDataSetMetaData());
assertEquals(metaDataList, report2.getDataSetMetaData());
assertEquals(metaDataClass.getName(),
report2.getDataSetMetaDataClassName());
assertTrue(report2.hasDataSetMetaData());
} else {
assertFalse(report2.hasDataSetMetaData());
}
if (collectHistograms[0]) {
assertEquals(pHist, report2.getHistograms(StatsType.Parameters));
Assert.assertTrue(report2.hasHistograms(StatsType.Parameters));
} else {
Assert.assertFalse(report2.hasHistograms(StatsType.Parameters));
}
if (collectHistograms[1]) {
assertEquals(gHist, report2.getHistograms(StatsType.Gradients));
Assert.assertTrue(report2.hasHistograms(StatsType.Gradients));
} else {
Assert.assertFalse(report2.hasHistograms(StatsType.Gradients));
}
if (collectHistograms[2]) {
assertEquals(uHist, report2.getHistograms(StatsType.Updates));
Assert.assertTrue(report2.hasHistograms(StatsType.Updates));
} else {
Assert.assertFalse(report2.hasHistograms(StatsType.Updates));
}
if (collectHistograms[3]) {
assertEquals(aHist, report2.getHistograms(StatsType.Activations));
Assert.assertTrue(report2.hasHistograms(StatsType.Activations));
} else {
Assert.assertFalse(report2.hasHistograms(StatsType.Activations));
}
if (collectMeanStdev[0]) {
assertEquals(pMean, report2.getMean(StatsType.Parameters));
assertEquals(pStd, report2.getStdev(StatsType.Parameters));
Assert.assertTrue(report2.hasSummaryStats(StatsType.Parameters,
SummaryType.Mean));
Assert.assertTrue(report2.hasSummaryStats(StatsType.Parameters,
SummaryType.Stdev));
} else {
Assert.assertFalse(report2.hasSummaryStats(StatsType.Parameters,
SummaryType.Mean));
Assert.assertFalse(report2.hasSummaryStats(StatsType.Parameters,
SummaryType.Stdev));
}
if (collectMeanStdev[1]) {
assertEquals(gMean, report2.getMean(StatsType.Gradients));
assertEquals(gStd, report2.getStdev(StatsType.Gradients));
Assert.assertTrue(report2.hasSummaryStats(StatsType.Gradients,
SummaryType.Mean));
Assert.assertTrue(report2.hasSummaryStats(StatsType.Gradients,
SummaryType.Stdev));
} else {
Assert.assertFalse(report2.hasSummaryStats(StatsType.Gradients,
SummaryType.Mean));
Assert.assertFalse(report2.hasSummaryStats(StatsType.Gradients,
SummaryType.Stdev));
}
if (collectMeanStdev[2]) {
assertEquals(uMean, report2.getMean(StatsType.Updates));
assertEquals(uStd, report2.getStdev(StatsType.Updates));
Assert.assertTrue(report2.hasSummaryStats(StatsType.Updates,
SummaryType.Mean));
Assert.assertTrue(report2.hasSummaryStats(StatsType.Updates,
SummaryType.Stdev));
} else {
Assert.assertFalse(report2.hasSummaryStats(StatsType.Updates,
SummaryType.Mean));
Assert.assertFalse(report2.hasSummaryStats(StatsType.Updates,
SummaryType.Stdev));
}
if (collectMeanStdev[3]) {
assertEquals(aMean, report2.getMean(StatsType.Activations));
assertEquals(aStd, report2.getStdev(StatsType.Activations));
Assert.assertTrue(report2.hasSummaryStats(StatsType.Activations,
SummaryType.Mean));
Assert.assertTrue(report2.hasSummaryStats(StatsType.Activations,
SummaryType.Stdev));
} else {
Assert.assertFalse(report2.hasSummaryStats(StatsType.Activations,
SummaryType.Mean));
Assert.assertFalse(report2.hasSummaryStats(StatsType.Activations,
SummaryType.Stdev));
}
if (collectMM[0]) {
assertEquals(pMM, report2.getMeanMagnitudes(StatsType.Parameters));
Assert.assertTrue(report2.hasSummaryStats(StatsType.Parameters,
SummaryType.MeanMagnitudes));
} else {
Assert.assertFalse(report2.hasSummaryStats(StatsType.Parameters,
SummaryType.MeanMagnitudes));
}
if (collectMM[1]) {
assertEquals(gMM, report2.getMeanMagnitudes(StatsType.Gradients));
Assert.assertTrue(report2.hasSummaryStats(StatsType.Gradients,
SummaryType.MeanMagnitudes));
} else {
Assert.assertFalse(report2.hasSummaryStats(StatsType.Gradients,
SummaryType.MeanMagnitudes));
}
if (collectMM[2]) {
assertEquals(uMM, report2.getMeanMagnitudes(StatsType.Updates));
Assert.assertTrue(report2.hasSummaryStats(StatsType.Updates,
SummaryType.MeanMagnitudes));
} else {
Assert.assertFalse(report2.hasSummaryStats(StatsType.Updates,
SummaryType.MeanMagnitudes));
}
if (collectMM[3]) {
assertEquals(aMM, report2.getMeanMagnitudes(StatsType.Activations));
Assert.assertTrue(report2.hasSummaryStats(StatsType.Activations,
SummaryType.MeanMagnitudes));
} else {
Assert.assertFalse(report2.hasSummaryStats(StatsType.Activations,
SummaryType.MeanMagnitudes));
}
//Check standard Java serialization
ByteArrayOutputStream baos = new ByteArrayOutputStream();
ObjectOutputStream oos = new ObjectOutputStream(baos);
oos.writeObject(report);
oos.close();
byte[] javaBytes = baos.toByteArray();
ObjectInputStream ois =
new ObjectInputStream(new ByteArrayInputStream(javaBytes));
SbeStatsReport report3 = (SbeStatsReport) ois.readObject();
assertEquals(report, report3);
testCount++;
}
}
}
}
}
}
}
}
}
Assert.assertEquals(13824, testCount);
}
@Test
public void testSbeStatsUpdateNullValues() throws Exception {
String[] paramNames = null; //new String[]{"param0", "param1"};
long time = System.currentTimeMillis();
int duration = 123456;
int iterCount = 123;
long perfRuntime = 1;
long perfTotalEx = 2;
long perfTotalMB = 3;
double perfEPS = 4.0;
double perfMBPS = 5.0;
long memJC = 6;
long memJM = 7;
long memOC = 8;
long memOM = 9;
long[] memDC = null;
long[] memDM = null;
String gc1Name = null;
int gcdc1 = 16;
int gcdt1 = 17;
String gc2Name = null;
int gcdc2 = 20;
int gcdt2 = 21;
double score = 22.0;
Map<String, Double> lrByParam = null;
Map<String, Histogram> pHist = null;
Map<String, Histogram> gHist = null;
Map<String, Histogram> uHist = null;
Map<String, Histogram> aHist = null;
Map<String, Double> pMean = null;
Map<String, Double> gMean = null;
Map<String, Double> uMean = null;
Map<String, Double> aMean = null;
Map<String, Double> pStd = null;
Map<String, Double> gStd = null;
Map<String, Double> uStd = null;
Map<String, Double> aStd = null;
Map<String, Double> pMM = null;
Map<String, Double> gMM = null;
Map<String, Double> uMM = null;
Map<String, Double> aMM = null;
boolean[] tf = new boolean[] {true, false};
boolean[][] tf4 = new boolean[][] {{false, false, false, false}, {true, false, false, false},
{false, true, false, false}, {false, false, true, false}, {false, false, false, true},
{true, true, true, true}};
//Total tests: 2^6 x 6^3 = 13,824 separate tests
int testCount = 0;
for (boolean collectPerformanceStats : tf) {
for (boolean collectMemoryStats : tf) {
for (boolean collectGCStats : tf) {
for (boolean collectDataSetMetaData : tf) {
for (boolean collectScore : tf) {
for (boolean collectLearningRates : tf) {
for (boolean[] collectHistograms : tf4) {
for (boolean[] collectMeanStdev : tf4) {
for (boolean[] collectMM : tf4) {
SbeStatsReport report = new SbeStatsReport();
report.reportIDs(null, null, null, time);
report.reportStatsCollectionDurationMS(duration);
report.reportIterationCount(iterCount);
if (collectPerformanceStats) {
report.reportPerformance(perfRuntime, perfTotalEx, perfTotalMB, perfEPS,
perfMBPS);
}
if (collectMemoryStats) {
report.reportMemoryUse(memJC, memJM, memOC, memOM, memDC, memDM);
}
if (collectGCStats) {
report.reportGarbageCollection(gc1Name, gcdc1, gcdt1);
report.reportGarbageCollection(gc2Name, gcdc2, gcdt2);
}
if (collectDataSetMetaData) {
//TODO
}
if (collectScore) {
report.reportScore(score);
}
if (collectLearningRates) {
report.reportLearningRates(lrByParam);
}
if (collectHistograms[0]) { //Param hist
report.reportHistograms(StatsType.Parameters, pHist);
}
if (collectHistograms[1]) {
report.reportHistograms(StatsType.Gradients, gHist);
}
if (collectHistograms[2]) { //Update hist
report.reportHistograms(StatsType.Updates, uHist);
}
if (collectHistograms[3]) { //Act hist
report.reportHistograms(StatsType.Activations, aHist);
}
if (collectMeanStdev[0]) { //Param mean/stdev
report.reportMean(StatsType.Parameters, pMean);
report.reportStdev(StatsType.Parameters, pStd);
}
if (collectMeanStdev[1]) { //Param mean/stdev
report.reportMean(StatsType.Gradients, gMean);
report.reportStdev(StatsType.Gradients, gStd);
}
if (collectMeanStdev[2]) { //Update mean/stdev
report.reportMean(StatsType.Updates, uMean);
report.reportStdev(StatsType.Updates, uStd);
}
if (collectMeanStdev[3]) { //Act mean/stdev
report.reportMean(StatsType.Activations, aMean);
report.reportStdev(StatsType.Activations, aStd);
}
if (collectMM[0]) { //Param mean mag
report.reportMeanMagnitudes(StatsType.Parameters, pMM);
}
if (collectMM[1]) { //Param mean mag
report.reportMeanMagnitudes(StatsType.Gradients, gMM);
}
if (collectMM[2]) { //Update mm
report.reportMeanMagnitudes(StatsType.Updates, uMM);
}
if (collectMM[3]) { //Act mm
report.reportMeanMagnitudes(StatsType.Activations, aMM);
}
byte[] bytes = report.encode();
StatsReport report2 = new SbeStatsReport();
report2.decode(bytes);
assertEquals(time, report2.getTimeStamp());
assertEquals(duration, report2.getStatsCollectionDurationMs());
assertEquals(iterCount, report2.getIterationCount());
if (collectPerformanceStats) {
assertEquals(perfRuntime, report2.getTotalRuntimeMs());
assertEquals(perfTotalEx, report2.getTotalExamples());
assertEquals(perfTotalMB, report2.getTotalMinibatches());
assertEquals(perfEPS, report2.getExamplesPerSecond(), 0.0);
assertEquals(perfMBPS, report2.getMinibatchesPerSecond(), 0.0);
Assert.assertTrue(report2.hasPerformance());
} else {
Assert.assertFalse(report2.hasPerformance());
}
if (collectMemoryStats) {
assertEquals(memJC, report2.getJvmCurrentBytes());
assertEquals(memJM, report2.getJvmMaxBytes());
assertEquals(memOC, report2.getOffHeapCurrentBytes());
assertEquals(memOM, report2.getOffHeapMaxBytes());
assertArrayEquals(memDC, report2.getDeviceCurrentBytes());
assertArrayEquals(memDM, report2.getDeviceMaxBytes());
Assert.assertTrue(report2.hasMemoryUse());
} else {
Assert.assertFalse(report2.hasMemoryUse());
}
if (collectGCStats) {
List<Pair<String, int[]>> gcs = report2.getGarbageCollectionStats();
Assert.assertEquals(2, gcs.size());
assertNullOrZeroLength(gcs.get(0).getFirst());
Assert.assertArrayEquals(new int[] {gcdc1, gcdt1},
gcs.get(0).getSecond());
assertNullOrZeroLength(gcs.get(1).getFirst());
Assert.assertArrayEquals(new int[] {gcdc2, gcdt2},
gcs.get(1).getSecond());
Assert.assertTrue(report2.hasGarbageCollection());
} else {
Assert.assertFalse(report2.hasGarbageCollection());
}
if (collectDataSetMetaData) {
//TODO
}
if (collectScore) {
assertEquals(score, report2.getScore(), 0.0);
Assert.assertTrue(report2.hasScore());
} else {
Assert.assertFalse(report2.hasScore());
}
if (collectLearningRates) {
assertNull(report2.getLearningRates());
} else {
Assert.assertFalse(report2.hasLearningRates());
}
assertNull(report2.getHistograms(StatsType.Parameters));
Assert.assertFalse(report2.hasHistograms(StatsType.Parameters));
assertNull(report2.getHistograms(StatsType.Gradients));
Assert.assertFalse(report2.hasHistograms(StatsType.Gradients));
assertNull(report2.getHistograms(StatsType.Updates));
Assert.assertFalse(report2.hasHistograms(StatsType.Updates));
assertNull(report2.getHistograms(StatsType.Activations));
Assert.assertFalse(report2.hasHistograms(StatsType.Activations));
assertNull(report2.getMean(StatsType.Parameters));
assertNull(report2.getStdev(StatsType.Parameters));
Assert.assertFalse(report2.hasSummaryStats(StatsType.Parameters,
SummaryType.Mean));
Assert.assertFalse(report2.hasSummaryStats(StatsType.Parameters,
SummaryType.Stdev));
assertNull(report2.getMean(StatsType.Gradients));
assertNull(report2.getStdev(StatsType.Gradients));
Assert.assertFalse(report2.hasSummaryStats(StatsType.Gradients,
SummaryType.Mean));
Assert.assertFalse(report2.hasSummaryStats(StatsType.Gradients,
SummaryType.Stdev));
assertNull(report2.getMean(StatsType.Updates));
assertNull(report2.getStdev(StatsType.Updates));
Assert.assertFalse(report2.hasSummaryStats(StatsType.Updates,
SummaryType.Mean));
Assert.assertFalse(report2.hasSummaryStats(StatsType.Updates,
SummaryType.Stdev));
assertNull(report2.getMean(StatsType.Activations));
assertNull(report2.getStdev(StatsType.Activations));
Assert.assertFalse(report2.hasSummaryStats(StatsType.Activations,
SummaryType.Mean));
Assert.assertFalse(report2.hasSummaryStats(StatsType.Activations,
SummaryType.Stdev));
assertNull(report2.getMeanMagnitudes(StatsType.Parameters));
Assert.assertFalse(report2.hasSummaryStats(StatsType.Parameters,
SummaryType.MeanMagnitudes));
assertNull(report2.getMeanMagnitudes(StatsType.Gradients));
Assert.assertFalse(report2.hasSummaryStats(StatsType.Gradients,
SummaryType.MeanMagnitudes));
assertNull(report2.getMeanMagnitudes(StatsType.Updates));
Assert.assertFalse(report2.hasSummaryStats(StatsType.Updates,
SummaryType.MeanMagnitudes));
assertNull(report2.getMeanMagnitudes(StatsType.Activations));
Assert.assertFalse(report2.hasSummaryStats(StatsType.Activations,
SummaryType.MeanMagnitudes));
//Check standard Java serialization
ByteArrayOutputStream baos = new ByteArrayOutputStream();
ObjectOutputStream oos = new ObjectOutputStream(baos);
oos.writeObject(report);
oos.close();
byte[] javaBytes = baos.toByteArray();
ObjectInputStream ois =
new ObjectInputStream(new ByteArrayInputStream(javaBytes));
SbeStatsReport report3 = (SbeStatsReport) ois.readObject();
assertEquals(report, report3);
testCount++;
}
}
}
}
}
}
}
}
}
Assert.assertEquals(13824, testCount);
}
}