/* ---------------------------------------------------------------------
* Numenta Platform for Intelligent Computing (NuPIC)
* Copyright (C) 2016, Numenta, Inc. Unless you have an agreement
* with Numenta, Inc., for a separate license for this software code, the
* following terms and conditions apply:
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero Public License version 3 as
* published by the Free Software Foundation.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
* See the GNU Affero Public License for more details.
*
* You should have received a copy of the GNU Affero Public License
* along with this program. If not, see http://www.gnu.org/licenses.
*
* http://numenta.org/licenses/
* ---------------------------------------------------------------------
*/
package org.numenta.nupic.network;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.numenta.nupic.algorithms.Anomaly.KEY_MODE;
import static org.numenta.nupic.algorithms.Anomaly.KEY_USE_MOVING_AVG;
import static org.numenta.nupic.algorithms.Anomaly.KEY_WINDOW_SIZE;
import static org.numenta.nupic.network.NetworkTestHarness.*;
import java.io.File;
import java.math.BigDecimal;
import java.math.RoundingMode;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CyclicBarrier;
import java.util.function.BiFunction;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import org.junit.AfterClass;
import org.junit.Test;
import org.numenta.nupic.FieldMetaType;
import org.numenta.nupic.Parameters;
import org.numenta.nupic.Parameters.KEY;
import org.numenta.nupic.algorithms.Anomaly;
import org.numenta.nupic.algorithms.Anomaly.Mode;
import org.numenta.nupic.algorithms.AnomalyLikelihood;
import org.numenta.nupic.algorithms.AnomalyLikelihoodMetrics;
import org.numenta.nupic.algorithms.AnomalyLikelihoodTest;
import org.numenta.nupic.algorithms.CLAClassifier;
import org.numenta.nupic.algorithms.Classification;
import org.numenta.nupic.algorithms.Sample;
import org.numenta.nupic.algorithms.SpatialPooler;
import org.numenta.nupic.algorithms.TemporalMemory;
import org.numenta.nupic.datagen.ResourceLocator;
import org.numenta.nupic.encoders.DateEncoder;
import org.numenta.nupic.encoders.MultiEncoder;
import org.numenta.nupic.model.Cell;
import org.numenta.nupic.model.Connections;
import org.numenta.nupic.model.SDR;
import org.numenta.nupic.network.Persistence.PersistenceAccess;
import org.numenta.nupic.network.sensor.FileSensor;
import org.numenta.nupic.network.sensor.HTMSensor;
import org.numenta.nupic.network.sensor.ObservableSensor;
import org.numenta.nupic.network.sensor.Publisher;
import org.numenta.nupic.network.sensor.Sensor;
import org.numenta.nupic.network.sensor.SensorParams;
import org.numenta.nupic.network.sensor.SensorParams.Keys;
import org.numenta.nupic.serialize.SerialConfig;
import org.numenta.nupic.util.ArrayUtils;
import org.numenta.nupic.util.Condition;
import org.numenta.nupic.util.FastRandom;
import org.numenta.nupic.util.MersenneTwister;
import org.numenta.nupic.util.Tuple;
import com.cedarsoftware.util.DeepEquals;
import gnu.trove.list.array.TIntArrayList;
import rx.Observer;
import rx.Subscriber;
import rx.observers.TestObserver;
public class PersistenceAPITest extends ObservableTestBase {
// TO TURN ON PRINTOUT: SET "TRUE" BELOW
/**
 * Printer used to visualize each day-of-week {@link Inference} (applied with the
 * layer's cells-per-column count in {@code testSerializedUnStartedNetworkRuns}).
 * Created with {@code false}, i.e. silent; pass {@code true} to
 * {@code createDayOfWeekInferencePrintout} to see the printout.
 */
private BiFunction<Inference, Integer, Integer> dayOfWeekPrintout = createDayOfWeekInferencePrintout(false);
/**
 * Removes the serialization test directory ({@link SerialConfig#SERIAL_TEST_DIR} under the
 * user's home directory) once all tests in this class have run.
 * <p>
 * Each entry directly inside the directory is deleted, then the directory itself. Any
 * failure is printed rather than rethrown so that cleanup problems do not mask the
 * outcome of the tests themselves.
 */
@AfterClass
public static void cleanUp() {
    System.out.println("cleaning up...");
    try {
        File serialDir = new File(System.getProperty("user.home") + File.separator + SerialConfig.SERIAL_TEST_DIR);
        if(serialDir.exists()) {
            // The stream returned by Files.list() holds an open directory handle and must be
            // closed (see its Javadoc); try-with-resources guarantees that, even on failure.
            try(Stream<Path> entries = Files.list(serialDir.toPath())) {
                entries.forEach(f -> {
                    try { Files.deleteIfExists(f.toAbsolutePath()); }
                    catch(Exception io) { throw new RuntimeException(io); }
                });
            }
            Files.delete(serialDir.toPath());
        }
    }catch(Exception e) {
        e.printStackTrace();
    }
}
/**
 * Verifies that {@code PersistenceAccess.ensurePathExists} creates the file named by the
 * config inside the configured directory under the user's home directory.
 */
@Test
public void testEnsurePathExists() {
    SerialConfig config = new SerialConfig("testEnsurePathExists", SerialConfig.SERIAL_TEST_DIR);
    PersistenceAPI persist = Persistence.get();
    persist.setConfig(config);
    try {
        ((PersistenceAccess)persist).ensurePathExists(config);
    }catch(Exception e) {
        // Report the cause instead of a bare fail(), which would hide why the call failed.
        e.printStackTrace();
        fail("ensurePathExists threw: " + e);
    }
    File f1 = new File(System.getProperty("user.home") + File.separator + config.getFileDir() + File.separator + "testEnsurePathExists");
    assertTrue("Expected path to exist: " + f1, f1.exists());
}
/**
 * Writes five check points of a small SP/TM/Anomaly network, then checks that at least
 * five check-point files are listed and that the check point preceding the newest one
 * is the second-newest file in the listing.
 */
@Test
public void testSearchAndListPreviousCheckPoint() {
    Parameters params = NetworkTestHarness.getParameters();
    Network network = Network.create("test network", params)
        .add(Network.createRegion("r1")
            .add(Network.createLayer("1", params)
                .add(Anomaly.create())
                .add(new TemporalMemory())
                .add(new SpatialPooler())));

    PersistenceAPI persistence = Persistence.get(new SerialConfig(null, SerialConfig.SERIAL_TEST_DIR));
    PersistenceAccess access = (PersistenceAccess)persistence;
    for(int checkPoint = 0; checkPoint < 5; checkPoint++) {
        access.getCheckPointFunction(network).apply(network);
    }

    List<String> files = persistence.listCheckPointFiles();
    assertTrue(files.size() > 4);

    String newest = files.get(files.size() - 1);
    String secondNewest = files.get(files.size() - 2);
    assertEquals(secondNewest, persistence.getPreviousCheckPoint(newest));
}
/////////////////////////////////////////////////////////////////////////////
// First, Test Serialization of Each (Critical) Object Individually //
/////////////////////////////////////////////////////////////////////////////
/////////////////////
// Parameters //
/////////////////////
/**
 * Round-trips {@link Parameters} through the persistence API twice: once from the byte
 * array returned by {@code write(p, name)} and once by reading the named file back.
 * Both copies must deep-equal the original, key by key.
 */
@Test
public void testSerializeParameters() {
    Parameters original = getParameters();
    PersistenceAPI api = Persistence.get(
        new SerialConfig("testSerializeParameters", SerialConfig.SERIAL_TEST_DIR));

    // Serialize (to bytes and to the named file), then rebuild from the bytes.
    byte[] bytes = api.write(original, "testSerializeParameters");
    Parameters fromBytes = api.read(bytes);
    assertTrue(original.keys().size() == fromBytes.keys().size());
    assertTrue(DeepEquals.deepEquals(original, fromBytes));
    for(KEY key : original.keys()) {
        deepCompare(fromBytes.get(key), original.get(key));
    }

    // Rebuild again, this time retrieving the object by file name.
    Parameters fromFile = api.read("testSerializeParameters");
    assertTrue(original.keys().size() == fromFile.keys().size());
    assertTrue(DeepEquals.deepEquals(original, fromFile));
    for(KEY key : original.keys()) {
        deepCompare(fromFile.get(key), original.get(key));
    }
}
/////////////////////
// Connections //
/////////////////////
/**
 * Round-trips a freshly initialized {@link Connections} through the persistence API and
 * checks every column and cell of two copies: one rebuilt from the byte array returned by
 * {@code write()}, and one read back from the file that {@code write()} persisted.
 */
@Test
public void testSerializeConnections() {
    Parameters p = getParameters();
    Connections con = new Connections();
    p.apply(con);
    TemporalMemory.init(con);
    SerialConfig config = new SerialConfig("testSerializeConnections", SerialConfig.SERIAL_TEST_DIR);
    PersistenceAPI api = Persistence.get(config);
    // 1. serialize (write() also persists to the file named by the config)
    byte[] data = api.write(con);
    // 2. deserialize from the returned bytes
    Connections serialized = api.read(data);
    assertTrue(DeepEquals.deepEquals(con, serialized));
    serialized.printParameters();
    int cellCount = con.getCellsPerColumn();
    for(int i = 0;i < con.getNumColumns();i++) {
        deepCompare(con.getColumn(i), serialized.getColumn(i));
        for(int j = 0;j < cellCount;j++) {
            Cell cell = serialized.getCell(i * cellCount + j);
            deepCompare(con.getCell(i * cellCount + j), cell);
        }
    }
    // 3. reify from file. This previously called read(data) again, which only repeated
    //    step 2 and never exercised the persisted file.
    Connections fromFile = api.read();
    assertTrue(DeepEquals.deepEquals(con, fromFile));
    for(int i = 0;i < con.getNumColumns();i++) {
        deepCompare(con.getColumn(i), fromFile.getColumn(i));
        for(int j = 0;j < cellCount;j++) {
            Cell cell = fromFile.getCell(i * cellCount + j);
            deepCompare(con.getCell(i * cellCount + j), cell);
        }
    }
}
// Connections with all types populated
// @SuppressWarnings("unused")
// @Test
// public void testMorePopulatedConnections() {
// TemporalMemory tm = new TemporalMemory();
// Connections cn = new Connections();
// cn.setConnectedPermanence(0.50);
// cn.setMinThreshold(1);
// // Init with default params defined in Connections.java default fields.
// tm.init(cn);
//
// SerialConfig config = new SerialConfig("testSerializeConnections2", SerialConfig.SERIAL_TEST_DIR);
// PersistenceAPI api = Persistence.get(config);
//
// DistalDendrite dd = cn.getCell(0).createSegment(cn);
// Synapse s0 = dd.createSynapse(cn, cn.getCell(23), 0.6);
// Synapse s1 = dd.createSynapse(cn, cn.getCell(37), 0.4);
// Synapse s2 = dd.createSynapse(cn, cn.getCell(477), 0.9);
//
// byte[] dda = api.write(dd);
// DistalDendrite ddo = api.read(dda);
// deepCompare(dd, ddo);
// List<Synapse> l1 = dd.getAllSynapses(cn);
// List<Synapse> l2 = ddo.getAllSynapses(cn);
// assertTrue(l2.equals(l1));
//
// DistalDendrite dd1 = cn.getCell(0).createSegment(cn);
// Synapse s3 = dd1.createSynapse(cn, cn.getCell(49), 0.9);
// Synapse s4 = dd1.createSynapse(cn, cn.getCell(3), 0.8);
//
// DistalDendrite dd2 = cn.getCell(1).createSegment(cn);
// Synapse s5 = dd2.createSynapse(cn, cn.getCell(733), 0.7);
//
// DistalDendrite dd3 = cn.getCell(8).createSegment(cn);
// Synapse s6 = dd3.createSynapse(cn, cn.getCell(486), 0.9);
//
//
// Connections cn2 = new Connections();
// cn2.setConnectedPermanence(0.50);
// cn2.setMinThreshold(1);
// tm.init(cn2);
//
// DistalDendrite ddb = cn2.getCell(0).createSegment(cn2);
// Synapse s0b = ddb.createSynapse(cn2, cn2.getCell(23), 0.6);
// Synapse s1b = ddb.createSynapse(cn2, cn2.getCell(37), 0.4);
// Synapse s2b = ddb.createSynapse(cn2, cn2.getCell(477), 0.9);
//
// DistalDendrite dd1b = cn2.getCell(0).createSegment(cn2);
// Synapse s3b = dd1b.createSynapse(cn2, cn2.getCell(49), 0.9);
// Synapse s4b = dd1b.createSynapse(cn2, cn2.getCell(3), 0.8);
//
// DistalDendrite dd2b = cn2.getCell(1).createSegment(cn2);
// Synapse s5b = dd2b.createSynapse(cn2, cn2.getCell(733), 0.7);
//
// DistalDendrite dd3b = cn2.getCell(8).createSegment(cn2);
// Synapse s6b = dd3b.createSynapse(cn2, cn2.getCell(486), 0.9);
//
// assertTrue(cn.equals(cn2));
//
// Set<Cell> activeCells = cn.getCellSet(new int[] { 733, 37, 974, 23 });
//
// SegmentSearch result = tm.getBestMatchingSegment(cn, cn.getCell(0), activeCells);
// assertEquals(dd, result.bestSegment);
// assertEquals(2, result.numActiveSynapses);
//
// result = tm.getBestMatchingSegment(cn, cn.getCell(1), activeCells);
// assertEquals(dd2, result.bestSegment);
// assertEquals(1, result.numActiveSynapses);
//
// result = tm.getBestMatchingSegment(cn, cn.getCell(8), activeCells);
// assertEquals(null, result.bestSegment);
// assertEquals(0, result.numActiveSynapses);
//
// result = tm.getBestMatchingSegment(cn, cn.getCell(100), activeCells);
// assertEquals(null, result.bestSegment);
// assertEquals(0, result.numActiveSynapses);
//
// //Test that we can repeat this
// result = tm.getBestMatchingSegment(cn, cn.getCell(0), activeCells);
// assertEquals(dd, result.bestSegment);
// assertEquals(2, result.numActiveSynapses);
//
// result = tm.getBestMatchingSegment(cn, cn.getCell(1), activeCells);
// assertEquals(dd2, result.bestSegment);
// assertEquals(1, result.numActiveSynapses);
//
// result = tm.getBestMatchingSegment(cn, cn.getCell(8), activeCells);
// assertEquals(null, result.bestSegment);
// assertEquals(0, result.numActiveSynapses);
//
// result = tm.getBestMatchingSegment(cn, cn.getCell(100), activeCells);
// assertEquals(null, result.bestSegment);
// assertEquals(0, result.numActiveSynapses);
//
// // 1. serialize
// byte[] data = api.write(cn, "testSerializeConnections2");
//
// // 2. deserialize
// Connections serialized = api.read(data);
//
// Set<Cell> serialActiveCells = serialized.getCellSet(new int[] { 733, 37, 974, 23 });
//
// deepCompare(activeCells, serialActiveCells);
//
// result = tm.getBestMatchingSegment(serialized, serialized.getCell(0), serialActiveCells);
// assertEquals(dd, result.bestSegment);
// assertEquals(2, result.numActiveSynapses);
//
// result = tm.getBestMatchingSegment(serialized, serialized.getCell(1), serialActiveCells);
// assertEquals(dd2, result.bestSegment);
// assertEquals(1, result.numActiveSynapses);
//
// result = tm.getBestMatchingSegment(serialized, serialized.getCell(8), serialActiveCells);
// assertEquals(null, result.bestSegment);
// assertEquals(0, result.numActiveSynapses);
//
// result = tm.getBestMatchingSegment(serialized, serialized.getCell(100), serialActiveCells);
// assertEquals(null, result.bestSegment);
// assertEquals(0, result.numActiveSynapses);
//
// boolean b = DeepEquals.deepEquals(cn, serialized);
// deepCompare(cn, serialized);
// assertTrue(b);
//
// //{0=[synapse: [ synIdx=0, inIdx=23, sgmtIdx=0, srcCellIdx=23 ], synapse: [ synIdx=1, inIdx=37, sgmtIdx=0, srcCellIdx=37 ], synapse: [ synIdx=2, inIdx=477, sgmtIdx=0, srcCellIdx=477 ]], 1=[synapse: [ synIdx=3, inIdx=49, sgmtIdx=1, srcCellIdx=49 ], synapse: [ synIdx=4, inIdx=3, sgmtIdx=1, srcCellIdx=3 ]], 2=[synapse: [ synIdx=5, inIdx=733, sgmtIdx=2, srcCellIdx=733 ]], 3=[synapse: [ synIdx=6, inIdx=486, sgmtIdx=3, srcCellIdx=486 ]]}
// //{0=[synapse: [ synIdx=0, inIdx=23, sgmtIdx=0, srcCellIdx=23 ], synapse: [ synIdx=1, inIdx=37, sgmtIdx=0, srcCellIdx=37 ], synapse: [ synIdx=2, inIdx=477, sgmtIdx=0, srcCellIdx=477 ]], 1=[synapse: [ synIdx=3, inIdx=49, sgmtIdx=1, srcCellIdx=49 ], synapse: [ synIdx=4, inIdx=3, sgmtIdx=1, srcCellIdx=3 ]], 2=[synapse: [ synIdx=5, inIdx=733, sgmtIdx=2, srcCellIdx=733 ]], 3=[synapse: [ synIdx=6, inIdx=486, sgmtIdx=3, srcCellIdx=486 ]]}
//
// }
/**
 * Serializes the {@link Connections} of a temporal-memory network after it has run, then
 * checks the de-serialized copy against the Connections of a second, identically run network.
 */
@Test
public void testThreadedPublisher_TemporalMemoryNetwork() {
    Connections firstRun = createAndRunTestTemporalMemoryNetwork()
        .lookup("r1").lookup("1").getConnections();

    PersistenceAPI api = Persistence.get(new SerialConfig(
        "testThreadedPublisher_TemporalMemoryNetwork", SerialConfig.SERIAL_TEST_DIR));
    Connections roundTripped = api.read(api.write(firstRun));

    Connections secondRun = createAndRunTestTemporalMemoryNetwork()
        .lookup("r1").lookup("1").getConnections();

    boolean equal = DeepEquals.deepEquals(secondRun, roundTripped);
    deepCompare(secondRun, roundTripped);
    assertTrue(equal);
}
/**
 * Serializes the {@link Connections} of a spatial-pooler network after it has run, then
 * checks the de-serialized copy against the Connections of a second, identically run
 * (never serialized) network. Both runs must be equal.
 */
@Test
public void testThreadedPublisher_SpatialPoolerNetwork() {
    Connections firstRun = createAndRunTestSpatialPoolerNetwork(0, 6)
        .lookup("r1").lookup("1").getConnections();

    PersistenceAPI api = Persistence.get(new SerialConfig(
        "testThreadedPublisher_SpatialPoolerNetwork", SerialConfig.SERIAL_TEST_DIR));
    Connections roundTripped = api.read(api.write(firstRun));

    Connections secondRun = createAndRunTestSpatialPoolerNetwork(0, 6)
        .lookup("r1").lookup("1").getConnections();

    boolean equal = DeepEquals.deepEquals(secondRun, roundTripped);
    deepCompare(secondRun, roundTripped);
    assertTrue(equal);
}
/////////////////////////// End Connections Serialization Testing //////////////////////////////////
/////////////////////
// HTMSensor //
/////////////////////
// Serialize HTMSensors though they'll probably be reconstituted rather than serialized
/**
 * Round-trips a file-backed {@link HTMSensor} whose encoder was initialized with the
 * day-of-week demo encoder parameters, and checks that the copy deep-equals the original.
 */
@Test
public void testHTMSensor_DaysOfWeek() {
    Object[] fileArgs = { "some name", ResourceLocator.path("days-of-week.csv") };
    HTMSensor<File> original = (HTMSensor<File>)Sensor.create(
        FileSensor::create, SensorParams.create(Keys::path, fileArgs));

    Parameters encoderParams = getParameters().union(NetworkTestHarness.getDayDemoTestEncoderParams());
    original.initEncoder(encoderParams);

    PersistenceAPI api = Persistence.get(
        new SerialConfig("testHTMSensor_DaysOfWeek", SerialConfig.SERIAL_TEST_DIR));
    HTMSensor<File> restored = api.read(api.write(original));

    boolean equal = DeepEquals.deepEquals(restored, original);
    deepCompare(restored, original);
    assertTrue(equal);
}
/**
 * Round-trips a file-backed {@link HTMSensor} over the hot-gym CSV and checks that the
 * copy deep-equals the original.
 * <p>
 * The config now targets {@link SerialConfig#SERIAL_TEST_DIR}, like every other test here.
 * Otherwise the file written by {@code write()} lands outside the directory removed by the
 * {@code @AfterClass} cleanup and is left behind.
 */
@Test
public void testHTMSensor_HotGym() {
    Object[] n = { "some name", ResourceLocator.path("rec-center-hourly-small.csv") };
    HTMSensor<File> sensor = (HTMSensor<File>)Sensor.create(
        FileSensor::create, SensorParams.create(Keys::path, n));
    sensor.initEncoder(getTestEncoderParams());
    SerialConfig config = new SerialConfig("testHTMSensor_HotGym", SerialConfig.SERIAL_TEST_DIR);
    PersistenceAPI api = Persistence.get(config);
    byte[] bytes = api.write(sensor);
    assertNotNull(bytes);
    assertTrue(bytes.length > 0);
    HTMSensor<File> serializedSensor = api.read(bytes);
    boolean b = DeepEquals.deepEquals(serializedSensor, sensor);
    deepCompare(serializedSensor, sensor);
    assertTrue(b);
}
/**
 * Round-trips an {@link ObservableSensor} backed by a three-header
 * {@link PublisherSupplier} and checks that the copy deep-equals the original.
 */
@Test
public void testSerializeObservableSensor() {
    PublisherSupplier supplier = PublisherSupplier.builder()
        .addHeader("dayOfWeek")
        .addHeader("darr")
        .addHeader("B")
        .build();
    Object[] sensorArgs = { "name", supplier };
    ObservableSensor<String[]> original = new ObservableSensor<>(SensorParams.create(Keys::obs, sensorArgs));

    PersistenceAPI api = Persistence.get(
        new SerialConfig("testSerializeObservableSensor", SerialConfig.SERIAL_TEST_DIR));
    ObservableSensor<String[]> restored = api.read(api.write(original));

    boolean equal = DeepEquals.deepEquals(restored, original);
    deepCompare(restored, original);
    assertTrue(equal);
}
//////////////////////////////////End HTMSensors ////////////////////////////////////
/////////////////////
// Anomaly //
/////////////////////
// Serialize Anomaly, AnomalyLikelihood and its support classes
/**
 * Round-trips a PURE-mode {@link Anomaly} and checks that the de-serialized instance
 * still computes the expected raw scores for a few active/predicted column pairs.
 */
@Test
public void testSerializeAnomaly() {
    Map<String, Object> params = new HashMap<>();
    params.put(KEY_MODE, Mode.PURE);
    Anomaly anomalyComputer = Anomaly.create(params);

    PersistenceAPI api = Persistence.get(
        new SerialConfig("testSerializeAnomaly1", SerialConfig.SERIAL_TEST_DIR));
    Anomaly restored = api.read(api.write(anomalyComputer));

    // Each row: active columns, predicted columns, exact expected score.
    int[][] activeColumns = { {}, {}, { 3, 5, 7 }, { 2, 3, 6 } };
    int[][] predictedColumns = { {}, { 3, 5 }, { 3, 5, 7 }, { 3, 5, 7 } };
    double[] expectedScores = { 0.0, 0.0, 0.0, 2.0 / 3.0 };
    for(int row = 0; row < expectedScores.length; row++) {
        double score = restored.compute(activeColumns[row], predictedColumns[row], 0, 0);
        assertEquals(expectedScores[row], score, 0);
    }
}
/**
 * Round-trips a PURE-mode {@link Anomaly} configured with a moving average over a window
 * of 3, then feeds the de-serialized instance nine records and checks each averaged score.
 */
@Test
public void testSerializeCumulativeAnomaly() {
    Map<String, Object> params = new HashMap<>();
    params.put(KEY_MODE, Mode.PURE);
    params.put(KEY_WINDOW_SIZE, 3);
    params.put(KEY_USE_MOVING_AVG, true);
    Anomaly anomalyComputer = Anomaly.create(params);

    PersistenceAPI api = Persistence.get(
        new SerialConfig("testSerializeCumulativeAnomaly", SerialConfig.SERIAL_TEST_DIR));
    Anomaly restored = api.read(api.write(anomalyComputer));
    assertNotNull(restored);

    // The prediction is always {1, 2, 6}; only the actual columns change per record.
    int[][] actualColumns = {
        { 1, 2, 6 }, { 1, 2, 6 }, { 1, 4, 6 },
        { 10, 11, 6 }, { 10, 11, 12 }, { 10, 11, 12 },
        { 10, 11, 12 }, { 1, 2, 6 }, { 1, 2, 6 }
    };
    double[] anomalyExpected = { 0.0, 0.0, 1.0/9.0, 3.0/9.0, 2.0/3.0, 8.0/9.0, 1.0, 2.0/3.0, 1.0/3.0 };
    for(int step = 0; step < actualColumns.length; step++) {
        double score = restored.compute(actualColumns[step], new int[] { 1, 2, 6 }, 0, 0);
        assertEquals(anomalyExpected[step], score, 0.01);
    }
}
/**
 * Checks that a LIKELIHOOD-mode {@link AnomalyLikelihood} can be serialized and
 * de-serialized into a non-null {@link Anomaly}. No scoring is exercised here.
 */
@Test
public void testSerializeAnomalyLikelihood() {
    Map<String, Object> params = new HashMap<>();
    params.put(KEY_MODE, Mode.LIKELIHOOD);
    AnomalyLikelihood likelihood = (AnomalyLikelihood)Anomaly.create(params);

    PersistenceAPI api = Persistence.get(
        new SerialConfig("testSerializeAnomalyLikelihood", SerialConfig.SERIAL_TEST_DIR));
    byte[] bytes = api.write(likelihood);

    Anomaly restored = api.read(bytes);
    assertNotNull(restored);
}
/**
 * Round-trips an {@link AnomalyLikelihood}, uses the de-serialized copy to estimate and then
 * update likelihoods on generated sample data, and round-trips the resulting
 * {@link AnomalyLikelihoodMetrics} through a file to check they remain usable.
 */
@Test
public void testSerializeAnomalyLikelihoodForUpdates() {
    Map<String, Object> params = new HashMap<>();
    params.put(KEY_MODE, Mode.LIKELIHOOD);
    AnomalyLikelihood an = (AnomalyLikelihood)Anomaly.create(params);
    // Serialize the Anomaly Computer without errors. The file name is distinct from the one
    // used by testSerializeAnomalyLikelihood so the two tests never share an output file.
    SerialConfig config = new SerialConfig("testSerializeAnomalyLikelihoodForUpdates", SerialConfig.SERIAL_TEST_DIR);
    PersistenceAPI api = Persistence.get(config);
    byte[] bytes = api.write(an);
    // Deserialize the Anomaly Computer and make sure it is usable
    AnomalyLikelihood serializedAn = api.read(bytes);
    assertNotNull(serializedAn);
    //----------------------------------------
    // Step 1. Generate an initial estimate using fake distribution of anomaly scores.
    List<Sample> data1 = AnomalyLikelihoodTest.generateSampleData(0.2, 0.2, 0.2, 0.2).subList(0, 1000);
    AnomalyLikelihoodMetrics metrics1 = serializedAn.estimateAnomalyLikelihoods(data1, 5, 0);
    //----------------------------------------
    // Step 2. Generate some new data with a higher average anomaly
    // score. Using the estimator from step 1, to compute likelihoods. Now we
    // should see a lot more anomalies.
    List<Sample> data2 = AnomalyLikelihoodTest.generateSampleData(0.6, 0.2, 0.2, 0.2).subList(0, 300);
    AnomalyLikelihoodMetrics metrics2 = serializedAn.updateAnomalyLikelihoods(data2, metrics1.getParams());
    // Serialize the Metrics too just to be sure everything can be serialized
    SerialConfig metricsConfig = new SerialConfig("testSerializeMetrics", SerialConfig.SERIAL_TEST_DIR);
    api = Persistence.get(metricsConfig);
    api.write(metrics2);
    // Deserialize the Metrics from the file written above
    AnomalyLikelihoodMetrics serializedMetrics = api.read();
    assertNotNull(serializedMetrics);
    // assertEquals takes (expected, actual); the expected sizes come from the input data.
    assertEquals(data2.size(), serializedMetrics.getLikelihoods().length);
    assertEquals(data2.size(), serializedMetrics.getAvgRecordList().size());
    assertTrue(serializedAn.isValidEstimatorParams(serializedMetrics.getParams()));
    // The new running total should be different
    assertFalse("running total should change after the update",
        metrics1.getAvgRecordList().total == serializedMetrics.getAvgRecordList().total);
    // We should have many more samples where likelihood is < 0.01, but not all
    Condition<Double> cond = new Condition.Adapter<Double>() {
        @Override
        public boolean eval(double d) { return d < 0.01; }
    };
    int conditionCount = ArrayUtils.where(serializedMetrics.getLikelihoods(), cond).length;
    assertTrue(conditionCount >= 25);
    assertTrue(conditionCount <= 250);
}
/////////////////////// End Serialize Anomaly //////////////////////////
///////////////////////////
// CLAClassifier //
///////////////////////////
// Test Serialize CLAClassifier
/**
 * Trains a {@link CLAClassifier} on four records, writes it to file and reads it back,
 * then feeds the restored classifier a fifth record. The resulting classification must
 * match the values an uninterrupted classifier would produce.
 */
@Test
public void testSerializeCLAClassifier() {
    CLAClassifier classifier = new CLAClassifier(new TIntArrayList(new int[] { 1 }), 0.1, 0.1, 0);
    int recordNum = 0;
    Map<String, Object> classification = new LinkedHashMap<String, Object>();

    // Training records: bucket index, actual value and active pattern, in order.
    int[] buckets = { 4, 5, 5, 4 };
    double[] values = { 34.7, 41.7, 44.9, 42.9 };
    int[][] patterns = { { 1, 5, 9 }, { 0, 6, 9, 11 }, { 6, 9 }, { 1, 5, 9 } };
    for(int record = 0; record < buckets.length; record++) {
        classification.put("bucketIdx", buckets[record]);
        classification.put("actValue", values[record]);
        classifier.compute(recordNum++, classification, patterns[record], true, true);
    }

    // Persist the trained classifier to file and read it back.
    PersistenceAPI api = Persistence.get(
        new SerialConfig("testSerializeCLAClassifier", SerialConfig.SERIAL_TEST_DIR));
    api.write(classifier);
    CLAClassifier restored = api.read();
    assertNotNull(restored);

    // Continue the sequence with the restored classifier.
    classification.put("bucketIdx", 4);
    classification.put("actValue", 34.7);
    Classification<Double> result = restored.compute(recordNum, classification, new int[] { 1, 5, 9 }, true, true);

    assertTrue(Arrays.equals(new int[] { 1 }, result.stepSet()));
    assertEquals(35.520000457763672, result.getActualValue(4), 0.00001);
    assertEquals(42.020000457763672, result.getActualValue(5), 0.00001);
    assertEquals(6, result.getStatCount(1));
    double[] expectedStats = { 0.0, 0.0, 0.0, 0.0, 0.12300123, 0.87699877 };
    for(int bucket = 0; bucket < expectedStats.length; bucket++) {
        assertEquals(expectedStats[bucket], result.getStat(1, bucket), 0.00001);
    }
}
//////////////////////// End CLAClassifier ///////////////////////
///////////////////////////
// Layers //
///////////////////////////
// Serialize a Layer
/**
 * Builds and closes a spatial-pooler {@link Layer} fed by an {@link ObservableSensor},
 * writes it to file and reads it back, and checks that the copy equals the original.
 * Resetting the copy's record number must then make the two layers unequal.
 */
@SuppressWarnings({ "unchecked", "rawtypes" })
@Test
public void testSerializeLayer() {
    Parameters p = NetworkTestHarness.getParameters().copy();
    p.set(KEY.RANDOM, new MersenneTwister(42));
    p.set(KEY.INFERRED_FIELDS, getInferredFieldsMap("dayOfWeek", CLAClassifier.class));
    p.set(KEY.FIELD_ENCODING_MAP, NetworkTestHarness.setupMap(
        null,                      // map
        8,                         // n
        0,                         // w
        0,                         // min
        0,                         // max
        0,                         // radius
        0,                         // resolution
        null,                      // periodic
        null,                      // clip
        Boolean.TRUE,              // forced
        "dayOfWeek",               // fieldName
        "darr",                    // fieldType (dense array as opposed to sparse array or "sarr")
        "SDRPassThroughEncoder")); // encoderType

    PublisherSupplier supplier = PublisherSupplier.builder()
        .addHeader("dayOfWeek")
        .addHeader("darr")
        .addHeader("B").build();
    Sensor<ObservableSensor<String[]>> sensor = Sensor.create(
        ObservableSensor::create, SensorParams.create(Keys::obs, new Object[] { "name", supplier }));

    Layer<?> layer = Network.createLayer("1", p)
        .alterParameter(KEY.AUTO_CLASSIFY, true)
        .add(new SpatialPooler())
        .add(sensor);

    Observer<Inference> printer = new Observer<Inference>() {
        @Override public void onCompleted() {}
        @Override public void onError(Throwable e) { e.printStackTrace(); }
        @Override public void onNext(Inference spatialPoolerOutput) {
            System.out.println("in onNext()");
        }
    };
    layer.subscribe(printer);
    layer.close();

    PersistenceAPI api = Persistence.get(new SerialConfig("testSerializeLayer", SerialConfig.SERIAL_TEST_DIR));
    api.write(layer);

    // Read the layer back from the file written above and compare.
    Layer<?> restored = api.read();
    assertEquals(restored, layer);
    deepCompare(layer, restored);

    // Changing a single attribute must break equality.
    restored.resetRecordNum();
    assertNotEquals(restored, layer);
}
////////////////////// End Layers ///////////////////////
///////////////////////////
// Full Network //
///////////////////////////
/**
 * Stores the loaded hot-gym hierarchy {@link Network}; the test passes as long as
 * {@code store()} completes without throwing.
 */
@Test
public void testHierarchicalNetwork() {
    Network hierarchy = getLoadedHotGymHierarchy();
    try {
        PersistenceAPI api = Persistence.get(
            new SerialConfig("testSerializeHierarchy", SerialConfig.SERIAL_TEST_DIR));
        api.store(hierarchy);
    }catch(Exception storeFailure) {
        storeFailure.printStackTrace();
        fail();
    }
}
/**
 * Test that a serialized/de-serialized {@link Network} can be run...
 * <p>
 * An un-started day-of-week network is stored and loaded back, checked for equality with
 * the original, and then the loaded copy is started and fed the seven day values per
 * cycle. Feeding stops after cycle index 284, i.e. 285 cycles, well short of NUM_CYCLES.
 */
@Test
public void testSerializedUnStartedNetworkRuns() {
    final int NUM_CYCLES = 600;
    final int INPUT_GROUP_COUNT = 7; // Days of Week

    Network network = getLoadedDayOfWeekNetwork();
    PersistenceAPI api = Persistence.get(
        new SerialConfig("testSerializedUnStartedNetworkRuns", SerialConfig.SERIAL_TEST_DIR));
    api.store(network);

    Network restored = api.load();
    assertEquals(restored, network);
    deepCompare(network, restored);

    int cellsPerCol = (int)restored.getParameters().get(KEY.CELLS_PER_COLUMN);
    restored.observe().subscribe(new Observer<Inference>() {
        @Override public void onCompleted() {}
        @Override public void onError(Throwable e) { e.printStackTrace(); }
        @Override public void onNext(Inference inf) {
            dayOfWeekPrintout.apply(inf, cellsPerCol);
        }
    });

    Publisher pub = restored.getPublisher();
    restored.start();
    for(int cycle = 0; cycle < NUM_CYCLES; cycle++) {
        // Day values are sent as their double string form ("0.0" .. "6.0").
        for(double day = 0; day < INPUT_GROUP_COUNT; day++) {
            pub.onNext("" + day);
        }
        restored.reset();
        if(cycle == 284) {
            break;
        }
    }
    pub.onComplete();

    try {
        restored.lookup("r1").lookup("1").getLayerThread().join();
    }catch(Exception joinFailure) {
        joinFailure.printStackTrace();
    }
}
/**
 * The {@link DateEncoder} presents a special challenge because its main
 * field (the DateFormatter) is not serializable and requires its state (format)
 * to be saved and re-installed following de-serialization.
 * <p>
 * A hot-gym network is stored and loaded back, checked for equality, then run for
 * NUM_CYCLES passes over the hot-gym stream. Every emitted {@link Inference} must be
 * non-null; errors raised inside the observer are surfaced via {@code checkObserver}.
 */
@Test
public void testSerializedUnStartedNetworkRuns_DateEncoder() {
    final int NUM_CYCLES = 100;

    Network network = getLoadedHotGymNetwork();
    PersistenceAPI api = Persistence.get(new SerialConfig(
        "testSerializedUnStartedNetworkRuns_DateEncoderFST", SerialConfig.SERIAL_TEST_DIR));
    api.store(network);

    Network restored = api.load();
    assertEquals(restored, network);
    deepCompare(network, restored);

    TestObserver<Inference> tester = new TestObserver<Inference>() {
        @Override public void onCompleted() {}
        @Override public void onNext(Inference inf) {
            assertNotNull(inf);
        }
    };
    restored.observe().subscribe(tester);

    Publisher pub = restored.getPublisher();
    restored.start();
    List<String> hotStream = makeStream().collect(Collectors.toList());
    for(int cycle = 0; cycle < NUM_CYCLES; cycle++) {
        for(String line : hotStream) {
            pub.onNext(line);
        }
        restored.reset();
    }
    pub.onComplete();

    try {
        restored.lookup("r1").lookup("1").getLayerThread().join();
    }catch(Exception joinFailure) {
        joinFailure.printStackTrace();
    }

    checkObserver(tester);
}
/**
 * RNG seeds reported by the two de-serialized networks in
 * {@link #testDeserializedInstancesRunExactlyTheSame()}: index 0 is written by the first
 * network's observer and index 1 by the second, each just before it waits on {@link #barrier}.
 */
long[] barrierSeeds = new long[2];
/** Three-party barrier (test thread plus one observer per network) that keeps both networks in lock-step. */
CyclicBarrier barrier;
/** Input cycle currently being fed by the test thread; used in failure messages. */
int runCycleCount = 0;
/**
 * Description of the first seed mismatch detected by the barrier action, or {@code null}
 * if none has been seen. It is written by the barrier action and read by the test thread
 * after {@code await()} returns. The barrier's memory-consistency guarantee (barrier action
 * happens-before a successful return from {@code await()}) makes the write visible.
 */
String barrierSeedMismatch;

/**
 * Runs two de-serialized {@link Network}s in lock-step using a {@link CyclicBarrier},
 * checking that their RNG seeds are exactly the same after each compute cycle - which
 * guarantees that the output of both Networks is the same (the RNGs must be called
 * exactly the same amount of times for this to be true).
 * <p>
 * A mismatch is recorded by the barrier action and reported by the test thread through
 * {@code fail()}, so JUnit sees a normal failure instead of the JVM being terminated.
 */
@Test
public void testDeserializedInstancesRunExactlyTheSame() {
    final int NUM_CYCLES = 100;
    Network network = getLoadedHotGymNetwork();
    SerialConfig config = new SerialConfig("testDeserializedInstancesRunExactlyTheSame", SerialConfig.SERIAL_TEST_DIR);
    PersistenceAPI api = Persistence.get(config);
    api.store(network);
    Network serializedNetwork1 = api.load();
    Network serializedNetwork2 = api.load();
    FastRandom r1 = (FastRandom)serializedNetwork1.lookup("r1").lookup("1").getConnections().getRandom();
    FastRandom r2 = (FastRandom)serializedNetwork2.lookup("r1").lookup("1").getConnections().getRandom();
    // Assert both starting seeds are equal
    assertEquals(r1.getSeed(), r2.getSeed());

    barrierSeedMismatch = null;
    // The barrier action only records a mismatch and never throws. Previously it ran
    // assertEquals inside catch(Exception). AssertionError is not an Exception, so that
    // catch (and its System.exit) never ran, and the error broke the barrier instead.
    barrier = new CyclicBarrier(3, () -> {
        if(barrierSeedMismatch == null && barrierSeeds[0] != barrierSeeds[1]) {
            barrierSeedMismatch = "Seed comparison failed at: " + runCycleCount
                + " (network1 seed=" + barrierSeeds[0] + ", network2 seed=" + barrierSeeds[1] + ")";
        }
    });

    serializedNetwork1.observe().subscribe(new Observer<Inference>() {
        @Override public void onCompleted() {}
        @Override public void onError(Throwable e) { e.printStackTrace(); }
        @Override
        public void onNext(Inference inf) {
            barrierSeeds[0] = r1.getSeed();
            try { barrier.await(); }
            catch(InterruptedException ie) { Thread.currentThread().interrupt(); }
            // A broken barrier is reported by the test thread; only log it here.
            catch(Exception b) { b.printStackTrace(); }
        }
    });

    serializedNetwork2.observe().subscribe(new Observer<Inference>() {
        @Override public void onCompleted() {}
        @Override public void onError(Throwable e) { e.printStackTrace(); }
        @Override
        public void onNext(Inference inf) {
            barrierSeeds[1] = r2.getSeed();
            try { barrier.await(); }
            catch(InterruptedException ie) { Thread.currentThread().interrupt(); }
            // A broken barrier is reported by the test thread; only log it here.
            catch(Exception b) { b.printStackTrace(); }
        }
    });

    Publisher pub1 = serializedNetwork1.getPublisher();
    Publisher pub2 = serializedNetwork2.getPublisher();
    serializedNetwork1.start();
    serializedNetwork2.start();
    runCycleCount = 0;
    List<String> hotStream = makeStream().collect(Collectors.toList());
    for(;runCycleCount < NUM_CYCLES;runCycleCount++) {
        for(String s : hotStream) {
            pub1.onNext(s);
            pub2.onNext(s);
            try {
                barrier.await();
            }catch(InterruptedException ie) {
                Thread.currentThread().interrupt();
                fail("Interrupted while waiting on the barrier at: " + runCycleCount);
            }catch(Exception b) {
                b.printStackTrace();
                fail("Barrier failed at: " + runCycleCount + " - " + b);
            }
            if(barrierSeedMismatch != null) {
                fail(barrierSeedMismatch);
            }
        }
    }
    pub1.onComplete();
    pub2.onComplete();
    try {
        serializedNetwork1.lookup("r1").lookup("1").getLayerThread().join();
        serializedNetwork2.lookup("r1").lookup("1").getLayerThread().join();
    }catch(Exception e) {
        e.printStackTrace();
    }
}
/**
 * Ensure that a Network run uninterrupted will have the same output as a {@link Network}
 * that has been halted, serialized, and restarted! This test runs a {@link Network}
 * all the way through, recording 10 outputs between set indexes. Then runs a second
 * Network, stopping in the middle of those indexes and serializing the Network. Then
 * de-serializes that Network, and continues on - both pre and post serialized Networks
 * record the same indexes as the first Network that ran all the way through. The outputs
 * of both Networks from the start to finish indexes are compared and tested that they
 * are exactly the same.
 */
@Test
public void testRunSerializedNetworkWithFileSensor() {
    // Stores the sample comparison outputs at the indicated record numbers.
    List<String> sampleExpectedOutput = new ArrayList<>(10);
    // Run the network all the way, while storing a sample of 10 outputs.
    Network net = getLoadedHotGymNetwork_FileSensor();
    net.observe().subscribe(new Observer<Inference>() {
        @Override public void onCompleted() {}
        @Override public void onError(Throwable e) { e.printStackTrace(); }
        @Override
        public void onNext(Inference inf) {
            recordSampleIfInWindow(inf, sampleExpectedOutput);
        }
    });
    net.start();
    try {
        net.lookup("r1").lookup("1").getLayerThread().join();
    }catch(Exception e) {
        e.printStackTrace();
    }

    // Now run the network part way through, halting in between the save points above
    Network network = getLoadedHotGymNetwork_FileSensor();
    // Store the actual outputs and the same record number indexes for comparison across pre and post serialized networks.
    List<String> actualOutputs = new ArrayList<>();
    network.observe().subscribe(new Observer<Inference>() {
        @Override public void onCompleted() {}
        @Override public void onError(Throwable e) { e.printStackTrace(); }
        @Override
        public void onNext(Inference inf) {
            recordSampleIfInWindow(inf, actualOutputs);
            if(inf.getRecordNum() == 1109) {
                network.halt();
            }
        }
    });
    network.start();
    try {
        network.lookup("r1").lookup("1").getLayerThread().join();
    }catch(Exception e) {
        e.printStackTrace();
    }
    SerialConfig config = new SerialConfig("testRunSerializedNetworkWithFileSensor", SerialConfig.SERIAL_TEST_DIR);
    PersistenceAPI api = Persistence.get(config);
    api.store(network);

    /////////////////////////////////////////////////////////
    // Now run the serialized network
    Network serializedNetwork = api.load();
    serializedNetwork.observe().subscribe(new Observer<Inference>() {
        @Override public void onCompleted() {}
        @Override public void onError(Throwable e) { e.printStackTrace(); }
        @Override
        public void onNext(Inference inf) {
            recordSampleIfInWindow(inf, actualOutputs);
        }
    });
    serializedNetwork.start();
    try {
        serializedNetwork.lookup("r1").lookup("1").getLayerThread().join();
    }catch(Exception e) {
        e.printStackTrace();
    }
    assertEquals(sampleExpectedOutput.size(), actualOutputs.size());
    // List equality reports the differing elements on failure, unlike a bare assertTrue.
    assertEquals(sampleExpectedOutput, actualOutputs);
}

/**
 * Appends a printable summary of {@code inf} (record number, layer input, anomaly score)
 * to {@code out} when its record number lies in the sampled window (1105, 1115].
 */
private static void recordSampleIfInWindow(Inference inf, List<String> out) {
    if(inf.getRecordNum() > 1105 && inf.getRecordNum() <= 1115) {
        out.add("" + inf.getRecordNum() + ": " + Arrays.toString((int[])inf.getLayerInput()) + ", " + inf.getAnomalyScore());
    }
}
/**
 * Tests the Network Serialization API with a FileSensor-driven Network: stores the Network
 * while it handles record 500, restarts a loaded copy at the stored index (expecting record
 * 501 next), then restarts another loaded copy from the beginning of the stream (expecting
 * records in order from 0, halting at record 500).
 */
@Test
public void testStoreAndLoad_FileSensor() {
    Network network = getLoadedHotGymNetwork_FileSensor();
    PersistenceAPI api = Persistence.get();
    TestObserver<Inference> tester;
    network.observe().subscribe(tester = new TestObserver<Inference>() {
        @Override public void onCompleted() {}
        @Override
        public void onNext(Inference inf) {
            if(inf.getRecordNum() == 500) {
                /////////////////////////////////
                //     Network Store Here      //
                /////////////////////////////////
                api.store(network);
            }
        }
    });
    network.start();
    try {
        network.lookup("r1").lookup("1").getLayerThread().join();
    }catch(Exception e) {
        e.printStackTrace();
    }

    Network serializedNetwork = api.load();
    TestObserver<Inference> tester2;
    serializedNetwork.observe().subscribe(tester2 = new TestObserver<Inference>() {
        @Override public void onCompleted() {}
        @Override
        public void onNext(Inference inf) {
            // Resuming at the stored index must produce record 501 first; halt right after it.
            assertEquals(501, inf.getRecordNum());
            serializedNetwork.halt();
        }
    });
    serializedNetwork.restart();
    try {
        serializedNetwork.lookup("r1").lookup("1").getLayerThread().join();
    }catch(Exception e) {
        e.printStackTrace();
    }

    // Test that we can start the Network from the beginning of the stream.
    Network serializedNetwork2 = api.load();
    TestObserver<Inference> tester3;
    serializedNetwork2.observe().subscribe(tester3 = new TestObserver<Inference>() {
        int idx = 0;
        @Override public void onCompleted() {}
        @Override
        public void onNext(Inference inf) {
            if(idx != inf.getRecordNum()) {
                fail("Expected record " + idx + " but got " + inf.getRecordNum());
            }
            // Stop once the restarted stream reaches the record at which the original was stored.
            // This check must live outside the mismatch branch: fail() throws, so anything after it
            // in that branch can never run.
            if(idx == 500) {
                serializedNetwork2.halt();
            }
            ++idx;
        }
    });
    serializedNetwork2.restart(false);
    try {
        serializedNetwork2.lookup("r1").lookup("1").getLayerThread().join();
    }catch(Exception e) {
        e.printStackTrace();
    }
    checkObserver(tester);
    checkObserver(tester2);
    checkObserver(tester3);
}
/**
* This test stops and starts a Network which has a Publisher - which is not the way
* it would be done in production as you could just stop entering data, and continue
* again as desired. The user controls the data flow when using a Publisher.
*
* Flow: publish 500 records, storing the Network when record 499 is observed; load it and
* restart at the stored index, expecting the next record to be 500; then load it again and
* restart from the beginning of the stream, expecting record 0.
*/
@Test
public void testStoreAndLoad_ObservableSensor() {
Network network = getLoadedHotGymNetwork();
PersistenceAPI api = Persistence.get();
TestObserver<Inference> tester;
network.observe().subscribe(tester = new TestObserver<Inference>() {
@Override public void onCompleted() {}
@Override
public void onNext(Inference inf) {
// System.out.println("" + inf.getRecordNum() + ", " + inf.getAnomalyScore() + (inf.getRecordNum() == 0 ? Arrays.toString((int[])inf.getLayerInput()) : ""));
if(inf.getRecordNum() == 499) {
/////////////////////////////////
// Network Store Here //
/////////////////////////////////
api.store(network);
}
}
});
network.start();
Publisher publisher = network.getPublisher();
List<String> hotStream = makeStream().collect(Collectors.toList());
int numRecords = 0;
boolean done = false;
// Cycle repeatedly over the short sample stream until exactly 500 records have been published.
while(true) {
for(String s : hotStream) {
publisher.onNext(s);
numRecords++;
if(numRecords == 500) {
done = true;
break;
}
}
if(done) break;
}
// NOTE(review): no halt() or onComplete() is issued before this untimed join(); it relies on the
// layer thread finishing after api.store() - confirm against PersistenceAPI.store().
try {
network.lookup("r1").lookup("1").getLayerThread().join();
// System.out.println("------------------> buffer size = " + publisher.getBufferSize());
// System.out.println("NETWORK TEST RECORD NUM: " + network.getRecordNum());
}catch(Exception e) {
e.printStackTrace();
}
Network serializedNetwork = api.load();
TestObserver<Inference> tester2;
serializedNetwork.observe().subscribe(tester2 = new TestObserver<Inference>() {
@Override public void onCompleted() {}
@Override
public void onNext(Inference inf) {
// System.out.println("1: " + inf.getRecordNum() + ", " + inf.getAnomalyScore());
// The first record after resuming must be 500; any other record number fails the test.
if(inf.getRecordNum() == 500) {
assertEquals(500, inf.getRecordNum());
serializedNetwork.halt();
}else{
fail();
}
}
});
// true: resume at the stored record index (tester2 expects record 500 next).
boolean startAtIndex = true;
serializedNetwork.restart(startAtIndex);
// IMPORTANT: Re-acquire the publisher after restart! (Can't use old one)
publisher = serializedNetwork.getPublisher();
// numRecords is already 500 here, so this loop publishes exactly one record before breaking.
for(String s : hotStream) {
publisher.onNext(s);
numRecords++;
if(numRecords > 500) {
break;
}
}
try {
serializedNetwork.lookup("r1").lookup("1").getLayerThread().join(5000);
}catch(Exception e) {
e.printStackTrace();
}
// Test that we can start the Network from the beginning of the stream.
Network serializedNetwork2 = api.load();
TestObserver<Inference> tester3;
serializedNetwork2.observe().subscribe(tester3 = new TestObserver<Inference>() {
int idx = 0;
@Override public void onCompleted() {}
@Override public void onError(Throwable e) { e.printStackTrace(); }
@Override
public void onNext(Inference inf) {
// System.out.println("2: " + inf.getRecordNum() + ", " + inf.getAnomalyScore());
// idx is never incremented: only one record is published below, so this only checks that
// the restarted Network begins at record 0.
if(idx != inf.getRecordNum()) {
fail();
}
assertTrue(idx == 0 && idx == inf.getRecordNum());
}
});
// false: restart from the beginning of the stream (tester3 expects record 0).
startAtIndex = false;
serializedNetwork2.restart(startAtIndex);
// IMPORTANT: Re-acquire the publisher after restart! (Can't use old one)
publisher = serializedNetwork2.getPublisher();
publisher.onNext(hotStream.get(0));
try {
serializedNetwork2.lookup("r1").lookup("1").getLayerThread().join(5000);
}catch(Exception e) {
e.printStackTrace();
}
// Surface any assertion failures captured inside the observers' onNext().
checkObserver(tester);
checkObserver(tester2);
checkObserver(tester3);
}
/**
 * Drives a synchronous Network via {@code computeImmediate()} for 50 sequences of 20 records,
 * stores and reloads it, then pumps the same data through the reloaded Network and checks
 * that it produced records with a positive record number.
 */
@SuppressWarnings("unchecked")
@Test
public void testStoreAndLoad_SynchronousNetwork() {
    Network network = getLoadedHotGymSynchronousNetwork();
    PersistenceAPI api = Persistence.get();
    Map<String, Map<String, Object>> fieldEncodingMap =
        (Map<String, Map<String, Object>>)network.getParameters().get(KEY.FIELD_ENCODING_MAP);
    MultiEncoder me = MultiEncoder.builder()
        .name("")
        .build()
        .addMultipleEncoders(fieldEncodingMap);
    network.lookup("r1").lookup("1").add(me);
    // We just use this to parse the date field
    DateEncoder dateEncoder = me.getEncoderOfType(FieldMetaType.DATETIME);
    Map<String, Object> m = new HashMap<>();
    List<String> l = makeStream().collect(Collectors.toList());
    for(int j = 0;j < 50;j++) {
        for(int i = 0;i < 20;i++) {
            String[] sa = l.get(i).split("[\\s]*\\,[\\s]*");
            m.put("timestamp", dateEncoder.parse(sa[0]));
            m.put("consumption", Double.parseDouble(sa[1]));
            network.computeImmediate(m);
        }
        network.reset();
    }

    //////////////////////////////////////
    //        Store the Network         //
    //////////////////////////////////////
    api.store(network);

    //////////////////////////////////////
    //       Reload the Network         //
    //////////////////////////////////////
    Network serializedNetwork = api.load();
    // Check before first use so a failed load reports as an assertion, not an NPE below.
    assertNotNull(serializedNetwork);

    boolean serializedNetworkRan = false;
    // Pump data through the serialized Network
    for(int j = 0;j < 50;j++) {
        for(int i = 0;i < 20;i++) {
            String[] sa = l.get(i).split("[\\s]*\\,[\\s]*");
            m.put("timestamp", dateEncoder.parse(sa[0]));
            m.put("consumption", Double.parseDouble(sa[1]));
            Inference inf = serializedNetwork.computeImmediate(m);
            serializedNetworkRan = inf.getRecordNum() > 0;
        }
        serializedNetwork.reset();
    }
    assertTrue(serializedNetworkRan);
}
// Observer handed to checkPoint() in testCheckpoint_FileSensor. It is reassigned on every
// checkpoint, so checkObserver() at the end of that test only inspects the most recent one.
TestObserver<byte[]> nestedTester;
/**
* Checkpoints a FileSensor-driven Network while it handles records 500 and 750 (halting at
* record 1000), then loads the last checkpoint and verifies that the restarted Network
* continues from where the checkpoint left off.
*/
@Test
public void testCheckpoint_FileSensor() {
Network network = getLoadedHotGymNetwork_FileSensor();
PersistenceAPI api = Persistence.get();
SerialConfig config = api.getConfig();
// Keep every checkpoint rather than only the latest one.
config.setOneCheckPointOnly(false);
TestObserver<Inference> tester;
network.observe().subscribe(tester = new TestObserver<Inference>() {
@Override public void onCompleted() {}
@Override
public void onNext(Inference inf) {
if(inf.getRecordNum() == 500 || inf.getRecordNum() == 750) {
/////////////////////////////////
// Network Store Here //
/////////////////////////////////
api.checkPointer(network).checkPoint(nestedTester = new TestObserver<byte[]>() {
@Override public void onCompleted() {}
@Override public void onError(Throwable e) { e.printStackTrace(); }
@Override public void onNext(byte[] bytes) {
assertTrue(bytes != null && bytes.length > 10);
}
});
}else if(inf.getRecordNum() == 1000) {
network.halt();
}
}
});
network.start();
try {
network.lookup("r1").lookup("1").getLayerThread().join();
}catch(Exception e) {
e.printStackTrace();
}
assertTrue(api.getLastCheckPoint() != null);
/////////////////////// Now test the checkpointed Network /////////////////////////
//////////////////////////////////////
// CheckPoint the Network //
//////////////////////////////////////
Network checkPointNetwork = null;
try {
checkPointNetwork = api.load(api.getLastCheckPointFileName());
assertNotNull(checkPointNetwork);
}catch(Exception e) {
e.printStackTrace();
fail();
}
TestObserver<Inference> tester2;
// Effectively-final alias so the anonymous observer below can halt the loaded Network.
final Network cpn = checkPointNetwork;
checkPointNetwork.observe().subscribe(tester2 = new TestObserver<Inference>() {
@Override public void onCompleted() {}
@Override
public void onNext(Inference inf) {
// Assert that the records continue from where the checkpoint left off.
// NOTE(review): the last checkpoint is requested while handling record 750, yet the first
// resumed record is expected to be 752 - presumably the checkpoint is captured after record
// 751 has been processed; confirm.
assertEquals(752, inf.getRecordNum());
cpn.halt();
}
});
checkPointNetwork.restart();
try {
checkPointNetwork.lookup("r1").lookup("1").getLayerThread().join();
}catch(Exception e) {
e.printStackTrace();
}
// Surface any assertion failures captured inside the observers' onNext().
checkObserver(tester);
checkObserver(nestedTester);
checkObserver(tester2);
}
// Observer handed to checkPoint() in testCheckpoint_ObservableSensor. It is reassigned on every
// checkpoint, so checkObserver() at the end of that test only inspects the most recent one.
TestObserver<byte[]> nestedTester2;
/**
* Checkpoints a Publisher-driven Network while it handles records 500 and 750 (halting at record
* 999 after 1000 records are published), then loads the last checkpoint, feeds it fresh data
* through a re-acquired Publisher, and verifies that it continues from the checkpoint.
*/
@Test
public void testCheckpoint_ObservableSensor() {
Network network = getLoadedHotGymNetwork();
PersistenceAPI api = Persistence.get();
SerialConfig conf = api.getConfig();
assertNotNull(conf);
// Keep every checkpoint rather than only the latest one.
conf.setOneCheckPointOnly(false);
TestObserver<Inference> tester;
network.observe().subscribe(tester = new TestObserver<Inference>() {
@Override public void onCompleted() {}
@Override
public void onNext(Inference inf) {
if(inf.getRecordNum() == 500 || inf.getRecordNum() == 750) {
/////////////////////////////////
// Network CheckPoint Here //
/////////////////////////////////
api.checkPointer(network).checkPoint(nestedTester2 = new TestObserver<byte[]>() {
@Override public void onCompleted() {}
@Override public void onNext(byte[] bytes) {
assertTrue(bytes != null && bytes.length > 10);
}
});
}else if(inf.getRecordNum() == 999) {
network.halt();
}
}
});
network.start();
Publisher publisher = network.getPublisher();
List<String> hotStream = makeStream().collect(Collectors.toList());
int numRecords = 0;
boolean done = false;
// Cycle repeatedly over the short sample stream until exactly 1000 records have been published.
while(true) {
for(String s : hotStream) {
publisher.onNext(s);
numRecords++;
if(numRecords == 1000) {
done = true; break;
}
}
if(done) break;
}
try {
network.lookup("r1").lookup("1").getLayerThread().join(5000);
}catch(Exception e) {
e.printStackTrace();
}
/////////////////////// Now test the checkpointed Network /////////////////////////
//////////////////////////////////////
// CheckPoint the Network //
//////////////////////////////////////
Network checkPointNetwork = null;
try {
checkPointNetwork = api.load(api.getLastCheckPointFileName());
assertNotNull(checkPointNetwork);
}catch(Exception e) {
e.printStackTrace();
fail();
}
TestObserver<Inference> tester2;
// Effectively-final alias so the anonymous observer below can halt the loaded Network.
final Network cpn = checkPointNetwork;
checkPointNetwork.observe().subscribe(tester2 = new TestObserver<Inference>() {
@Override public void onCompleted() {}
@Override
public void onNext(Inference inf) {
// Assert that the records continue from where the checkpoint left off.
// NOTE(review): the last checkpoint is requested while handling record 750, yet the first
// resumed record is expected to be 752 - presumably the checkpoint is captured after record
// 751 has been processed; confirm.
assertEquals(752, inf.getRecordNum());
cpn.halt();
}
});
checkPointNetwork.restart();
// The restarted Network needs its own Publisher; the original one cannot be reused.
publisher = checkPointNetwork.getPublisher();
numRecords = 0;
done = false;
// Publish 25 records; the observer above halts after the first inference.
while(true) {
for(String s : hotStream) {
publisher.onNext(s);
numRecords++;
if(numRecords == 25) {
done = true; break;
}
}
if(done) break;
}
try {
checkPointNetwork.lookup("r1").lookup("1").getLayerThread().join(2000);
}catch(Exception e) {
e.printStackTrace();
}
// Surface any assertion failures captured inside the observers' onNext().
checkObserver(tester);
checkObserver(tester2);
checkObserver(nestedTester2);
}
// Observer handed to checkPoint() in testCheckPoint_SynchronousNetwork.
TestObserver<byte[]> nestedTester3;
/**
 * Drives a synchronous Network via {@code computeImmediate()} for 50 sequences of 20 records,
 * checkpointing it right after the first record of the last sequence (record 980). It then
 * loads that checkpoint and verifies that processing resumes at record 981.
 */
@SuppressWarnings("unchecked")
@Test
public void testCheckPoint_SynchronousNetwork() {
    Network network = getLoadedHotGymSynchronousNetwork();
    PersistenceAPI api = Persistence.get();
    Map<String, Map<String, Object>> fieldEncodingMap =
        (Map<String, Map<String, Object>>)network.getParameters().get(KEY.FIELD_ENCODING_MAP);
    MultiEncoder me = MultiEncoder.builder()
        .name("")
        .build()
        .addMultipleEncoders(fieldEncodingMap);
    network.lookup("r1").lookup("1").add(me);
    // We just use this to parse the date field
    DateEncoder dateEncoder = me.getEncoderOfType(FieldMetaType.DATETIME);
    Map<String, Object> m = new HashMap<>();
    List<String> l = makeStream().collect(Collectors.toList());
    for(int j = 0;j < 50;j++) {
        for(int i = 0;i < 20;i++) {
            String[] sa = l.get(i).split("[\\s]*\\,[\\s]*");
            m.put("timestamp", dateEncoder.parse(sa[0]));
            m.put("consumption", Double.parseDouble(sa[1]));
            network.computeImmediate(m);
            if(j == 49 && i == 0) {
                api.checkPointer(network).checkPoint(nestedTester3 = new TestObserver<byte[]>() {
                    @Override public void onCompleted() {}
                    @Override public void onNext(byte[] bytes) {
                        assertTrue(bytes != null && bytes.length > 10);
                    }
                });
            }
        }
        network.reset();
    }

    //////////////////////////////////////
    //     Load the CheckPoint          //
    //////////////////////////////////////
    Network checkPointNetwork = null;
    try {
        checkPointNetwork = api.load(api.getLastCheckPointFileName());
        assertNotNull(checkPointNetwork);
    }catch(Exception e) {
        // Report the cause instead of failing silently (consistent with the other checkpoint tests).
        e.printStackTrace();
        fail("Failed to load checkpoint: " + e);
    }

    int postCheckPointProcessCount = 0;
    for(int i = 0;i < 20;i++) {
        String[] sa = l.get(i).split("[\\s]*\\,[\\s]*");
        m.put("timestamp", dateEncoder.parse(sa[0]));
        m.put("consumption", Double.parseDouble(sa[1]));
        Inference inf = checkPointNetwork.computeImmediate(m);
        // Test that we begin processing where the checkpoint left off...
        assertEquals(981 + i, inf.getRecordNum());
        ++postCheckPointProcessCount;
    }
    checkPointNetwork.reset();
    assertTrue(postCheckPointProcessCount > 19);
    checkObserver(nestedTester3);
}
// Observer handed to checkPoint() in testCheckPointHierarchies.
TestObserver<byte[]> nestedTester4;
/**
* Checkpoints the two-region Network built by getLoadedDayOfWeekStreamHierarchy() on its 11th
* inference, halts it on the next one, then loads the checkpoint, starts it, and checks the
* record number seen on the loaded Network's 11th inference.
*/
@Test
public void testCheckPointHierarchies() {
Network network = getLoadedDayOfWeekStreamHierarchy();
PersistenceAPI api = Persistence.get();
TestObserver<Inference> tester;
network.observe().subscribe(tester = new TestObserver<Inference>() {
int cycles = 0;
@Override public void onCompleted() {}
@Override public void onError(Throwable e) { e.printStackTrace(); }
@Override public void onNext(Inference i) {
// cycles++ compares the pre-increment value, so this branch runs on the 11th inference.
if(cycles++ == 10) {
////////////////////////
// CheckPoint Here //
////////////////////////
api.checkPointer(network).checkPoint(nestedTester4 = new TestObserver<byte[]>() {
@Override public void onCompleted() {}
@Override public void onNext(byte[] bytes) {
assertEquals(10, i.getRecordNum());
assertTrue(bytes != null && bytes.length > 10);
}
});
// cycles has already been incremented here, so this halts on the inference right after the checkpoint.
}else if(cycles == 12) {
network.halt();
}
}
});
network.start();
// The FileSensor lives in region "r2" (see getLoadedDayOfWeekStreamHierarchy), so join that layer's thread.
try {
network.lookup("r2").lookup("1").getLayerThread().join();
}catch(Exception e) {
e.printStackTrace();
}
//////////////////////////////////////
// CheckPoint the Network //
//////////////////////////////////////
Network cpn = null;
try {
cpn = api.load(api.getLastCheckPointFileName());
assertNotNull(cpn);
}catch(Exception e) {
e.printStackTrace();
fail();
}
TestObserver<Inference> tester2;
// Effectively-final alias so the anonymous observer below can halt the loaded Network.
final Network checkPointNetwork = cpn;
checkPointNetwork.observe().subscribe(tester2 = new TestObserver<Inference>() {
int cycles = 0;
@Override public void onCompleted() {}
@Override public void onNext(Inference i) {
// On the loaded Network's 11th inference the record number is expected to be 21.
if(cycles++ == 10) {
assertEquals(21, i.getRecordNum());
checkPointNetwork.halt();
}
}
});
// NOTE(review): uses start() where the other checkpoint tests call restart() - confirm this is intended.
checkPointNetwork.start();
try {
checkPointNetwork.lookup("r2").lookup("1").getLayerThread().join();
}catch(Exception e) {
e.printStackTrace();
}
// Surface any assertion failures captured inside the observers' onNext().
checkObserver(tester);
checkObserver(tester2);
checkObserver(nestedTester4);
}
//////////////////////////////
// Utility Methods //
//////////////////////////////
/**
 * Diagnostic deep comparison: prints both objects to stdout, labelled with whether
 * {@link DeepEquals#deepEquals(Object, Object)} considered them equal. A mismatch is only
 * printed; it does not fail the calling test.
 *
 * @param obj1  the expected object (may be {@code null})
 * @param obj2  the actual object (may be {@code null})
 */
private void deepCompare(Object obj1, Object obj2) {
    // Each label must describe its own object; null-safe so a null argument cannot throw here.
    String expectedType = obj1 == null ? "null" : obj1.getClass().getSimpleName();
    String actualType = obj2 == null ? "null" : obj2.getClass().getSimpleName();
    if(DeepEquals.deepEquals(obj1, obj2)) {
        System.out.println("expected(" + expectedType + "): " + obj1 + " actual: (" + actualType + "): " + obj2);
    }else{
        System.out.println("expected(" + expectedType + "): " + obj1 + " but was: (" + actualType + "): " + obj2);
    }
}
/**
 * Builds a two-region day-of-week Network: region "r1" holds layer "2" (Anomaly +
 * TemporalMemory) connected to layer "3" (SpatialPooler sharing layer 2's Connections);
 * region "r2" holds an auto-classifying layer "1" fed by a FileSensor over
 * days-of-week-stream.csv. The regions are connected "r1" to "r2".
 */
private Network getLoadedDayOfWeekStreamHierarchy() {
    Parameters params = NetworkTestHarness.getParameters();
    params = params.union(NetworkTestHarness.getDayDemoTestEncoderParams());
    params.set(KEY.RANDOM, new FastRandom(42));
    params.set(KEY.INFERRED_FIELDS, getInferredFieldsMap("dayOfWeek", CLAClassifier.class));

    Network network = Network.create("test network", params);

    Region lower = Network.createRegion("r1");
    Layer<?> temporalLayer = Network.createLayer("2", params)
        .add(Anomaly.create())
        .add(new TemporalMemory());
    lower = lower.add(temporalLayer)
        .add(Network.createLayer("3", params)
            .add(new SpatialPooler())
            .using(temporalLayer.getConnections()))
        .connect("2", "3");
    network = network.add(lower);

    Region upper = Network.createRegion("r2")
        .add(Network.createLayer("1", params)
            .alterParameter(KEY.AUTO_CLASSIFY, Boolean.TRUE)
            .add(new TemporalMemory())
            .add(new SpatialPooler())
            .add(Sensor.create(FileSensor::create, SensorParams.create(
                Keys::path, "", ResourceLocator.path("days-of-week-stream.csv")))));
    return network.add(upper).connect("r1", "r2");
}
/**
 * Builds a single-region day-of-week Network whose auto-classifying layer "1" contains
 * Anomaly, TemporalMemory and SpatialPooler components and is fed by an ObservableSensor.
 * The sensor's publisher is created from the headers "dayOfWeek" / "number" / "B".
 */
private Network getLoadedDayOfWeekNetwork() {
    Parameters params = NetworkTestHarness.getParameters().copy()
        .union(NetworkTestHarness.getDayDemoTestEncoderParams());
    params.set(KEY.RANDOM, new FastRandom(42));
    params.set(KEY.INFERRED_FIELDS, getInferredFieldsMap("dayOfWeek", CLAClassifier.class));

    Object[] sensorArgs = {
        "name",
        PublisherSupplier.builder()
            .addHeader("dayOfWeek")
            .addHeader("number")
            .addHeader("B")
            .build()
    };
    Sensor<ObservableSensor<String[]>> sensor =
        Sensor.create(ObservableSensor::create, SensorParams.create(Keys::obs, sensorArgs));

    Network network = Network.create("test network", params);
    Region region = Network.createRegion("r1");
    Layer<?> layer = Network.createLayer("1", params)
        .alterParameter(KEY.AUTO_CLASSIFY, true)
        .add(Anomaly.create())
        .add(new TemporalMemory())
        .add(new SpatialPooler())
        .add(sensor);
    return network.add(region.add(layer));
}
/**
 * Builds a two-region hot-gym Network: region "r1" holds layer "2" (Anomaly +
 * TemporalMemory) connected to layer "3" (SpatialPooler); region "r2" holds an
 * auto-classifying layer "1" fed by a FileSensor over rec-center-hourly.csv.
 * The regions are connected "r1" to "r2". Uses a MersenneTwister(42) RNG.
 */
private Network getLoadedHotGymHierarchy() {
    Parameters params = NetworkTestHarness.getParameters();
    params = params.union(NetworkTestHarness.getNetworkDemoTestEncoderParams());
    params.set(KEY.RANDOM, new MersenneTwister(42));
    params.set(KEY.INFERRED_FIELDS, getInferredFieldsMap("consumption", CLAClassifier.class));

    Network network = Network.create("test network", params);

    Region lower = Network.createRegion("r1");
    lower = lower.add(Network.createLayer("2", params)
            .add(Anomaly.create())
            .add(new TemporalMemory()))
        .add(Network.createLayer("3", params)
            .add(new SpatialPooler()))
        .connect("2", "3");
    network = network.add(lower);

    Region upper = Network.createRegion("r2")
        .add(Network.createLayer("1", params)
            .alterParameter(KEY.AUTO_CLASSIFY, Boolean.TRUE)
            .add(new TemporalMemory())
            .add(new SpatialPooler())
            .add(Sensor.create(FileSensor::create, SensorParams.create(
                Keys::path, "", ResourceLocator.path("rec-center-hourly.csv")))));
    return network.add(upper).connect("r1", "r2");
}
/**
 * Builds a single-region hot-gym Network whose auto-classifying layer "1" contains
 * Anomaly, TemporalMemory and SpatialPooler components and is fed by an ObservableSensor.
 * Data is pushed later through {@code network.getPublisher()}.
 */
private Network getLoadedHotGymNetwork() {
    Parameters params = NetworkTestHarness.getParameters().copy()
        .union(NetworkTestHarness.getHotGymTestEncoderParams());
    params.set(KEY.RANDOM, new FastRandom(42));
    params.set(KEY.INFERRED_FIELDS, getInferredFieldsMap("consumption", CLAClassifier.class));

    Object[] sensorArgs = {
        "name",
        PublisherSupplier.builder()
            .addHeader("timestamp, consumption")
            .addHeader("datetime, float")
            .addHeader("B")
            .build()
    };
    Sensor<ObservableSensor<String[]>> sensor =
        Sensor.create(ObservableSensor::create, SensorParams.create(Keys::obs, sensorArgs));

    Network network = Network.create("test network", params);
    Region region = Network.createRegion("r1");
    Layer<?> layer = Network.createLayer("1", params)
        .alterParameter(KEY.AUTO_CLASSIFY, true)
        .add(Anomaly.create())
        .add(new TemporalMemory())
        .add(new SpatialPooler())
        .add(sensor);
    return network.add(region.add(layer));
}
/**
 * Builds a sensor-less, single-region hot-gym Network meant to be driven synchronously
 * (e.g. via {@code computeImmediate()}). Its auto-classifying layer "1" contains Anomaly,
 * TemporalMemory and SpatialPooler components.
 */
private Network getLoadedHotGymSynchronousNetwork() {
    Parameters params = NetworkTestHarness.getParameters().copy()
        .union(NetworkTestHarness.getHotGymTestEncoderParams());
    params.set(KEY.RANDOM, new FastRandom(42));
    params.set(KEY.INFERRED_FIELDS, getInferredFieldsMap("consumption", CLAClassifier.class));

    Network network = Network.create("test network", params);
    Region region = Network.createRegion("r1");
    Layer<?> layer = Network.createLayer("1", params)
        .alterParameter(KEY.AUTO_CLASSIFY, true)
        .add(Anomaly.create())
        .add(new TemporalMemory())
        .add(new SpatialPooler());
    return network.add(region.add(layer));
}
/**
 * Builds a single-region hot-gym Network whose auto-classifying layer "1" contains Anomaly,
 * TemporalMemory and SpatialPooler components and is fed by a FileSensor reading
 * rec-center-hourly.csv.
 */
private Network getLoadedHotGymNetwork_FileSensor() {
    Parameters params = NetworkTestHarness.getParameters().copy()
        .union(NetworkTestHarness.getHotGymTestEncoderParams());
    params.set(KEY.RANDOM, new FastRandom(42));
    params.set(KEY.INFERRED_FIELDS, getInferredFieldsMap("consumption", CLAClassifier.class));

    Object[] sensorArgs = { "some name", ResourceLocator.path("rec-center-hourly.csv") };
    HTMSensor<File> sensor =
        (HTMSensor<File>)Sensor.create(FileSensor::create, SensorParams.create(Keys::path, sensorArgs));

    Network network = Network.create("test network", params);
    Region region = Network.createRegion("r1");
    Layer<?> layer = Network.createLayer("1", params)
        .alterParameter(KEY.AUTO_CLASSIFY, true)
        .add(Anomaly.create())
        .add(new TemporalMemory())
        .add(new SpatialPooler())
        .add(sensor);
    return network.add(region.add(layer));
}
/**
 * Builds a Network whose only layer holds a SpatialPooler fed by a manual {@link Publisher}
 * (a dense-array "dayOfWeek" field via SDRPassThroughEncoder). It then publishes input rows
 * {@code start..runTo} (inclusive) and asserts each SP output against the pre-computed
 * expected SDRs.
 *
 * NOTE(review): outputs are matched against the expected SDRs starting from index 0
 * regardless of {@code start}.
 *
 * @param start  index of the first input row to publish (0-6)
 * @param runTo  index of the last input row to publish, inclusive (0-6)
 * @return the Network, after waiting for its layer thread to finish
 */
private Network createAndRunTestSpatialPoolerNetwork(int start, int runTo) {
    Publisher manual = Publisher.builder()
        .addHeader("dayOfWeek")
        .addHeader("darr")
        .addHeader("B")
        .build();
    Sensor<ObservableSensor<String[]>> sensor = Sensor.create(
        ObservableSensor::create, SensorParams.create(Keys::obs, new Object[] {"name", manual}));

    Parameters params = NetworkTestHarness.getParameters().copy();
    params.set(KEY.RANDOM, new MersenneTwister(42));
    params.set(KEY.INFERRED_FIELDS, getInferredFieldsMap("dayOfWeek", CLAClassifier.class));
    // n=8, w/min/max/radius/resolution=0, periodic/clip unset, forced=TRUE;
    // fieldType "darr" = dense array (as opposed to sparse "sarr").
    params.set(KEY.FIELD_ENCODING_MAP, NetworkTestHarness.setupMap(
        null, 8, 0, 0, 0, 0, 0, null, null, Boolean.TRUE,
        "dayOfWeek", "darr", "SDRPassThroughEncoder"));

    Network network = Network.create("test network", params)
        .add(Network.createRegion("r1")
            .add(Network.createLayer("1", params)
                .add(new SpatialPooler())
                .add(sensor)));
    network.start();

    int[][] inputs = {
        { 1, 1, 0, 0, 0, 0, 0, 1 },
        { 1, 1, 1, 0, 0, 0, 0, 0 },
        { 0, 1, 1, 1, 0, 0, 0, 0 },
        { 0, 0, 1, 1, 1, 0, 0, 0 },
        { 0, 0, 0, 1, 1, 1, 0, 0 },
        { 0, 0, 0, 0, 1, 1, 1, 0 },
        { 0, 0, 0, 0, 0, 1, 1, 1 }
    };
    int[][] expecteds = {
        { 0, 1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0 },
        { 0, 1, 1, 0, 1, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0 },
        { 1, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0 },
        { 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0 },
        { 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0 },
        { 1, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0 },
        { 0, 0, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0 }
    };

    TestObserver<Inference> tester = new TestObserver<Inference>() {
        int outputIndex = 0;
        @Override public void onCompleted() {}
        @Override public void onError(Throwable e) {
            super.onError(e);
            e.printStackTrace();
        }
        @Override
        public void onNext(Inference spatialPoolerOutput) {
            assertTrue(Arrays.equals(expecteds[outputIndex++], spatialPoolerOutput.getSDR()));
        }
    };
    network.observe().subscribe(tester);

    // Now push some fake data through so that "onNext" is called above
    for(int row = start; row <= runTo; row++) {
        manual.onNext(Arrays.toString(inputs[row]));
    }
    manual.onComplete();

    try {
        network.lookup("r1").lookup("1").getLayerThread().join();
    }catch(Exception e) {
        e.printStackTrace();
    }
    checkObserver(tester);
    return network;
}
/**
 * Builds a Network whose only layer holds a TemporalMemory fed by a manual {@link Publisher}
 * (a 20-bit dense-array "dayOfWeek" field via SDRPassThroughEncoder). It publishes the seven
 * input patterns 602 times in sequence, completes the stream, waits for the layer thread, and
 * prints the active cell indices.
 *
 * @return the Network after the run
 */
private Network createAndRunTestTemporalMemoryNetwork() {
    Publisher manual = Publisher.builder()
        .addHeader("dayOfWeek")
        .addHeader("darr")
        .addHeader("B")
        .build();
    Sensor<ObservableSensor<String[]>> sensor = Sensor.create(
        ObservableSensor::create, SensorParams.create(Keys::obs, new Object[] {"name", manual}));

    Parameters params = getParameters();
    params.set(KEY.INFERRED_FIELDS, getInferredFieldsMap("dayOfWeek", CLAClassifier.class));
    // n=20, w/min/max/radius/resolution=0, periodic/clip unset, forced=TRUE;
    // fieldType "darr" = dense array (as opposed to sparse "sarr").
    params.set(KEY.FIELD_ENCODING_MAP, NetworkTestHarness.setupMap(
        null, 20, 0, 0, 0, 0, 0, null, null, Boolean.TRUE,
        "dayOfWeek", "darr", "SDRPassThroughEncoder"));

    Network network = Network.create("test network", params)
        .add(Network.createRegion("r1")
            .add(Network.createLayer("1", params)
                .add(new TemporalMemory())
                .add(sensor)));
    network.start();
    network.observe().subscribe(new Subscriber<Inference>() {
        @Override public void onCompleted() {}
        @Override public void onError(Throwable e) { e.printStackTrace(); }
        @Override public void onNext(Inference i) {}
    });

    final int[][] patterns = {
        { 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 0 },
        { 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0 },
        { 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0 },
        { 0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0 },
        { 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0 },
        { 0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1 },
        { 0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0 }
    };

    // Run until TemporalMemory is "warmed up".
    final int timeUntilStable = 602;
    for(int pass = 0; pass < timeUntilStable; pass++) {
        for(int[] pattern : patterns) {
            manual.onNext(Arrays.toString(pattern));
        }
    }
    manual.onComplete();

    Layer<?> layer = network.lookup("r1").lookup("1");
    try {
        layer.getLayerThread().join();
        System.out.println(Arrays.toString(SDR.asCellIndices(layer.getConnections().getActiveCells())));
    }catch(Exception e) {
        assertEquals(InterruptedException.class, e.getClass());
    }
    return network;
}
/**
 * Returns a fresh, ordered stream of 21 hot-gym sample records ("date hour:minute,consumption"),
 * covering 7/2/10 0:00 through 20:00. Each call returns a new stream, so callers may consume
 * it once per call.
 */
public Stream<String> makeStream() {
    String[] hourlyRecords = {
        "7/2/10 0:00,21.2",
        "7/2/10 1:00,16.4",
        "7/2/10 2:00,4.7",
        "7/2/10 3:00,4.7",
        "7/2/10 4:00,4.6",
        "7/2/10 5:00,23.5",
        "7/2/10 6:00,47.5",
        "7/2/10 7:00,45.4",
        "7/2/10 8:00,46.1",
        "7/2/10 9:00,41.5",
        "7/2/10 10:00,43.4",
        "7/2/10 11:00,43.8",
        "7/2/10 12:00,37.8",
        "7/2/10 13:00,36.6",
        "7/2/10 14:00,35.7",
        "7/2/10 15:00,38.9",
        "7/2/10 16:00,36.2",
        "7/2/10 17:00,36.6",
        "7/2/10 18:00,37.2",
        "7/2/10 19:00,38.2",
        "7/2/10 20:00,14.1"
    };
    return Arrays.stream(hourlyRecords);
}
/**
 * Builds encoder-only parameters for a two-field input.
 * <ul>
 *   <li>{@code "timestamp"}: a {@code DateEncoder} ({@code n} and {@code w} passed as 0), with
 *       day-of-week and time-of-day sub-fields and the date pattern set explicitly.</li>
 *   <li>{@code "consumption"}: a {@code RandomDistributedScalarEncoder} with
 *       {@code n=25}, {@code w=3}, {@code resolution=0.1}.</li>
 * </ul>
 *
 * @return encoder default parameters with {@code FIELD_ENCODING_MAP} populated
 */
private Parameters getTestEncoderParams() {
    // Date field: only its identity is set here; sub-field settings are added below.
    Map<String, Map<String, Object>> encodings = setupMap(
        null, 0, 0, 0, 0, 0, 0, null, null, null,
        "timestamp", "datetime", "DateEncoder");

    // Numeric field: n=25, w=3, resolution=0.1.
    encodings = setupMap(
        encodings, 25, 3, 0, 0, 0, 0.1, null, null, null,
        "consumption", "float", "RandomDistributedScalarEncoder");

    Map<String, Object> timestamp = encodings.get("timestamp");
    timestamp.put(KEY.DATEFIELD_DOFW.getFieldName(), new Tuple(1, 1.0)); // day of week
    timestamp.put(KEY.DATEFIELD_TOFD.getFieldName(), new Tuple(5, 4.0)); // time of day
    // NOTE(review): "YY" is year-of-era in Joda-Time but week-based-year in java.time;
    // confirm which formatter the DateEncoder uses (lowercase "yy" is unambiguous).
    timestamp.put(KEY.DATEFIELD_PATTERN.getFieldName(), "MM/dd/YY HH:mm");

    Parameters encoderParams = Parameters.getEncoderDefaultParameters();
    encoderParams.set(KEY.FIELD_ENCODING_MAP, encodings);
    return encoderParams;
}
/**
 * Adds (or updates) the encoder settings for one field in a field-encoding map.
 * <p>
 * The numeric settings ({@code n}, {@code w}, {@code minVal}, {@code maxVal}, {@code radius},
 * {@code resolution}) are always written. The remaining entries ({@code periodic}, {@code clip},
 * {@code forced}, {@code fieldName}, {@code fieldType}, {@code encoderType}) are written only
 * when the corresponding argument is non-null. An existing inner map for {@code fieldName}
 * is reused and updated in place.
 *
 * @param map the map to add to, or {@code null} to start a new {@link HashMap}
 * @return the map that was written to ({@code map} itself when it was non-null)
 */
private Map<String, Map<String, Object>> setupMap(
        Map<String, Map<String, Object>> map,
        int n, int w, double min, double max, double radius, double resolution, Boolean periodic,
        Boolean clip, Boolean forced, String fieldName, String fieldType, String encoderType) {
    Map<String, Map<String, Object>> target =
        map != null ? map : new HashMap<String, Map<String, Object>>();

    Map<String, Object> settings = target.get(fieldName);
    if(settings == null) {
        settings = new HashMap<String, Object>();
        target.put(fieldName, settings);
    }

    // Mandatory numeric settings (boxed as Integer / Double).
    settings.put("n", n);
    settings.put("w", w);
    settings.put("minVal", min);
    settings.put("maxVal", max);
    settings.put("radius", radius);
    settings.put("resolution", resolution);

    // Optional settings: omitted entirely rather than stored as null.
    if(periodic != null) {
        settings.put("periodic", periodic);
    }
    if(clip != null) {
        settings.put("clip", clip);
    }
    if(forced != null) {
        settings.put("forced", forced);
    }
    if(fieldName != null) {
        settings.put("fieldName", fieldName);
    }
    if(fieldType != null) {
        settings.put("fieldType", fieldType);
    }
    if(encoderType != null) {
        settings.put("encoderType", encoderType);
    }
    return target;
}
/**
 * Creates a stateful per-inference callback for day-of-week experiments.
 * <p>
 * Each call decodes the layer input back to a day number with {@code mapToInputData} and
 * increments an internal cycle counter (starting at 1) whenever that day is 1. When
 * {@code on} is true it additionally prints, for the record, the encoder input/output,
 * spatial pooler active columns, temporal memory previous predictions and active columns,
 * and the classifier's 1-step prediction and statistics for the {@code "dayOfWeek"} field.
 *
 * @param on whether to print diagnostics to standard out
 * @return a function of (inference, cells per column) returning the current cycle count
 */
private BiFunction<Inference, Integer, Integer> createDayOfWeekInferencePrintout(boolean on) {
    return new BiFunction<Inference, Integer, Integer>() {
        // Incremented each time the decoded input is day 1.
        private int cycles = 1;

        @Override
        public Integer apply(Inference inf, Integer cellsPerColumn) {
            Classification<Object> classification = inf.getClassification("dayOfWeek");
            double day = mapToInputData((int[])inf.getLayerInput());

            if(day == 1.0) {
                // Header is printed with the count *before* it is advanced.
                if(on) {
                    System.out.println("\n=========================");
                    System.out.println("CYCLE: " + cycles);
                }
                cycles++;
            }

            if(!on) {
                return cycles;
            }

            System.out.println("RECORD_NUM: " + inf.getRecordNum());
            System.out.println("ScalarEncoder Input = " + day);
            System.out.println("ScalarEncoder Output = " + Arrays.toString(inf.getEncoding()));
            System.out.println("SpatialPooler Output = " + Arrays.toString(inf.getFeedForwardActiveColumns()));
            if(inf.getPreviousPredictiveCells() != null) {
                System.out.println("TemporalMemory Previous Prediction = " +
                    Arrays.toString(SDR.cellsAsColumnIndices(inf.getPreviousPredictiveCells(), cellsPerColumn)));
            }
            System.out.println("TemporalMemory Actives = " + Arrays.toString(SDR.asColumnIndices(inf.getSDR(), cellsPerColumn)));

            String predictionLine = "CLAClassifier prediction = " +
                stringValue((Double)classification.getMostProbableValue(1)) + " --> " +
                ((Double)classification.getMostProbableValue(1));
            System.out.print(predictionLine);
            System.out.println(" | CLAClassifier 1 step prob = " + Arrays.toString(classification.getStats(1)) + "\n");

            return cycles;
        }
    };
}
/**
 * Decodes a day-of-week bit pattern back to its 1-based day number.
 *
 * @param encoding the encoder output to look up
 * @return 1..7 for the first exactly matching row of {@code dayMap}, or -1 if none matches
 */
private double mapToInputData(int[] encoding) {
    int dayNumber = 1;
    for(int[] pattern : dayMap) {
        if(Arrays.equals(encoding, pattern)) {
            return dayNumber;
        }
        dayNumber++;
    }
    return -1;
}
// Expected 8-bit encodings for days 1..7 (row index + 1 = day number); the first row
// wraps around, setting both the lowest and the highest bit.
private int[][] dayMap = {
    { 1, 1, 0, 0, 0, 0, 0, 1 },
    { 1, 1, 1, 0, 0, 0, 0, 0 },
    { 0, 1, 1, 1, 0, 0, 0, 0 },
    { 0, 0, 1, 1, 1, 0, 0, 0 },
    { 0, 0, 0, 1, 1, 1, 0, 0 },
    { 0, 0, 0, 0, 1, 1, 1, 0 },
    { 0, 0, 0, 0, 0, 1, 1, 1 },
};
/**
 * Converts a predicted day-of-week value into a human-readable label.
 * <p>
 * The value is first rounded to three decimal places (HALF_EVEN) and then truncated to its
 * integer part, so e.g. 6.9999 is treated as 7. Days 1..6 map to Monday..Saturday. Sunday
 * is accepted both as 0 and as 7, because {@code mapToInputData} numbers the seventh pattern
 * as 7 while the original mapping only recognised 0.
 *
 * @param valueIndex the predicted value; may be {@code null} when the classifier has no prediction
 * @return the day label, or an empty string for {@code null}, non-finite, or out-of-range values
 */
private String stringValue(Double valueIndex) {
    // new BigDecimal(double) throws on NaN/Infinity and unboxing null throws NPE;
    // treat all of these as "no label" instead of crashing the diagnostic printout.
    if(valueIndex == null || valueIndex.isNaN() || valueIndex.isInfinite()) {
        return "";
    }
    BigDecimal rounded = new BigDecimal(valueIndex).setScale(3, RoundingMode.HALF_EVEN);
    switch(rounded.intValue()) {
        case 1: return "Monday (1)";
        case 2: return "Tuesday (2)";
        case 3: return "Wednesday (3)";
        case 4: return "Thursday (4)";
        case 5: return "Friday (5)";
        case 6: return "Saturday (6)";
        case 0:
        case 7: return "Sunday (7)";
        default: return "";
    }
}
/**
 * Builds a complete parameter set starting from {@link Parameters#getAllDefaultParameters()}.
 * <p>
 * It configures an 8-bit input, 20 columns with 6 cells each, and spatial pooler settings with
 * global inhibition disabled. It also sets temporal memory thresholds and permanences, and
 * seeds all randomness with 42 for reproducible runs.
 *
 * @return the populated parameters
 */
public Parameters getParameters() {
    Parameters p = Parameters.getAllDefaultParameters();

    // Topology
    p.set(KEY.INPUT_DIMENSIONS, new int[] { 8 });
    p.set(KEY.COLUMN_DIMENSIONS, new int[] { 20 });
    p.set(KEY.CELLS_PER_COLUMN, 6);

    // Spatial pooler
    p.set(KEY.POTENTIAL_RADIUS, 12);
    p.set(KEY.POTENTIAL_PCT, 0.5);
    p.set(KEY.GLOBAL_INHIBITION, false);
    p.set(KEY.LOCAL_AREA_DENSITY, -1.0);
    p.set(KEY.NUM_ACTIVE_COLUMNS_PER_INH_AREA, 5.0);
    p.set(KEY.STIMULUS_THRESHOLD, 1.0);
    p.set(KEY.SYN_PERM_INACTIVE_DEC, 0.01);
    p.set(KEY.SYN_PERM_ACTIVE_INC, 0.1);
    p.set(KEY.SYN_PERM_TRIM_THRESHOLD, 0.05);
    p.set(KEY.SYN_PERM_CONNECTED, 0.1);
    p.set(KEY.MIN_PCT_OVERLAP_DUTY_CYCLES, 0.1);
    p.set(KEY.MIN_PCT_ACTIVE_DUTY_CYCLES, 0.1);
    p.set(KEY.DUTY_CYCLE_PERIOD, 10);
    p.set(KEY.MAX_BOOST, 10.0);
    p.set(KEY.SEED, 42);

    // Temporal memory
    p.set(KEY.INITIAL_PERMANENCE, 0.2);
    p.set(KEY.CONNECTED_PERMANENCE, 0.8);
    p.set(KEY.MIN_THRESHOLD, 5);
    p.set(KEY.MAX_NEW_SYNAPSE_COUNT, 6);
    p.set(KEY.PERMANENCE_INCREMENT, 0.05);
    p.set(KEY.PERMANENCE_DECREMENT, 0.05);
    p.set(KEY.ACTIVATION_THRESHOLD, 4);
    p.set(KEY.RANDOM, new FastRandom(42));

    return p;
}
}