/* ---------------------------------------------------------------------
* Numenta Platform for Intelligent Computing (NuPIC)
* Copyright (C) 2014, Numenta, Inc. Unless you have an agreement
* with Numenta, Inc., for a separate license for this software code, the
* following terms and conditions apply:
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero Public License version 3 as
* published by the Free Software Foundation.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
* See the GNU Affero Public License for more details.
*
* You should have received a copy of the GNU Affero Public License
* along with this program. If not, see http://www.gnu.org/licenses.
*
* http://numenta.org/licenses/
* ---------------------------------------------------------------------
*/
package org.numenta.nupic.algorithms;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.Map;
import org.junit.Test;
import org.numenta.nupic.network.Persistence;
import org.numenta.nupic.network.PersistenceAPI;
import org.numenta.nupic.serialize.SerialConfig;
import gnu.trove.list.array.TIntArrayList;
/**
 * Unit tests for {@link CLAClassifier}.
 *
 * <p>Most tests feed records through {@link #compute}, which wraps a bucket index and an actual
 * value into the {@code "bucketIdx"} / {@code "actValue"} classification map passed to
 * {@code CLAClassifier.compute}.
 */
public class CLAClassifierTest {
    /** Classifier under test; created by {@link #setUp()} or directly by the individual tests. */
    private CLAClassifier classifier;

    /**
     * Creates a classifier with the no-arg constructor's configuration.
     *
     * <p>Not annotated with a JUnit lifecycle annotation: tests that want this configuration call
     * it explicitly, the others construct their own classifier.
     */
    public void setUp() {
        classifier = new CLAClassifier();
    }

    /**
     * Feeds the same value ten times and expects a 100% likelihood for its bucket.
     */
    @Test
    public void testSingleValue() {
        setUp();
        Classification<Double> retVal = null;
        for(int recordNum = 0;recordNum < 10;recordNum++) {
            retVal = compute(classifier, recordNum, new int[] { 1, 5 }, 0, 10);
        }
        checkValue(retVal, 0, 10., 1.);
    }

    /**
     * Feeds the same value ten times and expects a 100% likelihood for its bucket when the
     * classifier is configured for 0-step-ahead prediction.
     */
    @Test
    public void testSingleValue0Steps() {
        classifier = new CLAClassifier(new TIntArrayList(new int[] { 0 }), 0.001, 0.3, 0);
        // Enough records to perform inference and learn the association.
        Classification<Double> retVal = null;
        for(int recordNum = 0;recordNum < 10;recordNum++) {
            retVal = compute(classifier, recordNum, new int[] { 1, 5 }, 0, 10);
        }
        assertEquals(10., retVal.getActualValue(0), .00001);
        assertEquals(1., retVal.getStat(0, 0), .00001);
    }

    /**
     * Checks the step set and actual values returned for the very first record.
     *
     * <p>The meaning of this test is diminished in Java, because Java is already strongly typed
     * and all expected value types are known and previously declared. NOTE(review): its body is
     * identical to {@link #testCompute1()}.
     */
    @Test
    public void testComputeResultTypes() {
        classifier = new CLAClassifier(new TIntArrayList(new int[] { 1 }), 0.1, 0.1, 0);
        Map<String, Object> classification = new LinkedHashMap<String, Object>();
        classification.put("bucketIdx", 4);
        classification.put("actValue", 34.7);
        Classification<Double> result = classifier.compute(0, classification, new int[] { 1, 5, 9 }, true, true);
        assertTrue(Arrays.equals(new int[] { 1 }, result.stepSet()));
        assertEquals(1, result.getActualValueCount());
        assertEquals(34.7, result.getActualValue(0), 0.01);
    }

    /** After a single record the result holds one actual value: the value just fed in. */
    @Test
    public void testCompute1() {
        classifier = new CLAClassifier(new TIntArrayList(new int[] { 1 }), 0.1, 0.1, 0);
        Map<String, Object> classification = new LinkedHashMap<String, Object>();
        classification.put("bucketIdx", 4);
        classification.put("actValue", 34.7);
        Classification<Double> result = classifier.compute(0, classification, new int[] { 1, 5, 9 }, true, true);
        assertTrue(Arrays.equals(new int[] { 1 }, result.stepSet()));
        assertEquals(1, result.getActualValueCount());
        assertEquals(34.7, result.getActualValue(0), 0.01);
    }

    /**
     * After a second record in bucket 4, the result exposes actual values for buckets 0..4 and
     * bucket 4 holds the value fed in.
     */
    @Test
    public void testCompute2() {
        classifier = new CLAClassifier(new TIntArrayList(new int[] { 1 }), 0.1, 0.1, 0);
        Map<String, Object> classification = new LinkedHashMap<String, Object>();
        classification.put("bucketIdx", 4);
        classification.put("actValue", 34.7);
        classifier.compute(0, classification, new int[] { 1, 5, 9 }, true, true);
        Classification<Double> result = classifier.compute(1, classification, new int[] { 1, 5, 9 }, true, true);
        assertTrue(Arrays.equals(new int[] { 1 }, result.stepSet()));
        assertEquals(5, result.getActualValueCount());
        assertEquals(34.7, result.getActualValue(4), 0.01);
    }

    /**
     * Feeds five records across buckets 4 and 5 and checks the blended actual values and the
     * 1-step likelihoods of all six buckets.
     */
    @Test
    public void testComputeComplex() {
        classifier = new CLAClassifier(new TIntArrayList(new int[] { 1 }), 0.1, 0.1, 0);
        int recordNum = 0;
        Map<String, Object> classification = new LinkedHashMap<String, Object>();
        classification.put("bucketIdx", 4);
        classification.put("actValue", 34.7);
        Classification<Double> result = classifier.compute(recordNum, classification, new int[] { 1, 5, 9 }, true, true);
        recordNum += 1;

        classification.put("bucketIdx", 5);
        classification.put("actValue", 41.7);
        result = classifier.compute(recordNum, classification, new int[] { 0, 6, 9, 11 }, true, true);
        recordNum += 1;

        classification.put("bucketIdx", 5);
        classification.put("actValue", 44.9);
        result = classifier.compute(recordNum, classification, new int[] { 6, 9 }, true, true);
        recordNum += 1;

        classification.put("bucketIdx", 4);
        classification.put("actValue", 42.9);
        result = classifier.compute(recordNum, classification, new int[] { 1, 5, 9 }, true, true);
        recordNum += 1;

        classification.put("bucketIdx", 4);
        classification.put("actValue", 34.7);
        result = classifier.compute(recordNum, classification, new int[] { 1, 5, 9 }, true, true);

        assertTrue(Arrays.equals(new int[] { 1 }, result.stepSet()));
        assertEquals(35.520000457763672, result.getActualValue(4), 0.00001);
        assertEquals(42.020000457763672, result.getActualValue(5), 0.00001);
        assertEquals(6, result.getStatCount(1));
        assertEquals(0.0, result.getStat(1, 0), 0.00001);
        assertEquals(0.0, result.getStat(1, 1), 0.00001);
        assertEquals(0.0, result.getStat(1, 2), 0.00001);
        assertEquals(0.0, result.getStat(1, 3), 0.00001);
        assertEquals(0.12300123, result.getStat(1, 4), 0.00001);
        assertEquals(0.87699877, result.getStat(1, 5), 0.00001);
    }

    /** A record with null bucket index and null actual value yields a single null actual value. */
    @Test
    public void testComputeWithMissingValue() {
        classifier = new CLAClassifier(new TIntArrayList(new int[] { 1 }), 0.1, 0.1, 0);
        Map<String, Object> classification = new LinkedHashMap<String, Object>();
        classification.put("bucketIdx", null);
        classification.put("actValue", null);
        Classification<Double> result = classifier.compute(0, classification, new int[] { 1, 5, 9 }, true, true);
        assertTrue(Arrays.equals(new int[] { 1 }, result.stepSet()));
        assertEquals(1, result.getActualValueCount());
        assertEquals(null, result.getActualValue(0));
    }

    /**
     * Categorical (String) actual values are returned as-is for their bucket.
     *
     * <p>NOTE(review): both compute calls use recordNum 0; confirm whether the second call was
     * meant to use recordNum 1.
     */
    @Test
    public void testComputeCategory() {
        classifier = new CLAClassifier(new TIntArrayList(new int[] { 1 }), 0.1, 0.1, 0);
        Map<String, Object> classification = new LinkedHashMap<String, Object>();
        classification.put("bucketIdx", 4);
        classification.put("actValue", "D");
        classifier.compute(0, classification, new int[] { 1, 5, 9 }, true, true);
        Classification<String> result = classifier.compute(0, classification, new int[] { 1, 5, 9 }, true, true);
        assertTrue(Arrays.equals(new int[] { 1 }, result.stepSet()));
        assertEquals("D", result.getActualValue(4));
    }

    /**
     * Feeding a different category ("E") for an already-seen bucket still reports the first
     * category ("D") for that bucket.
     */
    @Test
    public void testComputeCategory2() {
        classifier = new CLAClassifier(new TIntArrayList(new int[] { 1 }), 0.1, 0.1, 0);
        Map<String, Object> classification = new LinkedHashMap<String, Object>();
        classification.put("bucketIdx", 4);
        classification.put("actValue", "D");
        classifier.compute(0, classification, new int[] { 1, 5, 9 }, true, true);
        classification.put("actValue", "E");
        Classification<String> result = classifier.compute(0, classification, new int[] { 1, 5, 9 }, true, true);
        assertTrue(Arrays.equals(new int[] { 1 }, result.stepSet()));
        assertEquals("D", result.getActualValue(4));
    }

    /**
     * Runs the first four records of {@link #testComputeComplex()}, round-trips the classifier
     * through {@link PersistenceAPI} write/read, then feeds the fifth record to the deserialized
     * copy and expects the same results as the uninterrupted run.
     */
    @Test
    public void testSerialization() {
        classifier = new CLAClassifier(new TIntArrayList(new int[] { 1 }), 0.1, 0.1, 0);
        int recordNum = 0;
        Map<String, Object> classification = new LinkedHashMap<String, Object>();
        classification.put("bucketIdx", 4);
        classification.put("actValue", 34.7);
        Classification<Double> result = classifier.compute(recordNum, classification, new int[] { 1, 5, 9 }, true, true);
        recordNum += 1;

        classification.put("bucketIdx", 5);
        classification.put("actValue", 41.7);
        result = classifier.compute(recordNum, classification, new int[] { 0, 6, 9, 11 }, true, true);
        recordNum += 1;

        classification.put("bucketIdx", 5);
        classification.put("actValue", 44.9);
        result = classifier.compute(recordNum, classification, new int[] { 6, 9 }, true, true);
        recordNum += 1;

        classification.put("bucketIdx", 4);
        classification.put("actValue", 42.9);
        result = classifier.compute(recordNum, classification, new int[] { 1, 5, 9 }, true, true);
        recordNum += 1;

        // Configure serializer
        SerialConfig config = new SerialConfig("testSerializeClassifier", SerialConfig.SERIAL_TEST_DIR);
        PersistenceAPI api = Persistence.get(config);

        // 1. serialize
        byte[] data = api.write(classifier, "testSerializeClassifier");

        // 2. deserialize
        CLAClassifier serialized = api.read(data);

        // Continue the sequence using the deserialized classifier.
        classification.put("bucketIdx", 4);
        classification.put("actValue", 34.7);
        result = serialized.compute(recordNum, classification, new int[] { 1, 5, 9 }, true, true);

        assertTrue(Arrays.equals(new int[] { 1 }, result.stepSet()));
        assertEquals(35.520000457763672, result.getActualValue(4), 0.00001);
        assertEquals(42.020000457763672, result.getActualValue(5), 0.00001);
        assertEquals(6, result.getStatCount(1));
        assertEquals(0.0, result.getStat(1, 0), 0.00001);
        assertEquals(0.0, result.getStat(1, 1), 0.00001);
        assertEquals(0.0, result.getStat(1, 2), 0.00001);
        assertEquals(0.0, result.getStat(1, 3), 0.00001);
        assertEquals(0.12300123, result.getStat(1, 4), 0.00001);
        assertEquals(0.87699877, result.getStat(1, 5), 0.00001);
    }

    /**
     * Two patterns sharing bit 5: the second pattern first predicts the old value, and after one
     * more record the new value becomes more probable.
     */
    @Test
    public void testOverlapPattern() {
        setUp();
        Classification<Double> result = compute(classifier, 0, new int[] { 1, 5 }, 9, 9);
        result = compute(classifier, 1, new int[] { 1, 5 }, 9, 9);
        result = compute(classifier, 1, new int[] { 1, 5 }, 9, 9);
        result = compute(classifier, 2, new int[] { 3, 5 }, 2, 2);
        // The patterns overlap, so the prediction should still be the previous value (bucket 9)
        // with 100% likelihood.
        checkValue(result, 9, 9., 1.0);

        result = compute(classifier, 3, new int[] { 3, 5 }, 2, 2);
        // Second example: the new value should now be more probable than the old one.
        assertTrue(result.getStat(1, 2) > result.getStat(1, 9));
    }

    /**
     * Feeds 100 records of pattern [1], 1000 of pattern [2], then 3 of pattern [1, 2].
     *
     * <p>NOTE(review): this method has no {@code @Test} annotation and asserts nothing, so JUnit
     * never runs it and it verifies nothing; confirm whether it is intentionally disabled.
     */
    public void testScaling() {
        setUp();
        int recordNum = 0;
        for(int i = 0;i < 100;i++, recordNum++) {
            compute(classifier, recordNum, new int[] { 1 }, 5, 5);
        }
        for(int i = 0;i < 1000;i++, recordNum++) {
            compute(classifier, recordNum, new int[] { 2 }, 9, 9);
        }
        for(int i = 0;i < 3;i++, recordNum++) {
            compute(classifier, recordNum, new int[] { 1, 2 }, 6, 6);
        }
    }

    /**
     * With 1- and 2-step prediction, feeding the same value ten times yields a single actual value
     * with 100% likelihood at both step counts.
     */
    @Test
    public void testMultistepSingleValue() {
        setUp();
        // Switch the default-constructed classifier to 1- and 2-step prediction.
        classifier.steps = new TIntArrayList(new int[] { 1, 2 });

        Classification<Double> result = null;
        int recordNum = 0;
        for(int i = 0;i < 10;i++, recordNum++) {
            result = compute(classifier, recordNum, new int[] { 1, 5 }, 0, 10);
        }

        // Only one actual value bucket should be returned.
        assertTrue(Arrays.equals(new Object[] { 10. }, result.getActualValues()));
        // That bucket should have a probability of 100% at both step counts.
        assertTrue(Arrays.equals(new double[] { 1. }, result.getStats(1)));
        assertTrue(Arrays.equals(new double[] { 1. }, result.getStats(2)));
    }

    /**
     * Cycles through ten single-bit patterns (bucket {@code i % 10}, value {@code (i % 10) * 10})
     * for 100 records.
     *
     * <p>The last record is pattern 9, so the 1-step prediction should be bucket 0 and the
     * 2-step prediction bucket 1.
     */
    @Test
    public void testMultistepSimple() {
        classifier = new CLAClassifier(new TIntArrayList(new int[] { 1, 2 }), 0.001, 0.3, 0);
        Classification<Double> result = null;
        int recordNum = 0;
        for(int i = 0;i < 100;i++, recordNum++) {
            result = compute(classifier, recordNum, new int[] { i % 10 }, i % 10, (i % 10) * 10);
        }

        // All ten buckets' actual values should be reported.
        assertTrue(Arrays.equals(new Object[] { 0., 10., 20., 30., 40., 50., 60., 70., 80., 90. }, result.getActualValues()));
        assertEquals(1.0, result.getStat(1, 0), 0.1);
        for(int i = 1;i < 10;i++) {
            assertEquals(0.0, result.getStat(1, i), 0.1);
        }
        assertEquals(1.0, result.getStat(2, 1), 0.1);
    }

    /**
     * Tests missing-record support.
     *
     * <p>Here, we intend the classifier to learn the associations:
     * <pre>
     *   [1,3,5] => bucketIdx 1
     *   [2,4,6] => bucketIdx 2
     *   [7,8,9] => don't care
     * </pre>
     *
     * If the classifier ignores the recordNums in this test, it will learn the wrong associations.
     */
    @Test
    public void testMissingRecords() {
        classifier = new CLAClassifier(new TIntArrayList(new int[] { 1 }), 0.1, 0.1, 0);
        int recordNum = 0;
        Map<String, Object> classification = new LinkedHashMap<String, Object>();
        classification.put("bucketIdx", 0);
        classification.put("actValue", 0);
        classifier.compute(recordNum, classification, new int[] { 1, 3, 5 }, true, true);
        recordNum += 1;

        classification.put("bucketIdx", 1);
        classification.put("actValue", 1);
        classifier.compute(recordNum, classification, new int[] { 2, 4, 6 }, true, true);
        recordNum += 1;

        classification.put("bucketIdx", 2);
        classification.put("actValue", 2);
        classifier.compute(recordNum, classification, new int[] { 1, 3, 5 }, true, true);
        recordNum += 1;

        classification.put("bucketIdx", 1);
        classification.put("actValue", 1);
        classifier.compute(recordNum, classification, new int[] { 2, 4, 6 }, true, true);
        recordNum += 1;

        // ----------------------------------------------------------------------------------
        // At this point, we should have learned [1, 3, 5] => bucket 1
        //                                       [2, 4, 6] => bucket 2
        classification.put("bucketIdx", 2);
        classification.put("actValue", 2);
        Classification<Double> result = classifier.compute(recordNum, classification, new int[] { 1, 3, 5 }, true, true);
        recordNum += 1;
        assertEquals(0.0, result.getStat(1, 0), 0.00001);
        assertEquals(1.0, result.getStat(1, 1), 0.00001);
        assertEquals(0.0, result.getStat(1, 2), 0.00001);

        classification.put("bucketIdx", 1);
        classification.put("actValue", 1);
        result = classifier.compute(recordNum, classification, new int[] { 2, 4, 6 }, true, true);
        recordNum += 1;
        assertEquals(0.0, result.getStat(1, 0), 0.00001);
        assertEquals(0.0, result.getStat(1, 1), 0.00001);
        assertEquals(1.0, result.getStat(1, 2), 0.00001);

        // ----------------------------------------------------------------------------------
        // Feed in records that skip a recordNum and make sure they don't disturb what we learned.
        //
        // If we skip a record, the CLA should NOT learn that [2,4,6] from
        // the previous learning associates with bucket 0.
        recordNum += 1; // <----- Does the skip
        classification.put("bucketIdx", 0);
        classification.put("actValue", 0);
        result = classifier.compute(recordNum, classification, new int[] { 1, 3, 5 }, true, true);
        recordNum += 1;
        assertEquals(0.0, result.getStat(1, 0), 0.00001);
        assertEquals(1.0, result.getStat(1, 1), 0.00001);
        assertEquals(0.0, result.getStat(1, 2), 0.00001);

        // If we skip a record, the CLA should NOT learn that [1,3,5] from
        // the previous learning associates with bucket 0.
        recordNum += 1; // <----- Does the skip
        classification.put("bucketIdx", 0);
        classification.put("actValue", 0);
        result = classifier.compute(recordNum, classification, new int[] { 2, 4, 6 }, true, true);
        recordNum += 1;
        assertEquals(0.0, result.getStat(1, 0), 0.00001);
        assertEquals(0.0, result.getStat(1, 1), 0.00001);
        assertEquals(1.0, result.getStat(1, 2), 0.00001);

        // If we skip a record, the CLA should NOT learn that [2,4,6] from
        // the previous learning associates with bucket 0.
        recordNum += 1; // <----- Does the skip
        classification.put("bucketIdx", 0);
        classification.put("actValue", 0);
        result = classifier.compute(recordNum, classification, new int[] { 1, 3, 5 }, true, true);
        assertEquals(0.0, result.getStat(1, 0), 0.00001);
        assertEquals(1.0, result.getStat(1, 1), 0.00001);
        assertEquals(0.0, result.getStat(1, 2), 0.00001);
    }

    /**
     * Tests a missing-record edge case in classifier initialization: a record is missing within
     * the first n records, where n is the number of prediction steps (here 2).
     */
    @Test
    public void testMissingRecordInitialization() {
        classifier = new CLAClassifier(new TIntArrayList(new int[] { 2 }), 0.1, 0.1, 0);
        int recordNum = 0;
        Map<String, Object> classification = new LinkedHashMap<String, Object>();
        classification.put("bucketIdx", 0);
        classification.put("actValue", 34.7);
        classifier.compute(recordNum, classification, new int[] { 1, 5, 9 }, true, true);

        // Skip recordNum 1.
        recordNum = 2;
        Classification<Double> result = classifier.compute(recordNum, classification, new int[] { 1, 5, 9 }, true, true);

        assertTrue(Arrays.equals(new int[] { 2 }, result.stepSet()));
        assertEquals(1, result.getActualValueCount());
        assertEquals(34.7, result.getActualValue(0), 0.01);
    }

    /**
     * Asserts that a classification reports {@code value} as the actual value at {@code index}.
     * It also asserts that the 1-step likelihood of bucket {@code index} is {@code probability},
     * within 0.01.
     *
     * @param retVal      the classification to check
     * @param index       bucket index to inspect
     * @param value       expected actual value for that bucket
     * @param probability expected 1-step likelihood for that bucket
     */
    public void checkValue(Classification<?> retVal, int index, Object value, double probability) {
        // JUnit's assertEquals takes (expected, actual); keep that order so failure messages
        // report the values correctly.
        assertEquals(value, retVal.getActualValue(index));
        assertEquals(probability, retVal.getStat(1, index), 0.01);
    }

    /**
     * Feeds one record to {@code classifier}.
     *
     * <p>Builds the {@code "bucketIdx"} / {@code "actValue"} classification map and calls
     * {@code compute} with both trailing boolean flags set to {@code true} (presumably learn and
     * infer — confirm against {@code CLAClassifier.compute}).
     *
     * @param classifier the classifier to feed
     * @param recordNum  record number of this input
     * @param pattern    indices of the active input bits
     * @param bucket     bucket index of the actual value
     * @param value      the actual value
     * @return the classification returned by the classifier
     */
    public <T> Classification<T> compute(CLAClassifier classifier, int recordNum, int[] pattern,
        int bucket, Object value) {
        Map<String, Object> classification = new LinkedHashMap<String, Object>();
        classification.put("bucketIdx", bucket);
        classification.put("actValue", value);
        return classifier.compute(recordNum, classification, pattern, true, true);
    }
}