/* --------------------------------------------------------------------- * Numenta Platform for Intelligent Computing (NuPIC) * Copyright (C) 2014, Numenta, Inc. Unless you have an agreement * with Numenta, Inc., for a separate license for this software code, the * following terms and conditions apply: * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero Public License version 3 as * published by the Free Software Foundation. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. * See the GNU Affero Public License for more details. * * You should have received a copy of the GNU Affero Public License * along with this program. If not, see http://www.gnu.org/licenses. * * http://numenta.org/licenses/ * --------------------------------------------------------------------- */ package org.numenta.nupic.monitor.mixin; import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; import java.util.LinkedHashSet; import java.util.List; import java.util.Map; import java.util.Set; import java.util.stream.Collectors; import java.util.stream.Stream; import org.numenta.nupic.algorithms.TemporalMemory; import org.numenta.nupic.model.Cell; import org.numenta.nupic.model.Column; import org.numenta.nupic.model.ComputeCycle; import org.numenta.nupic.model.DistalDendrite; import org.numenta.nupic.model.Connections; import org.numenta.nupic.model.Segment; import org.numenta.nupic.model.Synapse; import org.numenta.nupic.monitor.ComputeDecorator; import org.numenta.nupic.util.ArrayUtils; import org.numenta.nupic.util.Tuple; import com.bethecoder.table.AsciiTableInstance; import com.bethecoder.table.spec.AsciiTable; /** * Contains methods to create the {@link Trace}s used to gather test results * and create {@link Metric}s from them for analysis and pretty-printing * * This interface contains "defender" methods or Traits that are used to collect * result data for the {@link TemporalMemory}. * * @author cogmission * */ public interface TemporalMemoryMonitorMixin extends MonitorMixinBase { /** * Returns the ComputeDecorator mixin target * * @return */ @SuppressWarnings("unchecked") public ComputeDecorator getMonitor(); /** * Returns the resetActive flag * @return */ public boolean resetActive(); /** * Sets the resetActive flag * @param b */ public void setResetActive(boolean b); /** * Returns the flag indicating whether the current traces * are stale and need to be recomputed, or not. * * @return */ public boolean transitionTracesStale(); /** * Sets the flag indicating whether the current traces * are stale and need to be recomputed, or not. * * @param b */ public void setTransitionTracesStale(boolean b); /** * Returns Trace of the active {@link Column} indexes. * @return */ default IndicesTrace mmGetTraceActiveColumns() { return (IndicesTrace)getTraceMap().get("activeColumns"); } /** * Returns Trace of the active {@link Cell} indexes. * @return */ default IndicesTrace mmGetTracePredictiveCells() { return (IndicesTrace)getTraceMap().get("predictiveCells"); } /** * Returns Trace count of {@link Segment}s * @return */ default CountsTrace mmGetTraceNumSegments() { return (CountsTrace)getTraceMap().get("numSegments"); } /** * Returns Trace count of {@link Synapse}s * @return */ default CountsTrace mmGetTraceNumSynapses() { return (CountsTrace)getTraceMap().get("numSynapses"); } /** * Returns Trace containing a sequence's labels * @return */ default StringsTrace mmGetTraceSequenceLabels() { return (StringsTrace)getTraceMap().get("sequenceLabels"); } /** * Returns Trace containing targeted resets for a given sequence * @return */ default BoolsTrace mmGetTraceResets() { return (BoolsTrace)getTraceMap().get("resets"); } /** * Trace of predicted => active cells * @param c * @return */ default IndicesTrace mmGetTracePredictedActiveCells() { mmComputeTransitionTraces(); return (IndicesTrace)getTraceMap().get("predictedActiveCells"); } /** * Trace of predicted => inactive cells * @return */ default IndicesTrace mmGetTracePredictedInactiveCells() { mmComputeTransitionTraces(); return (IndicesTrace)getTraceMap().get("predictedInactiveCells"); } /** * Trace of predicted => active columns * @return */ default IndicesTrace mmGetTracePredictedActiveColumns() { mmComputeTransitionTraces(); return (IndicesTrace)getTraceMap().get("predictedActiveColumns"); } /** * Trace of predicted => inactive columns * @return */ default IndicesTrace mmGetTracePredictedInactiveColumns() { mmComputeTransitionTraces(); return (IndicesTrace)getTraceMap().get("predictedInactiveColumns"); } /** * Trace of unpredicted => active columns * @return */ default IndicesTrace mmGetTraceUnpredictedActiveColumns() { mmComputeTransitionTraces(); return (IndicesTrace)getTraceMap().get("unpredictedActiveColumns"); } /** * Convenience method to compute a metric over an counts trace, excluding * resets. * * @param trace Trace of indices * @return */ default Metric mmGetMetricFromTrace(Trace<Number> trace) { return Metric.createFromTrace(trace, mmGetTraceResets()); } /** * Convenience method to compute a metric over an indices trace, excluding * resets. * * @param trace Trace of indices * @return */ default Metric mmGetMetricFromTrace(IndicesTrace trace) { List<LinkedHashSet<Integer>> data = null; BoolsTrace excludeResets = mmGetTraceResets(); if(excludeResets != null) { int[] i = { 0 }; data = trace.items.stream().filter(t -> !excludeResets.items.get(i[0]++)).collect(Collectors.toList()); } trace.items = data; CountsTrace iTrace = trace.makeCountsTrace(); return Metric.createFromTrace(iTrace, mmGetTraceResets()); } /** * Metric for number of predicted => active cells per column for each sequence * @return */ @SuppressWarnings("unchecked") default Metric mmGetMetricSequencesPredictedActiveCellsPerColumn() { mmComputeTransitionTraces(); List<Integer> numCellsPerColumn = new ArrayList<>(); for(Map.Entry<String, Set<Integer>> m : ((Map<String, Set<Integer>>)getDataMap().get("predictedActiveCellsForSequence")).entrySet()) { numCellsPerColumn.add(m.getValue().size()); } return new Metric(this, "# predicted => active cells per column for each sequence", numCellsPerColumn); } /** * Metric for number of sequences each predicted => active cell appears in * * Note: This metric is flawed when it comes to high-order sequences. * @return */ @SuppressWarnings("unchecked") default Metric mmGetMetricSequencesPredictedActiveCellsShared() { mmComputeTransitionTraces(); Map<Integer, Integer> numSequencesForCell = new HashMap<>(); for(Map.Entry<String, Set<Integer>> m : ((Map<String, Set<Integer>>)getDataMap().get("predictedActiveCellsForSequence")).entrySet()) { for(Integer cell : m.getValue()) { if(numSequencesForCell.get(cell) == null) { numSequencesForCell.put(cell, 0); continue; } numSequencesForCell.put(cell, numSequencesForCell.get(cell) + 1); } } return new Metric(this, "# sequences each predicted => active cells appears in", new ArrayList<>(numSequencesForCell.values())); } /** * Pretty print the connections in the temporal memory. * * @return */ default String mmPrettyPrintConnections() { StringBuilder text = new StringBuilder(); text.append("Segments: (format => (#) [(source cell=permanence ...), ...]\n") .append("------------------------------------\n"); Connections cnx = getConnections(); List<Integer> columns = Arrays.asList( ArrayUtils.toBoxed( ArrayUtils.range(0, cnx.getNumColumns()))); for(Integer column : columns) { int[] cells = cnx.getColumn(column).getCells(). stream().map(c -> c.getIndex()).mapToInt(i->i).toArray(); for(int cell : cells) { Map<Integer, String> segmentDict = new HashMap<>(); for(DistalDendrite dd : cnx.getSegments(cnx.getCell(cell))) { List<Tuple> synapseList = new ArrayList<Tuple>(); for(Synapse s : cnx.getSynapses(dd)) { Tuple synapseData = new Tuple(s.getInputIndex(), s.getPermanence()); synapseList.add(synapseData); } Stream<Tuple> tupes = synapseList.stream().sorted( (Tuple t1, Tuple t2) -> ((Integer)t1.get(0)).compareTo((Integer)t2.get(0))); List<String> synapseStringList = tupes.map(t -> String.format("%3d=%.2f", t.get(0), t.get(1))).collect(Collectors.toList()); segmentDict.put(dd.getIndex(), String.format("(%s)", synapseStringList.stream().collect(Collectors.joining(" ")))); } text.append(String.format("Column %3d / Cell %3d:\t(%d) %s\n", column, cell, segmentDict.values().size(), String.format("[%s]", segmentDict.values().stream().collect(Collectors.joining(", "))))); } if(column < columns.size() - 1) { text.append("\n"); } } text.append("------------------------------------\n"); return text.toString(); } /** * Pretty print the cell representations for sequences in the history. * @return */ @SuppressWarnings("unchecked") default String mmPrettyPrintSequenceCellRepresentations(String sortBy) { mmComputeTransitionTraces(); String[] header = { "Pattern", "Column", "predicted=>active cells" }; // Check required sort column header to see if it exists, and get index int sortIndex = -1; int idx = -1; for(String colHeader : header) { idx++; if(colHeader.equals(sortBy)) { sortIndex = idx; break; } } if(sortIndex == -1) { throw new IllegalArgumentException("No header named \"" + sortBy + "\" to sort by."); } String[][] data = new String[getDataMap().get("predictedActiveCellsForSequence").values().size()][]; int i = 0; for(Map.Entry<String, Set<Integer>> m : ((Map<String, Set<Integer>>)getDataMap().get("predictedActiveCellsForSequence")).entrySet()) { Map<Integer, List<Integer>> cellsForColumn = m.getValue().stream().collect( Collectors.groupingBy(cell -> getConnections().getCell(cell).getColumn().getIndex())); for(Integer column : cellsForColumn.keySet()) { data[i] = new String[] { m.getKey(), column.toString(), cellsForColumn.get(column).toString().replace("[", "").replace("]", "") }; i++; } } // Sort the data int finalIndex = sortIndex; Arrays.stream(data).sorted((sa1, sa2) -> sa1[finalIndex].compareTo(sa2[finalIndex])); String retVal = AsciiTableInstance.get().getTable(header, data, AsciiTable.ALIGN_CENTER); return retVal; } // ========================= // Helper Methods // ========================= /** * Computes the transition traces, if necessary. * * Transition traces are the following: * * predicted => active cells * predicted => inactive cells * predicted => active columns * predicted => inactive columns * unpredicted => active columns */ @SuppressWarnings("unchecked") default void mmComputeTransitionTraces() { if(!transitionTracesStale()) { return; } Map<String, Set<Integer>> predActCells = null; if((predActCells = (Map<String, Set<Integer>>)getDataMap() .get("predictedActiveCellsForSequence")) == null) { getDataMap().put("predictedActiveCellsForSequence", predActCells = new HashMap<String, Set<Integer>>()); } getTraceMap().put("predictedActiveCells", new IndicesTrace(this, "predicted => active cells (correct)")); getTraceMap().put("predictedInactiveCells", new IndicesTrace(this, "predicted => inactive cells (extra)")); getTraceMap().put("predictedActiveColumns", new IndicesTrace(this, "predicted => active columns (correct)")); getTraceMap().put("predictedInactiveColumns", new IndicesTrace(this, "predicted => inactive columns (extra)")); getTraceMap().put("unpredictedActiveColumns", new IndicesTrace(this, "unpredicted => active columns (bursting)")); IndicesTrace predictedCellsTrace = (IndicesTrace)getTraceMap().get("predictedCells"); int i = 0;LinkedHashSet<Integer> predictedActiveColumns = null; for(Set<Integer> activeColumns : mmGetTraceActiveColumns().items) { LinkedHashSet<Integer> predictedActiveCells = new LinkedHashSet<>(); LinkedHashSet<Integer> predictedInactiveCells = new LinkedHashSet<>(); predictedActiveColumns = new LinkedHashSet<>(); LinkedHashSet<Integer> predictedInactiveColumns = new LinkedHashSet<>(); for(Integer predictedCell : predictedCellsTrace.items.get(i)) { Integer predictedColumn = getConnections().getCell(predictedCell).getColumn().getIndex(); if(activeColumns.contains(predictedColumn)) { predictedActiveCells.add(predictedCell); predictedActiveColumns.add(predictedColumn); String sequenceLabel = (String)mmGetTraceSequenceLabels().items.get(i); if(sequenceLabel != null && !sequenceLabel.isEmpty()) { Set<Integer> sequencePredictedCells = null; if((sequencePredictedCells = (Set<Integer>)predActCells.get(sequenceLabel)) == null) { ((Map<String, Set<Integer>>)predActCells).put( sequenceLabel, sequencePredictedCells = new LinkedHashSet<Integer>()); } sequencePredictedCells.add(predictedCell); } }else{ predictedInactiveCells.add(predictedCell); predictedInactiveColumns.add(predictedColumn); } } LinkedHashSet<Integer> unpredictedActiveColumns = new LinkedHashSet<>(activeColumns); unpredictedActiveColumns.removeAll(predictedActiveColumns); ((IndicesTrace)getTraceMap().get("predictedActiveCells")).items.add(predictedActiveCells); ((IndicesTrace)getTraceMap().get("predictedInactiveCells")).items.add(predictedInactiveCells); ((IndicesTrace)getTraceMap().get("predictedActiveColumns")).items.add(predictedActiveColumns); ((IndicesTrace)getTraceMap().get("predictedInactiveColumns")).items.add(predictedInactiveColumns); ((IndicesTrace)getTraceMap().get("unpredictedActiveColumns")).items.add(unpredictedActiveColumns); i++; } setTransitionTracesStale(false); } // ========================= // Overrides // ========================= default ComputeCycle compute(Connections cnx, int[] activeColumns, String sequenceLabel, boolean learn) { // Append last cycle's predictiveCells to *predicTEDCells* trace ((IndicesTrace)getTraceMap().get("predictedCells")).items.add( new LinkedHashSet<Integer>(Connections.asCellIndexes(cnx.getPredictiveCells()))); ComputeCycle cycle = getMonitor().compute(cnx, activeColumns, learn); // Append this cycle's predictiveCells to *predicTIVECells* trace ((IndicesTrace)getTraceMap().get("predictiveCells")).items.add( new LinkedHashSet<Integer>(Connections.asCellIndexes(cnx.getPredictiveCells()))); ((IndicesTrace)getTraceMap().get("activeCells")).items.add( new LinkedHashSet<Integer>(Connections.asCellIndexes(cnx.getActiveCells()))); ((IndicesTrace)getTraceMap().get("activeColumns")).items.add( Arrays.stream(activeColumns).boxed().collect(Collectors.toCollection(LinkedHashSet::new))); ((CountsTrace)getTraceMap().get("numSegments")).items.add(cnx.numSegments()); ((CountsTrace)getTraceMap().get("numSynapses")).items.add((int)(cnx.numSynapses() ^ (cnx.numSynapses() >>> 32))); ((StringsTrace)getTraceMap().get("sequenceLabels")).items.add(sequenceLabel); ((BoolsTrace)getTraceMap().get("resets")).items.add(resetActive()); setResetActive(false); setTransitionTracesStale(true); return cycle; } /** * Called to delegate a {@link TemporalMemory#reset(Connections)} call and * then set a flag locally which controls remaking of test {@link Trace}s. * * @param c */ default void resetSequences(Connections c) { getMonitor().reset(c); setResetActive(true); } /** * Returns a list of {@link Trace} objects containing data sets used * to analyze the behavior and state of the {@link TemporalMemory} This * method is called from all of the "mmXXX" methods to make sure that * the data represents the most current execution cycle of the TM. * * @param verbosity setting which controls how much to print out. * @return List of {@link Trace}s */ @SuppressWarnings("unchecked") default <T extends Trace<?>> List<T> mmGetDefaultTraces(int verbosity) { List<T> traces = new ArrayList<>(); traces.add((T)mmGetTraceActiveColumns()); traces.add((T)mmGetTracePredictedActiveColumns()); traces.add((T)mmGetTracePredictedInactiveColumns()); traces.add((T)mmGetTraceUnpredictedActiveColumns()); traces.add((T)mmGetTracePredictedActiveCells()); traces.add((T)mmGetTracePredictedInactiveCells()); List<T> tracesToAdd = new ArrayList<>(); if(verbosity == 1) { for(Trace<?> t : traces) { tracesToAdd.add((T)((IndicesTrace)t).makeCountsTrace()); } traces.clear(); traces.addAll(tracesToAdd); } traces.add((T)mmGetTraceNumSegments()); traces.add((T)mmGetTraceNumSynapses()); traces.add((T)mmGetTraceSequenceLabels()); return traces; } /** * Returns a list of {@link Metric} objects containing statistics used * to analyze the behavior and state of the {@link TemporalMemory} This * method is called from all of the "mmXXX" methods to make sure that * the data represents the most current execution cycle of the TM. * * @param verbosity setting which controls how much to print out. * @return List of {@link Trace}s */ @SuppressWarnings("unchecked") default List<Metric> mmGetDefaultMetrics(int verbosity) { BoolsTrace resetsTrace = mmGetTraceResets(); List<Metric> metrics = new ArrayList<>(); List<?> utilTraces = mmGetDefaultTraces(verbosity); for(int i = 0;i < utilTraces.size() - 3;i++) { metrics.add(Metric.createFromTrace((Trace<Number>)utilTraces.get(i), resetsTrace)); } for(int i = utilTraces.size() - 3;i < utilTraces.size() - 1;i++) { metrics.add(Metric.createFromTrace((Trace<Number>)utilTraces.get(i), null)); } metrics.add(mmGetMetricSequencesPredictedActiveCellsPerColumn()); metrics.add(mmGetMetricSequencesPredictedActiveCellsShared()); return metrics; } /** * Clears the map of all {@link Trace}s */ default void mmClearHistory() { getTraceMap().clear(); getDataMap().clear(); getTraceMap().put("predictedCells", new IndicesTrace(this, "predicted cells")); getTraceMap().put("activeColumns", new IndicesTrace(this, "active columns")); getTraceMap().put("activeCells", new IndicesTrace(this, "active cells")); getTraceMap().put("predictiveCells", new IndicesTrace(this, "predictive cells")); getTraceMap().put("numSegments", new CountsTrace(this, "# segments")); getTraceMap().put("numSynapses", new CountsTrace(this, "# synapses")); getTraceMap().put("sequenceLabels", new StringsTrace(this, "sequence labels")); getTraceMap().put("resets", new BoolsTrace(this, "resets")); setTransitionTracesStale(true); } }