/* ---------------------------------------------------------------------
* Numenta Platform for Intelligent Computing (NuPIC)
* Copyright (C) 2014, Numenta, Inc. Unless you have an agreement
* with Numenta, Inc., for a separate license for this software code, the
* following terms and conditions apply:
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero Public License version 3 as
* published by the Free Software Foundation.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
* See the GNU Affero Public License for more details.
*
* You should have received a copy of the GNU Affero Public License
* along with this program. If not, see http://www.gnu.org/licenses.
*
* http://numenta.org/licenses/
* ---------------------------------------------------------------------
*/
package org.numenta.nupic.encoders;
import org.junit.Test;
import org.numenta.nupic.util.ArrayUtils;
import org.numenta.nupic.util.Condition;
import org.numenta.nupic.util.MinMax;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
public class CategoryEncoderTest {
private static final Logger LOGGER = LoggerFactory.getLogger(CategoryEncoderTest.class);
private CategoryEncoder ce;
private CategoryEncoder.Builder builder;
private void setUp() {
builder = ((CategoryEncoder.Builder)CategoryEncoder.builder())
.w(3)
.radius(0)
.minVal(0.0)
.maxVal(8.0)
.periodic(false)
.forced(true);
}
private void initCE() {
ce = builder.build();
}
@Test
public void testCategoryEncoder() {
String[] categories = new String[] { "ES", "GB", "US" };
setUp();
builder.radius(1);
builder.categoryList(Arrays.<String>asList(categories));
initCE();
LOGGER.info("Testing CategoryEncoder...");
// forced: is not recommended, but is used here for readability. see scalar.py
int[] output = ce.encode("US");
assertTrue(Arrays.equals(new int[] { 0,0,0,0,0,0,0,0,0,1,1,1 }, output));
// Test reverse lookup
DecodeResult decoded = ce.decode(output, "");
assertEquals(decoded.getFields().size(), 1, 0);
List<RangeList> rangeList = new ArrayList<RangeList>(decoded.getFields().values());
assertEquals(1, rangeList.get(0).size(), 0);
MinMax minMax = rangeList.get(0).getRanges().get(0);
assertEquals(minMax.min(), minMax.max(), 0);
assertTrue(minMax.min() == 3 && minMax.max() == 3);
LOGGER.info("decodedToStr of " + minMax + "=>" + ce.decodedToStr(decoded));
// Test topdown compute
for(String v : categories) {
output = ce.encode(v);
Encoding topDown = ce.topDownCompute(output).get(0);
assertEquals(v, topDown.getValue());
assertEquals((int)ce.getScalars(v).get(0), (int)topDown.getScalar().doubleValue());
int[] bucketIndices = ce.getBucketIndices(v);
LOGGER.info("bucket index => " + bucketIndices[0]);
topDown = ce.getBucketInfo(bucketIndices).get(0);
assertEquals(v, topDown.getValue());
assertEquals((int)ce.getScalars(v).get(0), (int)topDown.getScalar().doubleValue());
assertTrue(Arrays.equals(topDown.getEncoding(), output));
assertEquals(topDown.getValue(), ce.getBucketValues(String.class).get(bucketIndices[0]));
}
//-------------
// unknown category
output = ce.encode("NA");
assertTrue(Arrays.equals(new int[] { 1,1,1,0,0,0,0,0,0,0,0,0 }, output));
// Test reverse lookup
decoded = ce.decode(output, "");
assertEquals(decoded.getFields().size(), 1, 0);
rangeList = new ArrayList<RangeList>(decoded.getFields().values());
assertEquals(1, rangeList.get(0).size(), 0);
minMax = rangeList.get(0).getRanges().get(0);
assertEquals(minMax.min(), minMax.max(), 0);
assertTrue(minMax.min() == 0 && minMax.max() == 0);
LOGGER.info("decodedToStr of " + minMax + "=>" + ce.decodedToStr(decoded));
Encoding topDown = ce.topDownCompute(output).get(0);
assertEquals(topDown.getValue(), "<UNKNOWN>");
assertEquals(topDown.getScalar(), 0);
//--------------
// ES
output = ce.encode("ES");
assertTrue(Arrays.equals( new int[] {0,0,0,1,1,1,0,0,0,0,0,0 }, output));
// MISSING VALUE
int[] outputForMissing = ce.encode((String)null);
assertTrue(Arrays.equals( new int[] {0,0,0,0,0,0,0,0,0,0,0,0 }, outputForMissing));
// Test reverse lookup
decoded = ce.decode(output, "");
assertEquals(decoded.getFields().size(), 1, 0);
rangeList = new ArrayList<RangeList>(decoded.getFields().values());
assertEquals(1, rangeList.get(0).size(), 0);
minMax = rangeList.get(0).getRanges().get(0);
assertEquals(minMax.min(), minMax.max(), 0);
assertTrue(minMax.min() == 1 && minMax.max() == 1);
LOGGER.info("decodedToStr of " + minMax + "=>" + ce.decodedToStr(decoded));
// Test topdown compute
topDown = ce.topDownCompute(output).get(0);
assertEquals(topDown.getValue(), "ES");
assertEquals(topDown.getScalar(), (int)ce.getScalars("ES").get(0));
//----------------
// Multiple categories
Arrays.fill(output, 1);
// Test reverse lookup
decoded = ce.decode(output, "");
assertEquals(decoded.getFields().size(), 1, 0);
rangeList = new ArrayList<RangeList>(decoded.getFields().values());
assertEquals(1, rangeList.get(0).size(), 0);
minMax = rangeList.get(0).getRanges().get(0);
assertTrue(minMax.min() != minMax.max());
assertTrue(minMax.min() == 0 && minMax.max() == 3);
LOGGER.info("decodedToStr of " + minMax + "=>" + ce.decodedToStr(decoded));
//----------------
// Test with width = 1
categories = new String[] { "cat1", "cat2", "cat3", "cat4", "cat5" };
setUp();
builder.radius(1);
builder.categoryList(Arrays.<String>asList(categories));
initCE();
for(String cat : categories) {
output = ce.encode(cat);
topDown = ce.topDownCompute(output).get(0);
LOGGER.debug(cat + "->" + Arrays.toString(output) +
" " + ArrayUtils.where(output, new Condition.Adapter<Integer>() {
public boolean eval(int i) { return i == 1; }
}));
LOGGER.debug(" scalarTopDown: " + ce.topDownCompute(output));
LOGGER.debug(" topDown " + topDown);
assertEquals(topDown.getValue(), cat);
assertEquals(topDown.getScalar(), (int)ce.getScalars(cat).get(0));
}
//==================
// Test with width = 9, removing some bits in the encoded output
categories = new String[9];
for(int i = 0;i < 9;i++) categories[i] = String.format("cat%d", i + 1);
//forced: is not recommended, but is used here for readability.
setUp();
builder.radius(1);
builder.w(9);
builder.forced(true);
builder.categoryList(Arrays.<String>asList(categories));
initCE();
for(String cat : categories) {
output = ce.encode(cat);
topDown = ce.topDownCompute(output).get(0);
LOGGER.debug(cat + "->" + Arrays.toString(output) +
" " + ArrayUtils.where(output, new Condition.Adapter<Integer>() {
public boolean eval(int i) {
return i == 1;
}
}));
LOGGER.debug(" scalarTopDown: " + ce.topDownCompute(output));
LOGGER.debug(" topDown " + topDown);
assertEquals(topDown.getValue(), cat);
assertEquals(topDown.getScalar(), (int)ce.getScalars(cat).get(0));
// Get rid of 1 bit on the left
int[] outputNZs = ArrayUtils.where(output, new Condition.Adapter<Integer>() {
public boolean eval(int i) { return i == 1; }
});
// int[] outputPreserve = Arrays.copyOf(output, output.length);
output[outputNZs[0]] = 0;
// LOGGER.info("output = " + Arrays.toString(outputPreserve));
// LOGGER.info("outputNZs = " + Arrays.toString(outputNZs));
// LOGGER.info("outputDelta = " + Arrays.toString(output));
topDown = ce.topDownCompute(output).get(0);
LOGGER.debug("missing 1 bit on left: ->" + Arrays.toString(output) +
" " + ArrayUtils.where(output, new Condition.Adapter<Integer>() {
public boolean eval(int i) { return i == 1; }
}));
LOGGER.debug(" scalarTopDown: " + ce.topDownCompute(output));
LOGGER.debug(" topDown " + topDown);
assertEquals(topDown.getValue(), cat);
assertEquals(topDown.getScalar(), (int)ce.getScalars(cat).get(0));
// Get rid of 1 bit on the right
output[outputNZs[0]] = 1;
output[outputNZs[outputNZs.length - 1]] = 0;
topDown = ce.topDownCompute(output).get(0);
LOGGER.debug("missing 1 bit on right: ->" + Arrays.toString(output) +
" " + ArrayUtils.where(output, new Condition.Adapter<Integer>() {
public boolean eval(int i) { return i == 1; }
}));
LOGGER.debug(" scalarTopDown: " + ce.topDownCompute(output));
LOGGER.debug(" topDown " + topDown);
assertEquals(topDown.getValue(), cat);
assertEquals(topDown.getScalar(), (int)ce.getScalars(cat).get(0));
// Get rid of 4 bits on the left
Arrays.fill(output, 0);
int[] indexes = ArrayUtils.range(outputNZs[outputNZs.length - 5], outputNZs[outputNZs.length - 1] + 1);
for(int i = 0;i < indexes.length;i++) output[indexes[i]] = 1;
LOGGER.info(Arrays.toString(output));
topDown = ce.topDownCompute(output).get(0);
LOGGER.debug("missing 4 bits on left: ->" + Arrays.toString(output) +
" " + ArrayUtils.where(output, new Condition.Adapter<Integer>() {
public boolean eval(int i) { return i == 1; }
}));
LOGGER.debug(" scalarTopDown: " + ce.topDownCompute(output));
LOGGER.debug(" topDown " + topDown);
assertEquals(topDown.getValue(), cat);
assertEquals(topDown.getScalar(), (int)ce.getScalars(cat).get(0));
// Get rid of 4 bits on the right
Arrays.fill(output, 0);
indexes = ArrayUtils.range(outputNZs[0], outputNZs[5]);
for(int i = 0;i < indexes.length;i++) output[indexes[i]] = 1;
LOGGER.info(Arrays.toString(output));
topDown = ce.topDownCompute(output).get(0);
LOGGER.debug("missing 4 bits on left: ->" + Arrays.toString(output) +
" " + ArrayUtils.where(output, new Condition.Adapter<Integer>() {
public boolean eval(int i) { return i == 1; }
}));
LOGGER.debug(" scalarTopDown: " + ce.topDownCompute(output));
LOGGER.debug(" topDown " + topDown);
assertEquals(topDown.getValue(), cat);
assertEquals(topDown.getScalar(), (int)ce.getScalars(cat).get(0));
}
int[] output1 = ce.encode("cat1");
int[] output2 = ce.encode("cat9");
output = ArrayUtils.or(output1, output2);
topDown = ce.topDownCompute(output).get(0);
LOGGER.debug("cat1 + cat9 ->" + Arrays.toString(output) +
" " + ArrayUtils.where(output, new Condition.Adapter<Integer>() {
public boolean eval(int i) { return i == 1; }
}));
LOGGER.debug(" scalarTopDown: " + ce.topDownCompute(output));
LOGGER.debug(" topDown " + topDown);
assertTrue(topDown.getScalar().equals((int)ce.getScalars("cat1").get(0)) ||
topDown.getScalar().equals((int)ce.getScalars("cat9").get(0)));
LOGGER.info("passed"); //Just because they did it in the Python version :-)
}
}