/* ---------------------------------------------------------------------
* Numenta Platform for Intelligent Computing (NuPIC)
* Copyright (C) 2014, Numenta, Inc. Unless you have an agreement
* with Numenta, Inc., for a separate license for this software code, the
* following terms and conditions apply:
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero Public License version 3 as
* published by the Free Software Foundation.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
* See the GNU Affero Public License for more details.
*
* You should have received a copy of the GNU Affero Public License
* along with this program. If not, see http://www.gnu.org/licenses.
*
* http://numenta.org/licenses/
* ---------------------------------------------------------------------
*/
package org.numenta.nupic;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;

import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.Random;

import org.junit.Test;
import org.numenta.nupic.Parameters.KEY;
import org.numenta.nupic.model.Connections;
import org.numenta.nupic.util.MersenneTwister;
import org.numenta.nupic.util.Tuple;
public class ParametersTest {
private Parameters parameters;
/**
 * Exercises {@code Parameters} equality with scalar values, {@code int[]} values and
 * nested field-encoding maps, checking a matching and a differing case for each.
 */
@Test
public void testEquals() {
    Parameters left = Parameters.empty();
    Parameters right = Parameters.empty();
    assertEquals(left, right);

    // Scalar: identical values on both sides
    left.set(KEY.POTENTIAL_PCT, 32.0);
    right.set(KEY.POTENTIAL_PCT, 32.0);
    assertEquals(left, right);

    // Scalar: values differ
    left.set(KEY.POTENTIAL_PCT, 32.0);
    right.set(KEY.POTENTIAL_PCT, 32.2);
    assertNotEquals(left, right);
    right.set(KEY.POTENTIAL_PCT, 32.0); // bring both sides back in sync

    // int[]: distinct array instances with the same content
    left.set(KEY.COLUMN_DIMENSIONS, new int[] { 2048 });
    right.set(KEY.COLUMN_DIMENSIONS, new int[] { 2048 });
    assertEquals(left, right);

    // int[]: content differs
    left.set(KEY.COLUMN_DIMENSIONS, new int[] { 2048 });
    right.set(KEY.COLUMN_DIMENSIONS, new int[] { 2049 });
    assertNotEquals(left, right);
    right.set(KEY.COLUMN_DIMENSIONS, new int[] { 2048 }); // bring both sides back in sync

    // Field-encoding map: the very same map on both sides
    Map<String, Map<String, Object>> baseline = getHotGymFieldEncodingMap();
    left.set(KEY.FIELD_ENCODING_MAP, baseline);
    right.set(KEY.FIELD_ENCODING_MAP, baseline);
    assertEquals(left, right);

    // Field-encoding map: "n" differs for the timestamp field
    Map<String, Map<String, Object>> differentN = getHotGymFieldEncodingMap_varyN();
    left.set(KEY.FIELD_ENCODING_MAP, baseline);
    right.set(KEY.FIELD_ENCODING_MAP, differentN);
    assertNotEquals(left, right);

    // Field-encoding map: a nested Tuple value differs
    Map<String, Map<String, Object>> differentTuple = getHotGymFieldEncodingMap_varyDateFieldTupleValue();
    left.set(KEY.FIELD_ENCODING_MAP, baseline);
    right.set(KEY.FIELD_ENCODING_MAP, differentTuple);
    assertNotEquals(left, right);

    // Field-encoding map: one date-field key is absent
    Map<String, Map<String, Object>> missingKey = getHotGymFieldEncodingMap_varyDateFieldKey();
    left.set(KEY.FIELD_ENCODING_MAP, baseline);
    right.set(KEY.FIELD_ENCODING_MAP, missingKey);
    assertNotEquals(left, right);

    // Giving both sides the same variant makes them equal again
    left.set(KEY.FIELD_ENCODING_MAP, missingKey);
    assertEquals(left, right);
}
/**
 * Verifies that {@code Parameters.apply} transfers configured values onto a target
 * object: one property declared on the target's superclass ({@code columnDimensions})
 * and one declared on the target class itself ({@code potentialPct}).
 */
@Test
public void testApply() {
    DummyContainer dc = new DummyContainer();
    Parameters params = Parameters.getAllDefaultParameters();
    params.set(KEY.COLUMN_DIMENSIONS, new int[] { 2048 });
    params.set(KEY.POTENTIAL_PCT, 20.0);
    // NOTE(review): presumably here to show that apply() tolerates a null-valued key
    // for which the target has no property - confirm against Parameters.apply.
    params.set(KEY.CELLS_PER_COLUMN, null);

    params.apply(dc);

    // assertArrayEquals reports the differing index/value on failure, unlike
    // assertTrue(Arrays.equals(...)) which only reports "false".
    assertArrayEquals(new int[] { 2048 }, dc.getColumnDimensions());
    assertEquals(20.0, dc.getPotentialPct(), 0);
}
/**
 * Checks a few well-known default values and then verifies that values written to a
 * spatial-defaults instance can be read back unchanged.
 */
@Test
public void testDefaultsAndUpdates() {
    Parameters params = Parameters.getAllDefaultParameters();
    // Expected value goes first so that failure messages read correctly.
    assertEquals(32, params.get(KEY.CELLS_PER_COLUMN));
    assertEquals(42, params.get(KEY.SEED));
    assertEquals(MersenneTwister.class, params.get(KEY.RANDOM).getClass());

    // String concatenation invokes toString() on each default set.
    System.out.println("All Defaults:\n" + Parameters.getAllDefaultParameters());
    System.out.println("Spatial Defaults:\n" + Parameters.getSpatialDefaultParameters());
    System.out.println("Temporal Defaults:\n" + Parameters.getTemporalDefaultParameters());

    parameters = Parameters.getSpatialDefaultParameters();
    parameters.set(KEY.INPUT_DIMENSIONS, new int[] { 64, 64 });
    parameters.set(KEY.COLUMN_DIMENSIONS, new int[] { 32, 32 });
    parameters.set(KEY.NUM_ACTIVE_COLUMNS_PER_INH_AREA, 0.02 * 64 * 64);
    System.out.println("Updated/Combined:\n" + parameters);

    // The updates must actually be visible through get().
    assertArrayEquals(new int[] { 64, 64 }, (int[]) parameters.get(KEY.INPUT_DIMENSIONS));
    assertArrayEquals(new int[] { 32, 32 }, (int[]) parameters.get(KEY.COLUMN_DIMENSIONS));
    assertEquals(0.02 * 64 * 64, (double) parameters.get(KEY.NUM_ACTIVE_COLUMNS_PER_INH_AREA), 0.0);
}
/**
 * Minimal bean with a single {@code columnDimensions} property; it is the superclass
 * of the {@code DummyContainer} used as the target object in {@code testApply()}.
 */
public static class DummyContainerBase {
    private int[] columnDimensions;

    /** Stores the given array reference as-is (no defensive copy). */
    public void setColumnDimensions(int[] columnDimensions) {
        this.columnDimensions = columnDimensions;
    }

    /** Returns the array reference last stored, or {@code null} if none was set. */
    public int[] getColumnDimensions() {
        return this.columnDimensions;
    }
}
/**
 * Target object for {@code testApply()}: adds a {@code potentialPct} property on top
 * of the {@code columnDimensions} property inherited from {@link DummyContainerBase}.
 */
public static class DummyContainer extends DummyContainerBase {
    // A double field starts out as 0.0 without an explicit initializer.
    private double potentialPct;

    /** Stores the given value. */
    public void setPotentialPct(double potentialPct) {
        this.potentialPct = potentialPct;
    }

    /** Returns the value last stored, or 0.0 if none was set. */
    public double getPotentialPct() {
        return this.potentialPct;
    }
}
/**
 * Verifies that {@code union} copies a value from the argument {@code Parameters}
 * into the receiver, replacing the receiver's previous value for that key.
 */
@Test
public void testUnion() {
    Parameters params = Parameters.getAllDefaultParameters();
    Parameters arg = Parameters.getAllDefaultParameters();
    arg.set(KEY.CELLS_PER_COLUMN, 5);

    // Precondition: the receiver does not already hold the override value.
    assertNotEquals(5, params.get(KEY.CELLS_PER_COLUMN));

    params.union(arg);

    // assertEquals reports the actual value on failure, unlike assertTrue(x == 5).
    assertEquals(5, params.get(KEY.CELLS_PER_COLUMN));
}
/**
 * Verifies that {@code KEY.getKeyByFieldName} resolves "potentialPct" to
 * {@code KEY.POTENTIAL_PCT} and that "random" resolves to something else.
 */
@Test
public void testGetKeyByFieldName() {
    KEY potentialPct = KEY.POTENTIAL_PCT;
    assertEquals(potentialPct, KEY.getKeyByFieldName("potentialPct"));

    KEY lookedUpRandom = KEY.getKeyByFieldName("random");
    assertFalse(potentialPct.equals(lookedUpRandom));
}
/**
 * Verifies the range bounds reported by {@code KEY.SYN_PERM_ACTIVE_INC}:
 * minimum 0.0 and maximum 1.0.
 */
@Test
public void testGetMinMax() {
    assertEquals(0.0, KEY.SYN_PERM_ACTIVE_INC.getMin());
    assertEquals(1.0, KEY.SYN_PERM_ACTIVE_INC.getMax());
}
/**
 * Verifies range and type validation for {@code KEY.SYN_PERM_ACTIVE_INC}: values
 * outside [0.0, 1.0], a {@code null} range check and a wrongly typed value are rejected
 * with an {@link IllegalArgumentException} carrying a specific message, while the
 * boundary values and an in-range value are accepted and can be read back.
 */
@Test
public void testCheckRange() {
    Parameters params = Parameters.getAllDefaultParameters();

    // Above the maximum
    assertSetRejected(params, 2.0,
        "Can not set Parameters Property 'synPermActiveInc' because of value '2.0' not in range. Range[0.0-1.0]");

    // Below the minimum
    assertSetRejected(params, -0.6,
        "Can not set Parameters Property 'synPermActiveInc' because of value '-0.6' not in range. Range[0.0-1.0]");

    // Null argument passed straight to the range check
    try {
        KEY.SYN_PERM_ACTIVE_INC.checkRange(null);
        fail("checkRange(null) was expected to throw IllegalArgumentException");
    } catch (Exception e) {
        assertEquals(IllegalArgumentException.class, e.getClass());
        assertEquals("checkRange argument can not be null", e.getMessage());
    }

    // Type mismatch
    assertSetRejected(params, Boolean.TRUE,
        "Can not set Parameters Property 'synPermActiveInc' because of type mismatch. The required type is class java.lang.Double");

    // Values exactly at the min / max must be accepted
    assertSetAccepted(params, 0.0);
    assertSetAccepted(params, 1.0);

    // Value inside the range
    assertSetAccepted(params, 0.8);
}

/**
 * Asserts that setting {@code KEY.SYN_PERM_ACTIVE_INC} to {@code value} throws exactly
 * an {@link IllegalArgumentException} whose message equals {@code expectedMessage}.
 */
private static void assertSetRejected(Parameters params, Object value, String expectedMessage) {
    try {
        params.set(KEY.SYN_PERM_ACTIVE_INC, value);
        // fail() throws AssertionError, which the catch (Exception) below does not intercept.
        fail("Setting synPermActiveInc to '" + value + "' was expected to throw IllegalArgumentException");
    } catch (Exception e) {
        assertEquals(IllegalArgumentException.class, e.getClass());
        assertEquals(expectedMessage, e.getMessage());
    }
}

/**
 * Asserts that {@code KEY.SYN_PERM_ACTIVE_INC} accepts {@code value} and returns it
 * from {@code get}. Any exception propagates so the test report keeps its stack trace.
 */
private static void assertSetAccepted(Parameters params, double value) {
    params.set(KEY.SYN_PERM_ACTIVE_INC, value);
    assertEquals(value, (double) params.get(KEY.SYN_PERM_ACTIVE_INC), 0.0);
}
/**
 * Verifies that the full default parameter set reports a size of 48.
 */
@Test
public void testSize() {
    assertEquals(48, Parameters.getAllDefaultParameters().size());
}
/**
 * Verifies that {@code keys()} on the full default parameter set is non-null and
 * contains 48 entries.
 */
@Test
public void testKeys() {
    Parameters params = Parameters.getAllDefaultParameters();
    // Separate assertions so a failure shows which condition broke and the actual size.
    assertNotNull(params.keys());
    assertEquals(48, params.keys().size());
}
/**
 * Verifies that {@code clearParameter} removes a key's value: {@code get} returns a
 * non-null default before the call and {@code null} afterwards.
 */
@Test
public void testClearParameter() {
    Parameters defaults = Parameters.getAllDefaultParameters();
    KEY target = KEY.SYN_PERM_ACTIVE_INC;

    assertNotNull(defaults.get(target));
    defaults.clearParameter(target);
    assertNull(defaults.get(target));
}
/**
 * Applies the defaults to a {@code Connections} instance, changes one value in a second
 * default set and expects {@code logDiff} against that instance to return {@code true}
 * (presumably signalling that a difference was found - confirm against
 * {@code Parameters.logDiff}). Also checks that a {@code null} argument is rejected.
 */
@Test
public void testLogDiff() {
    Parameters defaults = Parameters.getAllDefaultParameters();
    assertNotNull(defaults.get(KEY.SYN_PERM_ACTIVE_INC));

    Connections cn = new Connections();
    defaults.apply(cn);

    Parameters modified = Parameters.getAllDefaultParameters();
    modified.set(KEY.SYN_PERM_ACTIVE_INC, 0.9);
    assertTrue(modified.logDiff(cn));

    try {
        modified.logDiff(null);
        fail();
    } catch (Exception e) {
        assertEquals(IllegalArgumentException.class, e.getClass());
        assertEquals("cn Object is required and can not be null", e.getMessage());
    }
}
/**
 * Verifies that each convenience setter on {@code Parameters} stores its argument
 * under the corresponding {@code KEY}, as observed through {@code get}.
 */
@Test
public void testSetterMethods() {
    Parameters p = Parameters.getAllDefaultParameters();

    // Integer-valued settings
    p.setCellsPerColumn(42);
    assertEquals(42, p.get(KEY.CELLS_PER_COLUMN));

    p.setActivationThreshold(42);
    assertEquals(42, p.get(KEY.ACTIVATION_THRESHOLD));

    p.setLearningRadius(42);
    assertEquals(42, p.get(KEY.LEARNING_RADIUS));

    p.setMinThreshold(42);
    assertEquals(42, p.get(KEY.MIN_THRESHOLD));

    p.setSeed(42);
    assertEquals(42, p.get(KEY.SEED));

    // Permanence settings (doubles)
    p.setInitialPermanence(0.82);
    assertEquals(0.82, p.get(KEY.INITIAL_PERMANENCE));

    p.setConnectedPermanence(0.82);
    assertEquals(0.82, p.get(KEY.CONNECTED_PERMANENCE));

    p.setPermanenceIncrement(0.11);
    assertEquals(0.11, p.get(KEY.PERMANENCE_INCREMENT));

    p.setPermanenceDecrement(0.11);
    assertEquals(0.11, p.get(KEY.PERMANENCE_DECREMENT));

    // Segment / synapse limits
    p.setMaxSegmentsPerCell(11);
    assertEquals(11, p.get(KEY.MAX_SEGMENTS_PER_CELL));

    p.setMaxSynapsesPerSegment(22);
    assertEquals(22, p.get(KEY.MAX_SYNAPSES_PER_SEGMENT));

    p.setMaxNewSynapseCount(32);
    assertEquals(32, p.get(KEY.MAX_NEW_SYNAPSE_COUNT));
}
/**
 * Builds the baseline "Hot Gym" field-encoding map used by {@code testEquals()}.
 * It holds two entries: "timestamp" (a DateEncoder with day-of-week, time-of-day and
 * date-pattern settings) and "consumption" (a RandomDistributedScalarEncoder).
 *
 * @return a newly created map from field name to that field's encoder settings
 */
public static Map<String, Map<String, Object>> getHotGymFieldEncodingMap() {
    Map<String, Map<String, Object>> encodings = setupMap(
        null,
        0, // n
        0, // w
        0, 0, 0, 0, null, null, null,
        "timestamp", "datetime", "DateEncoder");

    // setupMap() returns the map it was handed, so the result need not be re-assigned.
    setupMap(
        encodings,
        25, // n
        3,  // w
        0, 0, 0, 0.1, null, null, null,
        "consumption", "float", "RandomDistributedScalarEncoder");

    Map<String, Object> timestamp = encodings.get("timestamp");
    timestamp.put(KEY.DATEFIELD_DOFW.getFieldName(), new Tuple(1, 1.0)); // Day of week
    timestamp.put(KEY.DATEFIELD_TOFD.getFieldName(), new Tuple(5, 4.0)); // Time of day
    timestamp.put(KEY.DATEFIELD_PATTERN.getFieldName(), "MM/dd/YY HH:mm");

    return encodings;
}
/**
 * Returns a variant of the Hot Gym field-encoding map that differs from
 * {@link #getHotGymFieldEncodingMap()} in exactly one place: the "n" setting of the
 * "timestamp" field is 20 instead of 0.
 *
 * @return a newly created map from field name to that field's encoder settings
 */
public static Map<String, Map<String, Object>> getHotGymFieldEncodingMap_varyN() {
    // Start from a fresh baseline so the single deviation below is explicit.
    Map<String, Map<String, Object>> fieldEncodings = getHotGymFieldEncodingMap();
    // "n" is the key under which setupMap() stores its n argument.
    fieldEncodings.get("timestamp").put("n", 20);
    return fieldEncodings;
}
/**
 * Returns a variant of the Hot Gym field-encoding map that differs from
 * {@link #getHotGymFieldEncodingMap()} in exactly one place: the day-of-week Tuple of
 * the "timestamp" field is {@code (1, 2.0)} instead of {@code (1, 1.0)}.
 *
 * @return a newly created map from field name to that field's encoder settings
 */
public static Map<String, Map<String, Object>> getHotGymFieldEncodingMap_varyDateFieldTupleValue() {
    // Start from a fresh baseline so the single deviation below is explicit.
    Map<String, Map<String, Object>> fieldEncodings = getHotGymFieldEncodingMap();
    fieldEncodings.get("timestamp").put(KEY.DATEFIELD_DOFW.getFieldName(), new Tuple(1, 2.0)); // Day of week
    return fieldEncodings;
}
/**
 * Returns a variant of the Hot Gym field-encoding map that differs from
 * {@link #getHotGymFieldEncodingMap()} in exactly one place: the "timestamp" field has
 * no day-of-week entry.
 *
 * @return a newly created map from field name to that field's encoder settings
 */
public static Map<String, Map<String, Object>> getHotGymFieldEncodingMap_varyDateFieldKey() {
    // Start from a fresh baseline so the single deviation below is explicit.
    Map<String, Map<String, Object>> fieldEncodings = getHotGymFieldEncodingMap();
    fieldEncodings.get("timestamp").remove(KEY.DATEFIELD_DOFW.getFieldName()); // no Day of week
    return fieldEncodings;
}
/**
 * Creates or updates the encoder settings of a single field inside a field-encoding map.
 *
 * <p>The numeric settings ("n", "w", "minVal", "maxVal", "radius", "resolution") are
 * always written. The remaining settings ("periodic", "clipInput", "forced",
 * "fieldName", "fieldType", "encoderType") are written only when the corresponding
 * argument is non-null; an existing value is left untouched otherwise.
 *
 * @param map         the map to update, or {@code null} to have a new map created
 * @param n           value stored under "n"
 * @param w           value stored under "w"
 * @param min         value stored under "minVal"
 * @param max         value stored under "maxVal"
 * @param radius      value stored under "radius"
 * @param resolution  value stored under "resolution"
 * @param periodic    value stored under "periodic", skipped when {@code null}
 * @param clip        value stored under "clipInput", skipped when {@code null}
 * @param forced      value stored under "forced", skipped when {@code null}
 * @param fieldName   key of the outer map entry; also stored under "fieldName" when non-null
 * @param fieldType   value stored under "fieldType", skipped when {@code null}
 * @param encoderType value stored under "encoderType", skipped when {@code null}
 * @return {@code map} itself when it was non-null, otherwise the newly created map
 */
public static Map<String, Map<String, Object>> setupMap(
        Map<String, Map<String, Object>> map,
        int n, int w, double min, double max, double radius, double resolution, Boolean periodic,
        Boolean clip, Boolean forced, String fieldName, String fieldType, String encoderType) {

    Map<String, Map<String, Object>> result =
        (map != null) ? map : new HashMap<String, Map<String, Object>>();

    // Reuse the field's existing settings when present so earlier entries survive.
    Map<String, Object> settings = result.get(fieldName);
    if (settings == null) {
        settings = new HashMap<String, Object>();
        result.put(fieldName, settings);
    }

    settings.put("n", n);
    settings.put("w", w);
    settings.put("minVal", min);
    settings.put("maxVal", max);
    settings.put("radius", radius);
    settings.put("resolution", resolution);

    putIfNotNull(settings, "periodic", periodic);
    putIfNotNull(settings, "clipInput", clip);
    putIfNotNull(settings, "forced", forced);
    putIfNotNull(settings, "fieldName", fieldName);
    putIfNotNull(settings, "fieldType", fieldType);
    putIfNotNull(settings, "encoderType", encoderType);

    return result;
}

/** Stores {@code value} under {@code key} unless {@code value} is {@code null}. */
private static void putIfNotNull(Map<String, Object> target, String key, Object value) {
    if (value != null) {
        target.put(key, value);
    }
}
}