/* ---------------------------------------------------------------------
* Numenta Platform for Intelligent Computing (NuPIC)
* Copyright (C) 2014, Numenta, Inc. Unless you have an agreement
* with Numenta, Inc., for a separate license for this software code, the
* following terms and conditions apply:
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero Public License version 3 as
* published by the Free Software Foundation.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
* See the GNU Affero Public License for more details.
*
* You should have received a copy of the GNU Affero Public License
* along with this program. If not, see http://www.gnu.org/licenses.
*
* http://numenta.org/licenses/
* ---------------------------------------------------------------------
*/
package org.numenta.nupic.algorithms;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.numenta.nupic.model.Persistable;
import org.numenta.nupic.util.ArrayUtils;
import org.numenta.nupic.util.Deque;
import org.numenta.nupic.util.Tuple;
import gnu.trove.list.TIntList;
import gnu.trove.list.array.TIntArrayList;
/**
* A CLA classifier accepts a binary input from the level below (the
* "activationPattern") and information from the sensor and encoders (the
* "classification") describing the input to the system at that time step.
*
* When learning, for every bit in activation pattern, it records a history of the
* classification each time that bit was active. The history is weighted so that
* more recent activity has a bigger impact than older activity. The alpha
* parameter controls this weighting.
*
* For inference, it takes an ensemble approach. For every active bit in the
* activationPattern, it looks up the most likely classification(s) from the
* history stored for that bit and then votes across these to get the resulting
* classification(s).
*
* This classifier can learn and infer a number of simultaneous classifications
* at once, each representing a shift of a different number of time steps. For
* example, say you are doing multi-step prediction and want the predictions for
* 1 and 3 time steps in advance. The CLAClassifier would learn the associations
* between the activation pattern for time step T and the classifications for
* time step T+1, as well as the associations between activation pattern T and
* the classifications for T+3. The 'steps' constructor argument specifies the
* list of time-steps you want.
*
* @author Numenta
* @author David Ray
* @see BitHistory
*/
public class CLAClassifier implements Persistable, Classifier {
    private static final long serialVersionUID = 1L;

    /** Verbosity level for debug output written to stdout: 0, 1, or 2. */
    int verbosity = 0;
    /**
     * The alpha used to compute running averages of the bucket duty
     * cycles for each activation pattern bit. A lower alpha results
     * in longer term memory.
     */
    double alpha = 0.001;
    /**
     * Weight given to a newly observed numeric actual value when updating
     * the rolling average of the actual value stored for its bucket.
     */
    double actValueAlpha = 0.3;
    /**
     * The classifier's learning iteration. This is recomputed on every call
     * to {@link #compute} as {@code recordNum - recordNumMinusLearnIteration}.
     */
    int learnIteration;
    /**
     * This contains the offset between the recordNum (provided by caller) and
     * learnIteration (internal only, always starts at 0). A value of -1 means
     * compute() has not been called yet.
     */
    int recordNumMinusLearnIteration = -1;
    /**
     * This contains the value of the highest bucket index we've ever seen.
     * It is used to pre-allocate fixed size arrays that hold the weights of
     * each bucket index during inference.
     */
    int maxBucketIdx;
    /** The different step counts (time steps ahead) of the multi-step predictions to learn. */
    TIntList steps = new TIntArrayList();
    /**
     * History of recent (learnIteration, activation pattern) pairs, created with
     * capacity {@code max(steps) + 1}. We need to keep these so that we can
     * associate the current iteration's classification with the
     * activationPattern from N steps ago.
     */
    Deque<Tuple> patternNZHistory;
    /**
     * These are the bit histories. Each one is a BitHistory instance, stored in
     * this map, where the key is (bit, nSteps). The 'bit' is the index of the
     * bit in the activation pattern and nSteps is the number of steps of
     * prediction desired for that bit.
     */
    Map<Tuple, BitHistory> activeBitHistory = new HashMap<Tuple, BitHistory>();
    /**
     * This keeps track of the actual value to use for each bucket index. We
     * start with 1 bucket, no actual value so that the first infer has something
     * to return.
     */
    List<?> actualValues = new ArrayList<Object>();
    /** Prefix used for verbose debug output. */
    String g_debugPrefix = "CLAClassifier";

    /**
     * CLAClassifier no-arg constructor with defaults: steps = {1},
     * alpha = 0.001, actValueAlpha = 0.3, verbosity = 0.
     */
    public CLAClassifier() {
        this(new TIntArrayList(new int[] { 1 }), 0.001, 0.3, 0);
    }

    /**
     * Constructor for the CLA classifier
     *
     * @param steps         sequence of the different steps of multi-step predictions to learn
     * @param alpha         The alpha used to compute running averages of the bucket duty
     *                      cycles for each activation pattern bit. A lower alpha results
     *                      in longer term memory.
     * @param actValueAlpha weight of a newly observed numeric actual value in the
     *                      per-bucket rolling average:
     *                      {@code new = (1 - actValueAlpha) * old + actValueAlpha * observed}
     * @param verbosity     verbosity level, can be 0, 1, or 2
     */
    public CLAClassifier(TIntList steps, double alpha, double actValueAlpha, int verbosity) {
        this.steps = steps;
        this.alpha = alpha;
        this.actValueAlpha = actValueAlpha;
        this.verbosity = verbosity;
        actualValues.add(null);
        patternNZHistory = new Deque<Tuple>(ArrayUtils.max(steps.toArray()) + 1);
    }

    /**
     * Process one input sample.
     * This method is called by outer loop code outside the nupic-engine. We
     * use this instead of the nupic engine compute() because our inputs and
     * outputs aren't fixed size vectors of reals.
     *
     * @param recordNum      Record number of this input pattern. Record numbers should
     *                       normally increase sequentially by 1 each time unless there
     *                       are missing records in the dataset. Knowing this information
     *                       ensures that we don't get confused by missing records.
     * @param classification {@link Map} of the classification information:
     *                       bucketIdx: index of the encoder bucket (any {@link Number});
     *                       actValue: actual value going into the encoder
     * @param patternNZ      list of the active indices from the output below
     * @param learn          if true, learn this sample
     * @param infer          if true, perform inference
     *
     * @return {@link Classification} containing inference results (empty if
     *         {@code infer} is false). For each step in steps it holds an array
     *         with the normalized likelihood of each bucketIdx starting from
     *         bucketIdx 0, e.g. step 1 : [0.1, 0.3, 0.2, 0.4]. It also holds
     *         the actual value to use for each bucket, e.g. [1.5, 3.5, 5.5, 7.6].
     */
    @SuppressWarnings("unchecked")
    public <T> Classification<T> compute(int recordNum, Map<String, Object> classification, int[] patternNZ, boolean learn, boolean infer) {
        Classification<T> retVal = new Classification<T>();
        List<T> actualValues = (List<T>)this.actualValues;

        // Save the offset between recordNum and learnIteration if this is the first
        // compute
        if(recordNumMinusLearnIteration == -1) {
            recordNumMinusLearnIteration = recordNum - learnIteration;
        }
        // Update the learn iteration
        learnIteration = recordNum - recordNumMinusLearnIteration;

        if(verbosity >= 1) {
            System.out.println(String.format("\n%s: compute ", g_debugPrefix));
            System.out.println(" recordNum: " + recordNum);
            System.out.println(" learnIteration: " + learnIteration);
            System.out.println(String.format(" patternNZ(%d): %s", patternNZ.length, Arrays.toString(patternNZ)));
            System.out.println(" classificationIn: " + classification);
        }

        patternNZHistory.append(new Tuple(learnIteration, patternNZ));

        //------------------------------------------------------------------------
        // Inference:
        // For each active bit in the activationPattern, get the classification
        // votes
        //
        // For buckets which we don't have an actual value for yet, just plug in
        // any valid actual value. It doesn't matter what we use because that
        // bucket won't have non-zero likelihood anyways.
        if(infer) {
            // NOTE: If doing 0-step prediction, we shouldn't use any knowledge
            // of the classification input during inference.
            Object defaultValue;
            if(steps.get(0) == 0) {
                defaultValue = 0;
            }else{
                defaultValue = classification.get("actValue");
            }

            T[] actValues = (T[])new Object[actualValues.size()];
            for(int i = 0; i < actualValues.size(); i++) {
                T value = actualValues.get(i);
                actValues[i] = value == null ? (T)defaultValue : value;
            }
            retVal.setActualValues(actValues);

            // For each n-step prediction...
            for(int nSteps : steps.toArray()) {
                // Accumulate bucket index votes into these arrays
                double[] sumVotes = new double[maxBucketIdx + 1];
                double[] bitVotes = new double[maxBucketIdx + 1];
                for(int bit : patternNZ) {
                    BitHistory history = activeBitHistory.get(new Tuple(bit, nSteps));
                    if(history == null) continue;
                    // bitVotes is reused across bits: clear it so votes from a
                    // previous bit cannot leak in if infer() only writes some entries.
                    Arrays.fill(bitVotes, 0.0);
                    history.infer(learnIteration, bitVotes);
                    sumVotes = ArrayUtils.d_add(sumVotes, bitVotes);
                }

                // Return the votes for each bucket, normalized
                double total = ArrayUtils.sum(sumVotes);
                if(total > 0) {
                    sumVotes = ArrayUtils.divide(sumVotes, total);
                }else if(sumVotes.length > 0) {
                    // If all buckets have zero probability then simply make all of the
                    // buckets equally likely. There is no actual prediction for this
                    // timestep so any of the possible predictions are just as good.
                    Arrays.fill(sumVotes, 1.0 / (double)sumVotes.length);
                }
                retVal.setStats(nSteps, sumVotes);
            }
        }

        // ------------------------------------------------------------------------
        // Learning:
        // For each active bit in the activationPattern, store the classification
        // info. If the bucketIdx is null, we can't learn. This can happen when the
        // field is missing in a specific record.
        if(learn && classification.get("bucketIdx") != null) {
            // Get classification info
            int bucketIdx = ((Number)classification.get("bucketIdx")).intValue();
            Object actValue = classification.get("actValue");

            // Update maxBucketIndex
            maxBucketIdx = Math.max(maxBucketIdx, bucketIdx);

            // Update rolling average of actual values if it's a scalar. If it's
            // not, it must be a category, in which case each bucket only ever
            // sees one category so we don't need a running average.
            while(maxBucketIdx > actualValues.size() - 1) {
                actualValues.add(null);
            }
            T current = actualValues.get(bucketIdx);
            if(current == null) {
                actualValues.set(bucketIdx, (T)actValue);
            }else if(actValue instanceof Number && current instanceof Number) {
                Double val = (1.0 - actValueAlpha) * ((Number)current).doubleValue() +
                    actValueAlpha * ((Number)actValue).doubleValue();
                actualValues.set(bucketIdx, (T)val);
            }else if(actValue != null) {
                // Non-numeric (category) value: just replace. A null actValue
                // leaves the previously learned value in place.
                actualValues.set(bucketIdx, (T)actValue);
            }

            // Train each pattern that we have in our history that aligns with the
            // steps we have in steps
            for(int nSteps : steps.toArray()) {
                // Do we have the pattern that should be assigned to this classification
                // in our pattern history? If not, skip it
                int[] learnPatternNZ = findPatternNZ(learnIteration - nSteps);
                if(learnPatternNZ == null) continue;

                // Store classification info for each active bit from the pattern
                // that we got nSteps time steps ago.
                for(int bit : learnPatternNZ) {
                    // Get the history structure for this bit and step
                    Tuple key = new Tuple(bit, nSteps);
                    BitHistory history = activeBitHistory.get(key);
                    if(history == null) {
                        history = new BitHistory(this, bit, nSteps);
                        activeBitHistory.put(key, history);
                    }
                    history.store(learnIteration, bucketIdx);
                }
            }
        }

        if(infer && verbosity >= 1) {
            System.out.println(" inference: combined bucket likelihoods:");
            System.out.println(" actual bucket values: " + Arrays.toString((T[])retVal.getActualValues()));
            for(int key : retVal.stepSet()) {
                double[] votes = retVal.getStats(key);
                if(votes == null) continue;
                System.out.println(String.format(" %d steps: %s", key, pFormatArray(votes)));
                int bestBucketIdx = retVal.getMostProbableBucketIndex(key);
                System.out.println(String.format(" most likely bucket idx: %d, value: %s ", bestBucketIdx,
                    retVal.getActualValue(bestBucketIdx)));
            }
        }
        return retVal;
    }

    /**
     * Looks up the activation pattern recorded at the given learn iteration.
     *
     * @param iteration the learn iteration whose pattern is wanted
     * @return the stored active-bit indices, or null if that iteration is not
     *         (or no longer) in {@link #patternNZHistory}
     */
    private int[] findPatternNZ(int iteration) {
        for(Tuple t : patternNZHistory) {
            if((int)t.get(0) == iteration) {
                return (int[])t.get(1);
            }
        }
        return null;
    }

    /**
     * Return a string with a pretty-print of an array, each element formatted
     * with two decimals and separated by spaces, e.g. "[ 0.10 0.30 ]".
     *
     * @param arr the values to format (may be null)
     * @return the formatted string, or "" if {@code arr} is null
     */
    private String pFormatArray(double[] arr) {
        if(arr == null) return "";
        StringBuilder sb = new StringBuilder("[ ");
        for(double d : arr) {
            sb.append(String.format("%.2f ", d));
        }
        sb.append("]");
        return sb.toString();
    }
}