/* ---------------------------------------------------------------------
* Numenta Platform for Intelligent Computing (NuPIC)
* Copyright (C) 2014, Numenta, Inc. Unless you have an agreement
* with Numenta, Inc., for a separate license for this software code, the
* following terms and conditions apply:
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero Public License version 3 as
* published by the Free Software Foundation.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
* See the GNU Affero Public License for more details.
*
* You should have received a copy of the GNU Affero Public License
* along with this program. If not, see http://www.gnu.org/licenses.
*
* http://numenta.org/licenses/
* ---------------------------------------------------------------------
*/
package org.numenta.nupic.encoders;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.junit.Test;
import org.numenta.nupic.encoders.ScalarEncoder;
import org.numenta.nupic.util.MinMax;
import org.numenta.nupic.util.Tuple;
public class MultiEncoderTest {
private MultiEncoder me;
private MultiEncoder.Builder builder;
private void setUp() {
builder = MultiEncoder.builder().name("");
}
private void initME() {
me = builder.build();
}
/**
* Test addition of encoders one-by-one.
*/
@Test
public void testAdaptiveScalarEncoder() {
setUp();
initME();
Encoder.Builder<?,?> ase = me.getBuilder("AdaptiveScalarEncoder");
assertNotNull(ase);
try {
me.getBuilder("BogusEncoder");
fail(); //Expect exception thrown here
} catch (Exception e) {
assertTrue(e instanceof IllegalArgumentException);
}
//runMixedTests(me);
}
/**
* Test addition of encoders one-by-one.
*/
@Test
public void testSerialAdditions() {
setUp();
initME();
ScalarEncoder dow = ScalarEncoder.builder()
.w(3)
.resolution(1)
.minVal(1)
.maxVal(8)
.periodic(true)
.name("day of week")
.forced(true)
.build();
me.addEncoder("dow", dow);
ScalarEncoder myval = ScalarEncoder.builder()
.w(5)
.resolution(1)
.minVal(1)
.maxVal(10)
.periodic(false)
.name("aux")
.forced(true)
.build();
me.addEncoder("myval", myval);
runScalarTests(me);
List<String> categoryList = new ArrayList<String>();
categoryList.add("run");
categoryList.add("pass");
categoryList.add("kick");
CategoryEncoder myCat = CategoryEncoder.builder()
.radius(2)
.w(3)
.categoryList(categoryList)
.forced(true)
.build();
me.addEncoder("myCat", myCat);
runMixedTests(me);
}
/**
* Test addition of encoders all at once.
*/
@Test
public void testMultipleAdditions() {
setUp();
initME();
Map<String, Map<String, Object>> fieldEncodings = new HashMap<String, Map<String, Object>>();
fieldEncodings.put("dow", new HashMap<String, Object>());
fieldEncodings.get("dow").put("encoderType", "ScalarEncoder");
fieldEncodings.get("dow").put("fieldName", "dow");
fieldEncodings.get("dow").put("w", 3);
fieldEncodings.get("dow").put("resolution", 1.);
fieldEncodings.get("dow").put("minVal", 1.);
fieldEncodings.get("dow").put("maxVal", 8.);
fieldEncodings.get("dow").put("periodic", true);
fieldEncodings.get("dow").put("name", "day of week");
fieldEncodings.get("dow").put("forced", true);
fieldEncodings.put("myval", new HashMap<String, Object>());
fieldEncodings.get("myval").put("encoderType", "ScalarEncoder");
fieldEncodings.get("myval").put("fieldName", "myval");
fieldEncodings.get("myval").put("w", 5);
fieldEncodings.get("myval").put("resolution", 1.);
fieldEncodings.get("myval").put("minVal", 1.);
fieldEncodings.get("myval").put("maxVal", 10.);
fieldEncodings.get("myval").put("periodic", false);
fieldEncodings.get("myval").put("name", "aux");
fieldEncodings.get("myval").put("forced", true);
me.addMultipleEncoders(fieldEncodings);
runScalarTests(me);
setUp();
initME();
List<String> categoryList = new ArrayList<String>();
categoryList.add("run");
categoryList.add("pass");
categoryList.add("kick");
fieldEncodings.put("myCat", new HashMap<String, Object>());
fieldEncodings.get("myCat").put("encoderType", "CategoryEncoder");
fieldEncodings.get("myCat").put("fieldName", "myCat");
fieldEncodings.get("myCat").put("w", 3);
fieldEncodings.get("myCat").put("radius", 2.);
fieldEncodings.get("myCat").put("categoryList", categoryList);
fieldEncodings.get("myCat").put("forced", true);
me.addMultipleEncoders(fieldEncodings);
runMixedTests(me);
}
@SuppressWarnings("unchecked")
public void runScalarTests(MultiEncoder me) {
// should be 7 bits wide
// use of forced=true is not recommended, but here for readability, see scalar.py
int[] expected = new int[]{0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1};
Map<String, Object> d = new HashMap<String, Object>();
d.put("dow", 3.);
d.put("myval", 10.);
int[] output = me.encode(d);
assertTrue(Arrays.equals(expected, output));
// Check decoding
Tuple decoded = me.decode(output, "");
Map<String, RangeList> fields = (HashMap<String, RangeList>) decoded.get(0);
assertEquals(fields.keySet().size(), 2);
MinMax minMax = fields.get("aux").getRange(0);
assertTrue(minMax.toString().equals(new MinMax(10.0, 10.0).toString()));
minMax = fields.get("day of week").getRange(0);
assertTrue(minMax.toString().equals(new MinMax(3.0, 3.0).toString()));
}
public void runMixedTests(MultiEncoder me) {
Map<String, Object> d = new HashMap<String, Object>();
d.put("dow", 4.);
d.put("myval", 6.);
d.put("myCat", "pass");
int[] output = me.encode(d);
List<Encoding> topDownOut = me.topDownCompute(output);
// When encoders are added one at a time, they're kept in the order they were added,
// but when they're added all at once, they're sorted by name, so we need to be careful
// here.
ScalarEncoder dow= null, myval = null;
CategoryEncoder myCat = null;
Encoding dowActual = null, myvalActual = null, myCatActual = null;
for (int i = 0; i < me.getEncoders(me).size(); i++) {
EncoderTuple t = me.getEncoders(me).get(i);
String name = t.getName();
if (name.equals("dow")) {
dow = (ScalarEncoder) t.getEncoder();
dowActual = topDownOut.get(i);
} else if (name.equals("myval")) {
myval = (ScalarEncoder) t.getEncoder();
myvalActual = topDownOut.get(i);
} else if (name.equals("myCat")) {
myCat = (CategoryEncoder) t.getEncoder();
myCatActual = topDownOut.get(i);
}
}
Encoding dowExpected = dow.topDownCompute(dow.encode(4.)).get(0);
Encoding myvalExpected = myval.topDownCompute(myval.encode(6.)).get(0);
Encoding myCatExpected = myCat.topDownCompute(myCat.encode("pass")).get(0);
assertTrue(dowActual.equals(dowExpected));
assertTrue(myvalActual.equals(myvalExpected));
assertTrue(myCatActual.equals(myCatExpected));
}
}