/* --------------------------------------------------------------------- * Numenta Platform for Intelligent Computing (NuPIC) * Copyright (C) 2014, Numenta, Inc. Unless you have an agreement * with Numenta, Inc., for a separate license for this software code, the * following terms and conditions apply: * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero Public License version 3 as * published by the Free Software Foundation. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. * See the GNU Affero Public License for more details. * * You should have received a copy of the GNU Affero Public License * along with this program. If not, see http://www.gnu.org/licenses. * * http://numenta.org/licenses/ * --------------------------------------------------------------------- */ package org.numenta.nupic.encoders; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; import java.util.Arrays; import java.util.List; import java.util.Map; import org.junit.After; import org.junit.AfterClass; import org.junit.Assert; import org.junit.Before; import org.junit.BeforeClass; import org.junit.Test; /** * @author sambit * */ public class AdaptiveScalarEncoderTest { private AdaptiveScalarEncoder ase; private AdaptiveScalarEncoder.Builder builder; /** * @throws java.lang.Exception */ @BeforeClass public static void setUpBeforeClass() throws Exception { } /** * @throws java.lang.Exception */ @AfterClass public static void tearDownAfterClass() throws Exception { } /** * */ @Before public void setUp() { builder = AdaptiveScalarEncoder.adaptiveBuilder().n(14).w(3).minVal(1) .maxVal(8).radius(1.5).resolution(0.5).periodic(false) .forced(true); } /** * @throws java.lang.Exception */ @After public void tearDown() throws Exception { } private void initASE() { ase = builder.build(); } /** * Test method for * {@link org.numenta.nupic.encoders.AdaptiveScalarEncoder#AdaptiveScalarEncoder()} * . */ @Test public void testAdaptiveScalarEncoder() { setUp(); initASE(); Assert.assertNotNull("AdaptiveScalarEncoder class is null", ase); } @Test public void testInit() { setUp(); initASE(); Assert.assertNotNull("AdaptiveScalarEncoder class is null", ase); ase.setW(3); ase.setMinVal(1); ase.setMaxVal(8); ase.setN(14); ase.setRadius(1.5); ase.setResolution(0.5); ase.setForced(true); ase.init(); } /** * Test method for * {@link org.numenta.nupic.encoders.AdaptiveScalarEncoder#initEncoder(int, double, double, int, double, double)} * . */ @Test public void testInitEncoder() { setUp(); initASE(); ase.initEncoder(3, 1, 8, 14, 1.5, 0.5); Assert.assertNotNull("AdaptiveScalarEncoder class is null", ase); /////////// Negative Test /////////// setUp(); initASE(); Assert.assertNotNull("AdaptiveScalarEncoder class is null", ase); try { ase.setPeriodic(true); // Should cause failure during init ase.initEncoder(3, 1, 8, 14, 1.5, 0.5); fail(); }catch(Exception e) { assertEquals(IllegalStateException.class, e.getClass()); assertEquals("Adaptive scalar encoder does not encode periodic inputs", e.getMessage()); } } @Test public void testMissingData() { setUp(); initASE(); ase.initEncoder(3, 1, 8, 14, 1.5, 0.5); ase.setName("mv"); ase.setPeriodic(false); int[] empty = ase.encode(Encoder.SENTINEL_VALUE_FOR_MISSING_DATA); System.out.println("\nEncoded missing data as: " + Arrays.toString(empty)); int[] expected = new int[14]; assertTrue(Arrays.equals(expected, empty)); } @Test public void testNonPeriodicEncoderMinMaxSpec() { initASE(); int[] res = ase.encode(1.0); System.out.println("\nEncoded data as: " + Arrays.toString(res)); assertTrue(Arrays.equals(new int[] { 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, res)); res = ase.encode(2.0); System.out.println("\nEncoded data as: " + Arrays.toString(res)); assertTrue(Arrays.equals(new int[] { 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, res)); res = ase.encode(8.0); System.out.println("\nEncoded data as: " + Arrays.toString(res)); assertTrue(Arrays.equals(new int[] { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1 }, res)); } @Test public void testTopDownDecode() { initASE(); double minVal = ase.getMinVal(); System.out.println("\nThe min value is:" + minVal); double resolution = ase.getResolution(); System.out.println(String.format("\nTesting non-periodic encoder decoding, resolution of %f ...", resolution)); double maxVal = ase.getMaxVal(); System.out.println("\nThe max value is:" + maxVal); while(minVal < maxVal) { int[] output = ase.encode(minVal); DecodeResult decoded = ase.decode(output, ""); System.out.println("\nDecoding " + Arrays.toString(output) + String.format("(%f)", minVal) + " => " + decoded.toString()); Map<String, RangeList> fields = decoded.getFields(); Assert.assertEquals("Number of keys not matching", 1, fields.keySet().size()); System.out.println("\nField Key: " + fields.keySet().iterator().next()); Assert.assertEquals("Number of range not matching", 1, fields.get(fields.keySet().iterator().next()).size()); System.out.println("\nField Range Value: " + fields.get(fields.keySet().iterator().next()).get(0)); Assert.assertEquals("Range max and min are not matching", fields.get(fields.keySet().iterator().next()).getRange(0).max(), fields.get(fields.keySet().iterator().next()).getRange(0).min(), 0); assertTrue(Math.abs(fields.get(fields.keySet().iterator().next()).getRange(0).min() - minVal) < ase.getResolution()); java.util.List<Encoding> topDown = ase.topDownCompute(output); assertTrue(topDown.size() == 1); System.out.println("\nTopDown => " + topDown.toString()); int[] bucketIndices = ase.getBucketIndices(minVal); assertTrue("The bucket indice size is not matching", bucketIndices.length == 1); System.out.println("Bucket indices => " + Arrays.toString(bucketIndices)); List<Encoding> bucketInfoList = ase.getBucketInfo(bucketIndices); assertTrue((Math.abs((double)bucketInfoList.get(0).getValue() - minVal)) <= (ase.getResolution() / 2)); System.out.println("Bucket info value: " + bucketInfoList.get(0).getValue()); System.out.println("Minval: " + minVal + " Abs(BucketVal - Minval): " + Math.abs((double)bucketInfoList.get(0).getValue() - minVal)); System.out.println("Resolution: " + ase.getResolution() + " Resolution/2: " + ase.getResolution() / 2); assertTrue((double)bucketInfoList.get(0).getValue() == (double)ase.getBucketValues(Double.class).toArray()[bucketIndices[0]]); System.out.println("\nBucket info scalar: " + bucketInfoList.get(0).getScalar()); System.out.println("\nBucket info value: " + bucketInfoList.get(0).getValue()); assertTrue(bucketInfoList.get(0).getScalar().doubleValue() == (double)bucketInfoList.get(0).getValue()); System.out.println("\nBucket info encoding: " + bucketInfoList.get(0).getEncoding()); System.out.println("\nOriginal encoding: " + Arrays.toString(output)); assertTrue(Arrays.equals(bucketInfoList.get(0).getEncoding(), output)); minVal += resolution / 4; } } @Test public void testFillHoles() { initASE(); int[] inputArray = new int[] { 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1 }; double minVal = ase.getMinVal(); DecodeResult decoded = ase.decode(inputArray, ""); System.out.println("\nDecoding " + Arrays.toString(inputArray) + String.format("(%f)", minVal) + " => " + decoded.toString()); Map<String, RangeList> fields = decoded.getFields(); assertTrue(fields.size() == 1); Assert.assertEquals("Number of keys not matching", 1, fields.keySet().size()); System.out.println("\nField Key: " + fields.keySet().iterator().next()); Assert.assertEquals("Number of range not matching", 2, fields.get(fields.keySet().iterator().next()).size()); System.out.println("\nField Range Value: " + fields.get(fields.keySet().iterator().next()).get(0)); Assert.assertEquals("Range max and min are not matching", fields.get(fields.keySet().iterator().next()).getRange(0).max(), fields.get(fields.keySet().iterator().next()).getRange(0).min(), 0); assertTrue(fields.get(fields.keySet().iterator().next()).getRange(1).min() == 8.00); assertTrue(fields.get(fields.keySet().iterator().next()).getRange(1).max() == 8.00); int[] newArray = new int[] { 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1 }; DecodeResult newDecoded = ase.decode(newArray, ""); System.out.println("\nDecoding new array " + Arrays.toString(newArray) + String.format("(%f)", minVal) + " => " + newDecoded.toString()); Map<String, RangeList> newFields = newDecoded.getFields(); assertTrue(newFields.size() == 1); Assert.assertEquals("Number of keys not matching", 1, newFields.keySet().size()); System.out.println("\nField Key: " + newFields.keySet().iterator().next()); Assert.assertEquals("Number of range not matching", 2, newFields.get(newFields.keySet().iterator().next()).size()); System.out.println("\nField Range Value: " + newFields.get(newFields.keySet().iterator().next()).get(0)); Assert.assertEquals("Range max and min are not matching", newFields.get(newFields.keySet().iterator().next()).getRange(0).max(), newFields.get(newFields.keySet().iterator().next()).getRange(0).min(), 0); assertTrue(newFields.get(newFields.keySet().iterator().next()).getRange(1).min() == 8.00); assertTrue(newFields.get(newFields.keySet().iterator().next()).getRange(1).max() == 8.00); } @Test public void testSkippedMinMaxCode() { setUp(); initASE(); ase.setMinVal(ase.getMaxVal()); ase.getBucketIndices(ase.getMaxVal()); assertEquals(1, ase.getRangeInternal(), 0); // ASE enforces minimum range of 1.0 } }