package org.numenta.nupic.model; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotEquals; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertTrue; import java.util.ArrayList; import java.util.Arrays; import java.util.LinkedHashSet; import java.util.List; import java.util.Set; import java.util.function.Function; import java.util.stream.Collectors; import java.util.stream.IntStream; import org.junit.Test; import org.numenta.nupic.Parameters; import org.numenta.nupic.Parameters.KEY; import org.numenta.nupic.algorithms.TemporalMemory; import org.numenta.nupic.algorithms.TemporalMemory.ColumnData; import org.numenta.nupic.util.GroupBy2; import org.numenta.nupic.util.Tuple; import org.numenta.nupic.util.UniversalRandom; import chaschev.lang.Pair; public class ComputeCycleTest { @Test public void testConversionConstructor() { Column column = new Column(10, 0); List<Cell> cells = IntStream.range(0, 10) .mapToObj(i -> new Cell(column, i)) .collect(Collectors.toList()); Connections cnx = new Connections(); cnx.setActiveCells(new LinkedHashSet<Cell>( Arrays.asList( new Cell[] { cells.get(0), cells.get(1), cells.get(2), cells.get(3) }))); cnx.setWinnerCells(new LinkedHashSet<Cell>( Arrays.asList( new Cell[] { cells.get(1), cells.get(3), }))); ComputeCycle cc = new ComputeCycle(cnx); assertNotNull(cc.activeCells); assertEquals(4, cc.activeCells.size()); assertNotNull(cc.winnerCells); assertEquals(2, cc.winnerCells.size()); assertNotNull(cc.predictiveCells()); assertEquals(0, cc.predictiveCells().size()); ComputeCycle cc1 = new ComputeCycle(cnx); assertEquals(cc, cc1); assertEquals(cc.hashCode(), cc1.hashCode()); // Now test negative equality cnx.setWinnerCells(new LinkedHashSet<Cell>( Arrays.asList( new Cell[] { cells.get(4), cells.get(3), }))); ComputeCycle cc2 = new ComputeCycle(cnx); assertNotEquals(cc1, cc2); assertFalse(cc1.hashCode() == cc2.hashCode()); } @SuppressWarnings({ "unchecked", "rawtypes" }) @Test public void testActiveColumnsRetrievable() { TemporalMemory tm = new TemporalMemory(); Connections cn = new Connections(); Parameters p = getDefaultParameters(null, KEY.CELLS_PER_COLUMN, 1); p = getDefaultParameters(p, KEY.MIN_THRESHOLD, 1); p.apply(cn); TemporalMemory.init(cn); int[] previousActiveColumns = { 0, 1, 2, 3 }; Set<Cell> prevWinnerCells = cn.getCellSet(new int[] { 0, 1, 2, 3 }); int[] activeColumnsIndices = { 4 }; DistalDendrite matchingSegment = cn.createSegment(cn.getCell(4)); cn.createSynapse(matchingSegment, cn.getCell(0), 0.5); ComputeCycle cc = tm.compute(cn, previousActiveColumns, true); assertTrue(cc.winnerCells().equals(prevWinnerCells)); //cc = tm.compute(cn, activeColumnsIndices, true); Function<Column, Column> identity = Function.identity(); Function<DistalDendrite, Column> segToCol = segment -> segment.getParentCell().getColumn(); List<Column> activeColumns = Arrays.stream(activeColumnsIndices) .sorted() .mapToObj(i -> cn.getColumn(i)) .collect(Collectors.toList()); GroupBy2<Column> grouper = GroupBy2.<Column>of( new Pair(activeColumns, identity), new Pair(new ArrayList(cn.getActiveSegments()), segToCol), new Pair(new ArrayList(cn.getMatchingSegments()), segToCol)); ColumnData columnData = new ColumnData(); for(Tuple t : grouper) { // Executes only once columnData = columnData.set(t); assertTrue(columnData.activeColumns().equals(activeColumns)); assertTrue(columnData.activeSegments().isEmpty()); List<DistalDendrite> sos = columnData.matchingSegments(); assertEquals(1, sos.size()); assertEquals(0, sos.get(0).getIndex()); assertEquals(4, sos.get(0).getParentCell().getIndex()); assertTrue(columnData.column().equals(cn.getColumn(4))); } } private Parameters getDefaultParameters(Parameters p, KEY key, Object value) { Parameters retVal = p == null ? getDefaultParameters() : p; retVal.set(key, value); return retVal; } private Parameters getDefaultParameters() { Parameters retVal = Parameters.getTemporalDefaultParameters(); retVal.set(KEY.COLUMN_DIMENSIONS, new int[] { 32 }); retVal.set(KEY.CELLS_PER_COLUMN, 4); retVal.set(KEY.ACTIVATION_THRESHOLD, 3); retVal.set(KEY.INITIAL_PERMANENCE, 0.21); retVal.set(KEY.CONNECTED_PERMANENCE, 0.5); retVal.set(KEY.MIN_THRESHOLD, 2); retVal.set(KEY.MAX_NEW_SYNAPSE_COUNT, 3); retVal.set(KEY.PERMANENCE_INCREMENT, 0.10); retVal.set(KEY.PERMANENCE_DECREMENT, 0.10); retVal.set(KEY.PREDICTED_SEGMENT_DECREMENT, 0.0); retVal.set(KEY.RANDOM, new UniversalRandom(42)); retVal.set(KEY.SEED, 42); return retVal; } }