/* ---------------------------------------------------------------------
* Numenta Platform for Intelligent Computing (NuPIC)
* Copyright (C) 2014, Numenta, Inc. Unless you have an agreement
* with Numenta, Inc., for a separate license for this software code, the
* following terms and conditions apply:
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero Public License version 3 as
* published by the Free Software Foundation.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
* See the GNU Affero Public License for more details.
*
* You should have received a copy of the GNU Affero Public License
* along with this program. If not, see http://www.gnu.org/licenses.
*
* http://numenta.org/licenses/
* ---------------------------------------------------------------------
*/
package org.numenta.nupic.encoders;
import org.junit.Test;
import org.numenta.nupic.util.ArrayUtils;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
/**
 * Unit tests for {@link SDRCategoryEncoder}.
 *
 * <p>Covers encoding and decoding round trips, top-down computation, the
 * {@code <UNKNOWN>} category, empty inputs, tolerance to a single flipped bit,
 * decoding of two OR'ed representations, duplicate-category rejection, and
 * growth of the category list when no list is supplied up front.
 */
public class SDRCategoryEncoderTest {

    /** Fixed seed for choosing which bit to flip, so any failure is reproducible. */
    private static final long RANDOM_SEED = 42L;

    /**
     * Exercises an encoder built with a fixed category list (named "foo") and
     * one built without a list (named "bar") that adds categories as it sees them.
     */
    @Test
    public void testSDRCategoryEncoder() {
        // Make sure we have > 16 categories so that the encoder has to grow its SDRs.
        String[] categories = {"ES", "S1", "S2", "S3", "S4", "S5", "S6", "S7", "S8",
            "S9", "S10", "S11", "S12", "S13", "S14", "S15", "S16",
            "S17", "S18", "S19", "GB", "US"};
        int fieldWidth = 100;
        int bitsOn = 10;
        Random random = new Random(RANDOM_SEED);

        SDRCategoryEncoder sdrCategoryEncoder = SDRCategoryEncoder.builder()
            .n(fieldWidth)
            .w(bitsOn)
            .categoryList(Arrays.asList(categories))
            .name("foo")
            .forced(true).build();

        // Internal check: 23 SDRs for the 22 categories above. The extra one is
        // presumably <UNKNOWN>, which top-down computes to scalar 0 below.
        assertEquals(23, sdrCategoryEncoder.getSDRs().size());
        assertEquals(fieldWidth, sdrCategoryEncoder.getSDRs().iterator().next().length);

        // ES
        int[] es = sdrCategoryEncoder.encode("ES");
        assertEquals(bitsOn, ArrayUtils.aggregateArray(es));
        assertEquals(fieldWidth, es.length);
        DecodeResult x = sdrCategoryEncoder.decode(es);
        assertEquals("foo", x.getDescriptions().iterator().next());
        assertEquals("ES", x.getFields().get("foo").getDescription());
        List<Encoding> topDowns = sdrCategoryEncoder.topDownCompute(es);
        Encoding topDown = topDowns.get(0);
        assertEquals("ES", topDown.getValue());
        assertEquals(1, topDown.getScalar());
        assertEquals(bitsOn, ArrayUtils.aggregateArray(topDown.getEncoding()));

        // Every known category must survive an encode -> topDownCompute round trip.
        for (String category : categories) {
            int[] output = sdrCategoryEncoder.encode(category);
            topDown = sdrCategoryEncoder.topDownCompute(output).get(0);
            assertEquals(category, topDown.getValue());
            assertEquals((int) sdrCategoryEncoder.getScalars(category).get(0), topDown.getScalar());
            int[] bucketIndices = sdrCategoryEncoder.getBucketIndices(category);
            assertTrue("no bucket index for " + category, bucketIndices.length > 0);
        }

        // Unknown
        int[] unknown = sdrCategoryEncoder.encode("ASDFLKJLK");
        assertEquals(bitsOn, ArrayUtils.aggregateArray(unknown));
        assertEquals(fieldWidth, unknown.length);
        x = sdrCategoryEncoder.decode(unknown);
        assertEquals("<UNKNOWN>", x.getFields().get("foo").getDescription());
        topDown = sdrCategoryEncoder.topDownCompute(unknown).get(0);
        assertEquals("<UNKNOWN>", topDown.getValue());
        assertEquals(0, topDown.getScalar());

        // US
        int[] us = sdrCategoryEncoder.encode("US");
        assertEquals(bitsOn, ArrayUtils.aggregateArray(us));
        assertEquals(fieldWidth, us.length);
        x = sdrCategoryEncoder.decode(us);
        assertEquals("US", x.getFields().get("foo").getDescription());
        topDown = sdrCategoryEncoder.topDownCompute(us).get(0);
        assertEquals("US", topDown.getValue());
        assertEquals(categories.length, topDown.getScalar());
        assertEquals(bitsOn, ArrayUtils.aggregateArray(topDown.getEncoding()));

        // Empty input (null or "") encodes to an all-zero array of full width.
        for (String emptyValue : new String[] {null, ""}) {
            int[] empty = sdrCategoryEncoder.encode(emptyValue);
            assertEquals(0, ArrayUtils.aggregateArray(empty));
            assertEquals(fieldWidth, empty.length);
        }

        // Make sure it can still be decoded after a single-bit change.
        // nextInt's bound is exclusive, so every index of the array is a candidate.
        int bit = random.nextInt(us.length);
        us[bit] = 1 - us[bit];
        x = sdrCategoryEncoder.decode(us);
        assertEquals("US", x.getFields().get("foo").getDescription());

        // OR'ing two representations must decode to exactly those two categories.
        int[] newrep = ArrayUtils.or(unknown, us);
        x = sdrCategoryEncoder.decode(newrep);
        String name = x.getFields().get("foo").getDescription();
        if (!"US <UNKNOWN>".equals(name) && !"<UNKNOWN> US".equals(name)) {
            // Whatever is left after stripping the expected names is the intruder.
            String otherCategory = name.replace("US", "")
                .replace("<UNKNOWN>", "")
                .replace(" ", "");
            throw new AssertionError(String.format(
                "Decoding failure: got %s instead of US/<UNKNOWN>%n"
                    + "US: %s%nunknown: %s%nSum: %s%n%s: %s",
                name,
                ArrayUtils.intArrayToString(us),
                ArrayUtils.intArrayToString(unknown),
                ArrayUtils.intArrayToString(newrep),
                otherCategory,
                ArrayUtils.intArrayToString(sdrCategoryEncoder.encode(otherCategory))));
        }

        // An encoder without a category list adds categories as they are encoded.
        sdrCategoryEncoder = SDRCategoryEncoder.builder()
            .n(fieldWidth)
            .w(bitsOn)
            .name("bar")
            .forced(true).build();

        es = sdrCategoryEncoder.encode("ES");
        assertEquals(bitsOn, ArrayUtils.aggregateArray(es));
        assertEquals(fieldWidth, es.length);
        x = sdrCategoryEncoder.decode(es);
        assertEquals("bar", x.getDescriptions().get(0));
        assertEquals("ES", x.getFields().get("bar").getDescription());

        us = sdrCategoryEncoder.encode("US");
        assertEquals(bitsOn, ArrayUtils.aggregateArray(us));
        assertEquals(fieldWidth, us.length);
        x = sdrCategoryEncoder.decode(us);
        assertEquals("bar", x.getDescriptions().get(0));
        assertEquals("US", x.getFields().get("bar").getDescription());

        // Re-encoding an already-learned category must reproduce the same SDR.
        assertTrue(Arrays.equals(es, sdrCategoryEncoder.encode("ES")));
        assertTrue(Arrays.equals(us, sdrCategoryEncoder.encode("US")));

        // Make sure it can still be decoded after a single-bit change.
        bit = random.nextInt(us.length);
        us[bit] = 1 - us[bit];
        x = sdrCategoryEncoder.decode(us);
        assertEquals("US", x.getFields().get("bar").getDescription());

        // Add two reps together.
        newrep = ArrayUtils.or(us, es);
        x = sdrCategoryEncoder.decode(newrep);
        name = x.getFields().get("bar").getDescription();
        assertTrue("Unexpected decoding: " + name, "US ES".equals(name) || "ES US".equals(name));

        // Duplicate categories must be rejected when the encoder is built.
        List<String> newCategories = new ArrayList<>(Arrays.asList(categories));
        newCategories.add("ES");
        try {
            SDRCategoryEncoder.builder()
                .n(fieldWidth)
                .w(bitsOn)
                .categoryList(newCategories)
                .name("foo")
                .forced(true).build();
            throw new AssertionError("Did not catch duplicate category in constructor");
        } catch (IllegalArgumentException expected) {
            // Expected: "ES" appears twice in the category list.
        }
    }

    /**
     * Verifies growth of an encoder built without a category list. New
     * categories receive increasing scalars (1, 2, 3, ...). Empty inputs decode
     * to {@code <UNKNOWN>}. An unseen category maps to {@code <UNKNOWN>} while
     * learning is disabled and is added once learning is re-enabled.
     */
    @Test
    public void testAutoGrow() {
        int fieldWidth = 100;
        int bitsOn = 10;
        SDRCategoryEncoder sdrCategoryEncoder = SDRCategoryEncoder.builder()
            .n(fieldWidth)
            .w(bitsOn)
            .name("foo")
            .forced(true).build();

        // A freshly allocated int[] is already all zeros.
        int[] encoded = new int[fieldWidth];
        assertEquals("<UNKNOWN>", sdrCategoryEncoder.topDownCompute(encoded).get(0).getValue());

        sdrCategoryEncoder.encodeIntoArray("catA", encoded);
        assertEquals(bitsOn, ArrayUtils.aggregateArray(encoded));
        assertEquals(1.0, sdrCategoryEncoder.getScalars("catA").get(0), 0.0);
        // Snapshot: 'encoded' is reused (and overwritten) by later encodeIntoArray calls.
        int[] catA = encoded.clone();

        sdrCategoryEncoder.encodeIntoArray("catB", encoded);
        assertEquals(bitsOn, ArrayUtils.aggregateArray(encoded));
        assertEquals(2.0, sdrCategoryEncoder.getScalars("catB").get(0), 0.0);
        int[] catB = encoded.clone();

        assertEquals("catA", sdrCategoryEncoder.topDownCompute(catA).get(0).getValue());
        assertEquals("catB", sdrCategoryEncoder.topDownCompute(catB).get(0).getValue());

        // Empty input (null or "") encodes to all zeros, which decodes to <UNKNOWN>.
        for (String emptyValue : new String[] {null, ""}) {
            sdrCategoryEncoder.encodeIntoArray(emptyValue, encoded);
            assertEquals(0, ArrayUtils.aggregateArray(encoded));
            assertEquals("<UNKNOWN>", sdrCategoryEncoder.topDownCompute(encoded).get(0).getValue());
        }

        // With learning disabled, an unseen category is encoded as <UNKNOWN> (scalar 0).
        sdrCategoryEncoder.setLearning(false);
        sdrCategoryEncoder.encodeIntoArray("catC", encoded);
        assertEquals(bitsOn, ArrayUtils.aggregateArray(encoded));
        assertEquals(0.0, sdrCategoryEncoder.getScalars("catC").get(0), 0.0);
        assertEquals("<UNKNOWN>", sdrCategoryEncoder.topDownCompute(encoded).get(0).getValue());

        // Re-enabling learning lets the encoder add catC as the third category.
        sdrCategoryEncoder.setLearning(true);
        sdrCategoryEncoder.encodeIntoArray("catC", encoded);
        assertEquals(bitsOn, ArrayUtils.aggregateArray(encoded));
        assertEquals(3.0, sdrCategoryEncoder.getScalars("catC").get(0), 0.0);
        assertEquals("catC", sdrCategoryEncoder.topDownCompute(encoded).get(0).getValue());
    }
}