/* ---------------------------------------------------------------------
* Numenta Platform for Intelligent Computing (NuPIC)
* Copyright (C) 2014, Numenta, Inc. Unless you have an agreement
* with Numenta, Inc., for a separate license for this software code, the
* following terms and conditions apply:
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero Public License version 3 as
* published by the Free Software Foundation.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
* See the GNU Affero Public License for more details.
*
* You should have received a copy of the GNU Affero Public License
* along with this program. If not, see http://www.gnu.org/licenses.
*
* http://numenta.org/licenses/
* ---------------------------------------------------------------------
*/
package org.numenta.nupic.encoders;
import gnu.trove.list.TDoubleList;
import gnu.trove.list.array.TDoubleArrayList;
import org.junit.Test;
import org.numenta.nupic.util.ArrayUtils;
import org.numenta.nupic.util.MinMax;
import org.numenta.nupic.util.Tuple;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
/**
 * Unit tests for {@link ScalarEncoder}.
 *
 * <p>Covers:
 * <ul>
 *   <li>missing-data encoding</li>
 *   <li>{@code getFirstOnBit} range handling</li>
 *   <li>periodic and non-periodic bottom-up encoding</li>
 *   <li>decoding, top-down compute and bucket support</li>
 *   <li>closeness scores</li>
 *   <li>a regression test for an endless loop in top-down compute</li>
 * </ul>
 */
public class ScalarEncoderTest {
    private ScalarEncoder se;
    private ScalarEncoder.Builder builder;

    /**
     * Replaces {@link #builder} with a new builder from {@code ScalarEncoder.builder()}.
     * The builder holds the baseline configuration: n=14, w=3, radius=0.0,
     * range [1.0, 8.0], periodic, forced.
     * Tests call this again whenever they need to start a new scenario.
     */
    private void setUp() {
        builder = ScalarEncoder.builder()
            .n(14)
            .w(3)
            .radius(0.0)
            .minVal(1.0)
            .maxVal(8.0)
            .periodic(true)
            .forced(true);
    }

    /** Builds {@link #se} from the current state of {@link #builder}. */
    private void initSE() {
        se = builder.build();
    }

    /** Encoding the missing-data sentinel yields an all-zero 14-bit output. */
    @Test
    public void testScalarEncoder() {
        setUp();
        initSE();
        int[] empty = se.encode(Encoder.SENTINEL_VALUE_FOR_MISSING_DATA);
        System.out.println("\nEncoded missing data as: " + Arrays.toString(empty));
        int[] expected = new int[14];
        assertTrue(Arrays.equals(expected, empty));
    }

    /** getScalars returns the input value as its first element. */
    @Test
    public void testGetScalars() {
        setUp();
        initSE();
        TDoubleList scalars = se.getScalars(42.42d);
        assertEquals(42.42d, scalars.get(0), 0.01);
    }

    /** Decoding a null encoding returns null. */
    @Test
    public void testDecodeNull() {
        setUp();
        initSE();
        DecodeResult dr = se.decode(null, "blah");
        assertTrue(dr == null);
    }

    /**
     * Exercises getFirstOnBit for:
     * <ul>
     *   <li>missing data</li>
     *   <li>below-min and above-max inputs, under each periodic/clipInput combination</li>
     *   <li>a normal in-range input</li>
     * </ul>
     */
    @Test
    public void testGetFirstOnBit() {
        setUp();
        builder.periodic(false);
        builder.clipInput(true);
        initSE();
        int firstOnBit = -1;
        try {
            // NOTE(review): if getFirstOnBit returns a boxed Integer, the expected
            // NullPointerException may come from unboxing into firstOnBit; do not drop
            // this assignment without checking.
            firstOnBit = se.getFirstOnBit(Encoder.SENTINEL_VALUE_FOR_MISSING_DATA);
            fail();
        } catch(Exception e) {
            assertEquals(NullPointerException.class, e.getClass());
        }

        // Value < min with periodic == false && clipInput == true: first on-bit is 0.
        assertTrue(0 == se.getFirstOnBit(0.9));

        // Value less than min when periodic == true.
        // Should throw an exception.
        setUp();
        builder.periodic(true);
        builder.clipInput(true);
        initSE();
        try {
            se.getFirstOnBit(0.9);
            fail();
        } catch(Exception e) {
            assertEquals(IllegalStateException.class, e.getClass());
            assertEquals("input (0.9) less than range (1.0 - 8.0)", e.getMessage());
        }

        // Value greater than max when periodic == true.
        // Should throw an exception.
        setUp();
        builder.periodic(true);
        builder.clipInput(true);
        initSE();
        try {
            se.getFirstOnBit(100);
            fail();
        } catch(Exception e) {
            assertEquals(IllegalStateException.class, e.getClass());
            assertEquals("input (100.0) greater than periodic range (1.0 - 8.0)", e.getMessage());
        }

        // Value greater than max when periodic == false && clipInput == true.
        // No exception: the input is clipped, so the window starts at the highest
        // possible first on-bit (11 for n=14, w=3).
        setUp();
        builder.periodic(false);
        builder.clipInput(true);
        initSE();
        firstOnBit = se.getFirstOnBit(100);
        assertTrue(11 == firstOnBit);

        // Value greater than max when periodic == false && clipInput == false.
        // Should throw an exception.
        setUp();
        builder.periodic(false);
        builder.clipInput(false);
        initSE();
        try {
            se.getFirstOnBit(100);
            fail();
        } catch(Exception e) {
            assertEquals(IllegalStateException.class, e.getClass());
            assertEquals("input (100.0) greater than periodic range (1.0 - 8.0)", e.getMessage());
        }

        setUp();
        initSE();
        // Normal in-range input.
        assertTrue(11 == se.getFirstOnBit(7));
    }

    /**
     * Checks the default and named descriptions, and the exact periodic
     * bottom-up encodings (including wrap-around at both ends).
     * Also checks the derived resolution and radius.
     */
    @Test
    public void testBottomUpEncodingPeriodicEncoder() {
        setUp();
        initSE();
        assertEquals("[1:8]", se.getDescription().get(0).get(0));

        setUp();
        builder.name("scalar");
        initSE();
        assertEquals("scalar", se.getDescription().get(0).get(0));
        int[] res = se.encode(3.0);
        assertTrue(Arrays.equals(new int[] { 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0 }, res));
        res = se.encode(3.1);
        assertTrue(Arrays.equals(new int[] { 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0 }, res));
        res = se.encode(3.5);
        assertTrue(Arrays.equals(new int[] { 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0 }, res));
        res = se.encode(3.6);
        assertTrue(Arrays.equals(new int[] { 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0 }, res));
        res = se.encode(3.7);
        assertTrue(Arrays.equals(new int[] { 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0 }, res));
        res = se.encode(4d);
        assertTrue(Arrays.equals(new int[] { 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0 }, res));
        // The periodic encoder wraps the window around the ends of the bit array.
        res = se.encode(1d);
        assertTrue(Arrays.equals(new int[] { 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1 }, res));
        res = se.encode(1.5);
        assertTrue(Arrays.equals(new int[] { 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, res));
        res = se.encode(7d);
        assertTrue(Arrays.equals(new int[] { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1 }, res));
        res = se.encode(7.5);
        assertTrue(Arrays.equals(new int[] { 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1 }, res));

        assertEquals(0.5d, se.getResolution(), 0);
        assertEquals(1.5d, se.getRadius(), 0);
    }

    /**
     * Test that we get the same encoder when we construct it using resolution
     * or radius instead of n.
     */
    @Test
    public void testCreateResolution() {
        setUp();
        initSE();
        List<Tuple> dict = se.dict();

        setUp();
        builder.resolution(0.5);
        initSE();
        List<Tuple> compare = se.dict();
        assertEquals(dict.toString(), compare.toString());

        setUp();
        builder.radius(1.5);
        initSE();
        compare = se.dict();
        assertEquals(dict.toString(), compare.toString());

        // Negative test: renaming the encoder must change its dict.
        // Compared by toString() so it is judged the same way as the positive cases above.
        setUp();
        builder.resolution(0.5);
        initSE();
        se.setName("break this");
        compare = se.dict();
        assertFalse(dict.toString().equals(compare.toString()));
    }

    /**
     * Test the input description generation, top-down compute, and bucket
     * support on a periodic encoder.
     */
    @Test
    public void testDecodeAndResolution() {
        setUp();
        builder.name("scalar");
        initSE();
        double resolution = se.getResolution();
        StringBuilder out = new StringBuilder();
        for(double v = se.getMinVal(); v < se.getMaxVal(); v += (resolution / 4.0d)) {
            int[] output = se.encode(v);
            DecodeResult decoded = se.decode(output, "");
            System.out.println(out.append("decoding ").append(Arrays.toString(output)).append(" (").
                append(String.format("%.6f", v)).append(")=> ").append(se.decodedToStr(decoded)));
            out.setLength(0);

            Map<String, RangeList> fieldsMap = decoded.getFields();
            assertEquals(1, fieldsMap.size());
            RangeList ranges = new ArrayList<RangeList>(fieldsMap.values()).get(0);
            assertEquals(1, ranges.size());
            assertEquals(ranges.getRange(0).min(), ranges.getRange(0).max(), 0);
            // Bound the decode error in both directions, not just one.
            assertTrue(Math.abs(ranges.getRange(0).min() - v) < se.getResolution());

            Encoding topDown = se.topDownCompute(output).get(0);
            System.out.println("topdown => " + topDown);
            assertTrue(Arrays.equals(topDown.getEncoding(), output));
            assertTrue(Math.abs(((double)topDown.get(1)) - v) <= se.getResolution() / 2);

            // Test bucket support
            int[] bucketIndices = se.getBucketIndices(v);
            System.out.println("bucket index => " + bucketIndices[0]);
            topDown = se.getBucketInfo(bucketIndices).get(0);
            assertTrue(Math.abs(((double)topDown.get(1)) - v) <= se.getResolution() / 2);
            assertEquals(topDown.get(1), se.getBucketValues(Double.class).toArray()[bucketIndices[0]]);
            assertEquals(topDown.get(2), topDown.get(1));
            assertTrue(Arrays.equals(topDown.getEncoding(), output));
        }

        // -----------------------------------------------------------------------
        // Test the input description generation on a periodic encoder
        // built from radius rather than n.
        setUp();
        builder.name("scalar")
            .w(3)
            .radius(1.5)
            .minVal(1.0)
            .maxVal(8.0)
            .periodic(true)
            .forced(true);
        initSE();
        System.out.println("\nTesting periodic encoder decoding, resolution of " + se.getResolution());

        // Test with a "hole"
        int[] encoded = new int[] { 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0 };
        DecodeResult decoded = se.decode(encoded, "");
        Map<String, RangeList> fieldsMap = decoded.getFields();
        assertEquals(1, fieldsMap.size());
        assertEquals(1, decoded.getRanges("scalar").size());
        assertEquals("7.5, 7.5", decoded.getRanges("scalar").getRange(0).toString());

        // Test with something wider than w, and with a hole, and wrapped
        encoded = new int[] { 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0 };
        decoded = se.decode(encoded, "");
        fieldsMap = decoded.getFields();
        assertEquals(1, fieldsMap.size());
        assertEquals(2, decoded.getRanges("scalar").size());
        assertEquals("7.5, 8.0", decoded.getRanges("scalar").getRange(0).toString());

        // Test with something wider than w, no hole
        encoded = new int[] { 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0 };
        decoded = se.decode(encoded, "");
        fieldsMap = decoded.getFields();
        assertEquals(1, fieldsMap.size());
        assertEquals(1, decoded.getRanges("scalar").size());
        assertEquals("1.5, 2.5", decoded.getRanges("scalar").getRange(0).toString());

        // Test with 2 ranges
        encoded = new int[] { 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0 };
        decoded = se.decode(encoded, "");
        fieldsMap = decoded.getFields();
        assertEquals(1, fieldsMap.size());
        assertEquals(2, decoded.getRanges("scalar").size());
        assertEquals("1.5, 1.5", decoded.getRanges("scalar").getRange(0).toString());
        assertEquals("5.5, 6.0", decoded.getRanges("scalar").getRange(1).toString());

        // Test with 2 ranges, 1 of which is narrower than w
        encoded = new int[] { 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0 };
        decoded = se.decode(encoded, "");
        fieldsMap = decoded.getFields();
        assertEquals(1, fieldsMap.size());
        assertEquals(2, decoded.getRanges("scalar").size());
        assertEquals("1.5, 1.5", decoded.getRanges("scalar").getRange(0).toString());
        assertEquals("5.5, 6.0", decoded.getRanges("scalar").getRange(1).toString());
    }

    /**
     * Test closenessScores for a periodic encoder.
     * Only the first returned score is compared against the first expected value.
     */
    @Test
    public void testCloseness() {
        setUp();
        builder.name("day of week")
            .w(7)
            .radius(1.0)
            .minVal(0.0)
            .maxVal(7.0)
            .periodic(true)
            .forced(true);
        initSE();

        TDoubleList expValues = new TDoubleArrayList(new double[] { 2, 4, 7 });
        TDoubleList actValues = new TDoubleArrayList(new double[] { 4, 2, 1 });
        TDoubleList scores = se.closenessScores(expValues, actValues, false);
        for(Tuple t : ArrayUtils.zip(Arrays.asList(2, 2, 1), Arrays.asList(scores.get(0)))) {
            double expected = (int)t.get(0);
            double actual = (double)t.get(1);
            assertEquals(expected, actual, 0);
        }
    }

    /**
     * Non-periodic encoder tests:
     * <ul>
     *   <li>exact bottom-up encodings</li>
     *   <li>decode / top-down / bucket round trips</li>
     *   <li>hole filling</li>
     *   <li>min/max handling in top-down compute</li>
     *   <li>decode sweeps on very small and very large ranges</li>
     * </ul>
     */
    @Test
    public void testNonPeriodicBottomUp() {
        setUp();
        builder.name("day of week")
            .w(5)
            .n(14)
            .radius(1.0)
            .minVal(1.0)
            .maxVal(10.0)
            .periodic(false)
            .forced(true);
        initSE();
        System.out.println(String.format("Testing non-periodic encoder encoding resolution of %f...", se.getResolution()));
        assertTrue(Arrays.equals(se.encode(1d), new int[] { 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0 }));
        assertTrue(Arrays.equals(se.encode(2d), new int[] { 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0 }));
        assertTrue(Arrays.equals(se.encode(10d), new int[] { 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1 }));

        // Test that we get the same encoder when we construct it using resolution
        // instead of n
        setUp();
        builder.name("day of week")
            .w(5)
            .radius(5.0)
            .minVal(1.0)
            .maxVal(10.0)
            .periodic(false)
            .forced(true);
        initSE();
        double v = se.getMinVal();
        while(v < se.getMaxVal()) {
            int[] output = se.encode(v);
            DecodeResult decoded = se.decode(output, "");
            System.out.println("decoding " + Arrays.toString(output) + String.format("(%f)=>", v) + se.decodedToStr(decoded));

            assertEquals(1, decoded.getFields().size());
            List<RangeList> rangeList = new ArrayList<RangeList>(decoded.getFields().values());
            assertEquals(1, rangeList.get(0).size());
            MinMax minMax = rangeList.get(0).getRanges().get(0);
            assertEquals(minMax.min(), minMax.max(), 0);
            assertTrue(Math.abs(minMax.min() - v) <= se.getResolution());

            List<Encoding> topDowns = se.topDownCompute(output);
            Encoding topDown = topDowns.get(0);
            System.out.println("topDown => " + topDown);
            assertTrue(Arrays.equals(topDown.getEncoding(), output));
            assertTrue(Math.abs(((double)topDown.getValue()) - v) <= se.getResolution());

            // Test bucket support
            int[] bucketIndices = se.getBucketIndices(v);
            System.out.println("bucket index => " + bucketIndices[0]);
            topDown = se.getBucketInfo(bucketIndices).get(0);
            assertTrue(Math.abs(((double)topDown.getValue()) - v) <= se.getResolution() / 2);
            assertEquals(topDown.getScalar(), topDown.getValue());
            assertTrue(Arrays.equals(topDown.getEncoding(), output));

            // Next value
            v += se.getResolution() / 4;
        }

        // Make sure we can fill in holes
        DecodeResult decoded = se.decode(new int[] { 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1 }, "");
        assertEquals(1, decoded.getFields().size());
        List<RangeList> rangeList = new ArrayList<RangeList>(decoded.getFields().values());
        assertEquals(1, rangeList.get(0).size());
        System.out.println("decodedToStr of " + rangeList + " => " + se.decodedToStr(decoded));

        decoded = se.decode(new int[] { 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1 }, "");
        assertEquals(1, decoded.getFields().size());
        rangeList = new ArrayList<RangeList>(decoded.getFields().values());
        assertEquals(1, rangeList.get(0).size());
        System.out.println("decodedToStr of " + rangeList + " => " + se.decodedToStr(decoded));

        // Test min and max
        setUp();
        builder.name("scalar")
            .w(3)
            .minVal(1.0)
            .maxVal(10.0)
            .periodic(false)
            .forced(true);
        initSE();
        List<Encoding> decode = se.topDownCompute(new int[] { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1 });
        assertEquals(10, (Double)decode.get(0).getScalar(), 0);
        decode = se.topDownCompute(new int[] { 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 });
        assertEquals(1, (Double)decode.get(0).getScalar(), 0);

        // Make sure only the last and first encoding encodes to max and min, and there
        // is no value greater than max or less than min.
        setUp();
        builder.name("scalar")
            .w(3)
            .n(140)
            .radius(1.0)
            .minVal(1.0)
            .maxVal(141.0)
            .periodic(false)
            .forced(true);
        initSE();
        int numBits = 140;
        int width = 3;
        int lastStart = numBits - width;
        // Slide a width-bit window over every valid start position, including the
        // final one (bits lastStart .. numBits - 1), which must decode to max.
        for(int i = 0; i <= lastStart; i++) {
            int[] window = new int[numBits];
            ArrayUtils.setRangeTo(window, i, i + width, 1);
            decode = se.topDownCompute(window);
            int value = decode.get(0).getScalar().intValue();
            assertTrue(value <= 141);
            assertTrue(value >= 1);
            assertTrue(value < 141 || i == lastStart);
            assertTrue(value > 1 || i == 0);
        }

        // -------------------------------------------------------------------------
        // Test the input description generation and top-down compute on a small number
        // non-periodic encoder
        setUp();
        builder.name("scalar")
            .w(3)
            .n(15)
            .minVal(.001)
            .maxVal(.002)
            .periodic(false)
            .forced(true);
        initSE();
        System.out.println(String.format("\nTesting non-periodic encoder decoding resolution of %f...", se.getResolution()));
        assertDecodeAndTopDownWithinResolution();

        // -------------------------------------------------------------------------
        // Test the input description generation on a large number, non-periodic encoder
        setUp();
        builder.name("scalar")
            .w(3)
            .n(15)
            .minVal(1.0)
            .maxVal(1000000000.0)
            .periodic(false)
            .forced(true);
        initSE();
        System.out.println(String.format("\nTesting non-periodic encoder decoding resolution of %f...", se.getResolution()));
        assertDecodeAndTopDownWithinResolution();
    }

    /**
     * Sweeps the current {@link #se} from minVal (inclusive) to maxVal (exclusive)
     * in steps of a quarter resolution. For each value it checks that:
     * <ul>
     *   <li>decode() yields a single field holding a single point range within one
     *       resolution of the input</li>
     *   <li>topDownCompute() returns a scalar within half a resolution of the input</li>
     * </ul>
     */
    private void assertDecodeAndTopDownWithinResolution() {
        double v = se.getMinVal();
        while(v < se.getMaxVal()) {
            int[] output = se.encode(v);
            DecodeResult decoded = se.decode(output, "");
            System.out.println(String.format("decoding (%f)=>", v) + " " + se.decodedToStr(decoded));

            assertEquals(1, decoded.getFields().size());
            List<RangeList> rangeList = new ArrayList<RangeList>(decoded.getFields().values());
            assertEquals(1, rangeList.get(0).size());
            MinMax minMax = rangeList.get(0).getRanges().get(0);
            assertEquals(minMax.min(), minMax.max(), 0);
            assertTrue(Math.abs(minMax.min() - v) <= se.getResolution());

            List<Encoding> decode = se.topDownCompute(output);
            System.out.println("topdown => " + decode);
            assertTrue(Math.abs((Double)decode.get(0).getScalar() - v) <= se.getResolution() / 2);

            v += (se.getResolution() / 4);
        }
    }

    /**
     * This should not cause an OutOfMemoryError due to no resolution being set.
     * Fix for #142 (see: https://github.com/numenta/htm.java/issues/142)
     *
     * <p>The timeout turns a regression into an endless loop into a test failure
     * instead of a hung build.
     */
    @Test(timeout = 10000)
    public void endlessLoopInTopDownCompute() {
        ScalarEncoder encoder = ScalarEncoder.builder()
            .w( 5 )
            .n( 10 )
            .forced( true )
            .minVal( 0 )
            .maxVal( 100 )
            .build();

        encoder.topDownCompute( new int[] { 0, 0, 0, 0, 0, 0, 0, 0, 0, 1 } );
    }
}