package org.numenta.nupic.algorithms;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Random;
import org.junit.Test;
import org.numenta.nupic.network.Persistence;
import org.numenta.nupic.network.PersistenceAPI;
import org.numenta.nupic.serialize.SerialConfig;
import com.cedarsoftware.util.DeepEquals;
import gnu.trove.list.array.TIntArrayList;
/**
 * Unit tests for {@link SDRClassifier}.
 *
 * <p>Each test builds a fresh classifier, feeds it records through
 * {@code compute(recordNum, classification, patternNZ, learn, infer)} and checks the returned
 * {@link Classification}: its step set, per-bucket actual values and per-step likelihoods.
 */
public class SDRClassifierTest {
    private SDRClassifier classifier;

    /**
     * Send same value 10 times and expect 100% likelihood for prediction.
     */
    @Test
    public void testSingleValue() {
        classifier = new SDRClassifier(new TIntArrayList(new int[] { 1 }), 1.0, 0.3, 0);
        // Enough times to perform Inference and expect high likelihood for prediction.
        Classification<Double> retVal = null;
        for(int recordNum = 0; recordNum < 10; recordNum++) {
            retVal = compute(classifier, recordNum, new int[] {1, 5}, 0, 10);
        }
        assertEquals(10.0, retVal.getActualValue(0), 0.0);
        assertEquals(1.0, retVal.getStat(1, 0), 0.1);
    }

    /**
     * Send same value 10 times and expect high likelihood for prediction
     * using 0-step ahead prediction.
     */
    @Test
    public void testSingleValue0Steps() {
        classifier = new SDRClassifier(new TIntArrayList(new int[] { 0 }), 1.0, 0.3, 0);
        // Enough times to perform Inference and learn associations
        Classification<Double> retVal = null;
        for(int recordNum = 0; recordNum < 10; recordNum++) {
            retVal = compute(classifier, recordNum, new int[] {1, 5}, 0, 10);
        }
        assertEquals(10.0, retVal.getActualValue(0), 0.0);
        // Expected value first, actual second, so failure messages read correctly.
        assertEquals(1.0, retVal.getStat(0, 0), 0.01);
    }

    /**
     * The meaning of this test is diminished in Java, because Java is already strongly typed and
     * all expected value types are known and previously declared.
     */
    @Test
    public void testComputeResultTypes() {
        classifier = new SDRClassifier(new TIntArrayList(new int[] { 1 }), 0.1, 0.1, 0);
        Map<String, Object> classification = new LinkedHashMap<String, Object>();
        classification.put("bucketIdx", 4);
        classification.put("actValue", 34.7);
        Classification<Double> result = classifier.compute(0, classification, new int[] { 1, 5, 9 }, true, true);
        assertTrue(Arrays.equals(new int[] { 1 }, result.stepSet()));
        assertEquals(1, result.getActualValueCount());
        assertEquals(34.7, result.getActualValue(0), 0.00001);
    }

    /**
     * Learn-only computes return null, infer-only computes do not change the learned
     * distribution, and a compute with both flags off returns null.
     */
    @Test
    public void testComputeInferOrLearnOnly() {
        classifier = new SDRClassifier(new TIntArrayList(new int[] { 1 }), 1.0, 0.1, 0);
        // learn only
        int recordNum = 0;
        Map<String, Object> classification = new HashMap<String, Object>();
        classification.put("bucketIdx", 4);
        classification.put("actValue", 34.7);
        Classification<Double> retVal = classifier.compute(recordNum, classification, new int[] { 1, 5, 9 }, true, false);
        assertTrue(retVal == null);
        // infer only
        recordNum = 0;
        classification.put("bucketIdx", 2);
        classification.put("actValue", 14.2);
        Classification<Double> retVal1 = classifier.compute(recordNum, classification, new int[] { 1, 5, 9 }, false, true);
        recordNum += 1;
        classification.put("bucketIdx", 3);
        classification.put("actValue", 20.5);
        Classification<Double> retVal2 = classifier.compute(recordNum, classification, new int[] { 1, 5, 9 }, false, true);
        recordNum += 1;
        // Since learning was turned off and pattern was the same, predDist should
        // be same for previous two computes.
        assertArrayEquals(retVal1.getStats(1), retVal2.getStats(1), 0);
        // return null when learn and infer are both false
        classification.put("bucketIdx", 2);
        classification.put("actValue", 14.2);
        Classification<Double> retVal3 = classifier.compute(recordNum, classification, new int[] { 1, 2 }, false, false);
        assertTrue(retVal3 == null);
    }

    @Test
    public void testCompute1() {
        classifier = new SDRClassifier(new TIntArrayList(new int[] { 1 }), 0.1, 0.1, 0);
        int recordNum = 0;
        Map<String, Object> classification = new HashMap<String, Object>();
        classification.put("bucketIdx", 4);
        classification.put("actValue", 34.7);
        Classification<Double> result = classifier.compute(recordNum, classification, new int[] { 1, 5, 9 }, true, true);
        assertTrue(Arrays.equals(new int[] { 1 }, result.stepSet()));
        assertEquals(1, result.getActualValueCount());
        assertEquals(34.7, result.getActualValue(0), 0.00001);
    }

    @Test
    public void testCompute2() {
        classifier = new SDRClassifier(new TIntArrayList(new int[] { 1 }), 0.1, 0.1, 0);
        int recordNum = 0;
        Map<String, Object> classification = new HashMap<String, Object>();
        classification.put("bucketIdx", 4);
        classification.put("actValue", 34.7);
        classifier.compute(recordNum, classification, new int[] { 1, 5, 9 }, true, true);
        recordNum += 1;
        Classification<Double> result = classifier.compute(recordNum, classification, new int[] { 1, 5, 9 }, true, true);
        assertTrue(Arrays.equals(new int[] { 1 }, result.stepSet()));
        assertEquals(34.7, result.getActualValue(4), 0.00001);
    }

    /**
     * Feed a short mixed sequence and pin the exact actual values and 1-step likelihoods.
     */
    @Test
    public void testComputeComplex() {
        classifier = new SDRClassifier(new TIntArrayList(new int[] { 1 }), 1.0, 0.1, 0);
        int recordNum = 0;
        Map<String, Object> classification = new HashMap<String, Object>();
        classification.put("bucketIdx", 4);
        classification.put("actValue", 34.7);
        classifier.compute(recordNum, classification, new int[] { 1, 5, 9 }, true, true);
        recordNum += 1;
        classification.put("bucketIdx", 5);
        classification.put("actValue", 41.7);
        classifier.compute(recordNum, classification, new int[] { 0, 6, 9, 11 }, true, true);
        recordNum += 1;
        classification.put("bucketIdx", 5);
        classification.put("actValue", 44.9);
        classifier.compute(recordNum, classification, new int[] { 6, 9 }, true, true);
        recordNum += 1;
        classification.put("bucketIdx", 4);
        classification.put("actValue", 42.9);
        classifier.compute(recordNum, classification, new int[] { 1, 5, 9 }, true, true);
        recordNum += 1;
        classification.put("bucketIdx", 4);
        classification.put("actValue", 34.7);
        Classification<Double> result = classifier.compute(recordNum, classification, new int[] { 1, 5, 9 }, true, true);
        assertTrue(Arrays.equals(new int[] { 1 }, result.stepSet()));
        assertEquals(35.520000457763672, result.getActualValue(4), 0.00001);
        assertEquals(42.020000457763672, result.getActualValue(5), 0.00001);
        assertEquals(6, result.getStatCount(1));
        assertEquals(0.034234, result.getStat(1, 0), 0.00001);
        assertEquals(0.034234, result.getStat(1, 1), 0.00001);
        assertEquals(0.034234, result.getStat(1, 2), 0.00001);
        assertEquals(0.034234, result.getStat(1, 3), 0.00001);
        assertEquals(0.093058, result.getStat(1, 4), 0.00001);
        assertEquals(0.770004, result.getStat(1, 5), 0.00001);
    }

    /**
     * A record whose bucketIdx and actValue are both null still yields one (null) actual value.
     */
    @Test
    public void testComputeWithMissingValue() {
        classifier = new SDRClassifier(new TIntArrayList(new int[] { 1 }), 0.1, 0.1, 0);
        int recordNum = 0;
        Map<String, Object> classification = new HashMap<String, Object>();
        classification.put("bucketIdx", null);
        classification.put("actValue", null);
        Classification<Double> result = classifier.compute(recordNum, classification, new int[] { 1, 5, 9 }, true, true);
        assertTrue(Arrays.equals(new int[] { 1 }, result.stepSet()));
        assertEquals(1, result.getActualValueCount());
        assertEquals(null, result.getActualValue(0));
    }

    /**
     * Categorical (String) actual values are kept as Strings; unseen buckets report null.
     */
    @Test
    public void testComputeCategory() {
        classifier = new SDRClassifier(new TIntArrayList(new int[] { 1 }), 0.1, 0.1, 0);
        int recordNum = 0;
        Map<String, Object> classification = new HashMap<String, Object>();
        classification.put("bucketIdx", 4);
        classification.put("actValue", "D");
        classifier.compute(recordNum, classification, new int[] { 1, 5, 9 }, true, true);
        recordNum += 1;
        classification.put("bucketIdx", 4);
        classification.put("actValue", "D");
        Classification<String> result = classifier.compute(recordNum, classification, new int[] { 1, 5, 9 }, true, true);
        assertTrue(Arrays.equals(new int[] { 1 }, result.stepSet()));
        assertEquals("D", result.getActualValue(4));
        recordNum += 1;
        classification.put("bucketIdx", 5);
        classification.put("actValue", null);
        Classification<String> predictResult = classifier.compute(recordNum, classification, new int[] { 1, 5, 9 }, true, true);
        for(int i = 0; i < predictResult.getActualValueCount(); i++) {
            assertTrue(predictResult.getActualValue(i) == null ||
                predictResult.getActualValue(i).getClass().equals(String.class));
        }
    }

    /**
     * A categorical bucket keeps its first-seen value ("D") when a different value ("E")
     * later arrives for the same bucket.
     */
    @Test
    public void testComputeCategory2() {
        classifier = new SDRClassifier(new TIntArrayList(new int[] { 1 }), 0.1, 0.1, 0);
        int recordNum = 0;
        Map<String, Object> classification = new HashMap<String, Object>();
        classification.put("bucketIdx", 4);
        classification.put("actValue", "D");
        classifier.compute(recordNum, classification, new int[] { 1, 5, 9 }, true, true);
        recordNum += 1;
        classification.put("bucketIdx", 4);
        classification.put("actValue", "E");
        Classification<String> result = classifier.compute(recordNum, classification, new int[] { 1, 5, 9 }, true, true);
        assertTrue(Arrays.equals(new int[] { 1 }, result.stepSet()));
        assertEquals("D", result.getActualValue(4));
    }

    /**
     * An overlapping input first predicts the previously learned bucket, then shifts toward the
     * newly observed bucket after one more record.
     */
    @Test
    public void testOverlapPattern() {
        classifier = new SDRClassifier(new TIntArrayList(new int[] { 1 }), 10.0, 0.3, 0);
        compute(classifier, 0, new int[] { 1, 5 }, 9, 9);
        compute(classifier, 1, new int[] { 1, 5 }, 9, 9);
        Classification<Double> retVal = compute(classifier, 2, new int[] { 3, 5 }, 2, 2.0);
        // Since overlap - should be previous with high likelihood
        assertEquals(9.0, retVal.getActualValue(9), 0.0);
        assertTrue(retVal.getStat(1, 9) > 0.9);
        retVal = compute(classifier, 3, new int[] { 3, 5 }, 2, 2);
        // Second example: now new value should be more probable than old
        assertTrue(retVal.getStat(1, 2) > retVal.getStat(1, 9));
    }

    /**
     * A constant input/value stream with steps {1, 2} yields one actual value bucket with
     * likelihood 1.0 for both steps.
     */
    @Test
    public void testMultistepSingleValue() {
        classifier = new SDRClassifier(new TIntArrayList(new int[] { 1, 2 }), 0.001, 0.3, 0);
        Classification<Double> retVal = null;
        for(int recordNum = 0; recordNum < 10; recordNum++) {
            retVal = compute(classifier, recordNum, new int[] { 1, 5 }, 0, 10);
        }
        //Only should return one actual value bucket.
        assertEquals(10, retVal.getActualValue(retVal.getMostProbableBucketIndex(1)), 0.0);
        assertTrue(Arrays.equals(new Double[] { 10.0 }, retVal.getActualValues()));
        //Should have a probability of 100% for that bucket.
        assertEquals(1.0, retVal.getStat(1, 0), 0.0);
        assertEquals(1.0, retVal.getStat(2, 0), 0.0);
    }

    /**
     * A repeating 0..9 sequence: the 1-step prediction after input 9 is bucket 0 and the
     * 2-step prediction is bucket 1.
     */
    @Test
    public void testMultiStepSimple() {
        classifier = new SDRClassifier(new TIntArrayList(new int[] { 1, 2 }), 10.0, 0.3, 0);
        Classification<Double> retVal = null;
        int recordNum = 0;
        for(int i = 0; i < 100; i++) {
            retVal = compute(classifier, recordNum, new int[] { i % 10 }, i % 10, (i % 10)*10);
            recordNum += 1;
        }
        //Only should return one actual value bucket.
        assertArrayEquals(new Double[] { 0.0, 10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0 },
            retVal.getActualValues());
        assertTrue(retVal.getStat(1, 0) > 0.99);
        for(int i = 1; i < 10; i++) {
            assertTrue(retVal.getStat(1, i) < 0.01);
        }
        assertTrue(retVal.getStat(2, 1) > 0.99);
        for(int i = 0; i < 10; i++) {
            if(i == 1) continue;
            assertTrue(retVal.getStat(2, i) < 0.01);
        }
    }

    /**
     * Test missing record support.
     *
     * Here, we intend the classifier to learn the associations:
     * [1,3,5] => bucketIdx 1
     * [2,4,6] => bucketIdx 2
     * [7,8,9] => don't care
     *
     * If it doesn't pay attention to the recordNums in this test, it
     * will learn the wrong associations.
     */
    @Test
    public void testMissingRecords() {
        classifier = new SDRClassifier(new TIntArrayList(new int[] { 1 }), 1.0, 0.1, 0);
        int recordNum = 0;
        Map<String, Object> classification = new HashMap<String, Object>();
        classification.put("bucketIdx", 0);
        classification.put("actValue", 0);
        classifier.compute(recordNum, classification, new int[] { 1, 3, 5 }, true, true);
        recordNum += 1;
        classification.put("bucketIdx", 1);
        classification.put("actValue", 1);
        classifier.compute(recordNum, classification, new int[] { 2, 4, 6 }, true, true);
        recordNum += 1;
        classification.put("bucketIdx", 2);
        classification.put("actValue", 2);
        classifier.compute(recordNum, classification, new int[] { 1, 3, 5 }, true, true);
        recordNum += 1;
        classification.put("bucketIdx", 1);
        classification.put("actValue", 1);
        classifier.compute(recordNum, classification, new int[] { 2, 4, 6 }, true, true);
        recordNum += 1;
        // -----------------------------------------------------------------------
        // At this point, we should have learned [1,3,5] => bucket 1
        //                                       [2,4,6] => bucket 2
        classification.put("bucketIdx", 2);
        classification.put("actValue", 2);
        Classification<Double> result = classifier.compute(recordNum, classification, new int[] { 1, 3, 5 }, true, true);
        recordNum += 1;
        assertTrue(result.getStat(1, 0) < 0.1);
        assertTrue(result.getStat(1, 1) > 0.9);
        assertTrue(result.getStat(1, 2) < 0.1);
        classification.put("bucketIdx", 1);
        classification.put("actValue", 1);
        result = classifier.compute(recordNum, classification, new int[] { 2, 4, 6 }, true, true);
        recordNum += 1;
        assertTrue(result.getStat(1, 0) < 0.1);
        assertTrue(result.getStat(1, 1) < 0.1);
        assertTrue(result.getStat(1, 2) > 0.9);
        // -----------------------------------------------------------------------
        // Feed in records that skip and make sure they don't mess up
        // what we learned.
        // If we skip a record, the SDRClassifier should NOT learn that [2,4,6]
        // from the previous learn associates with bucket 0
        recordNum += 1;
        classification.put("bucketIdx", 0);
        classification.put("actValue", 0);
        result = classifier.compute(recordNum, classification, new int[] { 1, 3, 5 }, true, true);
        recordNum += 1;
        assertTrue(result.getStat(1, 0) < 0.1);
        assertTrue(result.getStat(1, 1) > 0.9);
        assertTrue(result.getStat(1, 2) < 0.1);
        // If we skip a record, the SDRClassifier should NOT learn that [1,3,5]
        // from the previous learn associates with bucket 0
        recordNum += 1;
        classification.put("bucketIdx", 0);
        classification.put("actValue", 0);
        result = classifier.compute(recordNum, classification, new int[] { 2, 4, 6 }, true, true);
        recordNum += 1;
        assertTrue(result.getStat(1, 0) < 0.1);
        assertTrue(result.getStat(1, 1) < 0.1);
        assertTrue(result.getStat(1, 2) > 0.9);
        // If we skip a record, the SDRClassifier should NOT learn that [2,4,6]
        // from the previous learn associates with bucket 0
        recordNum += 1;
        classification.put("bucketIdx", 0);
        classification.put("actValue", 0);
        result = classifier.compute(recordNum, classification, new int[] { 1, 3, 5 }, true, true);
        recordNum += 1;
        assertTrue(result.getStat(1, 0) < 0.1);
        assertTrue(result.getStat(1, 1) > 0.9);
        assertTrue(result.getStat(1, 2) < 0.1);
    }

    /**
     * Test missing record edge TestCase
     * Test an edge case in the classifier initialization when there is a
     * missing record in the first n records, where n is the # of prediction steps.
     */
    @Test
    public void testMissingRecordInitialization() {
        classifier = new SDRClassifier(new TIntArrayList(new int[] { 2 }), 0.1, 0.1, 0);
        int recordNum = 0;
        Map<String, Object> classification = new HashMap<String, Object>();
        classification.put("bucketIdx", 0);
        classification.put("actValue", 34.7);
        classifier.compute(recordNum, classification, new int[] { 1, 5, 9 }, true, true);
        recordNum = 2;
        classification.put("bucketIdx", 0);
        classification.put("actValue", 34.7);
        Classification<Double> result = classifier.compute(recordNum, classification, new int[] { 1, 5, 9 }, true, true);
        assertTrue(Arrays.equals(new int[] { 2 }, result.stepSet()));
        assertEquals(1, result.getStepCount());
        assertEquals(34.7, result.getActualValue(0), 0.01);
    }

    /**
     * Test the distribution of predictions.
     *
     * Here, we intend the classifier to learn the associations:
     * [1,3,5] => bucketIdx 0 (30%)
     *         => bucketIdx 1 (30%)
     *         => bucketIdx 2 (40%)
     *
     * [2,4,6] => bucketIdx 1 (50%)
     *         => bucketIdx 3 (50%)
     *
     * The classifier should get the distribution almost right given
     * enough repetitions and a small learning rate.
     */
    @Test
    public void testPredictionDistribution() {
        classifier = new SDRClassifier(new TIntArrayList(new int[] { 0 }), 0.001, 0.1, 0);
        int[] SDR1 = {1, 3, 5};
        // Must not share bits with SDR1: overlapping inputs are covered separately by
        // testPredictionDistributionOverlap.
        int[] SDR2 = {2, 4, 6};
        int recordNum = 0;
        int bucketIdx = 0;
        Map<String, Object> classification = new HashMap<String, Object>();
        Random random = new Random(42);
        for(int i = 0; i < 5000; i++) {
            double randomNumber = random.nextDouble();
            if(randomNumber < 0.3) {
                bucketIdx = 0;
            }
            else if(randomNumber < 0.6) {
                bucketIdx = 1;
            }
            else {
                bucketIdx = 2;
            }
            classification.put("bucketIdx", bucketIdx);
            classification.put("actValue", bucketIdx);
            classifier.compute(recordNum, classification, SDR1, true, false);
            recordNum += 1;
            randomNumber = random.nextDouble();
            if(randomNumber < 0.5) {
                bucketIdx = 1;
            }
            else {
                bucketIdx = 3;
            }
            classification.put("bucketIdx", bucketIdx);
            classification.put("actValue", bucketIdx);
            classifier.compute(recordNum, classification, SDR2, true, false);
            recordNum += 1;
        }
        Classification<Double> result1 = classifier.compute(2, null, SDR1, false, true);
        assertEquals(0.3, result1.getStat(0, 0), 0.1);
        assertEquals(0.3, result1.getStat(0, 1), 0.1);
        assertEquals(0.4, result1.getStat(0, 2), 0.1);
        Classification<Double> result2 = classifier.compute(2, null, SDR2, false, true);
        assertEquals(0.5, result2.getStat(0, 1), 0.1);
        assertEquals(0.5, result2.getStat(0, 3), 0.1);
    }

    /**
     * Test the distribution of predictions with overlapping input SDRs.
     *
     * Here, we intend the classifier to learn the associations:
     * SDR1 => bucketIdx 0 (30%)
     *      => bucketIdx 1 (30%)
     *      => bucketIdx 2 (40%)
     *
     * SDR2 => bucketIdx 1 (50%)
     *      => bucketIdx 3 (50%)
     *
     * SDR1 and SDR2 have 10% overlap (2 bits out of 20)
     * The classifier should get the distribution almost right despite the overlap
     */
    @Test
    public void testPredictionDistributionOverlap() {
        classifier = new SDRClassifier(new TIntArrayList(new int[] { 0 }), 0.0005, 0.1, 0);
        int recordNum = 0;
        // Generate 2 SDRs with 2 shared bits
        int[] SDR1 = new int[20];
        int[] SDR2 = new int[20];
        for(int i = 0; i < 40; i++) {
            if(i % 2 == 0)
                SDR1[i/2] = i; // SDR1 = {0, 2, 4, 6, ... , 38}
            else
                SDR2[(i - 1)/2] = i; // SDR2 = {1, 3, 5, 7, ... , 39}
        }
        // Replace two SDR2 bits with SDR1 bits (10 and 22) to create the overlap.
        SDR2[3] = SDR1[5];
        SDR2[5] = SDR1[11];
        int bucketIdx = 0;
        Map<String, Object> classification = new HashMap<String, Object>();
        Random random = new Random(42);
        for(int i = 0; i < 5000; i++) {
            double randomNumber = random.nextDouble();
            if (randomNumber < 0.3) {
                bucketIdx = 0;
            } else if (randomNumber < 0.6) {
                bucketIdx = 1;
            } else {
                bucketIdx = 2;
            }
            classification.put("bucketIdx", bucketIdx);
            classification.put("actValue", bucketIdx);
            classifier.compute(recordNum, classification, SDR1, true, false);
            recordNum += 1;
            randomNumber = random.nextDouble();
            if(randomNumber < 0.5) {
                bucketIdx = 1;
            }
            else {
                bucketIdx = 3;
            }
            classification.put("bucketIdx", bucketIdx);
            classification.put("actValue", bucketIdx);
            classifier.compute(recordNum, classification, SDR2, true, false);
            recordNum += 1;
        }
        Classification<Double> result1 = classifier.compute(2, null, SDR1, false, true);
        assertEquals(0.3, result1.getStat(0, 0), 0.1);
        assertEquals(0.3, result1.getStat(0, 1), 0.1);
        assertEquals(0.4, result1.getStat(0, 2), 0.1);
        Classification<Double> result2 = classifier.compute(2, null, SDR2, false, true);
        assertEquals(0.5, result2.getStat(0, 1), 0.1);
        assertEquals(0.5, result2.getStat(0, 3), 0.1);
    }

    /**
     * Test continuous learning
     *
     * First, for 10,000 iterations, we intend the classifier to learn the associations:
     * SDR1 => bucketIdx 0 (30%)
     *      => bucketIdx 1 (30%)
     *      => bucketIdx 2 (40%)
     *
     * SDR2 => bucketIdx 1 (50%)
     *      => bucketIdx 3 (50%)
     *
     * We then change the association for SDR1 and train it for 20,000 more iterations:
     * SDR1 => bucketIdx 0 (30%)
     *      => bucketIdx 1 (30%)
     *      => bucketIdx 3 (40%)
     *
     * No further training for SDR2
     *
     * The classifier should adapt continuously and learn new associations
     * for SDR1, but at the same time remember the old association for SDR2.
     */
    @Test
    public void testPredictionDistributionContinuousLearning() {
        classifier = new SDRClassifier(new TIntArrayList(new int[] { 0 }), 0.001, 0.1, 0);
        int recordNum = 0;
        int[] SDR1 = {1, 3, 5};
        int[] SDR2 = {2, 4, 6};
        int bucketIdx = 0;
        Map<String, Object> classification = new HashMap<String, Object>();
        Random random = new Random(42);
        double randomNumber = 0;
        // Phase 1: learn the initial distributions for SDR1 and SDR2.
        for(int i = 0; i < 10000; i++) {
            randomNumber = random.nextDouble();
            if (randomNumber < 0.3) {
                bucketIdx = 0;
            } else if (randomNumber < 0.6) {
                bucketIdx = 1;
            } else {
                bucketIdx = 2;
            }
            classification.put("bucketIdx", bucketIdx);
            classification.put("actValue", bucketIdx);
            classifier.compute(recordNum, classification, SDR1, true, false);
            recordNum += 1;
            randomNumber = random.nextDouble();
            if(randomNumber < 0.5) {
                bucketIdx = 1;
            }
            else {
                bucketIdx = 3;
            }
            classification.put("bucketIdx", bucketIdx);
            classification.put("actValue", bucketIdx);
            classifier.compute(recordNum, classification, SDR2, true, true);
            recordNum += 1;
        }
        classification.put("bucketIdx", 0);
        classification.put("actValue", 0);
        Classification<Double> result1 = classifier.compute(2, classification, SDR1, false, true);
        assertEquals(0.3, result1.getStat(0, 0), 0.1);
        assertEquals(0.3, result1.getStat(0, 1), 0.1);
        assertEquals(0.4, result1.getStat(0, 2), 0.1);
        classification.put("bucketIdx", 0);
        classification.put("actValue", 0);
        Classification<Double> result2 = classifier.compute(2, classification, SDR2, false, true);
        assertEquals(0.5, result2.getStat(0, 1), 0.1);
        assertEquals(0.5, result2.getStat(0, 3), 0.1);
        // Phase 2: retrain SDR1 only, moving the 40% mass from bucket 2 to bucket 3.
        for(int i = 0; i < 20000; i++) {
            randomNumber = random.nextDouble();
            if (randomNumber < 0.3) {
                bucketIdx = 0;
            } else if (randomNumber < 0.6) {
                bucketIdx = 1;
            } else {
                bucketIdx = 3;
            }
            classification.put("bucketIdx", bucketIdx);
            classification.put("actValue", bucketIdx);
            classifier.compute(recordNum, classification, SDR1, true, false);
            recordNum += 1;
        }
        Classification<Double> result1new = classifier.compute(2, null, SDR1, false, true);
        assertEquals(0.3, result1new.getStat(0, 0), 0.1);
        assertEquals(0.3, result1new.getStat(0, 1), 0.1);
        assertEquals(0.4, result1new.getStat(0, 3), 0.1);
        // SDR2 was not trained in phase 2, so its distribution must be unchanged.
        Classification<Double> result2new = classifier.compute(2, null, SDR2, false, true);
        assertTrue(Arrays.equals(result2new.getStats(0), result2.getStats(0)));
    }

    /**
     * Test predictions on an alternating data stream.
     *
     * The classifier is trained on the data stream
     * (SDR1, bucketIdx0)
     * (SDR2, bucketIdx1)
     * (SDR1, bucketIdx0)
     * (SDR2, bucketIdx1)
     * ...
     *
     * It is configured with steps {0} only, so we intend it to learn the associations:
     * SDR1 => bucketIdx 0
     * SDR2 => bucketIdx 1
     *
     * TODO: a 1-step classifier trained on the same stream should learn
     * SDR1 => bucketIdx 1 and SDR2 => bucketIdx 0; that is not configured or asserted here.
     */
    @Test
    public void testMultiStepPredictions() {
        classifier = new SDRClassifier(new TIntArrayList(new int[] { 0 }), 1.0, 0.1, 0);
        int recordNum = 0;
        int[] SDR1 = {1, 3, 5};
        int[] SDR2 = {2, 4, 6};
        Map<String, Object> classification = new HashMap<String, Object>();
        for(int i = 0; i < 100; i++) {
            classification.put("bucketIdx", 0);
            classification.put("actValue", 0);
            classifier.compute(recordNum, classification, SDR1, true, false);
            recordNum += 1;
            classification.put("bucketIdx", 1);
            classification.put("actValue", 1);
            classifier.compute(recordNum, classification, SDR2, true, true);
            recordNum += 1;
        }
        Classification<Double> result1 = classifier.compute(recordNum, null, SDR1, false, true);
        Classification<Double> result2 = classifier.compute(recordNum, null, SDR2, false, true);
        assertEquals(1.0, result1.getStat(0, 0), 0.1);
        assertEquals(0.0, result1.getStat(0, 1), 0.1);
        assertEquals(0.0, result2.getStat(0, 0), 0.1);
        assertEquals(1.0, result2.getStat(0, 1), 0.1);
    }

    /**
     * Serialize a classifier mid-stream and check that the deserialized copy keeps producing
     * the same results as an identically-fed, never-serialized classifier.
     */
    @Test
    public void testWriteRead() {
        // Create two classifiers, so one can be serialized and tested against the other
        SDRClassifier c1 = new SDRClassifier(new TIntArrayList(new int[] { 1 }), 0.1, 0.1, 0);
        SDRClassifier c2 = new SDRClassifier(new TIntArrayList(new int[] { 1 }), 0.1, 0.1, 0);
        // Create input vectors A, B, and C (int[] of active bit indices from below)
        int[] inputA = new int[] { 1, 5, 9 };
        int[] inputB = new int[] { 2, 4, 6 };
        int[] inputC = new int[] { 3, 5, 7 };
        // Create classification Map
        Map<String, Object> classification = new HashMap<>();
        // Advance recordNum for each input. With steps {1}, a pattern is associated with
        // the value arriving one record later (see testMissingRecords), so consecutive
        // record numbers make the post-deserialization computes learn from the restored
        // history and weights instead of all sharing recordNum 0.
        int recordNum = 0;
        // Have both classifiers process input A
        classification.put("bucketIdx", 0);
        classification.put("actValue", "A");
        Classification<String> result1 = c1.compute(recordNum, classification, inputA, true, true);
        Classification<String> result2 = c2.compute(recordNum, classification, inputA, true, true);
        // Compare results, should be equal
        assertArrayEquals(result1.getStats(1), result2.getStats(1), 0.0);
        assertArrayEquals(result1.getActualValues(), result2.getActualValues());
        // Serialize classifier #2
        SerialConfig config = new SerialConfig("testSerializeSDRClassifier", SerialConfig.SERIAL_TEST_DIR);
        PersistenceAPI api = Persistence.get(config);
        byte[] data = api.write(c2);
        // Deserialize classifier #2 into new reference.
        SDRClassifier reifiedC2 = api.read(data);
        // Make sure it isn't null...
        assertNotNull(reifiedC2);
        // Make sure pre- and post-serialization classifiers are identical
        assertTrue(DeepEquals.deepEquals(c2, reifiedC2));
        // Have the non-serialized classifier and the deserialized classifier
        // process input B.
        recordNum += 1;
        classification.put("bucketIdx", 1);
        classification.put("actValue", "B");
        result1 = c1.compute(recordNum, classification, inputB, true, true);
        result2 = reifiedC2.compute(recordNum, classification, inputB, true, true);
        // Compare the results. Make sure (de)serialization hasn't
        // messed up classifier #2.
        assertArrayEquals(result1.getStats(1), result2.getStats(1), 0.0);
        assertArrayEquals(result1.getActualValues(), result2.getActualValues());
        // Process input C - just to be safe...
        recordNum += 1;
        classification.put("bucketIdx", 2);
        classification.put("actValue", "C");
        result1 = c1.compute(recordNum, classification, inputC, true, true);
        result2 = reifiedC2.compute(recordNum, classification, inputC, true, true);
        // Compare results
        assertArrayEquals(result1.getStats(1), result2.getStats(1), 0.0);
        assertArrayEquals(result1.getActualValues(), result2.getActualValues());
    }

    /**
     * Feeds one record to {@code classifier} with learning and inference both enabled.
     *
     * @param classifier the classifier under test
     * @param recordNum  record number passed to {@code compute}
     * @param pattern    active input bit indices
     * @param bucket     value stored under the {@code "bucketIdx"} key
     * @param value      value stored under the {@code "actValue"} key
     * @return the {@link Classification} returned by {@code classifier.compute(...)}
     */
    public <T> Classification<T> compute(SDRClassifier classifier, int recordNum, int[] pattern,
        int bucket, Object value) {
        Map<String, Object> classification = new LinkedHashMap<String, Object>();
        classification.put("bucketIdx", bucket);
        classification.put("actValue", value);
        return classifier.compute(recordNum, classification, pattern, true, true);
    }
}