package org.numenta.nupic.encoders;

import static org.hamcrest.CoreMatchers.allOf;
import static org.hamcrest.CoreMatchers.endsWith;
import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.CoreMatchers.not;
import static org.hamcrest.CoreMatchers.startsWith;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedHashSet;
import java.util.List;

import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.numenta.nupic.FieldMetaType;
import org.numenta.nupic.util.Tuple;

/**
 * Unit tests for RandomDistributedScalarEncoder class
 *
 * @author Anubhav Chaturvedi
 *
 */
public class RandomDistributedScalarEncoderTest {

    /** Encoder under test; rebuilt by every test method. */
    private RandomDistributedScalarEncoder rdse;

    /** Builder used to configure {@link #rdse}. */
    private RandomDistributedScalarEncoder.Builder builder;

    /** JUnit rule used by the negative tests to declare the expected exception. */
    @Rule
    public ExpectedException exception = ExpectedException.none();

    /**
     * Test basic encoding functionality. Create encodings without crashing and
     * check they contain the correct number of on and off bits. Check some
     * encodings for expected overlap. Test that encodings for old values don't
     * change once we generate new buckets.
     */
    @Test
    public void testEncoding() {
        builder = RandomDistributedScalarEncoder.builder()
            .name("enc")
            .resolution(1)
            .w(23)
            .n(500)
            .setOffset(0);
        rdse = builder.build();

        int[] e0 = rdse.encode(-0.1);

        assertEquals("Number of on bits is incorrect", 23, getOnBits(e0));
        assertEquals("Width of the vector is incorrect", 500, e0.length);
        assertEquals("Offset doesn't correspond to middle bucket",
            rdse.getMaxBuckets() / 2, rdse.getBucketIndices(0)[0]);
        assertEquals("Number of buckets is not 1", 1, rdse.bucketMap.size());

        // Encode with a number that is resolution away from offset. Now we should
        // have two buckets and this encoding should be one bit away from e0
        int[] e1 = rdse.encode(1.0);
        assertEquals("Number of buckets is not 2", 2, rdse.bucketMap.size());
        assertEquals("Number of on bits is incorrect", 23, getOnBits(e1));
        assertEquals("Width of the vector is incorrect", 500, e1.length);
        assertEquals("Overlap is not equal to w-1", 22, computeOverlap(e0, e1));

        // Encode with a number that is resolution*w away from offset. Now we should
        // have many buckets and this encoding should have very little overlap with
        // e0
        int[] e25 = rdse.encode(25.0);
        assertTrue("Buckets created are not more than 23", rdse.bucketMap.size() > 23);
        assertEquals("Number of on bits is incorrect", 23, getOnBits(e25));
        assertEquals("Width of the vector is incorrect", 500, e25.length);
        assertTrue("Overlap is too high", computeOverlap(e0, e25) < 4);

        // Test encoding consistency. The encodings for previous numbers
        // shouldn't change even though we have added additional buckets
        assertThat(
            "Encodings are not consistent - they have changed after new buckets have been created",
            rdse.encode(-0.1), is(equalTo(e0)));
        assertThat(
            "Encodings are not consistent - they have changed after new buckets have been created",
            rdse.encode(1.0), is(equalTo(e1)));
    }

    /**
     * Test that missing values and NaN return all zero's.
     */
    @Test
    public void testMissingValues() {
        builder = RandomDistributedScalarEncoder.builder()
            .name("enc")
            .resolution(1);
        rdse = builder.build();

        int[] encodedNaN = rdse.encode(Double.NaN);
        assertEquals(0, getOnBits(encodedNaN));

        int[] encodedMissing = rdse.encode(Encoder.SENTINEL_VALUE_FOR_MISSING_DATA);
        assertEquals(0, getOnBits(encodedMissing));
    }

    /**
     * Test that numbers within the same resolution return the same encoding.
     * Numbers outside the resolution should return different encodings.
     */
    @Test
    public void testResolution() {
        builder = RandomDistributedScalarEncoder.builder()
            .name("enc")
            .resolution(1);
        rdse = builder.build();

        // Since 23.0 is the first encoded number, it will be the offset.
        // Since resolution is 1, 22.9 and 23.4 should have the same bucket index and
        // encoding.
        int[] e23 = rdse.encode(23.0);
        int[] e23_1 = rdse.encode(23.1);
        int[] e22_9 = rdse.encode(22.9);
        int[] e24 = rdse.encode(24.0);

        assertEquals(rdse.getW(), getOnBits(e23));
        assertThat("Numbers within resolution don't have the same encoding",
            e23_1, is(equalTo(e23)));
        assertThat("Numbers within resolution don't have the same encoding",
            e22_9, is(equalTo(e23)));
        assertThat("Numbers outside resolution have the same encoding",
            e23, is(not(equalTo(e24))));

        int[] e22_5 = rdse.encode(22.5);
        assertThat("Numbers outside resolution have the same encoding",
            e23, is(not(equalTo(e22_5))));
    }

    /**
     * Test that mapBucketIndexToNonZeroBits works and that max buckets and
     * clipping are handled properly.
     */
    @Test
    public void testMapBucketIndexToNonZeroBits() {
        builder = RandomDistributedScalarEncoder.builder()
            .resolution(1)
            .w(11)
            .n(150);
        rdse = builder.build();

        // Set a low number of max buckets
        rdse.initializeBucketMap(10, null);

        rdse.encode(0.0);
        rdse.encode(-7.0);
        rdse.encode(7.0);

        assertEquals("maxBuckets exceeded", rdse.getMaxBuckets(), rdse.bucketMap.size());

        // Out-of-range indices are expected to resolve to the first / last bucket.
        assertThat("mapBucketIndexToNonZeroBits did not handle negative index",
            rdse.mapBucketIndexToNonZeroBits(-1), is(equalTo(rdse.bucketMap.get(0))));
        assertThat("mapBucketIndexToNonZeroBits did not handle index above the maximum bucket",
            rdse.mapBucketIndexToNonZeroBits(1000), is(equalTo(rdse.bucketMap.get(9))));

        int[] e23 = rdse.encode(23.0);
        int[] e6 = rdse.encode(6.0);
        assertThat("Values not clipped correctly during encoding", e23, is(equalTo(e6)));

        int[] e_8 = rdse.encode(-8.0);
        int[] e_7 = rdse.encode(-7.0);
        assertThat("Values not clipped correctly during encoding", e_8, is(equalTo(e_7)));

        assertEquals("getBucketIndices returned negative bucket index",
            0, rdse.getBucketIndices(-8.0)[0]);
        assertEquals("getBucketIndices returned bucket index above the maximum bucket",
            rdse.getMaxBuckets() - 1, rdse.getBucketIndices(23.0)[0]);
    }

    /**
     * Building with n below 6*w must fail.
     *
     * @author Sean Connolly
     */
    @Test
    public void testParameterCheckWithInvalidN() {
        // n must be strictly greater than 6 * w; here n = 123 and 6 * w = 126
        exception.expect(IllegalStateException.class);
        exception.expectMessage("n must be strictly greater than 6*w. For good results we recommend n be strictly greater than 11*w.");
        RandomDistributedScalarEncoder.builder()
            .n((int) (5.9 * 21))
            .w(21)
            .resolution(1)
            .build();
    }

    /**
     * Building with an even w must fail.
     *
     * @author Sean Connolly
     */
    @Test
    public void testParameterCheckWithInvalidW() {
        // w must be an odd positive integer; 6 is even
        exception.expect(IllegalStateException.class);
        exception.expectMessage("W must be an odd positive integer (to eliminate centering difficulty)");
        RandomDistributedScalarEncoder.builder()
            .n(500)
            .w(6)
            .resolution(2)
            .build();
    }

    /**
     * Building with a negative resolution must fail.
     *
     * @author Sean Connolly
     */
    @Test
    public void testParameterCheckWithInvalidResolution() {
        // resolution must be a positive number
        exception.expect(IllegalStateException.class);
        exception.expectMessage("Resolution must be a positive number");
        RandomDistributedScalarEncoder.builder()
            .n(500)
            .w(5)
            .resolution(-1)
            .build();
    }

    /**
     * Check that the overlaps for the encodings are within the expected range.
     * Here we ask the encoder to create a bunch of representations under somewhat
     * stressful conditions, and then verify they are correct. We rely on the fact
     * that the _overlapOK and _countOverlapIndices methods are working correctly.
     */
    @Test
    public void testOverlapStatistics() {
        builder = RandomDistributedScalarEncoder.builder()
            .resolution(1)
            .w(11)
            .n(150)
            .setSeed(RandomDistributedScalarEncoder.DEFAULT_SEED);
        rdse = builder.build();

        rdse.encode(0.0);
        rdse.encode(-300.0);
        rdse.encode(300.0);

        assertTrue("Illegal overlap encountered in encoder", validateEncoder(rdse, 3));
    }

    /**
     * Test that the getWidth, getDescription, and getDecoderOutputFieldTypes
     * methods work.
     */
    @Test
    public void testGetMethods() {
        builder = RandomDistributedScalarEncoder.builder()
            .name("theName")
            .resolution(1)
            .n(500);
        rdse = builder.build();

        assertEquals("getWidth doesn't return the correct result", 500, rdse.getWidth());
        assertEquals(
            "getDescription doesn't return the correct result",
            new ArrayList<Tuple>(Arrays.asList(new Tuple[] { new Tuple("theName", 0) })),
            rdse.getDescription());
        assertThat(
            "getDecoderOutputFieldTypes doesn't return the correct result",
            rdse.getDecoderOutputFieldTypes(),
            is(equalTo(new LinkedHashSet<>(Arrays.asList(FieldMetaType.FLOAT, FieldMetaType.INTEGER)))));
    }

    /**
     * Test that offset is working properly
     */
    @Test
    public void testOffset() {
        // No offset configured: the first encoded value is expected to become the offset.
        builder = RandomDistributedScalarEncoder.builder()
            .name("enc")
            .resolution(1);
        rdse = builder.build();
        rdse.encode(23.0);
        assertEquals(
            "Offset not initialized to the first encoded value",
            23, rdse.getOffset(), 0);

        // Explicit offset: it must be kept regardless of the first encoded value.
        builder = RandomDistributedScalarEncoder.builder()
            .name("enc")
            .resolution(1)
            .setOffset(25.0);
        rdse = builder.build();
        rdse.encode(23.0);
        assertEquals(
            "Offset not initialized to specified constructor parameter",
            25, rdse.getOffset(), 0);
    }

    /**
     * Test that equal seeds produce equal encodings and different seeds produce
     * different encodings.
     */
    @Test
    public void testSeed() {
        builder = RandomDistributedScalarEncoder.builder()
            .name("enc")
            .resolution(1);

        RandomDistributedScalarEncoder encoder1 = builder.setSeed(42).build();
        RandomDistributedScalarEncoder encoder2 = builder.setSeed(42).build();
        RandomDistributedScalarEncoder encoder3 = builder.setSeed(-2).build();

        int[] e1 = encoder1.encode(23.0);
        int[] e2 = encoder2.encode(23.0);
        int[] e3 = encoder3.encode(23.0);

        assertThat("Same seed gives rise to different encodings", e1, is(equalTo(e2)));
        assertThat("Different seeds gives rise to same encodings", e1, is(not(equalTo(e3))));

        // A further check that an encoder seeded with -1 differs from encoder3 was
        // removed on purpose: testing the RNG is not part of the scope of this test,
        // and we cannot assure that the RNG will initialize the default seed to
        // different values.
    }

    /**
     * Test that the internal method _countOverlapIndices works as expected.
     */
    @Test
    public void testCountOverlapIndices() {
        // Create a fake set of encodings.
        int midIdx = initEncoderWithFakeBuckets();

        // Test some overlaps
        assertEquals("countOverlapIndices didn't work",
            5, rdse.countOverlapIndices(midIdx - 2, midIdx - 2));
        assertEquals("countOverlapIndices didn't work",
            4, rdse.countOverlapIndices(midIdx - 1, midIdx - 2));
        assertEquals("countOverlapIndices didn't work",
            2, rdse.countOverlapIndices(midIdx + 1, midIdx - 2));
        assertEquals("countOverlapIndices didn't work",
            0, rdse.countOverlapIndices(midIdx - 2, midIdx + 3));
    }

    /**
     * Test that countOverlapIndices rejects a pair of indices when neither of
     * them has a bucket.
     */
    @Test
    public void testCountOverlapIndicesWithWrongIndices_i_j() {
        int midIdx = initEncoderWithFakeBuckets();

        exception.expect(IllegalStateException.class);
        exception.expectMessage(
            allOf(
                startsWith("index"),
                endsWith("don't exist")
            )
        );
        rdse.countOverlapIndices(midIdx - 3, midIdx - 4);
    }

    /**
     * Test that countOverlapIndices rejects a pair of indices when only the
     * first one has no bucket.
     */
    @Test
    public void testCountOverlapIndicesWithWrongIndices_i() {
        int midIdx = initEncoderWithFakeBuckets();

        exception.expect(IllegalStateException.class);
        exception.expectMessage(
            allOf(
                startsWith("index"),
                endsWith("doesn't exist")
            )
        );
        rdse.countOverlapIndices(midIdx - 3, midIdx - 2);
    }

    /**
     * Test that the internal method {@link RandomDistributedScalarEncoder#overlapOK(int, int)}
     * works as expected.
     */
    @Test
    public void testOverlapOK() {
        builder = RandomDistributedScalarEncoder.builder()
            .name("enc")
            .resolution(1)
            .w(5)
            .n(5 * 20);
        rdse = builder.build();

        // Fake encodings; bucket midIdx - 3 deliberately duplicates bucket midIdx - 1.
        int midIdx = rdse.getMaxBuckets() / 2;
        rdse.bucketMap.put(midIdx - 3, getRangeAsList(4, 9));
        rdse.bucketMap.put(midIdx - 2, getRangeAsList(3, 8));
        rdse.bucketMap.put(midIdx - 1, getRangeAsList(4, 9));
        rdse.bucketMap.put(midIdx, getRangeAsList(5, 10));
        rdse.bucketMap.put(midIdx + 1, getRangeAsList(6, 11));
        rdse.bucketMap.put(midIdx + 2, getRangeAsList(7, 12));
        rdse.bucketMap.put(midIdx + 3, getRangeAsList(8, 13));
        rdse.minIndex = midIdx - 3;
        rdse.maxIndex = midIdx + 3;

        assertTrue("overlapOK didn't work", rdse.overlapOK(midIdx, midIdx - 1));
        assertTrue("overlapOK didn't work", rdse.overlapOK(midIdx - 2, midIdx + 3));
        assertFalse("overlapOK didn't work", rdse.overlapOK(midIdx - 3, midIdx - 1));

        assertTrue("overlapOK didn't work for far values",
            rdse.overlapOK(100, 50, 0));
        assertTrue("overlapOK didn't work for far values",
            rdse.overlapOK(100, 50, rdse.getMaxOverlap()));
        assertFalse("overlapOK didn't work for far values",
            rdse.overlapOK(100, 50, rdse.getMaxOverlap() + 1));
        assertTrue("overlapOK didn't work for far values",
            rdse.overlapOK(50, 50, 5));
        assertTrue("overlapOK didn't work for far values",
            rdse.overlapOK(48, 50, 3));
        assertTrue("overlapOK didn't work for far values",
            rdse.overlapOK(46, 50, 1));
        assertTrue("overlapOK didn't work for far values",
            rdse.overlapOK(45, 50, rdse.getMaxOverlap()));
        assertFalse("overlapOK didn't work for far values",
            rdse.overlapOK(48, 50, 4));
        assertFalse("overlapOK didn't work for far values",
            rdse.overlapOK(48, 50, 2));
        assertFalse("overlapOK didn't work for far values",
            rdse.overlapOK(46, 50, 2));
        assertFalse("overlapOK didn't work for far values",
            rdse.overlapOK(50, 50, 6));
    }

    /**
     * Test that countOverlap returns the expected counts for identical,
     * partially matching, reordered, differently sized and disjoint arrays.
     */
    @Test
    public void testCountOverlap() {
        builder = RandomDistributedScalarEncoder.builder()
            .name("enc")
            .resolution(1)
            .n(500);
        rdse = builder.build();

        int[] r1 = new int[] { 1, 2, 3, 4, 5, 6 };
        int[] r2 = new int[] { 1, 2, 3, 4, 5, 6 };
        assertEquals("countOverlap result is incorrect", 6, rdse.countOverlap(r1, r2));

        r1 = new int[] { 1, 2, 3, 4, 5, 6 };
        r2 = new int[] { 1, 2, 3, 4, 5, 7 };
        assertEquals("countOverlap result is incorrect", 5, rdse.countOverlap(r1, r2));

        r1 = new int[] { 1, 2, 3, 4, 5, 6 };
        r2 = new int[] { 6, 5, 4, 3, 2, 1 };
        assertEquals("countOverlap result is incorrect", 6, rdse.countOverlap(r1, r2));

        r1 = new int[] { 1, 2, 8, 4, 5, 6 };
        r2 = new int[] { 1, 2, 3, 4, 9, 6 };
        assertEquals("countOverlap result is incorrect", 4, rdse.countOverlap(r1, r2));

        r1 = new int[] { 1, 2, 3, 4, 5, 6 };
        r2 = new int[] { 1, 2, 3 };
        assertEquals("countOverlap result is incorrect", 3, rdse.countOverlap(r1, r2));

        r1 = new int[] { 7, 8, 9, 10, 11, 12 };
        r2 = new int[] { 1, 2, 3, 4, 5, 6 };
        assertEquals("countOverlap result is incorrect", 0, rdse.countOverlap(r1, r2));
    }

    /**
     * Sanity check for the {@link #getOnBits(int[])} helper used by the other tests.
     */
    @Test
    public void testGetOnBitsMethod() {
        int[] input1 = new int[] { 1, 0, 0, 0, 1 };
        int[] input2 = new int[] { 1, 0, 2, 0, 1 };

        assertEquals("getOnBits returned wrong value ", 2, getOnBits(input1));
        assertEquals("getOnBits did not return -1 for invalid input", -1, getOnBits(input2));
    }

    /**
     * Builds {@link #rdse} with w = 5 and n = 100 and installs six hand-made
     * buckets at indices {@code midIdx - 2 .. midIdx + 3}. Each bucket holds five
     * consecutive bit indices, shifted by one relative to its neighbour.
     *
     * @return the middle bucket index ({@code getMaxBuckets() / 2})
     */
    private int initEncoderWithFakeBuckets() {
        builder = RandomDistributedScalarEncoder.builder()
            .name("enc")
            .resolution(1)
            .w(5)
            .n(5 * 20);
        rdse = builder.build();

        int midIdx = rdse.getMaxBuckets() / 2;
        rdse.bucketMap.put(midIdx - 2, getRangeAsList(3, 8));
        rdse.bucketMap.put(midIdx - 1, getRangeAsList(4, 9));
        rdse.bucketMap.put(midIdx, getRangeAsList(5, 10));
        rdse.bucketMap.put(midIdx + 1, getRangeAsList(6, 11));
        rdse.bucketMap.put(midIdx + 2, getRangeAsList(7, 12));
        rdse.bucketMap.put(midIdx + 3, getRangeAsList(8, 13));
        rdse.minIndex = midIdx - 2;
        rdse.maxIndex = midIdx + 3;
        return midIdx;
    }

    /**
     * Returns the integers in {@code [lowerBound, upperBound)} as a fixed-size list.
     *
     * @param lowerBound first value, inclusive
     * @param upperBound last value, exclusive
     * @return the range as a list, or {@code null} if {@code lowerBound > upperBound}
     */
    private List<Integer> getRangeAsList(int lowerBound, int upperBound) {
        if (lowerBound > upperBound) {
            return null;
        }

        Integer[] arr = new Integer[upperBound - lowerBound];
        for (int i = lowerBound; i < upperBound; i++) {
            arr[i - lowerBound] = i;
        }
        return Arrays.asList(arr);
    }

    /**
     * Checks {@code overlapOK} for bucket pairs between the encoder's
     * {@code minIndex} and {@code maxIndex}; for each first index, the second
     * index advances in steps of {@code subsampling}.
     *
     * @param encoder     the encoder whose buckets are checked
     * @param subsampling step between the second indices that are checked
     * @return {@code true} if every checked pair passes {@code overlapOK}
     */
    private boolean validateEncoder(RandomDistributedScalarEncoder encoder, int subsampling) {
        for (int i = encoder.minIndex; i <= encoder.maxIndex; i++) {
            for (int j = i + 1; j <= encoder.maxIndex; j += subsampling) {
                if (!encoder.overlapOK(i, j)) {
                    return false;
                }
            }
        }
        return true;
    }

    /**
     * Counts the positions at which both arrays hold a 1.
     *
     * @param result1 first encoding
     * @param result2 second encoding
     * @return the overlap, or {@link Integer#MIN_VALUE} if the lengths differ
     */
    private int computeOverlap(int[] result1, int[] result2) {
        if (result1.length != result2.length) {
            return Integer.MIN_VALUE;
        }

        int overlap = 0;
        for (int i = 0; i < result1.length; i++) {
            if (result1[i] == 1 && result2[i] == 1) {
                overlap++;
            }
        }
        return overlap;
    }

    /**
     * Counts the 1 entries of an encoding.
     *
     * @param input the encoding to inspect
     * @return the number of 1s, or -1 if any entry is neither 0 nor 1
     */
    private int getOnBits(int[] input) {
        int onBits = 0;
        for (int bit : input) {
            if (bit == 1) {
                onBits += 1;
            } else if (bit != 0) {
                return -1;
            }
        }
        return onBits;
    }
}