package org.numenta.nupic.network;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.numenta.nupic.algorithms.Anomaly.KEY_MODE;
import java.math.BigDecimal;
import java.math.RoundingMode;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;
import org.junit.AfterClass;
import org.junit.Test;
import org.numenta.nupic.Parameters;
import org.numenta.nupic.Parameters.KEY;
import org.numenta.nupic.algorithms.Anomaly;
import org.numenta.nupic.algorithms.Anomaly.Mode;
import org.numenta.nupic.algorithms.SpatialPooler;
import org.numenta.nupic.algorithms.TemporalMemory;
import org.numenta.nupic.encoders.ScalarEncoder;
import org.numenta.nupic.model.Cell;
import org.numenta.nupic.model.ComputeCycle;
import org.numenta.nupic.model.Connections;
import org.numenta.nupic.model.SDR;
import org.numenta.nupic.network.sensor.ObservableSensor;
import org.numenta.nupic.network.sensor.Publisher;
import org.numenta.nupic.network.sensor.Sensor;
import org.numenta.nupic.network.sensor.SensorParams;
import org.numenta.nupic.util.ArrayUtils;
import org.numenta.nupic.util.FastRandom;
import org.numenta.nupic.util.UniversalRandom;
import com.cedarsoftware.util.DeepEquals;
import rx.Observer;
/**
* This file contains two principal test methods:
* <ul>
* <li>{@link #testSimpleLayer()} - which outputs data processed by the raw assembly of algorithms, and...</li>
* <li>{@link #testNetworkAPI()} - which outputs data processed by the Network API (NAPI)</li>
* </ul>
*
* As a sort of "vetting" of the NAPI, this illustrates that the resultant
* output is <b><i>exactly</i></b> the same proving that the NAPI does not
* impact the actual results. (at the time of this writing: 07/21/2016)
*
* @author cogmission
*/
public class NetworkConsistencyTest {
/** Exclusive upper bound of the inner input loops; each cycle feeds the values 1 .. UPPER_BOUNDARY - 1. */
private static final int UPPER_BOUNDARY = 8;
/** Number of records per cycle; used to derive sequence numbers and to map a record number back to its cycle. */
private static final int RECORDS_PER_CYCLE = 7;
/** Samples captured by the raw algorithm assembly (filled in by SimpleLayer.input). */
private static Set<SampleWeek> simpleSamples = new HashSet<>();
/** Samples captured by the Network API run (filled in by the observer registered in getNetwork). */
private static Set<SampleWeek> napiSamples = new HashSet<>();
/** When true, every processing step is echoed to standard out. */
private static boolean doPrintout = false;
/** Index of the cycle whose records are captured for comparison; drawn from a fixed-seed generator with bound 125. */
private static final int SAMPLE_WEEK = new UniversalRandom(42).nextInt(125);
/**
 * Runs once after all tests of this class and checks that the samples
 * captured from the raw algorithm assembly are identical to the samples
 * captured from the Network API run.
 */
@AfterClass
public static void compare() {
    System.out.println("USING SAMPLE #: " + SAMPLE_WEEK);

    assertEquals(napiSamples.size(), simpleSamples.size());

    if(doPrintout) {
        System.out.println("\n--------------------------------");

        // Walk both sets in lock-step and print their contents side by side.
        Iterator<SampleWeek> simpleIter = simpleSamples.iterator();
        Iterator<SampleWeek> napiIter = napiSamples.iterator();
        while(simpleIter.hasNext() && napiIter.hasNext()) {
            SampleWeek simple = simpleIter.next();
            SampleWeek napi = napiIter.next();

            System.out.println("Seq#: " + simple.seqNum + " - " + napi.seqNum);
            System.out.println("Encoder: " + Arrays.toString(simple.encoderOut) + " - " + Arrays.toString(napi.encoderOut));
            System.out.println("SP: " + Arrays.toString(simple.spOut) + " - " + Arrays.toString(napi.spOut));
            System.out.println("TM (in): " + Arrays.toString(simple.tmIn) + " - " + Arrays.toString(napi.tmIn));
            System.out.println("TM (pred. cols): " + Arrays.toString(simple.tmPred) + " - " + Arrays.toString(napi.tmPred));
            System.out.println("TM (Active Cells): " + Arrays.toString(simple.activeCells) + " - " + Arrays.toString(napi.activeCells));
            System.out.println("TM (Predictive Cells): " + Arrays.toString(simple.predictiveCells) + " - " + Arrays.toString(napi.predictiveCells));
            System.out.println("Anomaly Score: " + simple.score + " - " + napi.score);
            System.out.println("");
        }
    }

    assertTrue(DeepEquals.deepEquals(simpleSamples, napiSamples));
}
////////////////////////////////////////////
// JUnit Test Methods //
////////////////////////////////////////////
/**
 * Drives the hand-assembled {@code SimpleLayer} (encoder, spatial pooler,
 * temporal memory and anomaly computer wired together directly) with
 * 200 cycles of the values 1 .. UPPER_BOUNDARY - 1.
 */
@Test
public void testSimpleLayer() {
    SimpleLayer layer = new SimpleLayer();

    for(int cycle = 0; cycle < 200; cycle++) {
        // Sequence numbers continue across cycles.
        int seqOffset = RECORDS_PER_CYCLE * cycle;
        for(int day = 1; day < UPPER_BOUNDARY; day++) {
            layer.input(day, cycle, seqOffset + day);
        }
    }
}
/**
 * Drives the same input sequence as {@link #testSimpleLayer()} through an
 * equivalent assembly built with HTM.Java's Network API.
 * <p>
 * After publishing all records the test joins the layer's thread so that
 * it does not return before that thread has finished. A failure to do so
 * is reported as a test failure rather than being printed and ignored.
 */
@Test
public void testNetworkAPI() {
    Network network = getNetwork();
    network.start();

    Publisher publisher = network.getPublisher();
    for(int cycle = 0; cycle < 200; cycle++) {
        for(int day = 1; day < UPPER_BOUNDARY; day++) {
            publisher.onNext(String.valueOf(day));
        }
    }
    publisher.onComplete();

    try {
        Region region = network.lookup("NAB Region");
        region.lookup("NAB Layer").getLayerThread().join();
    } catch(InterruptedException e) {
        // Restore the interrupt flag for the test runner, then fail loudly.
        Thread.currentThread().interrupt();
        throw new AssertionError("Interrupted while waiting for the layer thread to finish", e);
    } catch(Exception e) {
        // Anything else means we could not wait for the network; the
        // comparison in the @AfterClass hook would be meaningless.
        throw new AssertionError("Unable to wait for the layer thread to finish", e);
    }
}
/**
 * Rudimentary test of the anomaly computation: six active columns of
 * which four were predicted must yield a score of 2/6.
 */
@Test
public void testComputeAnomaly_4of6() {
    Map<String, Object> anomalyConfig = new HashMap<>();
    anomalyConfig.put(KEY_MODE, Mode.PURE);
    Anomaly anomalyComputer = Anomaly.create(anomalyConfig);

    int[] activeColumns = { 2, 5, 6, 11, 14, 18 };
    int[] predictedColumns = { 2, 6, 11, 14 };
    double score = anomalyComputer.compute(activeColumns, predictedColumns, 0, 0);

    assertEquals(0.3333333333333333, score, 0);
}
/**
 * Rudimentary test of the anomaly computation: seven active columns of
 * which five were predicted must yield a score of 2/7.
 */
@Test
public void testComputeAnomaly_5of7() {
    Map<String, Object> anomalyConfig = new HashMap<>();
    anomalyConfig.put(KEY_MODE, Mode.PURE);
    Anomaly anomalyComputer = Anomaly.create(anomalyConfig);

    int[] activeColumns = { 0, 1, 8, 10, 13, 16, 18 };
    int[] predictedColumns = { 0, 10, 13, 16, 18 };
    double score = anomalyComputer.compute(activeColumns, predictedColumns, 0, 0);

    assertEquals(0.2857142857142857, score, 0);
}
//--------------------------------------------------------------------------------------
////////////////////////////////////////////
// Support Methods //
////////////////////////////////////////////
/**
 * Builds the parameter set shared by both the raw algorithm assembly and
 * the Network API assembly, starting from the library defaults and
 * overriding the values listed below.
 *
 * @return a freshly created and configured {@link Parameters} instance
 */
private Parameters getParameters() {
    Parameters p = Parameters.getAllDefaultParameters();

    // Topology
    p.set(KEY.INPUT_DIMENSIONS, new int[] { 8 });
    p.set(KEY.COLUMN_DIMENSIONS, new int[] { 20 });
    p.set(KEY.CELLS_PER_COLUMN, 6);

    // SpatialPooler specific
    p.set(KEY.POTENTIAL_RADIUS, 12);//3
    p.set(KEY.POTENTIAL_PCT, 0.5);//0.5
    p.set(KEY.GLOBAL_INHIBITION, false);
    p.set(KEY.LOCAL_AREA_DENSITY, -1.0);
    p.set(KEY.NUM_ACTIVE_COLUMNS_PER_INH_AREA, 5.0);
    p.set(KEY.STIMULUS_THRESHOLD, 1.0);
    p.set(KEY.SYN_PERM_INACTIVE_DEC, 0.0005);
    p.set(KEY.SYN_PERM_ACTIVE_INC, 0.0015);
    p.set(KEY.SYN_PERM_TRIM_THRESHOLD, 0.05);
    p.set(KEY.SYN_PERM_CONNECTED, 0.1);
    p.set(KEY.MIN_PCT_OVERLAP_DUTY_CYCLES, 0.1);
    p.set(KEY.MIN_PCT_ACTIVE_DUTY_CYCLES, 0.1);
    p.set(KEY.DUTY_CYCLE_PERIOD, 10);
    p.set(KEY.MAX_BOOST, 10.0);
    p.set(KEY.SEED, 42);

    // TemporalMemory specific
    p.set(KEY.INITIAL_PERMANENCE, 0.2);
    p.set(KEY.CONNECTED_PERMANENCE, 0.8);
    p.set(KEY.MIN_THRESHOLD, 5);
    p.set(KEY.MAX_NEW_SYNAPSE_COUNT, 6);
    p.set(KEY.PERMANENCE_INCREMENT, 0.1);//0.05
    p.set(KEY.PERMANENCE_DECREMENT, 0.1);//0.05
    p.set(KEY.ACTIVATION_THRESHOLD, 4);

    // Fixed-seed random source
    p.set(KEY.RANDOM, new FastRandom(42));

    return p;
}
/**
 * Parameters and meta information for the "dayOfWeek" encoder.
 *
 * @return a new field-encoding map holding a single "dayOfWeek" entry
 */
public Map<String, Map<String, Object>> getDayDemoFieldEncodingMap() {
    return setupMap(
        null,           // no existing map, a new one is created
        8,              // n
        3,              // w
        1.0,            // minVal
        8.0,            // maxVal
        1,              // radius
        1,              // resolution
        Boolean.TRUE,   // periodic
        null,           // clipInput (left unset)
        Boolean.TRUE,   // forced
        "dayOfWeek", "number", "ScalarEncoder");
}
/**
 * Adds (or updates) the encoder settings for one field in a field-encoding map.
 *
 * @param map           the map to populate, or {@code null} to create a new one
 * @param n             stored under "n"
 * @param w             stored under "w"
 * @param min           stored under "minVal"
 * @param max           stored under "maxVal"
 * @param radius        stored under "radius"
 * @param resolution    stored under "resolution"
 * @param periodic      stored under "periodic" when non-null
 * @param clip          stored under "clipInput" when non-null
 * @param forced        stored under "forced" when non-null
 * @param fieldName     key of the entry in {@code map}; also stored under "fieldName" when non-null
 * @param fieldType     stored under "fieldType" when non-null
 * @param encoderType   stored under "encoderType" when non-null
 * @return the map passed in, or the newly created one
 */
public static Map<String, Map<String, Object>> setupMap(
    Map<String, Map<String, Object>> map,
    int n, int w, double min, double max, double radius, double resolution, Boolean periodic,
    Boolean clip, Boolean forced, String fieldName, String fieldType, String encoderType) {

    Map<String, Map<String, Object>> encodings =
        (map != null) ? map : new HashMap<String, Map<String, Object>>();

    // Reuse the settings already registered for this field, if any.
    Map<String, Object> settings =
        encodings.computeIfAbsent(fieldName, key -> new HashMap<String, Object>());

    // Mandatory settings
    settings.put("n", n);
    settings.put("w", w);
    settings.put("minVal", min);
    settings.put("maxVal", max);
    settings.put("radius", radius);
    settings.put("resolution", resolution);

    // Optional settings are only written when a value was supplied.
    if(periodic != null) {
        settings.put("periodic", periodic);
    }
    if(clip != null) {
        settings.put("clipInput", clip);
    }
    if(forced != null) {
        settings.put("forced", forced);
    }
    if(fieldName != null) {
        settings.put("fieldName", fieldName);
    }
    if(fieldType != null) {
        settings.put("fieldType", fieldType);
    }
    if(encoderType != null) {
        settings.put("encoderType", encoderType);
    }

    return encodings;
}
/**
 * Translates a numeric day index into a printable day label.
 * <p>
 * The value is rounded to three decimal places (half-even) and then
 * truncated to an integer; 1 through 7 map to "Monday (1)" through
 * "Sunday (7)".
 *
 * @param valueIndex the numeric day value
 * @return the matching label, or an empty string when the value does not map to a day
 */
public String stringValue(Double valueIndex) {
    final String[] dayLabels = {
        "Monday (1)", "Tuesday (2)", "Wednesday (3)", "Thursday (4)",
        "Friday (5)", "Saturday (6)", "Sunday (7)"
    };

    int day = new BigDecimal(valueIndex).setScale(3, RoundingMode.HALF_EVEN).intValue();
    if(day < 1 || day > dayLabels.length) {
        return "";
    }
    return dayLabels[day - 1];
}
/**
 * Builds the Network API counterpart of the hand-wired layer: a single
 * region containing a single layer made of a manually fed sensor, a
 * {@link SpatialPooler}, a {@link TemporalMemory} and a PURE-mode
 * {@link Anomaly} computer.
 * <p>
 * An observer is subscribed to the network which optionally prints every
 * inference and stores the inferences belonging to cycle
 * {@code SAMPLE_WEEK} in {@code napiSamples}.
 *
 * @return the assembled (not yet started) network
 */
private Network getNetwork() {
    // Publisher-backed supplier through which the test pushes its input records
    PublisherSupplier supplier = PublisherSupplier.builder()
        .addHeader("dayOfWeek")
        .addHeader("number")
        .addHeader("B").build();

    // Model parameters, extended with the encoder description for the sensor
    Parameters parameters = getParameters();
    Map<String, Map<String, Object>> fieldEncodings = getDayDemoFieldEncodingMap();
    parameters.set(KEY.FIELD_ENCODING_MAP, fieldEncodings);
    final int cellsPerColumn = (int)parameters.get(KEY.CELLS_PER_COLUMN);

    Map<String, Object> anomalyParams = new HashMap<>();
    anomalyParams.put(KEY_MODE, Mode.PURE);

    // Assemble the network
    Network network = Network.create("NAB Network", parameters)
        .add(Network.createRegion("NAB Region")
            .add(Network.createLayer("NAB Layer", parameters)
                .add(Anomaly.create(anomalyParams))
                .add(new TemporalMemory())
                .add(new SpatialPooler())
                .add(Sensor.create(ObservableSensor::create,
                    SensorParams.create(SensorParams.Keys::obs, "Manual Input", supplier)))));

    network.observe().subscribe(new Observer<Inference>() {
        @Override public void onCompleted() {}
        @Override public void onError(Throwable e) { e.printStackTrace(); }

        @Override
        public void onNext(Inference inf) {
            int recordNum = inf.getRecordNum();
            // Cycle this record belongs to; uses the shared constant instead of a literal 7.
            int cycle = recordNum / RECORDS_PER_CYCLE;

            String layerInput = inf.getLayerInput().toString();
            // Column indexes of the predictive cells
            int[] predictedColumns = SDR.cellsAsColumnIndices(inf.getPredictiveCells(), cellsPerColumn);
            Set<Cell> actives = inf.getActiveCells();
            int[] actCellIndices = SDR.asCellIndices(actives);
            Set<Cell> pred = inf.getPredictiveCells();
            int[] predCellIndices = SDR.asCellIndices(pred);
            double score = inf.getAnomalyScore();

            if(doPrintout) {
                if(recordNum % RECORDS_PER_CYCLE == 0) {
                    System.out.println("--------------------------------------------------------");
                    System.out.println("Iteration: " + cycle);
                }
                System.out.println("===== " + layerInput + " - Sequence Num: " + (recordNum + 1) + " =====");
                System.out.println("ScalarEncoder Input = " + layerInput);
                System.out.println("ScalarEncoder Output = " + Arrays.toString(inf.getEncoding()));
                System.out.println("SpatialPooler Output = " + Arrays.toString(inf.getFeedForwardActiveColumns()));
                System.out.println("TemporalMemory Input = " + Arrays.toString(inf.getFeedForwardSparseActives()));
                System.out.println("TemporalMemory Prediction = " + Arrays.toString(predictedColumns));
                System.out.println("TemporalMemory Active Cells = " + Arrays.toString(actCellIndices));
                System.out.println("TemporalMemory Predictive Cells = " + Arrays.toString(predCellIndices));
                System.out.println("Anomaly Score = " + score);
            }

            // Capture only the records of the sampled cycle for the later comparison.
            if(cycle == SAMPLE_WEEK) {
                napiSamples.add(new SampleWeek(recordNum + 1, inf.getEncoding(), inf.getFeedForwardActiveColumns(),
                    inf.getFeedForwardSparseActives(), predictedColumns, actCellIndices, predCellIndices, score));
            }
        }
    });

    return network;
}
//---------------------------------------------------------------------------------------
/////////////////////////////////////
// A Simple Layer Class //
/////////////////////////////////////
/**
 * A minimal, hand-wired "layer": a {@link ScalarEncoder} feeding a
 * {@link SpatialPooler}, feeding a {@link TemporalMemory}, whose output is
 * scored by an {@link Anomaly} computer. All algorithms share a single
 * {@link Connections} instance.
 */
class SimpleLayer {
    private final Parameters params;
    private final Connections memory = new Connections();
    private final ScalarEncoder encoder;
    private final SpatialPooler spatialPooler;
    private final TemporalMemory temporalMemory;
    private final Anomaly anomaly;

    private int columnCount;
    private int cellsPerColumn;
    /** Predicted columns produced by the most recent call to {@link #input}. */
    private int[] predictedColumns;
    /** Predicted columns produced by the call before the most recent one. */
    private int[] prevPredictedColumns;

    /**
     * Creates a layer using the enclosing test's parameters, a periodic
     * day-of-week scalar encoder and a PURE-mode anomaly computer.
     */
    public SimpleLayer() {
        params = getParameters();

        ScalarEncoder.Builder dayBuilder =
            ScalarEncoder.builder()
                .n(8)
                .w(3)
                .radius(1.0)
                .minVal(1.0)
                .maxVal(8)
                .periodic(true)
                .forced(true)
                .resolution(1);
        encoder = dayBuilder.build();

        spatialPooler = new SpatialPooler();
        temporalMemory = new TemporalMemory();

        Map<String, Object> anomalyParams = new HashMap<>();
        anomalyParams.put(KEY_MODE, Mode.PURE);
        anomaly = Anomaly.create(anomalyParams);

        configure();
    }

    /**
     * Creates a layer from externally supplied components.
     *
     * @param p the parameters applied to the shared connections
     * @param e the encoder turning input values into bit arrays
     * @param s the spatial pooler
     * @param t the temporal memory
     * @param a the anomaly computer
     */
    public SimpleLayer(Parameters p, ScalarEncoder e, SpatialPooler s, TemporalMemory t, Anomaly a) {
        this.params = p;
        this.encoder = e;
        this.spatialPooler = s;
        this.temporalMemory = t;
        this.anomaly = a;

        configure();
    }

    /**
     * Applies the parameters to the shared connections, initializes both
     * algorithms with them and caches the resulting column and cell counts.
     */
    private void configure() {
        params.apply(memory);
        spatialPooler.init(memory);
        TemporalMemory.init(memory);

        // Taken from the initialized pools; if necessary this flattens a multi-dimensional index.
        columnCount = memory.getPotentialPools().getMaxIndex() + 1;
        cellsPerColumn = memory.getCellsPerColumn();
    }

    /**
     * Pushes one value through encoder, spatial pooler, temporal memory and
     * anomaly computation. When {@code recordNum} equals {@code SAMPLE_WEEK}
     * the results are stored in {@code simpleSamples}.
     *
     * @param value     the value to encode
     * @param recordNum compared against {@code SAMPLE_WEEK}; printed as the iteration number
     * @param seqNum    sequence number stored with the captured sample
     */
    public void input(double value, int recordNum, int seqNum) {
        // Encoder
        int[] encoding = encoder.encode(value);

        // Spatial pooler: writes the active columns into a dense array
        int[] output = new int[columnCount];
        spatialPooler.compute(memory, encoding, output, true);

        // Temporal memory: takes the indexes of the active columns
        int[] activeColumns = ArrayUtils.where(output, ArrayUtils.WHERE_1);
        ComputeCycle cc = temporalMemory.compute(memory, activeColumns, true);

        // Remember the previous prediction; the anomaly score compares it with the current input.
        prevPredictedColumns = predictedColumns;
        predictedColumns = SDR.cellsAsColumnIndices(cc.predictiveCells(), cellsPerColumn);

        Set<Cell> actives = cc.activeCells();
        int[] actCellIndices = SDR.asCellIndices(actives);
        Set<Cell> pred = cc.predictiveCells();
        int[] predCellIndices = SDR.asCellIndices(pred);

        // Anomaly
        double score = anomaly.compute(activeColumns, prevPredictedColumns, 0.0, 0);

        if(doPrintout) {
            if(value == 1) {
                System.out.println("--------------------------------------------------------");
                System.out.println("Iteration: " + recordNum);
            }
            System.out.println("===== " + stringValue(value) + " - Sequence Num: " + seqNum + " =====");
            System.out.println("ScalarEncoder Input = " + value);
            System.out.println("ScalarEncoder Output = " + Arrays.toString(encoding));
            System.out.println("SpatialPooler Output = " + Arrays.toString(output));
            System.out.println("TemporalMemory Input = " + Arrays.toString(activeColumns));
            System.out.println("TemporalMemory Prediction = " + Arrays.toString(predictedColumns));
            System.out.println("TemporalMemory Active Cells = " + Arrays.toString(actCellIndices));
            System.out.println("TemporalMemory Predictive Cells = " + Arrays.toString(predCellIndices));
            System.out.println("Anomaly Score = " + score);
        }

        if(recordNum == SAMPLE_WEEK) {
            simpleSamples.add(new SampleWeek(seqNum, encoding, output, activeColumns, predictedColumns, actCellIndices, predCellIndices, score));
        }
    }
}
class SampleWeek {
int seqNum;
int[] encoderOut, spOut, tmIn, tmPred, activeCells, predictiveCells;
double score;
public SampleWeek(int seq, int[] enc, int[] spo, int[] tmin, int[] tmpred, int[] actCells, int[] predCells, double sc) {
seqNum = seq;
encoderOut = enc;
spOut = spo;
tmIn = tmin;
tmPred = tmpred;
activeCells = actCells;
predictiveCells = predCells;
score = sc;
}
/* (non-Javadoc)
* @see java.lang.Object#hashCode()
*/
@Override
public int hashCode() {
final int prime = 31;
int result = 1;
result = prime * result + Arrays.hashCode(encoderOut);
long temp;
temp = Double.doubleToLongBits(score);
result = prime * result + (int)(temp ^ (temp >>> 32));
result = prime * result + seqNum;
result = prime * result + Arrays.hashCode(spOut);
result = prime * result + Arrays.hashCode(tmIn);
result = prime * result + Arrays.hashCode(tmPred);
result = prime * result + Arrays.hashCode(activeCells);
result = prime * result + Arrays.hashCode(predictiveCells);
return result;
}
/* (non-Javadoc)
* @see java.lang.Object#equals(java.lang.Object)
*/
@Override
public boolean equals(Object obj) {
if(this == obj)
return true;
if(obj == null)
return false;
if(getClass() != obj.getClass())
return false;
SampleWeek other = (SampleWeek)obj;
if(!Arrays.equals(encoderOut, other.encoderOut))
return false;
if(Double.doubleToLongBits(score) != Double.doubleToLongBits(other.score))
return false;
if(seqNum != other.seqNum)
return false;
if(!Arrays.equals(spOut, other.spOut))
return false;
if(!Arrays.equals(tmIn, other.tmIn))
return false;
if(!Arrays.equals(tmPred, other.tmPred))
return false;
if(!Arrays.equals(activeCells, other.activeCells))
return false;
if(!Arrays.equals(predictiveCells, other.predictiveCells))
return false;
return true;
}
}
}