/* ---------------------------------------------------------------------
* Numenta Platform for Intelligent Computing (NuPIC)
* Copyright (C) 2016, Numenta, Inc. Unless you have an agreement
* with Numenta, Inc., for a separate license for this software code, the
* following terms and conditions apply:
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero Public License version 3 as
* published by the Free Software Foundation.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
* See the GNU Affero Public License for more details.
*
* You should have received a copy of the GNU Affero Public License
* along with this program. If not, see http://www.gnu.org/licenses.
*
* http://numenta.org/licenses/
* ---------------------------------------------------------------------
*/
package org.numenta.nupic.network;
import java.io.File;
import java.io.FileInputStream;
import java.io.InputStream;
import java.math.BigDecimal;
import java.math.RoundingMode;
import java.util.Arrays;
import java.util.function.BiFunction;

import org.junit.Ignore;
import org.junit.Test;
import org.numenta.nupic.Parameters;
import org.numenta.nupic.Parameters.KEY;
import org.numenta.nupic.algorithms.Anomaly;
import org.numenta.nupic.algorithms.Classification;
import org.numenta.nupic.algorithms.SpatialPooler;
import org.numenta.nupic.algorithms.TemporalMemory;
import org.numenta.nupic.model.SDR;
import org.numenta.nupic.network.sensor.ObservableSensor;
import org.numenta.nupic.network.sensor.Publisher;
import org.numenta.nupic.network.sensor.Sensor;
import org.numenta.nupic.network.sensor.SensorParams;
import org.numenta.nupic.network.sensor.SensorParams.Keys;
import org.numenta.nupic.serialize.HTMObjectInput;
import org.numenta.nupic.util.FastRandom;

import rx.Observer;
/**
 * Exploratory "playground" for a day-of-week HTM network.
 *
 * <p>Builds a single-region network, repeatedly feeds it the values
 * {@code "0.0"} .. {@code "6.0"}, and prints the encoder, spatial pooler,
 * temporal memory, and classifier output for every record to stdout. It asserts
 * nothing and is kept disabled via {@code @Ignore}.
 */
public class PlaygroundTest {
    /**
     * Expected encoder bit patterns, one row per input group. {@link #mapToInputData(int[])}
     * maps row {@code i} back to the value {@code i + 1}.
     */
    private final int[][] dayMap = new int[][] {
        new int[] { 1, 1, 0, 0, 0, 0, 0, 1 },
        new int[] { 1, 1, 1, 0, 0, 0, 0, 0 },
        new int[] { 0, 1, 1, 1, 0, 0, 0, 0 },
        new int[] { 0, 0, 1, 1, 1, 0, 0, 0 },
        new int[] { 0, 0, 0, 1, 1, 1, 0, 0 },
        new int[] { 0, 0, 0, 0, 1, 1, 1, 0 },
        new int[] { 0, 0, 0, 0, 0, 1, 1, 1 },
    };

    /** Stateful printer (it counts cycles) used by the network observer in {@link #testPlayground()}. */
    private final BiFunction<Inference, Integer, Integer> dayOfWeekPrintout = createDayOfWeekInferencePrintout();

    /**
     * Runs {@code NUM_CYCLES} passes of 7 inputs each through the network and resets the
     * network after every pass. It then waits up to 2 seconds for layer "1"'s thread.
     *
     * <p>{@code @Test} is required so that JUnit actually recognises the method. Without it,
     * {@code @Ignore} has no effect and the method silently disappears from test reports.
     */
    @Test
    @Ignore
    public void testPlayground() {
        final int NUM_CYCLES = 600;
        final int INPUT_GROUP_COUNT = 7; // Days of Week

        ///////////////////////////////////////
        //          Load a Network           //
        ///////////////////////////////////////
        Network network = getLoadedDayOfWeekNetwork();

        int cellsPerCol = (int)network.getParameters().get(KEY.CELLS_PER_COLUMN);

        network.observe().subscribe(new Observer<Inference>() {
            @Override public void onCompleted() {}
            @Override public void onError(Throwable e) { e.printStackTrace(); }
            @Override
            public void onNext(Inference inf) {
                // See createDayOfWeekInferencePrintout(); the returned cycle count is not needed here.
                dayOfWeekPrintout.apply(inf, cellsPerCol);
            }
        });

        Publisher pub = network.getPublisher();

        network.start();

        for(int cycleCount = 0; cycleCount < NUM_CYCLES; cycleCount++) {
            // j is a double so that the published strings read "0.0", "1.0", ... "6.0".
            for(double j = 0; j < INPUT_GROUP_COUNT; j++) {
                pub.onNext("" + j);
            }

            network.reset();
        }

        // Test network output
        try {
            Region r1 = network.lookup("r1");
            r1.lookup("1").getLayerThread().join(2000);
        } catch(InterruptedException e) {
            // Restore the interrupt status so callers/test runners can observe it.
            Thread.currentThread().interrupt();
            e.printStackTrace();
        }
    }

    ///////////////////////////////////////////
    //            HELPER METHODS             //
    ///////////////////////////////////////////

    /**
     * Builds network "test network". It has one region "r1" containing one layer "1". That layer
     * has auto-classification enabled and holds an Anomaly, a TemporalMemory, a SpatialPooler,
     * and an {@link ObservableSensor} fed by a {@link PublisherSupplier}.
     *
     * @return the assembled, not-yet-started network
     */
    private Network getLoadedDayOfWeekNetwork() {
        Parameters p = NetworkTestHarness.getParameters().copy();
        p = p.union(NetworkTestHarness.getDayDemoTestEncoderParams());
        // Fixed seed so runs are reproducible.
        p.set(KEY.RANDOM, new FastRandom(42));

        // NOTE(review): the three headers presumably are the field-name, field-type and
        // special-flag header rows expected by PublisherSupplier — confirm.
        Sensor<ObservableSensor<String[]>> sensor = Sensor.create(
            ObservableSensor::create, SensorParams.create(Keys::obs, new Object[] {"name",
                PublisherSupplier.builder()
                    .addHeader("dayOfWeek")
                    .addHeader("number")
                    .addHeader("B").build() }));

        Network network = Network.create("test network", p).add(Network.createRegion("r1")
            .add(Network.createLayer("1", p)
                .alterParameter(KEY.AUTO_CLASSIFY, true)
                .add(Anomaly.create())
                .add(new TemporalMemory())
                .add(new SpatialPooler())
                .add(sensor)));

        return network;
    }

    /**
     * Creates a stateful printer that dumps one {@link Inference} to stdout. When the decoded
     * input equals 1.0, it prints a cycle banner and advances its internal cycle counter,
     * which starts at 1.
     *
     * @return a function of (inference, cellsPerColumn) that returns the counter value after
     *         printing, i.e. the number of the next cycle to be announced
     */
    private BiFunction<Inference, Integer, Integer> createDayOfWeekInferencePrintout() {
        return new BiFunction<Inference, Integer, Integer>() {
            private int cycles = 1;

            @Override
            public Integer apply(Inference inf, Integer cellsPerColumn) {
                Classification<Object> result = inf.getClassification("dayOfWeek");
                double day = mapToInputData((int[])inf.getLayerInput());
                if(day == 1.0) {
                    System.out.println("\n=========================");
                    System.out.println("CYCLE: " + cycles);
                    cycles++;
                }

                System.out.println("RECORD_NUM: " + inf.getRecordNum());
                System.out.println("ScalarEncoder Input = " + day);
                System.out.println("ScalarEncoder Output = " + Arrays.toString(inf.getEncoding()));
                System.out.println("SpatialPooler Output = " + Arrays.toString(inf.getFeedForwardActiveColumns()));

                if(inf.getPreviousPredictiveCells() != null) {
                    System.out.println("TemporalMemory Previous Prediction = " +
                        Arrays.toString(SDR.cellsAsColumnIndices(inf.getPreviousPredictiveCells(), cellsPerColumn)));
                }

                System.out.println("TemporalMemory Actives = " + Arrays.toString(SDR.asColumnIndices(inf.getSDR(), cellsPerColumn)));

                if(result == null) {
                    // No classification for this field; skip instead of throwing an NPE inside onNext,
                    // which would otherwise terminate the subscription.
                    System.out.println("CLAClassifier prediction = (none)\n");
                    return cycles;
                }

                // Hoisted so both uses see the same value; may be null (stringValue handles that).
                Double predicted = (Double)result.getMostProbableValue(1);
                System.out.print("CLAClassifier prediction = " +
                    stringValue(predicted) + " --> " + predicted);
                System.out.println(" | CLAClassifier 1 step prob = " + Arrays.toString(result.getStats(1)) + "\n");

                return cycles;
            }
        };
    }

    /**
     * Decodes an encoder output back to its input-group value.
     *
     * @param encoding the encoder bit pattern
     * @return {@code i + 1} if {@code encoding} equals {@code dayMap[i]}; otherwise {@code -1}
     */
    private double mapToInputData(int[] encoding) {
        for(int i = 0; i < dayMap.length; i++) {
            if(Arrays.equals(encoding, dayMap[i])) {
                return i + 1;
            }
        }

        return -1;
    }

    /**
     * Converts a predicted value to a day label. The value is first rounded to 3 decimals
     * (HALF_EVEN) and then truncated to an int. 0 maps to "Sunday (7)" and 1..6 map to
     * Monday..Saturday.
     *
     * @param valueIndex the predicted value; may be {@code null}
     * @return the day label, or {@code ""} when the value is {@code null} or outside 0..6
     */
    private String stringValue(Double valueIndex) {
        if(valueIndex == null) {
            // new BigDecimal(double) would throw an NPE when auto-unboxing a null Double.
            return "";
        }

        BigDecimal bdValue = new BigDecimal(valueIndex).setScale(3, RoundingMode.HALF_EVEN);
        switch(bdValue.intValue()) {
            case 1: return "Monday (1)";
            case 2: return "Tuesday (2)";
            case 3: return "Wednesday (3)";
            case 4: return "Thursday (4)";
            case 5: return "Friday (5)";
            case 6: return "Saturday (6)";
            case 0: return "Sunday (7)";
            default: return "";
        }
    }

    /**
     * Example of deserializing a persisted HTM object from the file "myfile".
     *
     * <p>Both the file stream and the {@link HTMObjectInput} are managed by try-with-resources.
     * As a result, the file is closed even if {@code getObjectInput} itself throws.
     *
     * @param args unused
     * @param <T> the expected Persistable subclass type
     * @return the deserialized object
     * @throws Exception on any I/O or deserialization failure
     */
    @SuppressWarnings("unchecked")
    public <T> T main(String[] args) throws Exception {
        try (InputStream input = new FileInputStream(new File("myfile"));
             HTMObjectInput reader = Persistence.get().serializer().getObjectInput(input)) {
            Class<?> aClass = null; // ... set to the Persistable subclass being read
            // Where T is the Persistable subclass type (HTM.java object).
            T t = (T) reader.readObject(aClass);
            return t;
        }
    }
}