/* ---------------------------------------------------------------------
* Numenta Platform for Intelligent Computing (NuPIC)
* Copyright (C) 2014, Numenta, Inc. Unless you have an agreement
* with Numenta, Inc., for a separate license for this software code, the
* following terms and conditions apply:
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero Public License version 3 as
* published by the Free Software Foundation.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
* See the GNU Affero Public License for more details.
*
* You should have received a copy of the GNU Affero Public License
* along with this program. If not, see http://www.gnu.org/licenses.
*
* http://numenta.org/licenses/
* ---------------------------------------------------------------------
*/
package org.numenta.nupic.network;
import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Set;
import org.numenta.nupic.algorithms.Classifier;
import org.numenta.nupic.algorithms.Classification;
import org.numenta.nupic.algorithms.SpatialPooler;
import org.numenta.nupic.algorithms.TemporalMemory;
import org.numenta.nupic.encoders.Encoder;
import org.numenta.nupic.model.Cell;
import org.numenta.nupic.model.ComputeCycle;
import org.numenta.nupic.util.ArrayUtils;
import org.numenta.nupic.util.NamedTuple;
import rx.functions.Func1;
/**
 * <p>
 * Abstraction used within the Network API to contain the significant return values of all
 * {@link Layer} inference-participating algorithms.
 * </p>
 * Namely:
 * <ul>
 * <li>Input Value</li>
 * <li>Bucket Index (carried inside the classifier input)</li>
 * <li>Encoding</li>
 * <li>SDR</li>
 * <li>Feed-forward active columns (dense and sparse)</li>
 * <li>Active, predictive and previously predictive {@link Cell}s</li>
 * <li>{@link Classification}</li>
 * <li>anomalyScore</li>
 * </ul>
 *
 * <p>
 * All of these fields are "optional": they depend on the configuration selected by the user
 * and may be absent depending on the user's choice of "terminal" point. A "terminal" point
 * is the end point in a chain of a {@code Layer}'s contained algorithms. For instance, if the
 * user does not include an {@link Encoder} in the {@link Layer} constructor, the slot
 * containing the "Bucket Index" will be empty.
 * </p>
 *
 * <p>
 * Instances are mutable; setters are fluent and return {@code this}. No synchronization is
 * performed.
 * </p>
 *
 * @author David Ray
 *
 */
public class ManualInput implements Inference {
    private static final long serialVersionUID = 1L;

    /** Sequence number of the record this inference belongs to. */
    private int recordNum;
    /** Per-field classifier input. Tuple = { Name, inputValue, bucketIndex, encoding } */
    private Map<String, NamedTuple> classifierInput;
    /** Holds one classifier for each field */
    NamedTuple classifiers;
    /** The most recent input object handed to the layer. */
    private Object layerInput;
    /** SDR resulting from the algorithms in the layer's chain. */
    private int[] sdr;
    /** Initial encoding produced by an {@link Encoder}. */
    private int[] encoding;
    /**
     * Active columns in the {@link SpatialPooler} at time "t". Presumably a dense 0/1 vector,
     * since the sparse form is derived from it with {@code ArrayUtils.WHERE_1}.
     */
    private int[] feedForwardActiveColumns;
    /** Active column indexes from the {@link SpatialPooler} at time "t" */
    private int[] feedForwardSparseActives;
    /** Predictive {@link Cell}s in the {@link TemporalMemory} at time "t - 1" */
    private Set<Cell> previousPredictiveCells;
    /** Predictive {@link Cell}s in the {@link TemporalMemory} at time "t" */
    private Set<Cell> predictiveCells;
    /** Active {@link Cell}s in the {@link TemporalMemory} at time "t" */
    private Set<Cell> activeCells;
    /** Most recent {@link Classification} per field name; created lazily. */
    private Map<String, Classification<Object>> classification;
    /** Most recently computed anomaly score. */
    private double anomalyScore;
    /** User object produced by a {@link Func1} inserted into the layer's chain. */
    private Object customObject;
    /** The {@link ComputeCycle} produced by the {@link TemporalMemory}. */
    ComputeCycle computeCycle;

    /**
     * Constructs a new, empty {@code ManualInput}.
     */
    public ManualInput() {}

    /**
     * Rebuilds a fresh {@code ManualInput} from a deserialized instance.
     *
     * <p>
     * All state is transferred by reference, including {@code recordNum} and
     * {@code computeCycle}. Both take part in {@link #equals(Object)} and
     * {@link #hashCode()}, so the result compares equal to the deserialized instance.
     * </p>
     *
     * @param manualInput the deserialized instance; must be a {@code ManualInput}
     * @return a new {@code ManualInput} holding the same state
     */
    @SuppressWarnings("unchecked")
    @Override
    public <T> T postDeSerialize(T manualInput) {
        ManualInput mi = (ManualInput)manualInput;
        ManualInput retVal = new ManualInput();
        retVal.recordNum = mi.recordNum;
        retVal.activeCells = mi.activeCells;
        retVal.anomalyScore = mi.anomalyScore;
        retVal.classification = mi.classification;
        retVal.classifierInput = mi.classifierInput;
        retVal.classifiers = mi.classifiers;
        retVal.computeCycle = mi.computeCycle;
        retVal.customObject = mi.customObject;
        retVal.encoding = mi.encoding;
        retVal.feedForwardActiveColumns = mi.feedForwardActiveColumns;
        retVal.feedForwardSparseActives = mi.feedForwardSparseActives;
        retVal.layerInput = mi.layerInput;
        retVal.predictiveCells = mi.predictiveCells;
        retVal.previousPredictiveCells = mi.previousPredictiveCells;
        retVal.sdr = mi.sdr;
        return (T)retVal;
    }

    /**
     * Sets the current record num associated with this {@code ManualInput}
     * instance
     *
     * @param num the current sequence number.
     * @return this
     */
    public ManualInput recordNum(int num) {
        this.recordNum = num;
        return this;
    }

    /**
     * Returns the current record num associated with this {@code ManualInput}
     * instance
     *
     * @return the current sequence number
     */
    @Override
    public int getRecordNum() {
        return recordNum;
    }

    /**
     * Sets the {@link ComputeCycle} from the TemporalMemory
     *
     * @param computeCycle the compute cycle to store
     * @return this
     */
    public ManualInput computeCycle(ComputeCycle computeCycle) {
        this.computeCycle = computeCycle;
        return this;
    }

    /**
     * Returns the {@link ComputeCycle}
     *
     * @return the stored compute cycle, or {@code null} if none was set
     */
    @Override
    public ComputeCycle getComputeCycle() {
        return computeCycle;
    }

    /**
     * Returns a custom Object during sequence processing where one or more
     * {@link Func1}(s) were added to a {@link Layer} in between algorithmic
     * components.
     *
     * @return the custom object set during processing
     */
    @Override
    public Object getCustomObject() {
        return customObject;
    }

    /**
     * Sets a custom Object during sequence processing where one or more
     * {@link Func1}(s) were added to a {@link Layer} in between algorithmic
     * components.
     *
     * @param o the custom object
     * @return this
     */
    public ManualInput customObject(Object o) {
        this.customObject = o;
        return this;
    }

    /**
     * <p>
     * Returns the {@link Map} used as input into the field's {@link Classifier}.
     * It is only actually used as input if a Classifier type has been specified for
     * the field.
     * </p>
     * <p>
     * This mapping contains the name of the field being classified mapped
     * to a {@link NamedTuple} containing:
     * </p>
     * <ul>
     * <li>name</li>
     * <li>inputValue</li>
     * <li>bucketIdx</li>
     * <li>encoding</li>
     * </ul>
     *
     * @return the current classifier input; an empty map is created lazily if none was set
     */
    @Override
    public Map<String, NamedTuple> getClassifierInput() {
        if(classifierInput == null) {
            classifierInput = new HashMap<String, NamedTuple>();
        }
        return classifierInput;
    }

    /**
     * Sets the current classifier input map.
     *
     * @param classifierInput map of field name to classifier-input tuple
     * @return this
     */
    ManualInput classifierInput(Map<String, NamedTuple> classifierInput) {
        this.classifierInput = classifierInput;
        return this;
    }

    /**
     * Sets the {@link NamedTuple} containing the classifiers used
     * for each particular input field.
     *
     * @param tuple the classifiers keyed by field name
     * @return this
     */
    public ManualInput classifiers(NamedTuple tuple) {
        this.classifiers = tuple;
        return this;
    }

    /**
     * Returns a {@link NamedTuple} keyed to the input field
     * names, whose values are the {@link Classifier} used
     * to track the classification of a particular field
     *
     * @return the classifiers tuple, or {@code null} if none was set
     */
    @Override
    public NamedTuple getClassifiers() {
        return classifiers;
    }

    /**
     * Returns the most recent input object
     *
     * @return the input
     */
    @Override
    public Object getLayerInput() {
        return layerInput;
    }

    /**
     * Sets the input object to be used and returns
     * this {@link ManualInput}
     *
     * @param inputValue the raw input
     * @return this
     */
    ManualInput layerInput(Object inputValue) {
        this.layerInput = inputValue;
        return this;
    }

    /**
     * Returns the <em>Sparse Distributed Representation</em> vector
     * which is the result of all algorithms in a series of algorithms
     * passed up the hierarchy.
     *
     * @return the SDR, or {@code null} if none was set
     */
    @Override
    public int[] getSDR() {
        return sdr;
    }

    /**
     * Inputs an sdr and returns this {@code ManualInput}
     *
     * @param sdr the SDR (stored by reference, not copied)
     * @return this
     */
    ManualInput sdr(int[] sdr) {
        this.sdr = sdr;
        return this;
    }

    /**
     * Returns the initial encoding produced by an {@link Encoder}
     * or one of its subtypes.
     *
     * @return the encoding, or {@code null} if none was set
     */
    @Override
    public int[] getEncoding() {
        return encoding;
    }

    /**
     * Inputs the initial encoding and returns this {@code ManualInput}
     *
     * @param sdr the encoding (stored by reference, not copied)
     * @return this
     */
    ManualInput encoding(int[] sdr) {
        this.encoding = sdr;
        return this;
    }

    /**
     * Convenience method to provide a copy of this {@link Inference}.
     *
     * <p>
     * Arrays, cell sets and the classifier-input and classification maps are copied, so
     * adding or removing entries does not affect this instance.
     * </p>
     * <p>
     * The map values, {@code layerInput}, {@code customObject} and {@code computeCycle}
     * are shared by reference. Fields that are {@code null} stay {@code null} in the copy
     * rather than causing a {@link NullPointerException}.
     * </p>
     *
     * @return a copy equal to this instance
     */
    ManualInput copy() {
        ManualInput retVal = new ManualInput();
        retVal.recordNum = this.recordNum;
        retVal.classifierInput = this.classifierInput == null ?
            null : new HashMap<String, NamedTuple>(this.classifierInput);
        retVal.classifiers = this.classifiers == null ?
            null : new NamedTuple(this.classifiers.keys(), this.classifiers.values().toArray());
        retVal.layerInput = this.layerInput;
        retVal.sdr = copyArray(this.sdr);
        retVal.encoding = copyArray(this.encoding);
        retVal.feedForwardActiveColumns = copyArray(this.feedForwardActiveColumns);
        retVal.feedForwardSparseActives = copyArray(this.feedForwardSparseActives);
        retVal.previousPredictiveCells = copyCells(this.previousPredictiveCells);
        retVal.predictiveCells = copyCells(this.predictiveCells);
        retVal.classification = this.classification == null ?
            null : new HashMap<String, Classification<Object>>(this.classification);
        retVal.anomalyScore = this.anomalyScore;
        retVal.customObject = this.customObject;
        retVal.computeCycle = this.computeCycle;
        retVal.activeCells = copyCells(this.activeCells);
        return retVal;
    }

    /** Returns a copy of {@code array}, or {@code null} if {@code array} is {@code null}. */
    private static int[] copyArray(int[] array) {
        return array == null ? null : Arrays.copyOf(array, array.length);
    }

    /** Returns an insertion-ordered copy of {@code cells}, or {@code null} if it is {@code null}. */
    private static Set<Cell> copyCells(Set<Cell> cells) {
        return cells == null ? null : new LinkedHashSet<Cell>(cells);
    }

    /**
     * Returns the most recent {@link Classification}
     *
     * @param fieldName the name of the classified field
     * @return the most recent {@link Classification}, or null if none exists.
     */
    @Override
    public Classification<Object> getClassification(String fieldName) {
        if(classification == null)
            return null;
        return classification.get(fieldName);
    }

    /**
     * Sets the specified field's last classifier computation and returns
     * this {@link Inference}
     *
     * @param fieldName the name of the classified field
     * @param classification the classification result to store
     * @return this
     */
    ManualInput storeClassification(String fieldName, Classification<Object> classification) {
        if(this.classification == null) {
            this.classification = new HashMap<String, Classification<Object>>();
        }
        this.classification.put(fieldName, classification);
        return this;
    }

    /**
     * Returns the most recent anomaly calculation.
     *
     * @return the anomaly score ({@code 0.0} if never set)
     */
    @Override
    public double getAnomalyScore() {
        return anomalyScore;
    }

    /**
     * Sets the current computed anomaly score and
     * returns this {@link Inference}
     *
     * @param d the anomaly score
     * @return this
     */
    ManualInput anomalyScore(double d) {
        this.anomalyScore = d;
        return this;
    }

    /**
     * Returns the column activation from a {@link SpatialPooler}
     *
     * @return the active columns, or {@code null} if none were set
     */
    @Override
    public int[] getFeedForwardActiveColumns() {
        return feedForwardActiveColumns;
    }

    /**
     * Sets the column activation from a {@link SpatialPooler}
     *
     * @param cols the active columns (stored by reference)
     * @return this
     */
    public ManualInput feedForwardActiveColumns(int[] cols) {
        this.feedForwardActiveColumns = cols;
        return this;
    }

    /**
     * Returns the active {@link Cell}s from a {@link TemporalMemory}
     *
     * @return the active cells, or {@code null} if none were set
     */
    @Override
    public Set<Cell> getActiveCells() {
        return activeCells;
    }

    /**
     * Sets the active {@link Cell}s from a {@link TemporalMemory}
     *
     * @param cells the active cells (stored by reference)
     * @return this
     */
    public ManualInput activeCells(Set<Cell> cells) {
        this.activeCells = cells;
        return this;
    }

    /**
     * Returns the column activations in sparse form.
     *
     * <p>
     * If no sparse form was set but dense active columns exist, the sparse form is derived
     * from them with {@code ArrayUtils.where(..., ArrayUtils.WHERE_1)}. The result is
     * cached for later calls.
     * </p>
     *
     * @return the sparse active column indexes, or {@code null} if neither form is available
     */
    @Override
    public int[] getFeedForwardSparseActives() {
        if(feedForwardSparseActives == null && feedForwardActiveColumns != null) {
            feedForwardSparseActives = ArrayUtils.where(feedForwardActiveColumns, ArrayUtils.WHERE_1);
        }
        return feedForwardSparseActives;
    }

    /**
     * Sets the column activations in sparse form.
     *
     * @param cols the sparse active column indexes (stored by reference)
     * @return this
     */
    public ManualInput feedForwardSparseActives(int[] cols) {
        this.feedForwardSparseActives = cols;
        return this;
    }

    /**
     * Returns the predictive cells from the previous inference cycle.
     *
     * @return the previously predictive cells, or {@code null} if none were set
     */
    @Override
    public Set<Cell> getPreviousPredictiveCells() {
        return previousPredictiveCells;
    }

    /**
     * Sets the previously predictive cells.
     *
     * @param cells the previously predictive cells (stored by reference)
     * @return this
     */
    public ManualInput previousPredictiveCells(Set<Cell> cells) {
        this.previousPredictiveCells = cells;
        return this;
    }

    /**
     * Returns the currently predictive cells.
     *
     * @return the predictive cells, or {@code null} if none were set
     */
    @Override
    public Set<Cell> getPredictiveCells() {
        return predictiveCells;
    }

    /**
     * Sets the currently predictive cells.
     *
     * <p>
     * Side effect: before being overwritten, the current predictive cells are moved into
     * {@code previousPredictiveCells}.
     * </p>
     *
     * @param cells the new predictive cells (stored by reference)
     * @return this
     */
    public ManualInput predictiveCells(Set<Cell> cells) {
        previousPredictiveCells = predictiveCells;
        this.predictiveCells = cells;
        return this;
    }

    /**
     * Hash code consistent with {@link #equals(Object)}.
     *
     * <p>
     * {@code classifiers}, {@code layerInput} and {@code customObject} are deliberately
     * excluded, matching equals.
     * </p>
     */
    @Override
    public int hashCode() {
        final int prime = 31;
        int result = 1;
        result = prime * result + ((activeCells == null) ? 0 : activeCells.hashCode());
        long temp;
        temp = Double.doubleToLongBits(anomalyScore);
        result = prime * result + (int)(temp ^ (temp >>> 32));
        result = prime * result + ((classification == null) ? 0 : classification.hashCode());
        result = prime * result + ((classifierInput == null) ? 0 : classifierInput.hashCode());
        result = prime * result + ((computeCycle == null) ? 0 : computeCycle.hashCode());
        result = prime * result + Arrays.hashCode(encoding);
        result = prime * result + Arrays.hashCode(feedForwardActiveColumns);
        result = prime * result + Arrays.hashCode(feedForwardSparseActives);
        result = prime * result + ((predictiveCells == null) ? 0 : predictiveCells.hashCode());
        result = prime * result + ((previousPredictiveCells == null) ? 0 : previousPredictiveCells.hashCode());
        result = prime * result + recordNum;
        result = prime * result + Arrays.hashCode(sdr);
        return result;
    }

    /**
     * Compares this instance with another {@code ManualInput} field by field.
     *
     * <p>
     * Objects that are not {@code ManualInput}s (including other {@link Inference}
     * implementations) are never equal to this one. They are rejected rather than cast,
     * which previously threw a {@link ClassCastException}.
     * </p>
     */
    @Override
    public boolean equals(Object obj) {
        if(this == obj)
            return true;
        if(!(obj instanceof ManualInput))
            return false;
        ManualInput other = (ManualInput)obj;
        if(activeCells == null) {
            if(other.activeCells != null)
                return false;
        } else if(!activeCells.equals(other.activeCells))
            return false;
        if(Double.doubleToLongBits(anomalyScore) != Double.doubleToLongBits(other.anomalyScore))
            return false;
        if(classification == null) {
            if(other.classification != null)
                return false;
        } else if(!classification.equals(other.classification))
            return false;
        if(classifierInput == null) {
            if(other.classifierInput != null)
                return false;
        } else if(!classifierInput.equals(other.classifierInput))
            return false;
        if(computeCycle == null) {
            if(other.computeCycle != null)
                return false;
        } else if(!computeCycle.equals(other.computeCycle))
            return false;
        if(!Arrays.equals(encoding, other.encoding))
            return false;
        if(!Arrays.equals(feedForwardActiveColumns, other.feedForwardActiveColumns))
            return false;
        if(!Arrays.equals(feedForwardSparseActives, other.feedForwardSparseActives))
            return false;
        if(predictiveCells == null) {
            if(other.predictiveCells != null)
                return false;
        } else if(!predictiveCells.equals(other.predictiveCells))
            return false;
        if(previousPredictiveCells == null) {
            if(other.previousPredictiveCells != null)
                return false;
        } else if(!previousPredictiveCells.equals(other.previousPredictiveCells))
            return false;
        if(recordNum != other.recordNum)
            return false;
        if(!Arrays.equals(sdr, other.sdr))
            return false;
        return true;
    }
}