/* ---------------------------------------------------------------------
* Numenta Platform for Intelligent Computing (NuPIC)
* Copyright (C) 2014, Numenta, Inc. Unless you have an agreement
* with Numenta, Inc., for a separate license for this software code, the
* following terms and conditions apply:
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero Public License version 3 as
* published by the Free Software Foundation.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
* See the GNU Affero Public License for more details.
*
* You should have received a copy of the GNU Affero Public License
* along with this program. If not, see http://www.gnu.org/licenses.
*
* http://numenta.org/licenses/
* ---------------------------------------------------------------------
*/
package org.numenta.nupic.encoders;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import java.util.Arrays;
import java.util.List;
import org.junit.Test;
import org.numenta.nupic.util.ArrayUtils;
import org.numenta.nupic.util.Tuple;
public class CoordinateEncoderTest {
private CoordinateEncoder ce;
private CoordinateEncoder.Builder builder;
private boolean verbose;
/**
 * Replaces {@link #builder} with a freshly created coordinate-encoder builder
 * configured with the name "coordinate", n = 33 and w = 3.
 * <p>
 * This is a plain helper that tests invoke explicitly; it carries no JUnit
 * fixture annotation.
 */
private void setUp() {
    CoordinateEncoder.Builder configured =
        CoordinateEncoder.builder().name("coordinate").n(33).w(3);
    this.builder = configured;
}
/**
 * Builds the encoder under test from the current {@link #builder} state and
 * stores it in {@link #ce}.
 */
private void initCE() {
    this.ce = this.builder.build();
}
/**
 * Building an encoder with an even, zero or negative {@code w} must be
 * rejected with the documented error message.
 */
@Test
public void testInvalidW() {
    // The default configuration is built first, outside any try block.
    setUp();
    initCE();

    final String expectedMessage = "w must be odd, and must be a positive integer";
    // Even, zero and negative widths, exercised in that order.
    final int[] invalidWidths = { 4, 0, -2 };

    for (int invalidW : invalidWidths) {
        try {
            setUp();
            builder.n(45);
            builder.w(invalidW);
            // Should fail here
            initCE();
            // Reached only when no exception was thrown; the AssertionError
            // raised by fail() is not an Exception, so the catch below does
            // not swallow it.
            fail();
        } catch (Exception e) {
            assertEquals(expectedMessage, e.getMessage());
        }
    }
}
/**
 * Building an encoder with n = 11 and w = 3 must be rejected with the
 * documented error message about n being too small relative to w.
 */
@Test
public void testInvalidN() {
    // The default configuration is built first, outside the try block.
    setUp();
    initCE();

    final String expectedMessage =
        "n must be an int strictly greater than 6*w. For " +
        "good results we recommend n be strictly greater than 11*w";

    // n = 11 is not strictly greater than 6 * w = 18
    try {
        setUp();
        builder.n(11);
        builder.w(3);
        // Should fail here
        initCE();
        fail();
    } catch (Exception e) {
        assertEquals(expectedMessage, e.getMessage());
    }
}
/**
 * The order computed for a coordinate must lie in [0, 1), and the sampled
 * coordinates must not all map to the same order value.
 */
@Test
public void testOrderForCoordinate() {
    CoordinateEncoder encoder = new CoordinateEncoder();
    double[] orders = {
        encoder.orderForCoordinate(new int[] { 2, 5, 10 }),
        encoder.orderForCoordinate(new int[] { 2, 5, 11 }),
        encoder.orderForCoordinate(new int[] { 2497477, -923478 })
    };

    for (double order : orders) {
        assertTrue(0 <= order && order < 1);
    }

    System.out.println(orders[0] + ", " + orders[1] + ", " + orders[2]);

    // Consecutive samples must differ from each other.
    assertTrue(orders[0] != orders[1]);
    assertTrue(orders[1] != orders[2]);
}
/**
 * The bit chosen for a coordinate must lie in [0, n), and the sampled
 * coordinates must not all map to the same bit.
 */
@Test
public void testBitForCoordinate() {
    int n = 1000;
    double[] bits = {
        CoordinateEncoder.bitForCoordinate(new int[] { 2, 5, 10 }, n),
        CoordinateEncoder.bitForCoordinate(new int[] { 2, 5, 11 }, n),
        CoordinateEncoder.bitForCoordinate(new int[] { 2497477, -923478 }, n)
    };

    for (double bit : bits) {
        assertTrue(0 <= bit && bit < n);
    }

    // Consecutive samples must differ from each other.
    assertTrue(bits[0] != bits[1]);
    assertTrue(bits[1] != bits[2]);

    // Small n
    n = 2;
    double smallBit = CoordinateEncoder.bitForCoordinate(new int[] { 5, 10 }, n);
    assertTrue(0 <= smallBit && smallBit < n);
}
/**
 * With a stub ordering of sum(coordinate) / 5, asking for the top 2 of the
 * one-dimensional coordinates 1..5 must yield {4} followed by {5}.
 */
@Test
public void testTopWCoordinates() {
    final int[][] candidates = { { 1 }, { 2 }, { 3 }, { 4 }, { 5 } };

    // Stub ordering: larger coordinates receive larger order values.
    CoordinateOrder sumOrder = new CoordinateOrder() {
        @Override
        public double orderForCoordinate(int[] coordinate) {
            return ArrayUtils.sum(coordinate) / 5.0d;
        }
    };

    CoordinateEncoder encoder = new CoordinateEncoder();
    int[][] winners = encoder.topWCoordinates(sumOrder, candidates, 2);

    assertEquals(2, winners.length);
    assertTrue(Arrays.equals(new int[] { 4 }, winners[0]));
    assertTrue(Arrays.equals(new int[] { 5 }, winners[1]));
}
/**
 * In one dimension, the neighbors of 100 within radius 5 must be 11
 * coordinates, with 95 first, 100 in the middle and 105 last.
 */
@Test
public void testNeighbors1D() {
    CoordinateEncoder encoder = new CoordinateEncoder();
    List<int[]> found = encoder.neighbors(new int[] { 100 }, 5);

    assertEquals(11, found.size());

    // Spot-check the first, middle and last positions of the result list.
    final int[] positions = { 0, 5, 10 };
    final int[] expectedValues = { 95, 100, 105 };
    for (int k = 0; k < positions.length; k++) {
        assertTrue(Arrays.equals(new int[] { expectedValues[k] }, found.get(positions[k])));
    }
}
/**
 * In two dimensions, the neighbors of (100, 200) within radius 5 must be
 * 121 coordinates that include the center and the four sampled corners.
 */
@Test
public void testNeighbors2D() {
    CoordinateEncoder encoder = new CoordinateEncoder();
    List<int[]> found = encoder.neighbors(new int[] { 100, 200 }, 5);

    assertEquals(121, found.size());

    // Two corners, the center, then the remaining two corners.
    final int[][] mustContain = {
        { 95, 195 },
        { 95, 205 },
        { 100, 200 },
        { 105, 195 },
        { 105, 205 }
    };
    for (int[] expected : mustContain) {
        assertTrue(ArrayUtils.contains(expected, found));
    }
}
/**
 * With radius 0, the only neighbor of a coordinate must be the coordinate
 * itself.
 */
@Test
public void testNeighbors0Radius() {
    CoordinateEncoder encoder = new CoordinateEncoder();
    List<int[]> found = encoder.neighbors(new int[] { 100, 200, 300 }, 0);

    assertEquals(1, found.size());
    assertTrue(ArrayUtils.contains(new int[] { 100, 200, 300 }, found));
}
/**
 * Encoding a coordinate must produce an output whose element sum equals the
 * encoder's {@code w}, and encoding the same input twice must give the same
 * output.
 */
@Test
public void testEncodeIntoArray() {
    setUp();
    builder.n(33);
    builder.w(3);
    initCE();

    int[] coordinate = new int[] { 100, 200 };

    int[] output1 = encode(ce, coordinate, 5);
    // JUnit's contract is assertEquals(expected, actual): the configured w is
    // the expectation and the computed sum is the value under test, so a
    // failure message reports them the right way round.
    assertEquals(ce.w, ArrayUtils.sum(output1));

    // Same coordinate and radius again: the encoding must be repeatable.
    int[] output2 = encode(ce, coordinate, 5);
    assertTrue(Arrays.equals(output1, output2));
}
/**
 * With n = 1999, w = 25 and radius 2, the encodings of (0, 0) and (0, 1)
 * must overlap by 0.8, within a tolerance of 0.019.
 */
@Test
public void testEncodeSaturateArea() {
    setUp();
    builder.n(1999);
    builder.w(25);
    builder.radius(2);
    initCE();

    int[] atOrigin = encode(ce, new int[] { 0, 0 }, 2);
    int[] oneStepAway = encode(ce, new int[] { 0, 1 }, 2);

    double shared = overlap(atOrigin, oneStepAway);
    assertEquals(0.8, shared, 0.019);
}
/**
 * Moving the encoded position farther from the starting coordinate, at a
 * fixed radius, must not increase the overlap with the starting encoding.
 */
@Test
public void testEncodeRelativePositions() {
    final int[] start = { 100, 200 };
    final int[] stepPerSample = { 2, 2 };

    // Radius stays at 10 (dRadius = 0) while the position moves for 5 samples.
    double[] overlaps = overlapsForRelativeAreas(
        999, 25, start, 10, stepPerSample, 0, 5, false);

    assertDecreasingOverlaps(overlaps);
}
/**
 * Changing the radius away from the initial one, in either direction and at
 * a fixed position, must not increase the overlap with the initial encoding.
 */
@Test
public void testEncodeRelativeRadii() {
    final int[] position = { 100, 200 };

    // Growing radius: starts at 5 and grows by 1 per sample.
    double[] growing = overlapsForRelativeAreas(
        999, 25, position, 5, null, 1, 5, false);
    assertDecreasingOverlaps(growing);

    // Shrinking radius: starts at 20 and shrinks by 2 per sample.
    double[] shrinking = overlapsForRelativeAreas(
        999, 25, position, 20, null, -2, 5, false);
    assertDecreasingOverlaps(shrinking);
}
/**
 * Increasing the radius while also moving the position must not increase
 * the overlap with the initial encoding.
 */
@Test
public void testEncodeRelativePositionsAndRadii() {
    final int[] start = { 100, 200 };
    final int[] stepPerSample = { 1, 1 };

    // Radius starts at 5 and grows by 1 per sample while the position moves.
    double[] overlaps = overlapsForRelativeAreas(
        999, 25, start, 5, stepPerSample, 1, 5, false);

    assertDecreasingOverlaps(overlaps);
}
/**
 * For several (n, w, radius) configurations, the overlaps between the
 * encoding of the origin and encodings of areas far from it must stay
 * below a per-configuration maximum and a shared average threshold.
 */
@Test
public void testEncodeUnrelatedAreas() {
    final double avgThreshold = 0.3;

    // Each row is { n, w, radius }; maxThresholds[k] applies to configs[k].
    final int[][] configs = {
        { 1499, 37, 5 },
        { 1499, 37, 10 },
        { 999, 25, 10 },
        { 499, 13, 10 }
    };
    final double[] maxThresholds = { 0.14, 0.12, 0.13, 0.16 };

    for (int k = 0; k < configs.length; k++) {
        int[] cfg = configs[k];
        double[] overlaps = overlapsForUnrelatedAreas(cfg[0], cfg[1], cfg[2], 100, false);
        assertTrue(ArrayUtils.max(overlaps) < maxThresholds[k]);
        assertTrue(ArrayUtils.average(overlaps) < avgThreshold);
    }
}
/**
 * Encodings of two positions one step apart, at radius 10, must overlap
 * heavily: over 100 sampled starting positions the minimum overlap must
 * exceed 0.75 and the average must exceed 0.90.
 * <p>
 * When the {@code verbose} field is true, the max, min and average are
 * also printed to standard output.
 */
@Test
public void testEncodeAdjacentPositions() {
    int repetitions = 100;
    int n = 999;
    int w = 25;
    int radius = 10;
    double minThreshold = 0.75;
    double avgThreshold = 0.90;
    double[] allOverlaps = new double[repetitions];

    for(int i = 0;i < repetitions;i++) {
        // One sample per start position (i*10, i*10): overlap with the
        // position shifted by (0, 1) at the same radius.
        double[] overlaps = overlapsForRelativeAreas(
            n, w, new int[] { i * 10, i * 10 }, radius, new int[] { 0, 1 }, 0, 1, false);

        allOverlaps[i] = overlaps[0];
    }

    assertTrue(ArrayUtils.min(allOverlaps) > minThreshold);
    assertTrue(ArrayUtils.average(allOverlaps) > avgThreshold);

    if(verbose) {
        // String.format takes printf-style conversions; "{0}"-style
        // placeholders belong to java.text.MessageFormat and would be
        // printed literally. %s is used for the statistics so the output
        // does not depend on the exact numeric return type of ArrayUtils.
        System.out.println(String.format("===== Adjacent positions overlap " +
            "(n = %d, w = %d, radius = %d) ===", n, w, radius));
        System.out.println(String.format("Max: %s", ArrayUtils.max(allOverlaps)));
        System.out.println(String.format("Min: %s", ArrayUtils.min(allOverlaps)));
        System.out.println(String.format("Average: %s", ArrayUtils.average(allOverlaps)));
    }
}
/**
 * Asserts that the given overlaps never increase from one element to the
 * next (equal neighbors are accepted).
 * <p>
 * Adjacent elements are compared directly, so the outcome does not depend
 * on how {@code ArrayUtils.where} encodes its matches: if it reports match
 * indices, summing them yields 0 when the only increase sits at index 0,
 * and such a sequence would be accepted by mistake.
 *
 * @param overlaps the overlap values, in the order they were measured
 */
public void assertDecreasingOverlaps(double[] overlaps) {
    for (int i = 1; i < overlaps.length; i++) {
        if (overlaps[i] > overlaps[i - 1]) {
            fail("Overlap increased between index " + (i - 1) + " and " + i
                + ": " + overlaps[i - 1] + " -> " + overlaps[i]);
        }
    }
}
/**
 * Encodes a (coordinate, radius) pair with the given encoder.
 *
 * @param encoder    the encoder to use
 * @param coordinate the coordinate to encode
 * @param radius     the radius to encode with
 * @return a new array of length {@code encoder.getWidth()} that the encoder
 *         has written its output into
 */
public int[] encode(CoordinateEncoder encoder, int[] coordinate, double radius) {
    final int[] sdr = new int[encoder.getWidth()];
    final Tuple input = new Tuple(coordinate, radius);
    encoder.encodeIntoArray(input, sdr);
    return sdr;
}
/**
 * Computes the overlap of two equally long arrays as
 * sum(and(sdr1, sdr2)) / sum(sdr1). The arrays are presumably 0/1
 * encodings, in which case this is the fraction of {@code sdr1}'s set
 * entries that are also set in {@code sdr2}.
 *
 * @param sdr1 the reference array; its sum is the denominator
 * @param sdr2 the array compared against the reference
 * @return the ratio described above
 */
public double overlap(int[] sdr1, int[] sdr2) {
    assertEquals(sdr1.length, sdr2.length);
    final int[] shared = ArrayUtils.and(sdr1, sdr2);
    final double sharedTotal = ArrayUtils.sum(shared);
    final double referenceTotal = ArrayUtils.sum(sdr1);
    return sharedTotal / referenceTotal;
}
/**
 * Builds an encoder with the given {@code n} and {@code w}, encodes the
 * initial position and radius once, and then measures the overlap between
 * that encoding and {@code num} further encodings whose position and
 * radius drift away from the initial ones.
 * <p>
 * Sample {@code k} (1-based) uses position
 * {@code initPosition + k * dPosition} (or {@code initPosition} unchanged
 * when {@code dPosition} is null) and radius
 * {@code initRadius + k * dRadius}.
 *
 * @param n            encoder output width passed to the builder
 * @param w            encoder w passed to the builder
 * @param initPosition the reference position
 * @param initRadius   the reference radius
 * @param dPosition    per-sample position step, or null to keep the position fixed
 * @param dRadius      per-sample radius step
 * @param num          number of samples to take
 * @param verbose      not used by this method
 * @return the overlaps of the {@code num} samples with the reference encoding
 */
public double[] overlapsForRelativeAreas(int n, int w, int[] initPosition, int initRadius,
        int[] dPosition, int dRadius, int num, boolean verbose) {
    setUp();
    builder.n(n);
    builder.w(w);
    initCE();

    final double[] result = new double[num];
    final int[] reference = encode(ce, initPosition, initRadius);

    for (int step = 1; step <= num; step++) {
        int[] position;
        if (dPosition == null) {
            position = initPosition;
        } else {
            // i_add receives a copy so that initPosition itself is never
            // handed to it (presumably it adds in place -- confirm in ArrayUtils).
            position = ArrayUtils.i_add(
                Arrays.copyOf(initPosition, initPosition.length),
                ArrayUtils.multiply(dPosition, step));
        }
        int radius = initRadius + step * dRadius;
        int[] sample = encode(ce, position, radius);
        result[step - 1] = overlap(reference, sample);
    }
    return result;
}
/**
 * Measures overlaps between the encoding of the origin and encodings of
 * positions that move away from it in steps of ten radii along the second
 * axis, with the radius held constant.
 *
 * @param n           encoder output width
 * @param w           encoder w
 * @param radius      radius used for every encoding
 * @param repetitions number of samples to take
 * @param verbose     forwarded to {@link #overlapsForRelativeAreas}
 * @return one overlap value per repetition
 */
public double[] overlapsForUnrelatedAreas(int n, int w, int radius, int repetitions, boolean verbose) {
    final int[] origin = { 0, 0 };
    final int[] stepPerSample = { 0, radius * 10 };
    return overlapsForRelativeAreas(
        n, w, origin, radius, stepPerSample, 0, repetitions, verbose);
}
/**
 * Using the encoder's own ordering, the top 3 coordinates of the full
 * 11 x 11 grid around (100, 200) must be exactly (95, 200), (99, 202) and
 * (102, 198), in that order.
 */
@Test
public void testTopStrict() {
    // Every (x, y) with x in [95, 105] and y in [195, 205], with x as the
    // outer (slow) index and y as the inner (fast) index: 121 coordinates.
    final int side = 11;
    int[][] input = new int[side * side][];
    int next = 0;
    for (int x = 95; x <= 105; x++) {
        for (int y = 195; y <= 205; y++) {
            input[next++] = new int[] { x, y };
        }
    }

    CoordinateEncoder c = new CoordinateEncoder();
    int[][] results = c.topWCoordinates(c, input, 3);

    int[][] expected = new int[][] { {95, 200}, {99, 202}, {102, 198} };
    // Check the count first: without it an empty result would pass the loop
    // below vacuously, and a longer one would end in an
    // ArrayIndexOutOfBoundsException instead of an assertion failure.
    assertEquals(expected.length, results.length);
    for(int i = 0;i < results.length;i++) {
        assertTrue(Arrays.equals(results[i], expected[i]));
    }
    System.out.println("done");
}
}