package org.numenta.nupic.encoders; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; import java.util.Arrays; import java.util.HashMap; import java.util.List; import java.util.Map; import org.joda.time.DateTime; import org.junit.Test; import org.numenta.nupic.Parameters.KEY; import org.numenta.nupic.network.NetworkTestHarness; import org.numenta.nupic.util.Tuple; /** * case "CategoryEncoder": return CategoryEncoder.builder(); case "CoordinateEncoder": return CoordinateEncoder.builder(); case "GeospatialCoordinateEncoder": return GeospatialCoordinateEncoder.geobuilder(); case "LogEncoder": return LogEncoder.builder(); case "PassThroughEncoder": return PassThroughEncoder.builder(); case "ScalarEncoder": return ScalarEncoder.builder(); case "SparsePassThroughEncoder": return SparsePassThroughEncoder.sparseBuilder(); case "SDRCategoryEncoder": return SDRCategoryEncoder.builder(); case "RandomDistributedScalarEncoder": return RandomDistributedScalarEncoder.builder(); case "DateEncoder": return DateEncoder.builder(); case "DeltaEncoder": return DeltaEncoder.deltaBuilder(); case "SDRPassThroughEncoder" : return SDRPassThroughEncoder.sptBuilder(); default: throw new IllegalArgumentException("Invalid encoder: " + encoderName); * @author cogmission * */ public class MultiEncoderAssemblerTest { @Test public void testAssembleCategoryEncoder() { // NetworkTestHarness.setupMap(map, n, w, min, max, radius, resolution, periodic, // clip, forced, fieldName, fieldType, encoderType) Map<String, Map<String, Object>> settings = NetworkTestHarness.setupMap( null, // map 0, // n 3, // w 0, // min 8.0, // max 0, // radius 0, // resolution null, // periodic null, // clip Boolean.TRUE, // forced "testfield", // fieldName "string", // fieldType "CategoryEncoder"); // encoderType System.out.println("map = " + settings); // Attempt build omitting the Category list (should fail) try { MultiEncoder me = MultiEncoder.builder().name("").build(); MultiEncoderAssembler.assemble(me, settings); fail(); }catch(Exception e) { assertEquals(IllegalStateException.class, e.getClass()); assertEquals("Category List cannot be null", e.getMessage()); } // Failure specifies additional fields which must be set String[] categories = new String[] { "ES", "GB", "US" }; settings.get("testfield").put("categoryList", Arrays.<String>asList(categories)); try { MultiEncoder me = MultiEncoder.builder().name("").build(); MultiEncoderAssembler.assemble(me, settings); fail(); }catch(Exception e) { assertEquals(IllegalStateException.class, e.getClass()); assertEquals("One of n, radius, resolution must be specified for a CategoryEncoder", e.getMessage()); } // Should now work settings.get("testfield").put("radius", 1.0); MultiEncoder me = null; try { me = MultiEncoder.builder().name("").build(); MultiEncoderAssembler.assemble(me, settings); }catch(Exception e) { fail(); } List<EncoderTuple> l = me.getEncoders(me); assertTrue(!l.isEmpty()); assertEquals(CategoryEncoder.class, l.get(0).getEncoder().getClass()); } @Test public void testAssembleSDRCategoryEncoder() { String[] categories = {"ES", "S1", "S2", "S3", "S4", "S5", "S6", "S7", "S8", "S9", "S10", "S11", "S12", "S13", "S14", "S15", "S16", "S17", "S18", "S19", "GB", "US"}; Map<String, Map<String, Object>> settings = NetworkTestHarness.setupMap( null, // map 100, // n 10, // w 0, // min 0, // max 0, // radius 0, // resolution null, // periodic null, // clip Boolean.TRUE, // forced "testfield", // fieldName "string", // fieldType "SDRCategoryEncoder"); // encoderType // Attempt build omitting the Category list (should fail) try { MultiEncoder me = MultiEncoder.builder().name("").build(); MultiEncoderAssembler.assemble(me, settings); fail(); }catch(Exception e) { assertEquals(IllegalStateException.class, e.getClass()); assertEquals("Category List cannot be null", e.getMessage()); } // Failure specifies additional fields which must be set settings.get("testfield").put("n", 0); settings.get("testfield").put("w", 0); settings.get("testfield").put("categoryList", Arrays.<String>asList(categories)); try { MultiEncoder me = MultiEncoder.builder().name("").build(); MultiEncoderAssembler.assemble(me, settings); fail(); }catch(Exception e) { assertEquals(IllegalStateException.class, e.getClass()); assertEquals("\"N\" should be set", e.getMessage()); } // Should now work settings.get("testfield").put("n", 100); settings.get("testfield").put("w", 10); MultiEncoder me = null; try { me = MultiEncoder.builder().name("").build(); MultiEncoderAssembler.assemble(me, settings); }catch(Exception e) { fail(); } List<EncoderTuple> l = me.getEncoders(me); assertTrue(!l.isEmpty()); assertEquals(SDRCategoryEncoder.class, l.get(0).getEncoder().getClass()); } @Test public void testAssembleCoordinateEncoder() { Map<String, Map<String, Object>> settings = NetworkTestHarness.setupMap( null, // map 0, // n 0, // w 0, // min 0, // max 0, // radius 0, // resolution null, // periodic null, // clip Boolean.TRUE, // forced "testfield", // fieldName "coord", // fieldType "CoordinateEncoder"); // encoderType // Attempt build omitting needed settings (should fail) try { MultiEncoder me = MultiEncoder.builder().name("").build(); MultiEncoderAssembler.assemble(me, settings); fail(); }catch(Exception e) { assertEquals(IllegalArgumentException.class, e.getClass()); assertEquals("w must be odd, and must be a positive integer", e.getMessage()); } // Should pass settings.get("testfield").put("n", 33); settings.get("testfield").put("w", 3); MultiEncoder me = null; try { me = MultiEncoder.builder().name("").build(); MultiEncoderAssembler.assemble(me, settings); }catch(Exception e) { fail(); } List<EncoderTuple> l = me.getEncoders(me); assertTrue(!l.isEmpty()); assertEquals(CoordinateEncoder.class, l.get(0).getEncoder().getClass()); } @Test public void testAssembleGeospatialCoordinateEncoder() { Map<String, Map<String, Object>> settings = NetworkTestHarness.setupMap( null, // map 0, // n 0, // w 0, // min 0, // max 0, // radius 0, // resolution null, // periodic null, // clip Boolean.TRUE, // forced "testfield", // fieldName "coord", // fieldType "GeospatialCoordinateEncoder"); // encoderType // Attempt build omitting needed settings (should fail) try { MultiEncoder me = MultiEncoder.builder().name("").build(); MultiEncoderAssembler.assemble(me, settings); fail(); }catch(Exception e) { assertEquals(IllegalStateException.class, e.getClass()); assertEquals("Scale or Timestep not set", e.getMessage()); } // Should pass settings.get("testfield").put("n", 33); settings.get("testfield").put("w", 3); settings.get("testfield").put("scale", 30); settings.get("testfield").put("timestep", 60); MultiEncoder me = null; try { me = MultiEncoder.builder().name("").build(); MultiEncoderAssembler.assemble(me, settings); }catch(Exception e) {e.printStackTrace(); fail(); } List<EncoderTuple> l = me.getEncoders(me); assertTrue(!l.isEmpty()); assertEquals(GeospatialCoordinateEncoder.class, l.get(0).getEncoder().getClass()); } @Test public void testAssembleLogEncoder() { Map<String, Map<String, Object>> settings = NetworkTestHarness.setupMap( null, // map 0, // n 0, // w 0, // min 0, // max 0, // radius 0, // resolution null, // periodic null, // clip Boolean.TRUE, // forced "testfield", // fieldName "float", // fieldType "LogEncoder"); // encoderType // Attempt build omitting needed settings (should fail) try { MultiEncoder me = MultiEncoder.builder().name("").build(); MultiEncoderAssembler.assemble(me, settings); fail(); }catch(Exception e) { assertEquals(IllegalStateException.class, e.getClass()); assertEquals("One of n, radius, resolution must be specified for a LogEncoder", e.getMessage()); } // Should pass settings.get("testfield").put("w", 5); settings.get("testfield").put("resolution", 0.1); settings.get("testfield").put("minVal", 1.0); settings.get("testfield").put("maxVal", 10000.0); MultiEncoder me = null; try { me = MultiEncoder.builder().name("").build(); MultiEncoderAssembler.assemble(me, settings); }catch(Exception e) {e.printStackTrace(); fail(); } List<EncoderTuple> l = me.getEncoders(me); assertTrue(!l.isEmpty()); assertEquals(LogEncoder.class, l.get(0).getEncoder().getClass()); } @Test public void testAssemblePassThroughEncoder() { Map<String, Map<String, Object>> settings = NetworkTestHarness.setupMap( null, // map 0, // n 0, // w 0, // min 0, // max 0, // radius 0, // resolution null, // periodic null, // clip Boolean.TRUE, // forced "testfield", // fieldName "int", // fieldType "PassThroughEncoder"); // encoderType // Should pass through MultiEncoder me = null; try { me = MultiEncoder.builder().name("").build(); MultiEncoderAssembler.assemble(me, settings); }catch(Exception e) {e.printStackTrace(); fail(); } List<EncoderTuple> l = me.getEncoders(me); assertTrue(!l.isEmpty()); assertEquals(PassThroughEncoder.class, l.get(0).getEncoder().getClass()); } @Test public void testAssembleScalarEncoder() { Map<String, Map<String, Object>> settings = NetworkTestHarness.setupMap( null, // map 0, // n 0, // w 0, // min 0, // max 0, // radius 0, // resolution null, // periodic null, // clip Boolean.TRUE, // forced "testfield", // fieldName "int", // fieldType "ScalarEncoder"); // encoderType // Attempt build omitting needed settings (should fail) try { MultiEncoder me = MultiEncoder.builder().name("").build(); MultiEncoderAssembler.assemble(me, settings); fail(); }catch(Exception e) { assertEquals(IllegalStateException.class, e.getClass()); assertEquals("W must be an odd number (to eliminate centering difficulty)", e.getMessage()); } // Should pass settings.get("testfield").put("w", 5); settings.get("testfield").put("resolution", 0.1); settings.get("testfield").put("minVal", 1.0); settings.get("testfield").put("maxVal", 10000.0); MultiEncoder me = null; try { me = MultiEncoder.builder().name("").build(); MultiEncoderAssembler.assemble(me, settings); }catch(Exception e) {e.printStackTrace(); fail(); } List<EncoderTuple> l = me.getEncoders(me); assertTrue(!l.isEmpty()); assertEquals(ScalarEncoder.class, l.get(0).getEncoder().getClass()); } @Test public void testAssembleSparsePassThroughEncoder() { Map<String, Map<String, Object>> settings = NetworkTestHarness.setupMap( null, // map 0, // n 0, // w 0, // min 0, // max 0, // radius 0, // resolution null, // periodic null, // clip Boolean.TRUE, // forced "testfield", // fieldName "sarr", // fieldType "SparsePassThroughEncoder"); // encoderType // Should pass through MultiEncoder me = null; try { me = MultiEncoder.builder().name("").build(); MultiEncoderAssembler.assemble(me, settings); }catch(Exception e) {e.printStackTrace(); fail(); } List<EncoderTuple> l = me.getEncoders(me); assertTrue(!l.isEmpty()); assertEquals(SparsePassThroughEncoder.class, l.get(0).getEncoder().getClass()); } @Test public void testAssembleRandomDistributedScalarEncoder() { Map<String, Map<String, Object>> settings = NetworkTestHarness.setupMap( null, // map 0, // n 0, // w 0, // min 0, // max 0, // radius 0, // resolution null, // periodic null, // clip Boolean.TRUE, // forced "testfield", // fieldName "int", // fieldType "RandomDistributedScalarEncoder"); // encoderType // Attempt build omitting needed settings (should fail) try { MultiEncoder me = MultiEncoder.builder().name("").build(); MultiEncoderAssembler.assemble(me, settings); fail(); }catch(Exception e) { assertEquals(IllegalStateException.class, e.getClass()); assertEquals("W must be an odd positive integer (to eliminate centering difficulty)", e.getMessage()); } // Attempt build omitting needed settings (should fail) settings.get("testfield").put("w", 5); try { MultiEncoder me = MultiEncoder.builder().name("").build(); MultiEncoderAssembler.assemble(me, settings); fail(); }catch(Exception e) { assertEquals(IllegalStateException.class, e.getClass()); assertEquals("Resolution must be a positive number", e.getMessage()); } // Should pass settings.get("testfield").put("n", 60); settings.get("testfield").put("resolution", 1.0); MultiEncoder me = null; try { me = MultiEncoder.builder().name("").build(); MultiEncoderAssembler.assemble(me, settings); }catch(Exception e) {e.printStackTrace(); fail(); } List<EncoderTuple> l = me.getEncoders(me); assertTrue(!l.isEmpty()); assertEquals(RandomDistributedScalarEncoder.class, l.get(0).getEncoder().getClass()); } @Test public void testAssembleDateEncoder() { Map<String, Map<String, Object>> settings = NetworkTestHarness.setupMap( null, // map 0, // n 0, // w 0, // min 0, // max 0, // radius 0, // resolution null, // periodic null, // clip Boolean.TRUE, // forced "testfield", // fieldName "datetime", // fieldType "DateEncoder"); // encoderType // Attempt build omitting needed settings (should fail) try { MultiEncoder me = MultiEncoder.builder().name("").build(); MultiEncoderAssembler.assemble(me, settings); List<EncoderTuple> l = me.getEncoders(me); assertTrue(!l.isEmpty()); assertEquals(DateEncoder.class, l.get(0).getEncoder().getClass()); Map<String, Object> entry = new HashMap<>(); entry.put("testfield", new DateTime()); me.encode(entry); fail(); }catch(Exception e) { assertEquals(IllegalStateException.class, e.getClass()); } // Attempt build omitting needed settings (should fail) settings.get("testfield").put("w", 5); try { settings.get("testfield").put(KEY.DATEFIELD_DOFW.getFieldName(), new Tuple(1, 1.0)); // Day of week settings.get("testfield").put(KEY.DATEFIELD_TOFD.getFieldName(), new Tuple(5, 4.0)); // Time of day MultiEncoder me = MultiEncoder.builder().name("").build(); MultiEncoderAssembler.assemble(me, settings); List<EncoderTuple> l = me.getEncoders(me); assertTrue(!l.isEmpty()); assertEquals(DateEncoder.class, l.get(0).getEncoder().getClass()); Map<String, Object> entry = new HashMap<>(); entry.put("testfield", new DateTime()); int[] encoding = me.encode(entry); assertTrue(encoding.length > 0); }catch(Exception e) { fail(); } } @Test public void testAssembleDeltaEncoder() { Map<String, Map<String, Object>> settings = NetworkTestHarness.setupMap( null, // map 0, // n 0, // w 0, // min 0, // max 0, // radius 0, // resolution null, // periodic null, // clip Boolean.TRUE, // forced "testfield", // fieldName "float", // fieldType "DeltaEncoder"); // encoderType // Attempt build omitting needed settings (should fail) try { MultiEncoder me = MultiEncoder.builder().name("").build(); MultiEncoderAssembler.assemble(me, settings); fail(); }catch(Exception e) { assertEquals(IllegalStateException.class, e.getClass()); assertEquals("W must be an odd number (to eliminate centering difficulty)", e.getMessage()); } // Should pass settings.get("testfield").put("n", 100); settings.get("testfield").put("w", 21); settings.get("testfield").put("radius", 1.5); settings.get("testfield").put("resolution", 0.5); settings.get("testfield").put("minVal", 1.0); settings.get("testfield").put("maxVal", 8.0); MultiEncoder me = null; try { me = MultiEncoder.builder().name("").build(); MultiEncoderAssembler.assemble(me, settings); }catch(Exception e) {e.printStackTrace(); fail(); } List<EncoderTuple> l = me.getEncoders(me); assertTrue(!l.isEmpty()); assertEquals(DeltaEncoder.class, l.get(0).getEncoder().getClass()); } @Test public void testAssembleSDRPassThroughEncoder() { Map<String, Map<String, Object>> settings = NetworkTestHarness.setupMap( null, // map 0, // n 0, // w 0, // min 0, // max 0, // radius 0, // resolution null, // periodic null, // clip Boolean.TRUE, // forced "testfield", // fieldName "sarr", // fieldType "SDRPassThroughEncoder"); // encoderType // Attempt build omitting needed settings (should fail) try { MultiEncoder me = MultiEncoder.builder().name("").build(); MultiEncoderAssembler.assemble(me, settings); List<EncoderTuple> l = me.getEncoders(me); assertTrue(!l.isEmpty()); assertEquals(SDRPassThroughEncoder.class, l.get(0).getEncoder().getClass()); }catch(Exception e) { fail(); } } @Test public void testAssembleAdaptiveScalarEncoder() { Map<String, Map<String, Object>> settings = NetworkTestHarness.setupMap( null, // map 0, // n 0, // w 0, // min 0, // max 0, // radius 0, // resolution null, // periodic null, // clip Boolean.TRUE, // forced "testfield", // fieldName "int", // fieldType "ScalarEncoder"); // encoderType // Attempt build omitting needed settings (should fail) try { MultiEncoder me = MultiEncoder.builder().name("").build(); MultiEncoderAssembler.assemble(me, settings); fail(); }catch(Exception e) { assertEquals(IllegalStateException.class, e.getClass()); assertEquals("W must be an odd number (to eliminate centering difficulty)", e.getMessage()); } // Should pass settings.get("testfield").put("w", 5); settings.get("testfield").put("resolution", 0.1); settings.get("testfield").put("minVal", 1.0); settings.get("testfield").put("maxVal", 10000.0); MultiEncoder me = null; try { me = MultiEncoder.builder().name("").build(); MultiEncoderAssembler.assemble(me, settings); }catch(Exception e) {e.printStackTrace(); fail(); } List<EncoderTuple> l = me.getEncoders(me); assertTrue(!l.isEmpty()); assertEquals(ScalarEncoder.class, l.get(0).getEncoder().getClass()); } }