/* ---------------------------------------------------------------------
* Numenta Platform for Intelligent Computing (NuPIC)
* Copyright (C) 2016, Numenta, Inc. Unless you have an agreement
* with Numenta, Inc., for a separate license for this software code, the
* following terms and conditions apply:
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero Public License version 3 as
* published by the Free Software Foundation.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
* See the GNU Affero Public License for more details.
*
* You should have received a copy of the GNU Affero Public License
* along with this program. If not, see http://www.gnu.org/licenses.
*
* http://numenta.org/licenses/
* ---------------------------------------------------------------------
*/
package org.numenta.nupic.algorithms;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import java.util.Arrays;
import java.util.stream.IntStream;
import org.junit.Test;
import org.numenta.nupic.Parameters;
import org.numenta.nupic.Parameters.KEY;
import org.numenta.nupic.algorithms.SpatialPooler.InvalidSPParamValueException;
import org.numenta.nupic.model.Connections;
import org.numenta.nupic.model.Pool;
import org.numenta.nupic.util.AbstractSparseBinaryMatrix;
import org.numenta.nupic.util.ArrayUtils;
import org.numenta.nupic.util.Condition;
import org.numenta.nupic.util.SparseBinaryMatrix;
import org.numenta.nupic.util.SparseObjectMatrix;
import org.numenta.nupic.util.UniversalRandom;
import gnu.trove.list.array.TIntArrayList;
import gnu.trove.set.hash.TIntHashSet;
public class SpatialPoolerTest {
/** Parameter set for the current test; built by setupParameters() or setupDefaultParameters() and applied to {@link #mem} by initSP(). */
private Parameters parameters;
/** Pooler under test; assigned by initSP() (some tests replace it with a mock afterwards). */
private SpatialPooler sp;
/** Connections instance passed to the pooler's methods; assigned by initSP(). */
private Connections mem;
/**
 * Builds a small 1-D configuration (5 inputs, 5 columns) on top of
 * {@link Parameters#getAllDefaultParameters()} and stores it in
 * {@link #parameters}. Tests usually override individual values before
 * calling {@link #initSP()}.
 */
public void setupParameters() {
    Parameters p = Parameters.getAllDefaultParameters();

    // Topology
    p.set(KEY.INPUT_DIMENSIONS, new int[] { 5 });
    p.set(KEY.COLUMN_DIMENSIONS, new int[] { 5 });
    p.set(KEY.POTENTIAL_RADIUS, 5);
    p.set(KEY.POTENTIAL_PCT, 0.5);

    // Inhibition
    p.set(KEY.GLOBAL_INHIBITION, false);
    p.set(KEY.LOCAL_AREA_DENSITY, -1.0);
    p.set(KEY.NUM_ACTIVE_COLUMNS_PER_INH_AREA, 3.0);
    p.set(KEY.STIMULUS_THRESHOLD, 0.0);

    // Permanence learning
    p.set(KEY.SYN_PERM_INACTIVE_DEC, 0.01);
    p.set(KEY.SYN_PERM_ACTIVE_INC, 0.1);
    p.set(KEY.SYN_PERM_CONNECTED, 0.1);

    // Duty cycles and boosting
    p.set(KEY.MIN_PCT_OVERLAP_DUTY_CYCLES, 0.1);
    p.set(KEY.MIN_PCT_ACTIVE_DUTY_CYCLES, 0.1);
    p.set(KEY.DUTY_CYCLE_PERIOD, 10);
    p.set(KEY.MAX_BOOST, 10.0);

    p.setRandom(new UniversalRandom(42));
    parameters = p;
}
/**
 * Builds a larger 2-D configuration (32x32 inputs, 64x64 columns) with an
 * explicit SEED of 42 on top of {@link Parameters#getAllDefaultParameters()}
 * and stores it in {@link #parameters}.
 */
public void setupDefaultParameters() {
    Parameters p = Parameters.getAllDefaultParameters();

    // Topology
    p.set(KEY.INPUT_DIMENSIONS, new int[] { 32, 32 });
    p.set(KEY.COLUMN_DIMENSIONS, new int[] { 64, 64 });
    p.set(KEY.POTENTIAL_RADIUS, 16);
    p.set(KEY.POTENTIAL_PCT, 0.5);

    // Inhibition
    p.set(KEY.GLOBAL_INHIBITION, false);
    p.set(KEY.LOCAL_AREA_DENSITY, -1.0);
    p.set(KEY.NUM_ACTIVE_COLUMNS_PER_INH_AREA, 10.0);
    p.set(KEY.STIMULUS_THRESHOLD, 0.0);

    // Permanence learning
    p.set(KEY.SYN_PERM_INACTIVE_DEC, 0.008);
    p.set(KEY.SYN_PERM_ACTIVE_INC, 0.05);
    p.set(KEY.SYN_PERM_CONNECTED, 0.10);

    // Duty cycles and boosting
    p.set(KEY.MIN_PCT_OVERLAP_DUTY_CYCLES, 0.001);
    p.set(KEY.MIN_PCT_ACTIVE_DUTY_CYCLES, 0.001);
    p.set(KEY.DUTY_CYCLE_PERIOD, 1000);
    p.set(KEY.MAX_BOOST, 10.0);

    // Reproducibility
    p.set(KEY.SEED, 42);
    p.setRandom(new UniversalRandom(42));
    parameters = p;
}
/**
 * Creates a fresh SpatialPooler and Connections, applies the current
 * {@link #parameters} to the Connections, then initializes the pooler
 * against it. Both fields are assigned before apply()/init() run, so they
 * already reference the new objects even if either call throws.
 */
private void initSP() {
sp = new SpatialPooler();
mem = new Connections();
parameters.apply(mem);
sp.init(mem);
}
/**
 * Verifies that the values configured by {@link #setupParameters()} are
 * carried into the {@link Connections} instance by {@link #initSP()}, along
 * with the resulting input and column counts.
 */
@Test
public void confirmSPConstruction() {
    setupParameters();
    initSP();
    assertEquals(5, mem.getInputDimensions()[0]);
    assertEquals(5, mem.getColumnDimensions()[0]);
    assertEquals(5, mem.getPotentialRadius());
    assertEquals(0.5, mem.getPotentialPct(), 0);
    assertFalse(mem.getGlobalInhibition());
    assertEquals(-1.0, mem.getLocalAreaDensity(), 0);
    assertEquals(3, mem.getNumActiveColumnsPerInhArea(), 0);
    // setupParameters() configures STIMULUS_THRESHOLD = 0.0; check it exactly
    // rather than with a tolerance wide enough to accept any value in [0, 2].
    assertEquals(0.0, mem.getStimulusThreshold(), 0);
    assertEquals(0.01, mem.getSynPermInactiveDec(), 0);
    assertEquals(0.1, mem.getSynPermActiveInc(), 0);
    assertEquals(0.1, mem.getSynPermConnected(), 0);
    assertEquals(0.1, mem.getMinPctOverlapDutyCycles(), 0);
    assertEquals(0.1, mem.getMinPctActiveDutyCycles(), 0);
    assertEquals(10, mem.getDutyCyclePeriod(), 0);
    assertEquals(10.0, mem.getMaxBoost(), 0);
    // SEED is not set by setupParameters(); 42 is presumably the value from
    // Parameters.getAllDefaultParameters() — confirm there if this fails.
    assertEquals(42, mem.getSeed());
    assertEquals(5, mem.getNumInputs());
    assertEquals(5, mem.getNumColumns());
}
/**
 * Checks that feeding in the same input vector leads to polarized
 * permanence values: either zeros or ones, but no fractions.
 *
 * Inhibition is mocked so every column is active on every step. After 20
 * learning steps, each column's dense permanences are converted with
 * ArrayUtils.toIntArray and compared with the input vector.
 */
@Test
public void testCompute1() {
    setupParameters();
    parameters.setInputDimensions(new int[] { 9 });
    parameters.setColumnDimensions(new int[] { 5 });
    parameters.setPotentialRadius(5);
    // With radius 5 and potentialPct = 1 the pool presumably spans all 9
    // inputs. (The Python version uses 0.5 here due to its use of dense
    // permanence instead of sparse.)
    parameters.setPotentialPct(1);
    parameters.setGlobalInhibition(false);
    parameters.setLocalAreaDensity(-1.0);
    parameters.setNumActiveColumnsPerInhArea(3);
    parameters.setStimulusThreshold(1);
    parameters.setSynPermInactiveDec(0.01);
    parameters.setSynPermActiveInc(0.1);
    parameters.setSynPermConnected(0.1);
    parameters.setSynPermTrimThreshold(0);
    parameters.setMinPctOverlapDutyCycles(0.1);
    parameters.setMinPctActiveDutyCycles(0.1);
    parameters.setDutyCyclePeriod(10);
    parameters.setMaxBoost(10);
    initSP();

    // Force every column to win inhibition so all of them learn each step.
    SpatialPooler mock = new SpatialPooler() {
        private static final long serialVersionUID = 1L;
        @Override
        public int[] inhibitColumns(Connections c, double[] overlaps) {
            return new int[] { 0, 1, 2, 3, 4 };
        }
    };

    int[] inputVector = new int[] { 1, 0, 1, 0, 1, 0, 0, 1, 1 };
    int[] activeArray = new int[] { 0, 0, 0, 0, 0 };
    for(int i = 0; i < 20; i++) {
        mock.compute(mem, inputVector, activeArray, true);
    }

    for(int i = 0; i < mem.getNumColumns(); i++) {
        int[] permanences = ArrayUtils.toIntArray(mem.getPotentialPools().get(i).getDensePermanences(mem));
        assertTrue("column " + i + " permanences " + Arrays.toString(permanences),
            Arrays.equals(inputVector, permanences));
    }
}
/**
 * Checks that columns only change the permanence values for
 * inputs that are within their potential pool.
 *
 * All columns are forced active and an all-ones input is presented 20
 * times; each column's dense permanences (as ints) are then compared with
 * its slice of the connected-counts matrix.
 */
@Test
public void testCompute2() {
    setupParameters();
    parameters.setInputDimensions(new int[] { 10 });
    parameters.setColumnDimensions(new int[] { 5 });
    parameters.setPotentialRadius(3);
    parameters.setPotentialPct(0.3);
    parameters.setGlobalInhibition(false);
    parameters.setLocalAreaDensity(-1.0);
    parameters.setNumActiveColumnsPerInhArea(3);
    parameters.setStimulusThreshold(1);
    parameters.setSynPermInactiveDec(0.01);
    parameters.setSynPermActiveInc(0.1);
    parameters.setSynPermConnected(0.1);
    parameters.setMinPctOverlapDutyCycles(0.1);
    parameters.setMinPctActiveDutyCycles(0.1);
    parameters.setDutyCyclePeriod(10);
    parameters.setMaxBoost(10);
    initSP();

    // Every column wins inhibition, so every column learns on each step.
    SpatialPooler allColumnsActive = new SpatialPooler() {
        private static final long serialVersionUID = 1L;
        @Override
        public int[] inhibitColumns(Connections c, double[] overlaps) {
            return new int[] { 0, 1, 2, 3, 4 };
        }
    };

    int[] allOnes = new int[10];
    Arrays.fill(allOnes, 1);
    int[] activeColumns = new int[5];
    IntStream.range(0, 20)
        .forEach(step -> allColumnsActive.compute(mem, allOnes, activeColumns, true));

    for (int column = 0; column < mem.getNumColumns(); column++) {
        int[] densePerms = ArrayUtils.toIntArray(
            mem.getPotentialPools().get(column).getDensePermanences(mem));
        int[] connectedRow = (int[])mem.getConnectedCounts().getSlice(column);
        assertTrue(Arrays.equals(densePerms, connectedRow));
    }
}
/**
 * When stimulusThreshold is 0, allow columns without any overlap to become
 * active. This test focuses on the global inhibition code path.
 */
@Test
public void testZeroOverlap_NoStimulusThreshold_GlobalInhibition() {
    setupZeroOverlapParameters(true, 10, 3.0, 0.0);
    initSP();
    assertEquals(3, countActiveColumnsForZeroInput());
}

/**
 * When stimulusThreshold is > 0, don't allow columns without any overlap to
 * become active. This test focuses on the global inhibition code path.
 */
@Test
public void testZeroOverlap_StimulusThreshold_GlobalInhibition() {
    setupZeroOverlapParameters(true, 10, 3.0, 1.0);
    initSP();
    assertEquals(0, countActiveColumnsForZeroInput());
}

/**
 * When stimulusThreshold is 0, allow columns without any overlap to become
 * active. This test focuses on the local inhibition code path.
 */
@Test
public void testZeroOverlap_NoStimulusThreshold_LocalInhibition() {
    setupZeroOverlapParameters(false, 5, 1.0, 0.0);
    initSP();
    // This exact number of active columns is determined by the inhibition
    // radius, which changes based on the random synapses (i.e. weird math).
    // Force it to a known number.
    mem.setInhibitionRadius(2);
    assertEquals(6, countActiveColumnsForZeroInput());
}

/**
 * When stimulusThreshold is > 0, don't allow columns without any overlap to
 * become active. This test focuses on the local inhibition code path.
 */
@Test
public void testZeroOverlap_StimulusThreshold_LocalInhibition() {
    setupZeroOverlapParameters(false, 10, 3.0, 1.0);
    initSP();
    assertEquals(0, countActiveColumnsForZeroInput());
}

/** Number of input bits used by the zero-overlap tests. */
private static final int ZERO_OVERLAP_INPUT_SIZE = 10;
/** Number of columns used by the zero-overlap tests. */
private static final int ZERO_OVERLAP_NUM_COLUMNS = 20;

/**
 * Builds the shared zero-overlap configuration on top of
 * Parameters.getSpatialDefaultParameters() and stores it in {@link #parameters}.
 */
private void setupZeroOverlapParameters(boolean globalInhibition, int potentialRadius,
    double numActiveColumnsPerInhArea, double stimulusThreshold) {
    parameters = Parameters.getSpatialDefaultParameters();
    parameters.set(KEY.INPUT_DIMENSIONS, new int[] { ZERO_OVERLAP_INPUT_SIZE });
    parameters.set(KEY.COLUMN_DIMENSIONS, new int[] { ZERO_OVERLAP_NUM_COLUMNS });
    parameters.set(KEY.POTENTIAL_RADIUS, potentialRadius);
    parameters.set(KEY.GLOBAL_INHIBITION, globalInhibition);
    parameters.set(KEY.NUM_ACTIVE_COLUMNS_PER_INH_AREA, numActiveColumnsPerInhArea);
    parameters.set(KEY.STIMULUS_THRESHOLD, stimulusThreshold);
    parameters.set(KEY.RANDOM, new UniversalRandom(42));
    parameters.set(KEY.SEED, 42);
}

/**
 * Runs one learning compute step of {@link #sp} on an all-zero input and
 * returns how many columns became active.
 */
private int countActiveColumnsForZeroInput() {
    int[] activeArray = new int[ZERO_OVERLAP_NUM_COLUMNS];
    sp.compute(mem, new int[ZERO_OVERLAP_INPUT_SIZE], activeArray, true);
    return ArrayUtils.where(activeArray, ArrayUtils.INT_GREATER_THAN_0).length;
}
/**
 * With boost factors forced to 2.0 and an all-ones input, checks the raw
 * overlaps against {2, 1, 0} and that the boosted overlaps are twice them.
 */
@Test
public void testOverlapsOutput() {
    parameters = Parameters.getSpatialDefaultParameters();
    parameters.setInputDimensions(new int[] { 5 });
    parameters.setColumnDimensions(new int[] { 3 });
    parameters.setPotentialRadius(5);
    parameters.setNumActiveColumnsPerInhArea(5);
    parameters.setGlobalInhibition(true);
    parameters.setSynPermActiveInc(0.1);
    parameters.setSynPermInactiveDec(0.1);
    parameters.setSeed(42);
    parameters.setRandom(new UniversalRandom(42));

    SpatialPooler pooler = new SpatialPooler();
    Connections connections = new Connections();
    parameters.apply(connections);
    pooler.init(connections);
    connections.setBoostFactors(new double[] { 2.0, 2.0, 2.0 });

    int[] allOnes = { 1, 1, 1, 1, 1 };
    pooler.compute(connections, allOnes, new int[3], true);

    int[] expectedOverlaps = { 2, 1, 0 };
    int[] actualOverlaps = connections.getOverlaps();
    double[] actualBoosted = connections.getBoostedOverlaps();
    for (int col = 0; col < connections.getNumColumns(); col++) {
        int want = expectedOverlaps[col];
        assertEquals(want, actualOverlaps[col]);
        assertEquals(want * 2, actualBoosted[col], 0.01);
    }
}
/**
 * Given a specific input and initialization params the SP should return this
 * exact output.
 *
 * Previously (in the Python implementation) the output varied between
 * platforms such as OS X and Linux.
 */
@Test
public void testExactOutput() {
    setupParameters();
    parameters.setInputDimensions(new int[] { 1, 188 });
    parameters.setColumnDimensions(new int[] { 2048, 1 });
    parameters.setPotentialRadius(94);
    parameters.setPotentialPct(0.5);
    parameters.setGlobalInhibition(true);
    parameters.setLocalAreaDensity(-1.0);
    parameters.setNumActiveColumnsPerInhArea(40);
    parameters.setStimulusThreshold(0);
    parameters.setSynPermInactiveDec(0.01);
    parameters.setSynPermActiveInc(0.1);
    parameters.setMinPctOverlapDutyCycles(0.001);
    parameters.setMinPctActiveDutyCycles(0.001);
    parameters.setDutyCyclePeriod(1000);
    parameters.setMaxBoost(10);
    initSP();
    int[] inputVector = {
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0
    };
    int[] activeArray = new int[2048];
    sp.compute(mem, inputVector, activeArray, true);
    // Same "value > 0" predicate the other tests in this class use.
    int[] real = ArrayUtils.where(activeArray, ArrayUtils.INT_GREATER_THAN_0);
    int[] expected = new int[] {
        74, 203, 237, 270, 288, 317, 479, 529, 530, 622, 659, 720, 757, 790, 924, 956, 1033,
        1041, 1112, 1332, 1386, 1430, 1500, 1517, 1578, 1584, 1651, 1664, 1717, 1735, 1747,
        1748, 1775, 1779, 1788, 1813, 1888, 1911, 1938, 1958 };
    assertTrue("active columns " + Arrays.toString(real), Arrays.equals(expected, real));
}
/**
 * Checks stripUnlearnedColumns over four duty-cycle scenarios. Per the
 * expectations below, active columns whose active duty cycle is 0 are
 * removed and the rest are kept.
 */
@Test
public void testStripNeverLearned() {
    setupParameters();
    parameters.setColumnDimensions(new int[] { 6 });
    parameters.setInputDimensions(new int[] { 9 });
    initSP();

    double[][] dutyCycles = {
        { 0.5, 0.1, 0, 0.2, 0.4, 0 },
        { 0.9, 0, 0, 0, 0.4, 0.3 },
        { 0, 0, 0, 0, 0, 0 },
        { 1, 1, 1, 1, 1, 1 }
    };
    int[][] activeColumns = {
        { 0, 1, 2, 4 },
        ArrayUtils.range(0, 6),
        ArrayUtils.range(0, 6),
        ArrayUtils.range(0, 6)
    };
    int[][] expectedSurvivors = {
        { 0, 1, 4 },
        { 0, 4, 5 },
        {},
        ArrayUtils.range(0, 6)
    };

    for (int scenario = 0; scenario < dutyCycles.length; scenario++) {
        mem.updateActiveDutyCycles(dutyCycles[scenario]);
        int[] stripped = sp.stripUnlearnedColumns(mem, activeColumns[scenario]);
        assertTrue(Arrays.equals(expectedSurvivors[scenario], stripped));
    }
}
/**
 * Checks mapColumn for 1-D and 2-D topologies, including equal dimensions,
 * single-element dimensions, and inputs smaller than columns.
 */
@Test
public void testMapColumn() {
    // 1D
    assertColumnMapping(new int[] { 4 }, new int[] { 12 },
        new int[][] { { 0, 1 }, { 1, 4 }, { 2, 7 }, { 3, 10 } });
    // 1D with same dimension of columns and inputs
    assertColumnMapping(new int[] { 4 }, new int[] { 4 },
        new int[][] { { 0, 0 }, { 1, 1 }, { 2, 2 }, { 3, 3 } });
    // 1D with dimensions of length 1
    assertColumnMapping(new int[] { 1 }, new int[] { 1 },
        new int[][] { { 0, 0 } });
    // 2D
    assertColumnMapping(new int[] { 12, 4 }, new int[] { 36, 12 },
        new int[][] { { 0, 13 }, { 4, 49 }, { 5, 52 }, { 7, 58 }, { 47, 418 } });
    // 2D with some input dimensions smaller than column dimensions
    assertColumnMapping(new int[] { 4, 4 }, new int[] { 3, 5 },
        new int[][] { { 0, 0 }, { 3, 4 }, { 15, 14 } });
}

/**
 * Re-initializes the pooler with the given dimensions and asserts each
 * {column, expectedInput} pair against mapColumn.
 */
private void assertColumnMapping(int[] columnDims, int[] inputDims, int[][] columnToInput) {
    setupParameters();
    parameters.setColumnDimensions(columnDims);
    parameters.setInputDimensions(inputDims);
    initSP();
    for (int[] pair : columnToInput) {
        assertEquals(pair[1], sp.mapColumn(mem, pair[0]));
    }
}
/**
 * Exercises mapPotential on a 1-D topology (12 inputs, 4 columns, radius 2)
 * without wrap-around, with wrap-around, and with potentialPct below 1.
 */
@Test
public void testMapPotential1D() {
    setupParameters();
    parameters.setInputDimensions(new int[] { 12 });
    parameters.setColumnDimensions(new int[] { 4 });
    parameters.setPotentialRadius(2);
    parameters.setPotentialPct(1);
    parameters.set(KEY.WRAP_AROUND, false);
    initSP();
    assertEquals(12, mem.getInputDimensions()[0]);
    assertEquals(4, mem.getColumnDimensions()[0]);
    assertEquals(2, mem.getPotentialRadius());

    // Test without wrapAround and potentialPct = 1
    int[] expected = new int[] { 0, 1, 2, 3 };
    int[] mask = sp.mapPotential(mem, 0, false);
    assertTrue(Arrays.toString(mask), Arrays.equals(expected, mask));

    expected = new int[] { 5, 6, 7, 8, 9 };
    mask = sp.mapPotential(mem, 2, false);
    assertTrue(Arrays.toString(mask), Arrays.equals(expected, mask));

    // Test with wrapAround and potentialPct = 1
    mem.setWrapAround(true);
    expected = new int[] { 0, 1, 2, 3, 11 };
    mask = sp.mapPotential(mem, 0, true);
    assertTrue(Arrays.toString(mask), Arrays.equals(expected, mask));

    expected = new int[] { 0, 8, 9, 10, 11 };
    mask = sp.mapPotential(mem, 3, true);
    assertTrue(Arrays.toString(mask), Arrays.equals(expected, mask));

    // Test with wrapAround and potentialPct < 1: the pool must have 3 entries,
    // all drawn from the full wrap-around pool.
    parameters.setPotentialPct(0.5);
    parameters.set(KEY.WRAP_AROUND, true);
    initSP();
    int[] supersetMask = new int[] { 0, 1, 2, 3, 11 };
    mask = sp.mapPotential(mem, 0, true);
    // Expected value first, so a failure message reports the right numbers.
    assertEquals(3, mask.length);
    TIntArrayList unionList = new TIntArrayList(supersetMask);
    unionList.addAll(mask);
    int[] unionMask = ArrayUtils.unique(unionList.toArray());
    assertTrue(Arrays.toString(mask), Arrays.equals(unionMask, supersetMask));
}
/**
 * Exercises mapPotential on a 2-D topology (6x12 inputs, 2x4 columns):
 * radius 1 without wrap-around, then radius 2 with wrap-around.
 */
@Test
public void testMapPotential2D() {
    setupParameters();
    parameters.setInputDimensions(new int[] { 6, 12 });
    parameters.setColumnDimensions(new int[] { 2, 4 });
    parameters.setPotentialRadius(1);
    parameters.setPotentialPct(1);
    initSP();

    // Without wrapAround, radius 1
    TIntHashSet expected = new TIntHashSet(new int[] { 0, 1, 2, 12, 13, 14, 24, 25, 26 });
    assertTrue(expected.equals(new TIntHashSet(sp.mapPotential(mem, 0, false))));

    expected = new TIntHashSet(new int[] { 6, 7, 8, 18, 19, 20, 30, 31, 32 });
    assertTrue(expected.equals(new TIntHashSet(sp.mapPotential(mem, 2, false))));

    // With wrapAround, radius 2: pools spill over to the last input column/row
    parameters.setPotentialRadius(2);
    initSP();

    expected = new TIntHashSet(new int[] {
        0, 1, 2, 3, 11,
        12, 13, 14, 15, 23,
        24, 25, 26, 27, 35,
        36, 37, 38, 39, 47,
        60, 61, 62, 63, 71 });
    assertTrue(expected.equals(new TIntHashSet(sp.mapPotential(mem, 0, true))));

    expected = new TIntHashSet(new int[] {
        0, 8, 9, 10, 11,
        12, 20, 21, 22, 23,
        24, 32, 33, 34, 35,
        36, 44, 45, 46, 47,
        60, 68, 69, 70, 71 });
    assertTrue(expected.equals(new TIntHashSet(sp.mapPotential(mem, 3, true))));
}
/**
 * With a single column and a single input, the potential pool (no
 * wrap-around, potentialPct = 1) must be exactly input 0.
 */
@Test
public void testMapPotential1Column1Input() {
    setupParameters();
    parameters.setInputDimensions(new int[] { 1 });
    parameters.setColumnDimensions(new int[] { 1 });
    parameters.setPotentialRadius(2);
    parameters.setPotentialPct(1);
    parameters.set(KEY.WRAP_AROUND, false);
    initSP();

    // The mask holds the *position* of the one "on" bit. The Python version
    // returns [1], which is the dense form of the on bit in the zeroth position.
    TIntHashSet actual = new TIntHashSet(sp.mapPotential(mem, 0, false));
    assertTrue(new TIntHashSet(new int[] { 0 }).equals(actual));
}
//////////////////////////////////////////////////////////////
/**
 * Local test apparatus for {@link #testInhibitColumns()}: flags recording
 * which inhibition method a mock pooler invoked, plus the density it got.
 */
boolean globalCalled = false;
boolean localCalled = false;
double _density = 0;

/** Clears both call flags and the captured density. */
public void reset() {
    setGlobalCalled(false);
    setLocalCalled(false);
    _density = 0;
}

/** Records whether the global inhibition mock was invoked. */
public void setGlobalCalled(boolean b) { globalCalled = b; }

/** Records whether the local inhibition mock was invoked. */
public void setLocalCalled(boolean b) { localCalled = b; }
//////////////////////////////////////////////////////////////
/**
 * Verifies that inhibitColumns routes to inhibitColumnsGlobal in the
 * global-inhibition case and to inhibitColumnsLocal in the local cases
 * exercised here, and checks the density argument each receives. The
 * anonymous mocks record which method ran and the density they were given
 * in globalCalled / localCalled / _density (cleared by reset()).
 */
@Test
public void testInhibitColumns() {
setupParameters();
parameters.setColumnDimensions(new int[] { 5 });
parameters.setInhibitionRadius(10);
initSP();
//Mocks to test which method gets called
SpatialPooler inhibitColumnsGlobal = new SpatialPooler() {
private static final long serialVersionUID = 1L;
@Override public int[] inhibitColumnsGlobal(Connections c, double[] overlap, double density) {
setGlobalCalled(true);
_density = density;
return new int[] { 1 };
}
};
SpatialPooler inhibitColumnsLocal = new SpatialPooler() {
private static final long serialVersionUID = 1L;
@Override public int[] inhibitColumnsLocal(Connections c, double[] overlap, double density) {
setLocalCalled(true);
_density = density;
return new int[] { 2 };
}
};
// Case 1: global inhibition with an explicit localAreaDensity of 0.1.
// Overlaps are sampled from mem.getRandom(); the mocks never read them.
double[] overlaps = ArrayUtils.sample(mem.getNumColumns(), mem.getRandom());
mem.setNumActiveColumnsPerInhArea(5);
mem.setLocalAreaDensity(0.1);
mem.setGlobalInhibition(true);
mem.setInhibitionRadius(5);
double trueDensity = mem.getLocalAreaDensity();
inhibitColumnsGlobal.inhibitColumns(mem, overlaps);
assertTrue(globalCalled);
assertTrue(!localCalled);
assertEquals(trueDensity, _density, .01d);
//////
// Case 2: local inhibition; localAreaDensity is still 0.1, so the same
// density is expected.
// NOTE(review): setColumnDimensions() below does not re-run init; whether
// getNumColumns() then reflects 50x10 (500) is not visible here — confirm.
reset();
mem.setColumnDimensions(new int[] { 50, 10 });
//Internally calculated during init, to overwrite we put after init
mem.setGlobalInhibition(false);
mem.setInhibitionRadius(7);
// Zeroed tie-breaker, presumably so it cannot perturb the overlaps — confirm.
double[] tieBreaker = new double[500];
Arrays.fill(tieBreaker, 0);
mem.setTieBreaker(tieBreaker);
overlaps = ArrayUtils.sample(mem.getNumColumns(), mem.getRandom());
inhibitColumnsLocal.inhibitColumns(mem, overlaps);
trueDensity = mem.getLocalAreaDensity();
assertTrue(!globalCalled);
assertTrue(localCalled);
assertEquals(trueDensity, _density, .01d);
//////
// Case 3: re-initialize with 100x10 inputs/columns and localAreaDensity = -1,
// so the density comes from numActiveColumnsPerInhArea (3). The expected
// 3 / 81 matches 3 / (2 * radius + 1)^2 for radius 4 over two dimensions —
// presumably the inhibition-area formula in SpatialPooler; confirm there.
reset();
parameters.setInputDimensions(new int[] { 100, 10 });
parameters.setColumnDimensions(new int[] { 100, 10 });
parameters.setGlobalInhibition(false);
parameters.setLocalAreaDensity(-1);
parameters.setNumActiveColumnsPerInhArea(3);
initSP();
//Internally calculated during init, to overwrite we put after init
mem.setInhibitionRadius(4);
tieBreaker = new double[1000];
Arrays.fill(tieBreaker, 0);
mem.setTieBreaker(tieBreaker);
overlaps = ArrayUtils.sample(mem.getNumColumns(), mem.getRandom());
inhibitColumnsLocal.inhibitColumns(mem, overlaps);
trueDensity = 3.0 / 81.0;
assertTrue(!globalCalled);
assertTrue(localCalled);
assertEquals(trueDensity, _density, .01d);
//////
// Case 4: 7 active columns with radius 1 would give 7 / 9 under the same
// formula; the expected 0.5 presumably reflects an upper bound applied by
// SpatialPooler — confirm.
reset();
mem.setNumActiveColumnsPerInhArea(7);
//Internally calculated during init, to overwrite we put after init
mem.setInhibitionRadius(1);
tieBreaker = new double[1000];
Arrays.fill(tieBreaker, 0);
mem.setTieBreaker(tieBreaker);
overlaps = ArrayUtils.sample(mem.getNumColumns(), mem.getRandom());
inhibitColumnsLocal.inhibitColumns(mem, overlaps);
trueDensity = 0.5;
assertTrue(!globalCalled);
assertTrue(localCalled);
assertEquals(trueDensity, _density, .01d);
}
/**
 * Checks updateBoostFactors over four duty-cycle scenarios, comparing every
 * entry of mem.getBoostFactors() with the expected value within 0.1.
 */
@Test
public void testUpdateBoostFactors() {
    setupParameters();
    parameters.setInputDimensions(new int[] { 5/*Don't care*/ });
    parameters.setColumnDimensions(new int[] { 5 });
    parameters.setMaxBoost(10.0);
    parameters.setRandom(new UniversalRandom(42));
    initSP();
    mem.setNumColumns(6);

    // Scenario 1: tiny minimum duty cycles -> expected boost of 1
    double[] minActive = new double[6];
    Arrays.fill(minActive, 0.000001D);
    mem.setMinActiveDutyCycles(minActive);
    double[] active = { 0.1, 0.3, 0.02, 0.04, 0.7, 0.12 };
    mem.setActiveDutyCycles(active);
    double[] expected = { 1, 1, 1, 1, 1, 1 };
    sp.updateBoostFactors(mem);
    assertBoostFactorsNear(expected);

    // Scenario 2: minimum equals actual, boost factors pre-zeroed -> still 1
    mem.setMinActiveDutyCycles(new double[] { 0.1, 0.3, 0.02, 0.04, 0.7, 0.12 });
    Arrays.fill(mem.getBoostFactors(), 0);
    sp.updateBoostFactors(mem);
    assertBoostFactorsNear(expected);

    // Scenario 3: actual duty cycles at a tenth of the minimum -> 9.1
    mem.setMinActiveDutyCycles(new double[] { 0.1, 0.2, 0.02, 0.03, 0.7, 0.12 });
    active = new double[] { 0.01, 0.02, 0.002, 0.003, 0.07, 0.012 };
    mem.setActiveDutyCycles(active);
    expected = new double[] { 9.1, 9.1, 9.1, 9.1, 9.1, 9.1 };
    sp.updateBoostFactors(mem);
    assertBoostFactorsNear(expected);

    // Scenario 4: zero activity -> maxBoost (10)
    mem.setMinActiveDutyCycles(new double[] { 0.1, 0.2, 0.02, 0.03, 0.7, 0.12 });
    Arrays.fill(active, 0);
    mem.setActiveDutyCycles(active);
    Arrays.fill(expected, 10.0);
    sp.updateBoostFactors(mem);
    assertBoostFactorsNear(expected);
}

/** Asserts each entry of mem.getBoostFactors() is within 0.1 of {@code expected}. */
private void assertBoostFactorsNear(double[] expected) {
    double[] actual = mem.getBoostFactors();
    for (int idx = 0; idx < actual.length; idx++) {
        assertEquals(expected[idx], actual[idx], 0.1D);
    }
}
/**
 * Checks updateInhibitionRadius with global inhibition and, using stubbed
 * connected-span / columns-per-input values, the rounding and clipping of
 * the local-inhibition radius.
 */
@Test
public void testUpdateInhibitionRadius() {
    setupParameters();
    initSP();

    // Global inhibition: expected radius is 57, the largest column dimension
    mem.setGlobalInhibition(true);
    mem.setColumnDimensions(new int[] { 57, 31, 2 });
    sp.updateInhibitionRadius(mem);
    assertEquals(57, mem.getInhibitionRadius());

    // ((3 * 4) - 1) / 2 = 5.5 => rounds up to 6
    SpatialPooler stub = stubRadiusInputs(3, 4);
    mem.setGlobalInhibition(false);
    sp = stub;
    sp.updateInhibitionRadius(mem);
    assertEquals(6, mem.getInhibitionRadius());

    // ((0.5 * 1.2) - 1) / 2 is below 1 => clipped at 1
    stub = stubRadiusInputs(0.5, 1.2);
    mem.setGlobalInhibition(false);
    sp = stub;
    sp.updateInhibitionRadius(mem);
    assertEquals(1, mem.getInhibitionRadius());

    // ((2 * 2.4) - 1) / 2 = 1.9 => rounds up to 2
    stub = stubRadiusInputs(2.4, 2);
    mem.setGlobalInhibition(false);
    sp = stub;
    sp.updateInhibitionRadius(mem);
    assertEquals(2, mem.getInhibitionRadius());
}

/**
 * Returns a SpatialPooler whose avgConnectedSpanForColumnND and
 * avgColumnsPerInput always return the given constants.
 */
private static SpatialPooler stubRadiusInputs(final double connectedSpan, final double columnsPerInput) {
    return new SpatialPooler() {
        private static final long serialVersionUID = 1L;
        @Override
        public double avgConnectedSpanForColumnND(Connections c, int columnIndex) {
            return connectedSpan;
        }
        @Override
        public double avgColumnsPerInput(Connections c) {
            return columnsPerInput;
        }
    };
}
/**
 * Checks avgColumnsPerInput for several column/input dimension pairs; each
 * expected value is the mean of columnDim / inputDim across dimensions.
 */
@Test
public void testAvgColumnsPerInput() {
    setupParameters();
    initSP();

    int[][] columnDims = {
        { 2, 2, 2, 2 },
        { 2, 2, 2, 2 },
        { 3, 3 },
        { 25 },
        { 3, 3, 3, 5, 5, 6, 6 },
        { 3, 6, 9, 12 }
    };
    int[][] inputDims = {
        { 4, 4, 4, 4 },
        { 7, 5, 1, 3 },
        { 3, 3 },
        { 5 },
        { 3, 3, 3, 5, 5, 6, 6 },
        { 3, 3, 3 , 3 }
    };
    double[] expectedAverages = {
        0.5,
        (2.0/7 + 2.0/5 + 2.0/1 + 2/3.0) / 4.0d,
        1,
        5,
        1,
        2.5
    };

    for (int k = 0; k < expectedAverages.length; k++) {
        mem.setColumnDimensions(columnDims[k]);
        mem.setInputDimensions(inputDims[k]);
        assertEquals(expectedAverages[k], sp.avgColumnsPerInput(mem), 0);
    }
}
/**
 * Checks avgConnectedSpanForColumnND on a 4x4x2x5 input. Each of the 5
 * columns gets a hand-picked set of connected inputs (given as coordinates),
 * and its average span is compared with the precomputed value.
 */
@Test
public void testAvgConnectedSpanForColumnND() {
    sp = new SpatialPooler();
    mem = new Connections();
    mem.setInputDimensions(new int[] { 4, 4, 2, 5 });
    mem.setColumnDimensions(new int[] { 5 });
    sp.initMatrices(mem);

    // Connected input coordinates per column; sorted flat indices in comments.
    int[][][] connectedCoords = {
        // [ 45 46 48 105 125 145]
        { { 1, 0, 1, 0 }, { 1, 0, 1, 1 }, { 3, 2, 1, 0 }, { 3, 0, 1, 0 }, { 1, 0, 1, 3 }, { 2, 2, 1, 0 } },
        // [ 80 85 120 125]
        { { 2, 0, 1, 0 }, { 2, 0, 0, 0 }, { 3, 0, 0, 0 }, { 3, 0, 1, 0 } },
        // [ 1 3 6 9 42 156]
        { { 0, 0, 1, 4 }, { 0, 0, 0, 3 }, { 0, 0, 0, 1 }, { 1, 0, 0, 2 }, { 0, 0, 1, 1 }, { 3, 3, 1, 1 } },
        // [ 0 159]
        { { 3, 3, 1, 4 }, { 0, 0, 0, 0 } },
        // []
        { }
    };
    int[] poolSizes = { 6, 4, 4, 4, 4 };

    for (int column = 0; column < connectedCoords.length; column++) {
        int[] connected = new int[connectedCoords[column].length];
        for (int k = 0; k < connected.length; k++) {
            connected[k] = mem.getInputMatrix().computeIndex(connectedCoords[column][k], false);
        }
        Arrays.sort(connected);
        mem.getPotentialPools().set(column, new Pool(poolSizes[column]));
        mem.getColumn(column).setProximalConnectedSynapsesForTest(mem, connected);
    }

    double[] expectedSpans = { 11.0/4d, 6.0/4d, 14.0/4d, 15.0/4d, 0d };
    for (int column = 0; column < mem.getNumColumns(); column++) {
        assertEquals(expectedSpans[column], sp.avgConnectedSpanForColumnND(mem, column), 0);
    }
}
/**
 * Columns whose overlap duty cycle is below their minimum overlap duty cycle
 * (columns 0, 1, 3 and 4 here) are expected to have every permanence in their
 * potential pool raised by synPermBelowStimulusInc (0.01); values that are
 * still below synPermTrimThreshold (0.05) are expected to read back as zero.
 * Column 2 (duty cycle 0.1) is expected to keep its permanences unchanged.
 */
@Test
public void testBumpUpWeakColumns() {
    setupParameters();
    parameters.setInputDimensions(new int[] { 8 });
    parameters.setColumnDimensions(new int[] { 5 });
    initSP();

    mem.setSynPermBelowStimulusInc(0.01);
    mem.setSynPermTrimThreshold(0.05);
    mem.setOverlapDutyCycles(new double[] { 0, 0.009, 0.1, 0.001, 0.002 });
    mem.setMinOverlapDutyCycles(new double[] { .01, .01, .01, .01, .01 });

    // 1 marks an input bit that belongs to the column's potential pool.
    int[][] potentialPools = new int[][] {
        { 1, 1, 1, 1, 0, 0, 0, 0 },
        { 1, 0, 0, 0, 1, 1, 0, 1 },
        { 0, 0, 1, 0, 1, 1, 1, 0 },
        { 1, 1, 1, 0, 0, 0, 1, 0 },
        { 1, 1, 1, 1, 1, 1, 1, 1 }
    };
    double[][] permanences = new double[][] {
        { 0.200, 0.120, 0.090, 0.040, 0.000, 0.000, 0.000, 0.000 },
        { 0.150, 0.000, 0.000, 0.000, 0.180, 0.120, 0.000, 0.450 },
        { 0.000, 0.000, 0.014, 0.000, 0.032, 0.044, 0.110, 0.000 },
        { 0.041, 0.000, 0.000, 0.000, 0.000, 0.000, 0.178, 0.000 },
        { 0.100, 0.738, 0.045, 0.002, 0.050, 0.008, 0.208, 0.034 }
    };
    double[][] truePermanences = new double[][] {
        { 0.210, 0.130, 0.100, 0.000, 0.000, 0.000, 0.000, 0.000 },
        { 0.160, 0.000, 0.000, 0.000, 0.190, 0.130, 0.000, 0.460 },
        { 0.000, 0.000, 0.014, 0.000, 0.032, 0.044, 0.110, 0.000 },
        { 0.051, 0.000, 0.000, 0.000, 0.000, 0.000, 0.188, 0.000 },
        { 0.110, 0.748, 0.055, 0.000, 0.060, 0.000, 0.218, 0.000 }
    };

    Condition<?> isPoolMember = new Condition.Adapter<Integer>() {
        public boolean eval(int n) {
            return n == 1;
        }
    };

    // Install each column's pool membership and starting permanences.
    for(int col = 0; col < mem.getNumColumns(); col++) {
        int[] poolIndices = ArrayUtils.where(potentialPools[col], isPoolMember);
        mem.getColumn(col).setProximalConnectedSynapsesForTest(mem, poolIndices);
        mem.getColumn(col).setProximalPermanences(mem, permanences[col]);
    }

    sp.bumpUpWeakColumns(mem);

    for(int col = 0; col < mem.getNumColumns(); col++) {
        double[] actual = mem.getPotentialPools().get(col).getDensePermanences(mem);
        double[] expected = truePermanences[col];
        for(int bit = 0; bit < expected.length; bit++) {
            assertEquals(expected[bit], actual[bit], 0.01);
        }
    }
}
/**
 * Local minimum duty cycles are derived per column from its neighborhood
 * (inhibition radius 1). The run is repeated without and with wrap-around;
 * only the last column's expectation differs between the two, because with
 * wrap-around column 0 becomes its neighbor.
 */
@Test
public void testUpdateMinDutyCycleLocal() {
    boolean[] wrapSettings = { false, true };
    double[][] expectedMinActive = {
        { 0.09, 0.09, 0.07, 0.07, 0.07, 0.01, 0.012, 0.012 },   // wrapAround = false
        { 0.09, 0.09, 0.07, 0.07, 0.07, 0.01, 0.012, 0.09 }     // wrapAround = true
    };
    double[][] expectedMinOverlap = {
        { 0.14, 0.14, 0.1, 0.156, 0.156, 0.156, 0.11, 0.02 },   // wrapAround = false
        { 0.14, 0.14, 0.1, 0.156, 0.156, 0.156, 0.11, 0.14 }    // wrapAround = true
    };

    for(int run = 0; run < wrapSettings.length; run++) {
        // Fresh pooler for every run so the two settings cannot influence each other.
        setupDefaultParameters();
        parameters.setInputDimensions(new int[] { 5 });
        parameters.setColumnDimensions(new int[] { 8 });
        parameters.set(KEY.WRAP_AROUND, wrapSettings[run]);
        initSP();

        mem.setInhibitionRadius(1);
        mem.setOverlapDutyCycles(new double[] { 0.7, 0.1, 0.5, 0.01, 0.78, 0.55, 0.1, 0.001 });
        mem.setActiveDutyCycles(new double[] { 0.9, 0.3, 0.5, 0.7, 0.1, 0.01, 0.08, 0.12 });
        mem.setMinPctActiveDutyCycles(0.1);
        mem.setMinPctOverlapDutyCycles(0.2);
        sp.updateMinDutyCyclesLocal(mem);

        double[] minActive = mem.getMinActiveDutyCycles();
        for(int col = 0; col < expectedMinActive[run].length; col++) {
            assertEquals(expectedMinActive[run][col], minActive[col], 0.01);
        }
        double[] minOverlap = mem.getMinOverlapDutyCycles();
        for(int col = 0; col < expectedMinOverlap[run].length; col++) {
            assertEquals(expectedMinOverlap[run][col], minOverlap[col], 0.01);
        }
    }
}
/**
 * With global inhibition every column shares the same minimum duty cycles.
 * The expected values below are the configured minimum percentage times the
 * largest duty cycle across all columns, checked for three scenarios.
 */
@Test
public void testUpdateMinDutyCycleGlobal() {
    setupParameters();
    parameters.setInputDimensions(new int[] { 5 });
    parameters.setColumnDimensions(new int[] { 5 });
    initSP();

    // The expected values are the same products the pooler computes, so a tight
    // tolerance is possible; 0.01 could not distinguish e.g. 0.012 from 0.0.
    final double tolerance = 1e-6;

    // Scenario 1: max overlap duty cycle is 6, max active duty cycle is 0.6.
    mem.setMinPctOverlapDutyCycles(0.01);
    mem.setMinPctActiveDutyCycles(0.02);
    mem.setOverlapDutyCycles(new double[] { 0.06, 1, 3, 6, 0.5 });
    mem.setActiveDutyCycles(new double[] { 0.6, 0.07, 0.5, 0.4, 0.3 });
    sp.updateMinDutyCyclesGlobal(mem);
    double[] trueMinActiveDutyCycles = new double[mem.getNumColumns()];
    Arrays.fill(trueMinActiveDutyCycles, 0.02*0.6);
    double[] trueMinOverlapDutyCycles = new double[mem.getNumColumns()];
    Arrays.fill(trueMinOverlapDutyCycles, 0.01*6);
    for(int i = 0;i < mem.getNumColumns();i++) {
        assertEquals(trueMinOverlapDutyCycles[i], mem.getMinOverlapDutyCycles()[i], tolerance);
        assertEquals(trueMinActiveDutyCycles[i], mem.getMinActiveDutyCycles()[i], tolerance);
    }

    // Scenario 2: max overlap duty cycle is 2.4, max active duty cycle is 0.54.
    mem.setMinPctOverlapDutyCycles(0.015);
    mem.setMinPctActiveDutyCycles(0.03);
    mem.setOverlapDutyCycles(new double[] { 0.86, 2.4, 0.03, 1.6, 1.5 });
    mem.setActiveDutyCycles(new double[] { 0.16, 0.007, 0.15, 0.54, 0.13 });
    sp.updateMinDutyCyclesGlobal(mem);
    Arrays.fill(trueMinOverlapDutyCycles, 0.015*2.4);
    // The active minima must follow the new maximum too, not keep scenario 1's value.
    Arrays.fill(trueMinActiveDutyCycles, 0.03*0.54);
    for(int i = 0;i < mem.getNumColumns();i++) {
        assertEquals(trueMinOverlapDutyCycles[i], mem.getMinOverlapDutyCycles()[i], tolerance);
        assertEquals(trueMinActiveDutyCycles[i], mem.getMinActiveDutyCycles()[i], tolerance);
    }

    // Scenario 3: all duty cycles are zero, so every minimum must be zero.
    mem.setMinPctOverlapDutyCycles(0.015);
    mem.setMinPctActiveDutyCycles(0.03);
    mem.setOverlapDutyCycles(new double[5]);
    mem.setActiveDutyCycles(new double[5]);
    sp.updateMinDutyCyclesGlobal(mem);
    Arrays.fill(trueMinOverlapDutyCycles, 0);
    Arrays.fill(trueMinActiveDutyCycles, 0);
    for(int i = 0;i < mem.getNumColumns();i++) {
        assertEquals(trueMinActiveDutyCycles[i], mem.getMinActiveDutyCycles()[i], tolerance);
        assertEquals(trueMinOverlapDutyCycles[i], mem.getMinOverlapDutyCycles()[i], tolerance);
    }
}
/**
 * isUpdateRound is expected to report true exactly for iteration numbers that
 * are multiples of the configured update period (0 included).
 */
@Test
public void testIsUpdateRound() {
    setupParameters();
    parameters.setInputDimensions(new int[] { 5 });
    parameters.setColumnDimensions(new int[] { 5 });
    initSP();

    int[] periods = { 50, 125 };
    int[][] iterations = {
        { 1, 39, 50, 1009, 1250 },
        { 0, 200, 249, 1330, 1249, 1375 }
    };
    boolean[][] expected = {
        { false, false, true, false, true },
        { true, false, false, false, false, true }
    };

    for(int p = 0; p < periods.length; p++) {
        mem.setUpdatePeriod(periods[p]);
        for(int k = 0; k < iterations[p].length; k++) {
            mem.setIterationNum(iterations[p][k]);
            assertEquals("period " + periods[p] + ", iteration " + iterations[p][k],
                expected[p][k], sp.isUpdateRound(mem));
        }
    }
}
/**
 * adaptSynapses is expected to increment permanences of active inputs by
 * synPermActiveInc (0.1), decrement inactive ones by synPermInactiveDec (0.01),
 * zero values below synPermTrimThreshold (0.05), and only touch the active
 * columns (0, 1, 2). Two independent phases are run against the same pooler.
 */
@Test
public void testAdaptSynapses() {
    setupParameters();
    parameters.setInputDimensions(new int[] { 8 });
    parameters.setColumnDimensions(new int[] { 4 });
    parameters.setSynPermInactiveDec(0.01);
    parameters.setSynPermActiveInc(0.1);
    initSP();
    mem.setSynPermTrimThreshold(0.05);

    int[][] pools1 = new int[][] {
        { 1, 1, 1, 1, 0, 0, 0, 0 },
        { 1, 0, 0, 0, 1, 1, 0, 1 },
        { 0, 0, 1, 0, 0, 0, 1, 0 },
        { 1, 0, 0, 0, 0, 0, 1, 0 }
    };
    double[][] start1 = new double[][] {
        { 0.200, 0.120, 0.090, 0.040, 0.000, 0.000, 0.000, 0.000 },
        { 0.150, 0.000, 0.000, 0.000, 0.180, 0.120, 0.000, 0.450 },
        { 0.000, 0.000, 0.014, 0.000, 0.000, 0.000, 0.110, 0.000 },
        { 0.040, 0.000, 0.000, 0.000, 0.000, 0.000, 0.178, 0.000 }
    };
    double[][] expected1 = new double[][] {
        { 0.300, 0.110, 0.080, 0.140, 0.000, 0.000, 0.000, 0.000 },
        //  Inc     Dec     Dec     Inc     -       -       -       -
        { 0.250, 0.000, 0.000, 0.000, 0.280, 0.110, 0.000, 0.440 },
        //  Inc     -       -       -       Inc     Dec     -       Dec
        { 0.000, 0.000, 0.000, 0.000, 0.000, 0.000, 0.210, 0.000 },
        //  -       -       Trim    -       -       -       Inc     -
        { 0.040, 0.000, 0.000, 0.000, 0.000, 0.000, 0.178, 0.000 }
        //  -       -       -       -       -       -       -       -   (column 3 is not active)
    };

    int[][] pools2 = new int[][] {
        { 1, 1, 1, 0, 0, 0, 0, 0 },
        { 0, 1, 1, 1, 0, 0, 0, 0 },
        { 0, 0, 1, 1, 1, 0, 0, 0 },
        { 1, 0, 0, 0, 0, 0, 1, 0 }
    };
    double[][] start2 = new double[][] {
        { 0.200, 0.120, 0.090, 0.000, 0.000, 0.000, 0.000, 0.000 },
        { 0.000, 0.017, 0.232, 0.400, 0.180, 0.120, 0.000, 0.450 },
        { 0.000, 0.000, 0.014, 0.051, 0.730, 0.000, 0.000, 0.000 },
        { 0.170, 0.000, 0.000, 0.000, 0.000, 0.000, 0.380, 0.000 }
    };
    double[][] expected2 = new double[][] {
        { 0.300, 0.110, 0.080, 0.000, 0.000, 0.000, 0.000, 0.000 },
        { 0.000, 0.000, 0.222, 0.500, 0.000, 0.000, 0.000, 0.000 },
        { 0.000, 0.000, 0.000, 0.151, 0.830, 0.000, 0.000, 0.000 },
        { 0.170, 0.000, 0.000, 0.000, 0.000, 0.000, 0.380, 0.000 }
    };

    int[][][] poolsByPhase = { pools1, pools2 };
    double[][][] startByPhase = { start1, start2 };
    double[][][] expectedByPhase = { expected1, expected2 };

    Condition<?> inPool = new Condition.Adapter<Integer>() {
        public boolean eval(int n) {
            return n == 1;
        }
    };
    int[] inputVector = new int[] { 1, 0, 0, 1, 1, 0, 1, 0 };
    int[] activeColumns = new int[] { 0, 1, 2 };

    for(int phase = 0; phase < poolsByPhase.length; phase++) {
        for(int col = 0; col < mem.getNumColumns(); col++) {
            int[] poolIndices = ArrayUtils.where(poolsByPhase[phase][col], inPool);
            mem.getColumn(col).setProximalConnectedSynapsesForTest(mem, poolIndices);
            mem.getColumn(col).setProximalPermanences(mem, startByPhase[phase][col]);
        }

        sp.adaptSynapses(mem, inputVector, activeColumns);

        for(int col = 0; col < mem.getNumColumns(); col++) {
            double[] actual = mem.getPotentialPools().get(col).getDensePermanences(mem);
            double[] want = expectedByPhase[phase][col];
            for(int bit = 0; bit < want.length; bit++) {
                assertEquals(want[bit], actual[bit], 0.01);
            }
        }
    }
}
/**
 * raisePermanenceToThreshold is expected to keep raising every masked
 * permanence by synPermBelowStimulusInc (0.01) until at least
 * stimulusThreshold (3) of them exceed synPermConnected (0.1). The expected
 * rows show 1, 0, 2, 4 and 9 increments respectively.
 */
@Test
public void testRaisePermanenceThreshold() {
    setupParameters();
    parameters.setInputDimensions(new int[] { 5 });
    parameters.setColumnDimensions(new int[] { 5 });
    parameters.setSynPermConnected(0.1);
    parameters.setStimulusThreshold(3);
    parameters.setSynPermBelowStimulusInc(0.01);
    //The following parameter is not set to "1" in the Python version
    //This is necessary to reproduce the test conditions of having as
    //many pool members as Input Bits, which would never happen under
    //normal circumstances because we want to enforce sparsity
    parameters.setPotentialPct(1);
    initSP();

    // Starting permanences for each column, installed through the Connections.
    SparseObjectMatrix<double[]> objMatrix = new SparseObjectMatrix<double[]>(new int[] { 5, 5 });
    objMatrix.set(0, new double[] { 0.0, 0.11, 0.095, 0.092, 0.01 });
    objMatrix.set(1, new double[] { 0.12, 0.15, 0.02, 0.12, 0.09 });
    objMatrix.set(2, new double[] { 0.51, 0.081, 0.025, 0.089, 0.31 });
    objMatrix.set(3, new double[] { 0.18, 0.0601, 0.11, 0.011, 0.03 });
    objMatrix.set(4, new double[] { 0.011, 0.011, 0.011, 0.011, 0.011 });
    mem.setProximalPermanences(objMatrix);
    mem.setConnectedCounts(new int[] { 1, 3, 2, 2, 0 });

    double[][] truePermanences = new double[][] {
        {0.01, 0.12, 0.105, 0.102, 0.02},       // incremented once
        {0.12, 0.15, 0.02, 0.12, 0.09},         // no change
        {0.53, 0.101, 0.045, 0.109, 0.33},      // increment twice
        {0.22, 0.1001, 0.15, 0.051, 0.07},      // increment four times
        {0.101, 0.101, 0.101, 0.101, 0.101}};   // increment 9 times

    for(int i = 0;i < mem.getNumColumns();i++) {
        double[] perm = mem.getPotentialPools().get(i).getSparsePermanences();
        // The mask holds positions within perm. Because potentialPct = 1, every
        // position belongs to the pool, so all of them are masked. Column
        // indices would only match these positions by coincidence (5 columns, 5 inputs).
        int[] mask = IntStream.range(0, perm.length).toArray();
        sp.raisePermanenceToThreshold(mem, perm, mask);
        for(int j = 0;j < perm.length;j++) {
            assertEquals(truePermanences[i][j], perm[j], 0.001);
        }
    }
}
/**
 * With raisePerm = true, updatePermanencesForColumn is expected to leave each
 * column connected exactly to the inputs marked in the expected dense rows,
 * and the per-column connected counts to match.
 */
@Test
public void testUpdatePermanencesForColumn() {
    setupParameters();
    parameters.setInputDimensions(new int[] { 5 });
    parameters.setColumnDimensions(new int[] { 5 });
    parameters.setSynPermTrimThreshold(0.05);
    //The following parameter is not set to "1" in the Python version
    //This is necessary to reproduce the test conditions of having as
    //many pool members as Input Bits, which would never happen under
    //normal circumstances because we want to enforce sparsity
    parameters.setPotentialPct(1);
    initSP();

    double[][] permanences = new double[][] {
        {-0.10, 0.500, 0.400, 0.010, 0.020},
        {0.300, 0.010, 0.020, 0.120, 0.090},
        {0.070, 0.050, 1.030, 0.190, 0.060},
        {0.180, 0.090, 0.110, 0.010, 0.030},
        {0.200, 0.101, 0.050, -0.09, 1.100}};
    int[][] expectedDenseConnections = new int[][] {
        {0, 1, 1, 0, 0},
        {1, 0, 0, 1, 0},
        {0, 0, 1, 1, 0},
        {1, 0, 1, 0, 0},
        {1, 1, 0, 0, 1}};
    // Sparse input indices passed as the potential mask for each column.
    int[][] masks = new int[][] {
        { 1, 2 },
        { 0, 3 },
        { 2, 3 },
        { 0, 2 },
        { 0, 1, 4 }
    };
    int[] expectedCounts = new int[] {2, 2, 2, 2, 3};

    for(int col = 0; col < mem.getNumColumns(); col++) {
        double[] colPerms = permanences[col];
        mem.getColumn(col).setProximalPermanences(mem, colPerms);
        sp.updatePermanencesForColumn(mem, colPerms, mem.getColumn(col), masks[col], true);

        String want = Arrays.toString(expectedDenseConnections[col]);
        String got = Arrays.toString(
            mem.getColumn(col).getProximalDendrite().getConnectedSynapsesDense(mem));
        assertEquals(want, got);
    }

    String countsWant = Arrays.toString(expectedCounts);
    String countsGot = Arrays.toString(mem.getConnectedCounts().getTrueCounts());
    assertEquals(countsWant, countsGot);
}
/**
 * A column's overlap is the number of active input bits it is connected to.
 * The overlap percentage divides that by the column's number of connected
 * synapses (e.g. with only bit 9 active, the staircase columns give
 * 1/10, 1/8, 1/6, 1/4 and 1/2).
 */
@Test
public void testCalculateOverlap() {
    setupDefaultParameters();
    parameters.setInputDimensions(new int[] { 10 });
    parameters.setColumnDimensions(new int[] { 5 });
    initSP();

    int[] dimensions = new int[] { 5, 10 };
    // Column i is connected to inputs 2i..9 (10, 8, 6, 4 and 2 synapses).
    int[][] staircase = new int[][] {
        {1, 1, 1, 1, 1, 1, 1, 1, 1, 1},
        {0, 0, 1, 1, 1, 1, 1, 1, 1, 1},
        {0, 0, 0, 0, 1, 1, 1, 1, 1, 1},
        {0, 0, 0, 0, 0, 0, 1, 1, 1, 1},
        {0, 0, 0, 0, 0, 0, 0, 0, 1, 1}};
    // Column i is connected to inputs i and i + 5 (2 synapses each).
    int[][] paired = new int[][] {
        {1, 0, 0, 0, 0, 1, 0, 0, 0, 0},
        {0, 1, 0, 0, 0, 0, 1, 0, 0, 0},
        {0, 0, 1, 0, 0, 0, 0, 1, 0, 0},
        {0, 0, 0, 1, 0, 0, 0, 0, 1, 0},
        {0, 0, 0, 0, 1, 0, 0, 0, 0, 1}};

    int[][][] connectedCases = { staircase, staircase, staircase, paired };
    int[][] inputCases = {
        { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 },   // no active input
        { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 },   // every input active
        { 0, 0, 0, 0, 0, 0, 0, 0, 0, 1 },   // only the last input active
        { 1, 0, 1, 0, 1, 0, 1, 0, 1, 0 }    // every other input active
    };
    int[][] expectedOverlaps = {
        { 0, 0, 0, 0, 0 },
        { 10, 8, 6, 4, 2 },
        { 1, 1, 1, 1, 1 },
        { 1, 1, 1, 1, 1 }
    };
    double[][] expectedOverlapsPct = {
        { 0, 0, 0, 0, 0 },
        { 1, 1, 1, 1, 1 },
        { 0.1, 0.125, 1.0/6, 0.25, 0.5 },
        { 0.5, 0.5, 0.5, 0.5, 0.5 }
    };

    for(int caseIdx = 0; caseIdx < inputCases.length; caseIdx++) {
        int[][] connected = connectedCases[caseIdx];
        AbstractSparseBinaryMatrix sm = new SparseBinaryMatrix(dimensions);
        for(int i = 0;i < dimensions[0];i++) {
            for(int j = 0;j < dimensions[1];j++) {
                sm.set(connected[i][j], i, j);
            }
        }
        mem.setConnectedMatrix(sm);
        // Sanity check: the matrix holds exactly what was written.
        for(int i = 0;i < dimensions[0];i++) {
            for(int j = 0;j < dimensions[1];j++) {
                assertEquals(connected[i][j], sm.getIntValue(i, j));
            }
        }

        int[] overlaps = sp.calculateOverlap(mem, inputCases[caseIdx]);
        double[] overlapsPct = sp.calculateOverlapPct(mem, overlaps);

        String label = "case " + caseIdx;
        assertEquals(label, Arrays.toString(expectedOverlaps[caseIdx]), Arrays.toString(overlaps));
        assertEquals(label, Arrays.toString(expectedOverlapsPct[caseIdx]), Arrays.toString(overlapsPct));
    }
}
/**
 * Test initial permanence generation. Ensure that the correct number of
 * synapses are initialized in a connected state, and that every permanence
 * value lies within [0, synPermMax].
 */
@Test
public void testInitPermanence1() {
    setupParameters();
    sp = new SpatialPooler() {
        private static final long serialVersionUID = 1L;
        @Override
        public void raisePermanenceToThreshold(Connections c, double[] perm, int[] maskPotential) {
            //Mock out
        }
    };
    mem = new Connections();
    parameters.apply(mem);
    sp.init(mem);
    mem.setNumInputs(10);
    mem.setPotentialRadius(2);

    // connectedPct = 1: every masked bit starts connected.
    double connectedPct = 1;
    int[] mask = new int[] { 0, 1, 2, 8, 9 };
    double[] perm = sp.initPermanence(mem, mask, 0, connectedPct);
    int numcon = ArrayUtils.valueGreaterCount(mem.getSynPermConnected(), perm);
    assertEquals(5, numcon);

    // connectedPct = 0: no bit starts connected.
    connectedPct = 0;
    perm = sp.initPermanence(mem, mask, 0, connectedPct);
    numcon = ArrayUtils.valueGreaterCount(mem.getSynPermConnected(), perm);
    assertEquals(0, numcon);

    // connectedPct = 0.5 over 100 inputs: some, but not all, start connected.
    connectedPct = 0.5;
    mem.setPotentialRadius(100);
    mem.setNumInputs(100);
    mask = new int[100];
    for(int i = 0;i < 100;i++) mask[i] = i;
    final double[] perma = sp.initPermanence(mem, mask, 0, connectedPct);
    numcon = ArrayUtils.valueGreaterOrEqualCount(mem.getSynPermConnected(), perma);
    assertTrue(numcon > 0);
    assertTrue(numcon < mem.getNumInputs());

    // Every permanence (not merely some) must stay within [minThresh, maxThresh].
    final double minThresh = 0.0;
    final double maxThresh = mem.getSynPermMax();
    double[] inRange = ArrayUtils.retainLogicalAnd(perma, new Condition[] {
        new Condition.Adapter<Object>() {
            public boolean eval(double d) {
                return d >= minThresh;
            }
        },
        new Condition.Adapter<Object>() {
            public boolean eval(double d) {
                return d <= maxThresh;
            }
        }
    });
    assertEquals(perma.length, inRange.length);
}
/**
 * Test initial permanence generation. Ensure that permanence values are only
 * assigned to bits within a column's potential pool: for each mask, the bits
 * that receive a positive permanence must be exactly the masked bits.
 */
@Test
public void testInitPermanence2() {
    setupParameters();
    sp = new SpatialPooler() {
        private static final long serialVersionUID = 1L;
        @Override
        public void raisePermanenceToThreshold(Connections c, double[] perm, int[] maskPotential) {
            //Mock out
        }
    };
    mem = new Connections();
    parameters.apply(mem);
    sp.init(mem);
    mem.setNumInputs(10);

    Condition<?> isPositive = new Condition.Adapter<Object>() {
        public boolean eval(double d) {
            return d > 0;
        }
    };
    double connectedPct = 1;
    int[][] masks = {
        { 0, 1 },
        { 4, 5, 6 },
        { 8, 9 },
        { 0, 1, 2, 3, 4, 5, 6, 8, 9 }
    };

    for(int[] mask : masks) {
        double[] perm = sp.initPermanence(mem, mask, 0, connectedPct);
        // The expected connected bits are the mask itself.
        assertTrue(Arrays.equals(mask, ArrayUtils.where(perm, isPositive)));
    }
}
/**
 * Tests that duty cycles are updated according to
 * {@code newDc = (dc * (period - 1) + newvals) / period}. Each expected array
 * below is that formula evaluated by hand. The last case uses a short period (2),
 * so the new values carry much more weight.
 */
@Test
public void testUpdateDutyCycleHelper() {
    setupParameters();
    parameters.setInputDimensions(new int[] { 5 });
    parameters.setColumnDimensions(new int[] { 5 });
    initSP();

    // (1000 * 999 + 0) / 1000 = 999
    double[] dc = new double[5];
    Arrays.fill(dc, 1000.0);
    double[] newvals = new double[5];
    int period = 1000;
    double[] newDc = sp.updateDutyCyclesHelper(mem, dc, newvals, period);
    double[] trueNewDc = new double[] { 999, 999, 999, 999, 999 };
    assertEquals(Arrays.toString(trueNewDc), Arrays.toString(newDc));

    // (1000 * 999 + 1000) / 1000 = 1000: a steady input leaves the duty cycle unchanged.
    // The expectation is spelled out rather than copied from dc after the call,
    // so the check stays meaningful even if the helper modified dc in place.
    dc = new double[5];
    Arrays.fill(dc, 1000.0);
    newvals = new double[5];
    Arrays.fill(newvals, 1000);
    period = 1000;
    newDc = sp.updateDutyCyclesHelper(mem, dc, newvals, period);
    trueNewDc = new double[] { 1000, 1000, 1000, 1000, 1000 };
    assertEquals(Arrays.toString(trueNewDc), Arrays.toString(newDc));

    // (1000 * 999 + v) / 1000 for v = 2000, 4000, 5000, 6000, 7000
    dc = new double[5];
    Arrays.fill(dc, 1000.0);
    newvals = new double[] { 2000, 4000, 5000, 6000, 7000 };
    period = 1000;
    newDc = sp.updateDutyCyclesHelper(mem, dc, newvals, period);
    trueNewDc = new double[] { 1001, 1003, 1004, 1005, 1006 };
    assertEquals(Arrays.toString(trueNewDc), Arrays.toString(newDc));

    // period = 2 with zero input halves every duty cycle.
    dc = new double[] { 1000, 800, 600, 400, 2000 };
    newvals = new double[5];
    period = 2;
    newDc = sp.updateDutyCyclesHelper(mem, dc, newvals, period);
    trueNewDc = new double[] { 500, 400, 300, 200, 1000 };
    assertEquals(Arrays.toString(trueNewDc), Arrays.toString(newDc));
}
/**
 * Global inhibition is expected to keep the density fraction of columns with
 * the highest overlaps: 3 of 10 at density 0.3, and 5 of 10 at density 0.5.
 */
@Test
public void testInhibitColumnsGlobal() {
    setupParameters();
    parameters.setColumnDimensions(new int[] { 10 });
    initSP();
    //Internally calculated during init, to overwrite we put after init.
    //It must be set on the Connections: the parameters are presumably only
    //applied to mem inside initSP(), so changing them afterwards has no effect.
    mem.setInhibitionRadius(2);

    double density = 0.3;
    double[] overlaps = new double[] { 1, 2, 1, 4, 8, 3, 12, 5, 4, 1 };
    int[] active = sp.inhibitColumnsGlobal(mem, overlaps, density);
    int[] trueActive = new int[] { 4, 6, 7 };
    Arrays.sort(active);
    assertTrue(Arrays.equals(trueActive, active));

    density = 0.5;
    mem.setNumColumns(10);
    overlaps = IntStream.range(0, 10).mapToDouble(i -> i).toArray();
    active = sp.inhibitColumnsGlobal(mem, overlaps, density);
    trueActive = IntStream.range(5, 10).toArray();
    // Compare as sets, as in the first case; the order of winners is not part of the contract tested here.
    Arrays.sort(active);
    assertTrue(Arrays.equals(trueActive, active));
}
/**
 * Checks local inhibition for several overlap patterns and inhibition radii,
 * with and without wrap-around. In the comment strings, W marks a column
 * expected to win (be active) and L a column expected to lose, in column order.
 */
@Test
public void testInhibitColumnsLocal() {
    setupParameters();
    parameters.setInputDimensions(new int[] { 5 });
    parameters.setColumnDimensions(new int[] { 10 });
    initSP();
    //Internally calculated during init, to overwrite we put after init
    mem.setInhibitionRadius(2);
    double density = 0.5;
    double[] overlaps = new double[] { 1, 2, 7, 0, 3, 4, 16, 1, 1.5, 1.7 };
    //  L W W L L W W L W W (wrapAround=true)
    //  L W W L L W W L L W (wrapAround=false)
    mem.setWrapAround(true);
    int[] trueActive = new int[] {1, 2, 5, 6, 8, 9};
    int[] active = sp.inhibitColumnsLocal(mem, overlaps, density);
    assertTrue(Arrays.equals(trueActive, active));
    mem.setWrapAround(false);
    trueActive = new int[] {1, 2, 5, 6, 9};
    active = sp.inhibitColumnsLocal(mem, overlaps, density);
    assertTrue(Arrays.equals(trueActive, active));

    density = 0.5;
    mem.setInhibitionRadius(3);
    overlaps = new double[] { 1, 2, 7, 0, 3, 4, 16, 1, 1.5, 1.7 };
    //  L W W L W W W L L W (wrapAround=true)
    //  L W W L W W W L L W (wrapAround=false)
    mem.setWrapAround(true);
    trueActive = new int[] { 1, 2, 4, 5, 6, 9 };
    active = sp.inhibitColumnsLocal(mem, overlaps, density);
    assertTrue(Arrays.equals(trueActive, active));
    mem.setWrapAround(false);
    trueActive = new int[] { 1, 2, 4, 5, 6, 9 };
    active = sp.inhibitColumnsLocal(mem, overlaps, density);
    assertTrue(Arrays.equals(trueActive, active));

    // Test add to winners
    density = 0.3333;
    mem.setInhibitionRadius(3);
    overlaps = new double[] { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 };
    //  W W L L W W L L L L (wrapAround=true)
    //  W W L L W W L L W L (wrapAround=false)
    mem.setWrapAround(true);
    trueActive = new int[] { 0, 1, 4, 5 };
    active = sp.inhibitColumnsLocal(mem, overlaps, density);
    assertTrue(Arrays.equals(trueActive, active));
    mem.setWrapAround(false);
    trueActive = new int[] { 0, 1, 4, 5, 8 };
    active = sp.inhibitColumnsLocal(mem, overlaps, density);
    assertTrue(Arrays.equals(trueActive, active));
}
// /**
// * As coded in the Python test
// */
// @Test
// public void testGetNeighborsND() {
// //This setup isn't relevant to this test
// setupParameters();
// parameters.setInputDimensions(new int[] { 9, 5 });
// parameters.setColumnDimensions(new int[] { 5, 5 });
// initSP();
// ////////////////////// Test not part of Python port /////////////////////
// int[] result = sp.getNeighborsND(mem, 2, mem.getInputMatrix(), 3, true).toArray();
// int[] expected = new int[] {
// 0, 1, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
// 13, 14, 15, 16, 17, 18, 19, 30, 31, 32, 33,
// 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44
// };
// for(int i = 0;i < result.length;i++) {
// assertEquals(expected[i], result[i]);
// }
// /////////////////////////////////////////////////////////////////////////
// setupParameters();
// int[] dimensions = new int[] { 5, 7, 2 };
// parameters.setInputDimensions(dimensions);
// parameters.setColumnDimensions(dimensions);
// initSP();
// int radius = 1;
// int x = 1;
// int y = 3;
// int z = 2;
// int columnIndex = mem.getInputMatrix().computeIndex(new int[] { z, y, x });
// int[] neighbors = sp.getNeighborsND(mem, columnIndex, mem.getInputMatrix(), radius, true).toArray();
// String expect = "[18, 19, 20, 21, 22, 23, 32, 33, 34, 36, 37, 46, 47, 48, 49, 50, 51]";
// assertEquals(expect, ArrayUtils.print1DArray(neighbors));
//
// /////////////////////////////////////////
// setupParameters();
// dimensions = new int[] { 5, 7, 9 };
// parameters.setInputDimensions(dimensions);
// parameters.setColumnDimensions(dimensions);
// initSP();
// radius = 3;
// x = 0;
// y = 0;
// z = 3;
// columnIndex = mem.getInputMatrix().computeIndex(new int[] { z, y, x });
// neighbors = sp.getNeighborsND(mem, columnIndex, mem.getInputMatrix(), radius, true).toArray();
// expect = "[0, 1, 2, 3, 6, 7, 8, 9, 10, 11, 12, 15, 16, 17, 18, 19, 20, 21, 24, 25, 26, "
// + "27, 28, 29, 30, 33, 34, 35, 36, 37, 38, 39, 42, 43, 44, 45, 46, 47, 48, 51, "
// + "52, 53, 54, 55, 56, 57, 60, 61, 62, 63, 64, 65, 66, 69, 70, 71, 72, 73, 74, "
// + "75, 78, 79, 80, 81, 82, 83, 84, 87, 88, 89, 90, 91, 92, 93, 96, 97, 98, 99, "
// + "100, 101, 102, 105, 106, 107, 108, 109, 110, 111, 114, 115, 116, 117, 118, 119, "
// + "120, 123, 124, 125, 126, 127, 128, 129, 132, 133, 134, 135, 136, 137, 138, 141, "
// + "142, 143, 144, 145, 146, 147, 150, 151, 152, 153, 154, 155, 156, 159, 160, 161, "
// + "162, 163, 164, 165, 168, 169, 170, 171, 172, 173, 174, 177, 178, 179, 180, 181, "
// + "182, 183, 186, 187, 188, 190, 191, 192, 195, 196, 197, 198, 199, 200, 201, 204, "
// + "205, 206, 207, 208, 209, 210, 213, 214, 215, 216, 217, 218, 219, 222, 223, 224, "
// + "225, 226, 227, 228, 231, 232, 233, 234, 235, 236, 237, 240, 241, 242, 243, 244, "
// + "245, 246, 249, 250, 251, 252, 253, 254, 255, 258, 259, 260, 261, 262, 263, 264, "
// + "267, 268, 269, 270, 271, 272, 273, 276, 277, 278, 279, 280, 281, 282, 285, 286, "
// + "287, 288, 289, 290, 291, 294, 295, 296, 297, 298, 299, 300, 303, 304, 305, 306, "
// + "307, 308, 309, 312, 313, 314]";
// assertEquals(expect, ArrayUtils.print1DArray(neighbors));
//
// /////////////////////////////////////////
// setupParameters();
// dimensions = new int[] { 5, 10, 7, 6 };
// parameters.setInputDimensions(dimensions);
// parameters.setColumnDimensions(dimensions);
// initSP();
//
// radius = 4;
// int w = 2;
// x = 5;
// y = 6;
// z = 2;
// columnIndex = mem.getInputMatrix().computeIndex(new int[] { z, y, x, w });
// neighbors = sp.getNeighborsND(mem, columnIndex, mem.getInputMatrix(), radius, true).toArray();
// TIntHashSet trueNeighbors = new TIntHashSet();
// for(int i = -radius;i <= radius;i++) {
// for(int j = -radius;j <= radius;j++) {
// for(int k = -radius;k <= radius;k++) {
// for(int m = -radius;m <= radius;m++) {
// int zprime = (int)ArrayUtils.positiveRemainder((z + i), dimensions[0]);
// int yprime = (int)ArrayUtils.positiveRemainder((y + j), dimensions[1]);
// int xprime = (int)ArrayUtils.positiveRemainder((x + k), dimensions[2]);
// int wprime = (int)ArrayUtils.positiveRemainder((w + m), dimensions[3]);
// trueNeighbors.add(mem.getInputMatrix().computeIndex(new int[] { zprime, yprime, xprime, wprime }));
// }
// }
// }
// }
// trueNeighbors.remove(columnIndex);
// int[] tneighbors = ArrayUtils.unique(trueNeighbors.toArray());
// assertEquals(ArrayUtils.print1DArray(tneighbors), ArrayUtils.print1DArray(neighbors));
//
// /////////////////////////////////////////
// //Tests from getNeighbors1D from Python unit test
// setupParameters();
// dimensions = new int[] { 8 };
// parameters.setColumnDimensions(dimensions);
// parameters.setInputDimensions(dimensions);
// initSP();
// AbstractSparseBinaryMatrix sbm = (AbstractSparseBinaryMatrix)mem.getInputMatrix();
// sbm.set(new int[] { 2, 4 }, new int[] { 1, 1 }, true);
// radius = 1;
// columnIndex = 3;
// int[] mask = sp.getNeighborsND(mem, columnIndex, mem.getInputMatrix(), radius, true).toArray();
// TIntArrayList msk = new TIntArrayList(mask);
// TIntArrayList neg = new TIntArrayList(ArrayUtils.range(0, dimensions[0]));
// neg.removeAll(msk);
// assertTrue(sbm.all(mask));
// assertFalse(sbm.any(neg));
//
// //////
// setupParameters();
// dimensions = new int[] { 8 };
// parameters.setInputDimensions(dimensions);
// initSP();
// sbm = (AbstractSparseBinaryMatrix)mem.getInputMatrix();
// sbm.set(new int[] { 1, 2, 4, 5 }, new int[] { 1, 1, 1, 1 }, true);
// radius = 2;
// columnIndex = 3;
// mask = sp.getNeighborsND(mem, columnIndex, mem.getInputMatrix(), radius, true).toArray();
// msk = new TIntArrayList(mask);
// neg = new TIntArrayList(ArrayUtils.range(0, dimensions[0]));
// neg.removeAll(msk);
// assertTrue(sbm.all(mask));
// assertFalse(sbm.any(neg));
//
// //Wrap around
// setupParameters();
// dimensions = new int[] { 8 };
// parameters.setInputDimensions(dimensions);
// initSP();
// sbm = (AbstractSparseBinaryMatrix)mem.getInputMatrix();
// sbm.set(new int[] { 1, 2, 6, 7 }, new int[] { 1, 1, 1, 1 }, true);
// radius = 2;
// columnIndex = 0;
// mask = sp.getNeighborsND(mem, columnIndex, mem.getInputMatrix(), radius, true).toArray();
// msk = new TIntArrayList(mask);
// neg = new TIntArrayList(ArrayUtils.range(0, dimensions[0]));
// neg.removeAll(msk);
// assertTrue(sbm.all(mask));
// assertFalse(sbm.any(neg));
//
// //Radius too big
// setupParameters();
// dimensions = new int[] { 8 };
// parameters.setInputDimensions(dimensions);
// initSP();
// sbm = (AbstractSparseBinaryMatrix)mem.getInputMatrix();
// sbm.set(new int[] { 0, 1, 2, 3, 4, 5, 7 }, new int[] { 1, 1, 1, 1, 1, 1, 1 }, true);
// radius = 20;
// columnIndex = 6;
// mask = sp.getNeighborsND(mem, columnIndex, mem.getInputMatrix(), radius, true).toArray();
// msk = new TIntArrayList(mask);
// neg = new TIntArrayList(ArrayUtils.range(0, dimensions[0]));
// neg.removeAll(msk);
// assertTrue(sbm.all(mask));
// assertFalse(sbm.any(neg));
//
// //These are all the same tests from 2D
// setupParameters();
// dimensions = new int[] { 6, 5 };
// parameters.setInputDimensions(dimensions);
// parameters.setColumnDimensions(dimensions);
// initSP();
// sbm = (AbstractSparseBinaryMatrix)mem.getInputMatrix();
// int[][] input = new int[][] {
// {0, 0, 0, 0, 0},
// {0, 0, 0, 0, 0},
// {0, 1, 1, 1, 0},
// {0, 1, 0, 1, 0},
// {0, 1, 1, 1, 0},
// {0, 0, 0, 0, 0}};
// for(int i = 0;i < input.length;i++) {
// for(int j = 0;j < input[i].length;j++) {
// if(input[i][j] == 1)
// sbm.set(sbm.computeIndex(new int[] { i, j }), 1);
// }
// }
// radius = 1;
// columnIndex = 3*5 + 2;
// mask = sp.getNeighborsND(mem, columnIndex, mem.getInputMatrix(), radius, true).toArray();
// msk = new TIntArrayList(mask);
// neg = new TIntArrayList(ArrayUtils.range(0, dimensions[0]));
// neg.removeAll(msk);
// assertTrue(sbm.all(mask));
// assertFalse(sbm.any(neg));
//
// ////////
// setupParameters();
// dimensions = new int[] { 6, 5 };
// parameters.setInputDimensions(dimensions);
// parameters.setColumnDimensions(dimensions);
// initSP();
// sbm = (AbstractSparseBinaryMatrix)mem.getInputMatrix();
// input = new int[][] {
// {0, 0, 0, 0, 0},
// {1, 1, 1, 1, 1},
// {1, 1, 1, 1, 1},
// {1, 1, 0, 1, 1},
// {1, 1, 1, 1, 1},
// {1, 1, 1, 1, 1}};
// for(int i = 0;i < input.length;i++) {
// for(int j = 0;j < input[i].length;j++) {
// if(input[i][j] == 1)
// sbm.set(sbm.computeIndex(new int[] { i, j }), 1);
// }
// }
// radius = 2;
// columnIndex = 3*5 + 2;
// mask = sp.getNeighborsND(mem, columnIndex, mem.getInputMatrix(), radius, true).toArray();
// msk = new TIntArrayList(mask);
// neg = new TIntArrayList(ArrayUtils.range(0, dimensions[0]));
// neg.removeAll(msk);
// assertTrue(sbm.all(mask));
// assertFalse(sbm.any(neg));
//
// //Radius too big
// setupParameters();
// dimensions = new int[] { 6, 5 };
// parameters.setInputDimensions(dimensions);
// parameters.setColumnDimensions(dimensions);
// initSP();
// sbm = (AbstractSparseBinaryMatrix)mem.getInputMatrix();
// input = new int[][] {
// {1, 1, 1, 1, 1},
// {1, 1, 1, 1, 1},
// {1, 1, 1, 1, 1},
// {1, 1, 0, 1, 1},
// {1, 1, 1, 1, 1},
// {1, 1, 1, 1, 1}};
// for(int i = 0;i < input.length;i++) {
// for(int j = 0;j < input[i].length;j++) {
// if(input[i][j] == 1)
// sbm.set(sbm.computeIndex(new int[] { i, j }), 1);
// }
// }
// radius = 7;
// columnIndex = 3*5 + 2;
// mask = sp.getNeighborsND(mem, columnIndex, mem.getInputMatrix(), radius, true).toArray();
// msk = new TIntArrayList(mask);
// neg = new TIntArrayList(ArrayUtils.range(0, dimensions[0]));
// neg.removeAll(msk);
// assertTrue(sbm.all(mask));
// assertFalse(sbm.any(neg));
//
// //Wrap-around
// setupParameters();
// dimensions = new int[] { 6, 5 };
// parameters.setInputDimensions(dimensions);
// parameters.setColumnDimensions(dimensions);
// initSP();
// sbm = (AbstractSparseBinaryMatrix)mem.getInputMatrix();
// input = new int[][] {
// {1, 0, 0, 1, 1},
// {0, 0, 0, 0, 0},
// {0, 0, 0, 0, 0},
// {0, 0, 0, 0, 0},
// {1, 0, 0, 1, 1},
// {1, 0, 0, 1, 0}};
// for(int i = 0;i < input.length;i++) {
// for(int j = 0;j < input[i].length;j++) {
// if(input[i][j] == 1)
// sbm.set(sbm.computeIndex(new int[] { i, j }), 1);
// }
// }
// radius = 1;
// columnIndex = sbm.getMaxIndex();
// mask = sp.getNeighborsND(mem, columnIndex, mem.getInputMatrix(), radius, true).toArray();
// msk = new TIntArrayList(mask);
// neg = new TIntArrayList(ArrayUtils.range(0, dimensions[0]));
// neg.removeAll(msk);
// assertTrue(sbm.all(mask));
// assertFalse(sbm.any(neg));
// }
/**
 * Verifies that {@link SpatialPooler#init} validates its parameters: invalid inhibition
 * settings, zero columns and zero inputs are each rejected with an
 * {@link InvalidSPParamValueException}, while a valid configuration initializes cleanly.
 */
@Test
public void testInit() {
    setupParameters();
    parameters.setNumActiveColumnsPerInhArea(0);
    parameters.setLocalAreaDensity(0);
    SpatialPooler sp = new SpatialPooler();

    // With numActiveColumnsPerInhArea at 0, a Local Area Density of 0 is rejected
    assertInitRejected(sp, "Inhibition parameters are invalid");

    // Local Area Density can't be above 0.5
    parameters.setLocalAreaDensity(0.51);
    assertInitRejected(sp, "Inhibition parameters are invalid");

    // Local Area Density should be sane here; any exception propagates and fails the
    // test with its full stack trace instead of being swallowed.
    parameters.setLocalAreaDensity(0.5);
    Connections c = new Connections();
    parameters.apply(c);
    sp.init(c);

    // Num columns cannot be 0
    parameters.set(KEY.COLUMN_DIMENSIONS, new int[] { 0 });
    assertInitRejected(sp, "Invalid number of columns: 0");

    // Reset column dims
    parameters.set(KEY.COLUMN_DIMENSIONS, new int[] { 5 });
    // Num inputs cannot be 0
    parameters.set(KEY.INPUT_DIMENSIONS, new int[] { 0 });
    assertInitRejected(sp, "Invalid number of inputs: 0");
}

/**
 * Applies the current {@code parameters} to a fresh {@link Connections} and asserts that
 * {@code sp.init} rejects them with an {@link InvalidSPParamValueException} whose message
 * equals {@code expectedMessage}.
 */
private void assertInitRejected(SpatialPooler sp, String expectedMessage) {
    Connections c = new Connections();
    parameters.apply(c);
    try {
        sp.init(c);
    } catch (Exception e) {
        // Check the type first so an unexpected exception is reported as such,
        // rather than as a confusing message mismatch.
        assertEquals(InvalidSPParamValueException.class, e.getClass());
        assertEquals(expectedMessage, e.getMessage());
        return;
    }
    fail("Expected InvalidSPParamValueException: " + expectedMessage);
}
/**
 * Verifies that {@link SpatialPooler#compute} rejects an input vector whose length differs
 * from the number of inputs implied by the input dimensions, and accepts one of the
 * correct length.
 */
@Test
public void testComputeInputMismatch() {
    setupParameters();
    parameters.set(KEY.INPUT_DIMENSIONS, new int[] { 2, 4 });
    parameters.setColumnDimensions(new int[] { 5, 1 });
    // The active array holds one slot per column: the product of the column dimensions above.
    int numColumns = 5 * 1;
    Connections c = new Connections();
    parameters.apply(c);
    int misMatchedDims = 6; // not 8
    SpatialPooler sp = new SpatialPooler();
    sp.init(c);
    try {
        sp.compute(c, new int[misMatchedDims], new int[numColumns], true);
        fail("Expected InvalidSPParamValueException for a mis-sized input vector");
    } catch (Exception e) {
        // Check the type first so an unexpected exception is reported as such.
        assertEquals(InvalidSPParamValueException.class, e.getClass());
        assertEquals("Input array must be same size as the defined number"
            + " of inputs: From Params: 8, From Input Vector: 6", e.getMessage());
    }

    // Now do the right thing: same parameters, fresh Connections, correctly sized input.
    c = new Connections();
    parameters.apply(c);
    int matchedDims = 8; // same as input dimension multiplied, above
    sp.init(c);
    try {
        sp.compute(c, new int[matchedDims], new int[numColumns], true);
    } catch (Exception e) {
        // Keep the original exception as the cause so the failure is diagnosable.
        throw new AssertionError("compute rejected a correctly sized input vector", e);
    }
}
}