/* ---------------------------------------------------------------------
* Numenta Platform for Intelligent Computing (NuPIC)
* Copyright (C) 2016, Numenta, Inc. Unless you have an agreement
* with Numenta, Inc., for a separate license for this software code, the
* following terms and conditions apply:
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero Public License version 3 as
* published by the Free Software Foundation.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
* See the GNU Affero Public License for more details.
*
* You should have received a copy of the GNU Affero Public License
* along with this program. If not, see http://www.gnu.org/licenses.
*
* http://numenta.org/licenses/
* ---------------------------------------------------------------------
*/
package org.numenta.nupic.algorithms;
import static org.numenta.nupic.util.GroupBy2.Slot.NONE;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Random;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.numenta.nupic.model.Cell;
import org.numenta.nupic.model.Column;
import org.numenta.nupic.model.ComputeCycle;
import org.numenta.nupic.model.Connections;
import org.numenta.nupic.model.Connections.Activity;
import org.numenta.nupic.model.DistalDendrite;
import org.numenta.nupic.model.Synapse;
import org.numenta.nupic.monitor.ComputeDecorator;
import org.numenta.nupic.util.GroupBy2;
import org.numenta.nupic.util.GroupBy2.Slot;
import org.numenta.nupic.util.SparseObjectMatrix;
import org.numenta.nupic.util.Tuple;
import chaschev.lang.Pair;
/**
* Temporal Memory implementation in Java.
*
* @author Numenta
* @author cogmission
*/
public class TemporalMemory implements ComputeDecorator, Serializable{
/** simple serial version id */
private static final long serialVersionUID = 1L;
private static final double EPSILON = 0.00001;
private static final int ACTIVE_COLUMNS = 1;
/**
* Uses the specified {@link Connections} object to Build the structural
* anatomy needed by this {@code TemporalMemory} to implement its algorithms.
*
* The connections object holds the {@link Column} and {@link Cell} infrastructure,
* and is used by both the {@link SpatialPooler} and {@link TemporalMemory}. Either of
* these can be used separately, and therefore this Connections object may have its
* Columns and Cells initialized by either the init method of the SpatialPooler or the
* init method of the TemporalMemory. We check for this so that complete initialization
* of both Columns and Cells occurs, without either being redundant (initialized more than
* once). However, {@link Cell}s only get created when initializing a TemporalMemory, because
* they are not used by the SpatialPooler.
*
* @param c {@link Connections} object
*/
public static void init(Connections c) {
SparseObjectMatrix<Column> matrix = c.getMemory() == null ?
new SparseObjectMatrix<Column>(c.getColumnDimensions()) :
c.getMemory();
c.setMemory(matrix);
int numColumns = matrix.getMaxIndex() + 1;
c.setNumColumns(numColumns);
int cellsPerColumn = c.getCellsPerColumn();
Cell[] cells = new Cell[numColumns * cellsPerColumn];
//Used as flag to determine if Column objects have been created.
Column colZero = matrix.getObject(0);
for(int i = 0;i < numColumns;i++) {
Column column = colZero == null ?
new Column(cellsPerColumn, i) : matrix.getObject(i);
for(int j = 0;j < cellsPerColumn;j++) {
cells[i * cellsPerColumn + j] = column.getCell(j);
}
//If columns have not been previously configured
if(colZero == null) matrix.set(i, column);
}
//Only the TemporalMemory initializes cells so no need to test for redundancy
c.setCells(cells);
}
@Override
public ComputeCycle compute(Connections connections, int[] activeColumns, boolean learn) {
ComputeCycle cycle = new ComputeCycle();
activateCells(connections, cycle, activeColumns, learn);
activateDendrites(connections, cycle, learn);
return cycle;
}
/**
* Calculate the active cells, using the current active columns and dendrite
* segments. Grow and reinforce synapses.
*
* <pre>
* Pseudocode:
* for each column
* if column is active and has active distal dendrite segments
* call activatePredictedColumn
* if column is active and doesn't have active distal dendrite segments
* call burstColumn
* if column is inactive and has matching distal dendrite segments
* call punishPredictedColumn
*
* </pre>
*
* @param conn
* @param activeColumnIndices
* @param learn
*/
@SuppressWarnings("unchecked")
public void activateCells(Connections conn, ComputeCycle cycle, int[] activeColumnIndices, boolean learn) {
ColumnData columnData = new ColumnData();
Set<Cell> prevActiveCells = conn.getActiveCells();
Set<Cell> prevWinnerCells = conn.getWinnerCells();
List<Column> activeColumns = Arrays.stream(activeColumnIndices)
.sorted()
.mapToObj(i -> conn.getColumn(i))
.collect(Collectors.toList());
Function<Column, Column> identity = Function.identity();
Function<DistalDendrite, Column> segToCol = segment -> segment.getParentCell().getColumn();
@SuppressWarnings({ "rawtypes" })
GroupBy2<Column> grouper = GroupBy2.<Column>of(
new Pair(activeColumns, identity),
new Pair(new ArrayList<>(conn.getActiveSegments()), segToCol),
new Pair(new ArrayList<>(conn.getMatchingSegments()), segToCol));
double permanenceIncrement = conn.getPermanenceIncrement();
double permanenceDecrement = conn.getPermanenceDecrement();
for(Tuple t : grouper) {
columnData = columnData.set(t);
if(columnData.isNotNone(ACTIVE_COLUMNS)) {
if(!columnData.activeSegments().isEmpty()) {
List<Cell> cellsToAdd = activatePredictedColumn(conn, columnData.activeSegments(),
columnData.matchingSegments(), prevActiveCells, prevWinnerCells,
permanenceIncrement, permanenceDecrement, learn);
cycle.activeCells.addAll(cellsToAdd);
cycle.winnerCells.addAll(cellsToAdd);
}else{
Tuple cellsXwinnerCell = burstColumn(conn, columnData.column(), columnData.matchingSegments(),
prevActiveCells, prevWinnerCells, permanenceIncrement, permanenceDecrement, conn.getRandom(),
learn);
cycle.activeCells.addAll((List<Cell>)cellsXwinnerCell.get(0));
cycle.winnerCells.add((Cell)cellsXwinnerCell.get(1));
}
}else{
if(learn) {
punishPredictedColumn(conn, columnData.activeSegments(), columnData.matchingSegments(),
prevActiveCells, prevWinnerCells, conn.getPredictedSegmentDecrement());
}
}
}
}
/**
* Calculate dendrite segment activity, using the current active cells.
*
* <pre>
* Pseudocode:
* for each distal dendrite segment with activity >= activationThreshold
* mark the segment as active
* for each distal dendrite segment with unconnected activity >= minThreshold
* mark the segment as matching
* </pre>
*
* @param conn the Connectivity
* @param cycle Stores current compute cycle results
* @param learn If true, segment activations will be recorded. This information is used
* during segment cleanup.
*/
public void activateDendrites(Connections conn, ComputeCycle cycle, boolean learn) {
Activity activity = conn.computeActivity(cycle.activeCells, conn.getConnectedPermanence());
List<DistalDendrite> activeSegments = IntStream.range(0, activity.numActiveConnected.length)
.filter(i -> activity.numActiveConnected[i] >= conn.getActivationThreshold())
.mapToObj(i -> conn.segmentForFlatIdx(i))
.collect(Collectors.toList());
List<DistalDendrite> matchingSegments = IntStream.range(0, activity.numActiveConnected.length)
.filter(i -> activity.numActivePotential[i] >= conn.getMinThreshold())
.mapToObj(i -> conn.segmentForFlatIdx(i))
.collect(Collectors.toList());
Collections.sort(activeSegments, conn.segmentPositionSortKey);
Collections.sort(matchingSegments, conn.segmentPositionSortKey);
cycle.activeSegments = activeSegments;
cycle.matchingSegments = matchingSegments;
conn.lastActivity = activity;
conn.setActiveCells(new LinkedHashSet<>(cycle.activeCells));
conn.setWinnerCells(new LinkedHashSet<>(cycle.winnerCells));
conn.setActiveSegments(activeSegments);
conn.setMatchingSegments(matchingSegments);
// Forces generation of the predictive cells from the above active segments
conn.clearPredictiveCells();
conn.getPredictiveCells();
if(learn) {
activeSegments.stream().forEach(s -> conn.recordSegmentActivity(s));
conn.startNewIteration();
}
}
/**
* Indicates the start of a new sequence. Clears any predictions and makes sure
* synapses don't grow to the currently active cells in the next time step.
*/
@Override
public void reset(Connections connections) {
connections.getActiveCells().clear();
connections.getWinnerCells().clear();
connections.getActiveSegments().clear();
connections.getMatchingSegments().clear();
}
/**
* Determines which cells in a predicted column should be added to winner cells
* list, and learns on the segments that correctly predicted this column.
*
* @param conn the connections
* @param activeSegments Active segments in the specified column
* @param matchingSegments Matching segments in the specified column
* @param prevActiveCells Active cells in `t-1`
* @param prevWinnerCells Winner cells in `t-1`
* @param learn If true, grow and reinforce synapses
*
* <pre>
* Pseudocode:
* for each cell in the column that has an active distal dendrite segment
* mark the cell as active
* mark the cell as a winner cell
* (learning) for each active distal dendrite segment
* strengthen active synapses
* weaken inactive synapses
* grow synapses to previous winner cells
* </pre>
*
* @return A list of predicted cells that will be added to active cells and winner
* cells.
*/
public List<Cell> activatePredictedColumn(Connections conn, List<DistalDendrite> activeSegments,
List<DistalDendrite> matchingSegments, Set<Cell> prevActiveCells, Set<Cell> prevWinnerCells,
double permanenceIncrement, double permanenceDecrement, boolean learn) {
List<Cell> cellsToAdd = new ArrayList<>();
Cell previousCell = null;
Cell currCell;
for(DistalDendrite segment : activeSegments) {
if((currCell = segment.getParentCell()) != previousCell) {
cellsToAdd.add(currCell);
previousCell = currCell;
}
if(learn) {
adaptSegment(conn, segment, prevActiveCells, permanenceIncrement, permanenceDecrement);
int numActive = conn.getLastActivity().numActivePotential[segment.getIndex()];
int nGrowDesired = conn.getMaxNewSynapseCount() - numActive;
if(nGrowDesired > 0) {
growSynapses(conn, prevWinnerCells, segment, conn.getInitialPermanence(),
nGrowDesired, conn.getRandom());
}
}
}
return cellsToAdd;
}
/**
* Activates all of the cells in an unpredicted active column,
* chooses a winner cell, and, if learning is turned on, either adapts or
* creates a segment. growSynapses is invoked on this segment.
* </p><p>
* <b>Pseudocode:</b>
* </p><p>
* <pre>
* mark all cells as active
* if there are any matching distal dendrite segments
* find the most active matching segment
* mark its cell as a winner cell
* (learning)
* grow and reinforce synapses to previous winner cells
* else
* find the cell with the least segments, mark it as a winner cell
* (learning)
* (optimization) if there are previous winner cells
* add a segment to this winner cell
* grow synapses to previous winner cells
* </pre>
* </p>
*
* @param conn Connections instance for the TM
* @param column Bursting {@link Column}
* @param matchingSegments List of matching {@link DistalDendrite}s
* @param prevActiveCells Active cells in `t-1`
* @param prevWinnerCells Winner cells in `t-1`
* @param permanenceIncrement Amount by which permanences of synapses
* are decremented during learning
* @param permanenceDecrement Amount by which permanences of synapses
* are incremented during learning
* @param random Random number generator
* @param learn Whether or not learning is enabled
*
* @return Tuple containing:
* cells list of the processed column's cells
* bestCell the best cell
*/
public Tuple burstColumn(Connections conn, Column column, List<DistalDendrite> matchingSegments,
Set<Cell> prevActiveCells, Set<Cell> prevWinnerCells, double permanenceIncrement, double permanenceDecrement,
Random random, boolean learn) {
List<Cell> cells = column.getCells();
Cell bestCell = null;
if(!matchingSegments.isEmpty()) {
int[] numPoten = conn.getLastActivity().numActivePotential;
Comparator<DistalDendrite> cmp = (dd1,dd2) -> numPoten[dd1.getIndex()] - numPoten[dd2.getIndex()];
DistalDendrite bestSegment = matchingSegments.stream().max(cmp).get();
bestCell = bestSegment.getParentCell();
if(learn) {
adaptSegment(conn, bestSegment, prevActiveCells, permanenceIncrement, permanenceDecrement);
int nGrowDesired = conn.getMaxNewSynapseCount() - numPoten[bestSegment.getIndex()];
if(nGrowDesired > 0) {
growSynapses(conn, prevWinnerCells, bestSegment, conn.getInitialPermanence(),
nGrowDesired, random);
}
}
}else{
bestCell = leastUsedCell(conn, cells, random);
if(learn) {
int nGrowExact = Math.min(conn.getMaxNewSynapseCount(), prevWinnerCells.size());
if(nGrowExact > 0) {
DistalDendrite bestSegment = conn.createSegment(bestCell);
growSynapses(conn, prevWinnerCells, bestSegment, conn.getInitialPermanence(),
nGrowExact, random);
}
}
}
return new Tuple(cells, bestCell);
}
/**
* Punishes the Segments that incorrectly predicted a column to be active.
*
* <p>
* <pre>
* Pseudocode:
* for each matching segment in the column
* weaken active synapses
* </pre>
* </p>
*
* @param conn Connections instance for the tm
* @param activeSegments An iterable of {@link DistalDendrite} actives
* @param matchingSegments An iterable of {@link DistalDendrite} matching
* for the column compute is operating on
* that are matching; None if empty
* @param prevActiveCells Active cells in `t-1`
* @param prevWinnerCells Winner cells in `t-1`
* are decremented during learning.
* @param predictedSegmentDecrement Amount by which segments are punished for incorrect predictions
*/
public void punishPredictedColumn(Connections conn, List<DistalDendrite> activeSegments,
List<DistalDendrite> matchingSegments, Set<Cell> prevActiveCells, Set<Cell> prevWinnerCells,
double predictedSegmentDecrement) {
if(predictedSegmentDecrement > 0) {
for(DistalDendrite segment : matchingSegments) {
adaptSegment(conn, segment, prevActiveCells, -conn.getPredictedSegmentDecrement(), 0);
}
}
}
////////////////////////////
// Helper Methods //
////////////////////////////
/**
* Gets the cell with the smallest number of segments.
* Break ties randomly.
*
* @param conn Connections instance for the tm
* @param cells List of {@link Cell}s
* @param random Random Number Generator
*
* @return the least used {@code Cell}
*/
public Cell leastUsedCell(Connections conn, List<Cell> cells, Random random) {
List<Cell> leastUsedCells = new ArrayList<>();
int minNumSegments = Integer.MAX_VALUE;
for(Cell cell : cells) {
int numSegments = conn.numSegments(cell);
if(numSegments < minNumSegments) {
minNumSegments = numSegments;
leastUsedCells.clear();
}
if(numSegments == minNumSegments) {
leastUsedCells.add(cell);
}
}
int i = random.nextInt(leastUsedCells.size());
return leastUsedCells.get(i);
}
/**
* Creates nDesiredNewSynapes synapses on the segment passed in if
* possible, choosing random cells from the previous winner cells that are
* not already on the segment.
* <p>
* <b>Notes:</b> The process of writing the last value into the index in the array
* that was most recently changed is to ensure the same results that we get
* in the c++ implementation using iter_swap with vectors.
* </p>
*
* @param conn Connections instance for the tm
* @param prevWinnerCells Winner cells in `t-1`
* @param segment Segment to grow synapses on.
* @param initialPermanence Initial permanence of a new synapse.
* @param nDesiredNewSynapses Desired number of synapses to grow
* @param random Tm object used to generate random
* numbers
*/
public void growSynapses(Connections conn, Set<Cell> prevWinnerCells, DistalDendrite segment,
double initialPermanence, int nDesiredNewSynapses, Random random) {
List<Cell> candidates = new ArrayList<>(prevWinnerCells);
Collections.sort(candidates);
for(Synapse synapse : conn.getSynapses(segment)) {
Cell presynapticCell = synapse.getPresynapticCell();
int index = candidates.indexOf(presynapticCell);
if(index != -1) {
candidates.remove(index);
}
}
int candidatesLength = candidates.size();
int nActual = nDesiredNewSynapses < candidatesLength ? nDesiredNewSynapses : candidatesLength;
for(int i = 0;i < nActual;i++) {
int rand = random.nextInt(candidates.size());
conn.createSynapse(segment, candidates.get(rand), initialPermanence);
candidates.remove(rand);
}
}
/**
* Updates synapses on segment.
* Strengthens active synapses; weakens inactive synapses.
*
* @param conn {@link Connections} instance for the tm
* @param segment {@link DistalDendrite} to adapt
* @param prevActiveCells Active {@link Cell}s in `t-1`
* @param permanenceIncrement Amount to increment active synapses
* @param permanenceDecrement Amount to decrement inactive synapses
*/
public void adaptSegment(Connections conn, DistalDendrite segment, Set<Cell> prevActiveCells,
double permanenceIncrement, double permanenceDecrement) {
// Destroying a synapse modifies the set that we're iterating through.
List<Synapse> synapsesToDestroy = new ArrayList<>();
for(Synapse synapse : conn.getSynapses(segment)) {
double permanence = synapse.getPermanence();
if(prevActiveCells.contains(synapse.getPresynapticCell())) {
permanence += permanenceIncrement;
}else{
permanence -= permanenceDecrement;
}
// Keep permanence within min/max bounds
permanence = permanence < 0 ? 0 : permanence > 1.0 ? 1.0 : permanence;
// Use this to examine issues caused by subtle floating point differences
// be careful to set the scale (1 below) to the max significant digits right of the decimal point
// between the permanenceIncrement and initialPermanence
//
// permanence = new BigDecimal(permanence).setScale(1, RoundingMode.HALF_UP).doubleValue();
if(permanence < EPSILON) {
synapsesToDestroy.add(synapse);
}else{
synapse.setPermanence(conn, permanence);
}
}
for(Synapse s : synapsesToDestroy) {
conn.destroySynapse(s);
}
if(conn.numSynapses(segment) == 0) {
conn.destroySegment(segment);
}
}
/**
* Used in the {@link TemporalMemory#compute(Connections, int[], boolean)} method
* to make pulling values out of the {@link GroupBy2} more readable and named.
*/
@SuppressWarnings("unchecked")
public static class ColumnData implements Serializable {
/** Default Serial */
private static final long serialVersionUID = 1L;
Tuple t;
public ColumnData() {}
public ColumnData(Tuple t) {
this.t = t;
}
public Column column() { return (Column)t.get(0); }
public List<Column> activeColumns() { return (List<Column>)t.get(1); }
public List<DistalDendrite> activeSegments() {
return ((List<?>)t.get(2)).get(0).equals(Slot.empty()) ?
Collections.emptyList() :
(List<DistalDendrite>)t.get(2);
}
public List<DistalDendrite> matchingSegments() {
return ((List<?>)t.get(3)).get(0).equals(Slot.empty()) ?
Collections.emptyList() :
(List<DistalDendrite>)t.get(3);
}
public ColumnData set(Tuple t) { this.t = t; return this; }
/**
* Returns a boolean flag indicating whether the slot contained by the
* tuple at the specified index is filled with the special empty
* indicator.
*
* @param memberIndex the index of the tuple to assess.
* @return true if <em><b>not</b></em> none, false if it <em><b>is none</b></em>.
*/
public boolean isNotNone(int memberIndex) {
return !((List<?>)t.get(memberIndex)).get(0).equals(NONE);
}
}
}