/* --------------------------------------------------------------------- * Numenta Platform for Intelligent Computing (NuPIC) * Copyright (C) 2014, Numenta, Inc. Unless you have an agreement * with Numenta, Inc., for a separate license for this software code, the * following terms and conditions apply: * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero Public License version 3 as * published by the Free Software Foundation. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. * See the GNU Affero Public License for more details. * * You should have received a copy of the GNU Affero Public License * along with this program. If not, see http://www.gnu.org/licenses. * * http://numenta.org/licenses/ * --------------------------------------------------------------------- */ package org.numenta.nupic.encoders; import gnu.trove.list.TDoubleList; import gnu.trove.list.array.TDoubleArrayList; import gnu.trove.set.TIntSet; import gnu.trove.set.hash.TIntHashSet; import org.numenta.nupic.FieldMetaType; import org.numenta.nupic.util.ArrayUtils; import org.numenta.nupic.util.Condition; import org.numenta.nupic.util.MinMax; import org.numenta.nupic.util.SparseObjectMatrix; import org.numenta.nupic.util.Tuple; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; import java.util.Collections; import java.util.HashMap; import java.util.HashSet; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.Random; import java.util.Set; /** * Encodes a list of discrete categories (described by strings), that aren't * related to each other. * Each encoding is an SDR in which w out of n bits are turned on. * <p/> * Unknown categories are encoded as a single * * @see Encoder * @see Encoding */ public class SDRCategoryEncoder extends Encoder<String> { private static final long serialVersionUID = 1L; private static final Logger LOG = LoggerFactory.getLogger(SDRCategoryEncoder.class); private Random random; private int thresholdOverlap; private final SDRByCategoryMap sdrByCategory = new SDRByCategoryMap(); /** * Inner class for keeping Categories and SDRs in ordered way */ @SuppressWarnings("serial") private static final class SDRByCategoryMap extends LinkedHashMap<String, int[]> { public int[] getSdr(int index) { Map.Entry<String, int[]> entry = this.getEntry(index); if (entry == null) return null; return entry.getValue(); } public String getCategory(int index) { Map.Entry<String, int[]> entry = this.getEntry(index); if (entry == null) return null; return entry.getKey(); } public int getIndexByCategory(String category) { Set<String> categories = this.keySet(); int inx = 0; for (String s : categories) { if (s.equals(category)) { return inx; } inx++; } return 0; } private Map.Entry<String, int[]> getEntry(int i) { Set<Map.Entry<String, int[]>> entries = entrySet(); if (i < 0 || i > entries.size()) { throw new IllegalArgumentException("Index should be in following range:[0," + entries.size() + "]"); } int j = 0; for (Map.Entry<String, int[]> entry : entries) if (j++ == i) return entry; return null; } } /** * Returns a builder for building {@code SDRCategoryEncoder}s. * This is the only way to instantiate {@code SDRCategoryEncoder} * * @return a {@code SDRCategoryEncoder.Builder} */ public static SDRCategoryEncoder.Builder builder() { return new Builder(); } private SDRCategoryEncoder() { } /* Python mapping def __init__(self, n, w, categoryList = None, name="category", verbosity=0, encoderSeed=1, forced=False): */ private void init(int n, int w, List<String> categoryList, String name, int encoderSeed, boolean forced) { /*Python ref: n is total bits in output w is the number of bits that are turned on for each rep categoryList is a list of strings that define the categories. If "none" then categories will automatically be added as they are encountered. forced (default False) : if True, skip checks for parameters' settings; see encoders/scalar.py for details*/ this.n = n; this.w = w; this.encLearningEnabled = true; this.random = new Random(); if (encoderSeed != -1) { this.random.setSeed(encoderSeed); } if (!forced) { if (n / w < 2) { throw new IllegalArgumentException(String.format( "Number of ON bits in SDR (%d) must be much smaller than the output width (%d)", w, n)); } if (w < 21) { throw new IllegalArgumentException(String.format( "Number of bits in the SDR (%d) must be greater than 2, and should be >= 21, pass forced=True to init() to override this check", w)); } } /* #Calculate average overlap of SDRs for decoding #Density is fraction of bits on, and it is also the #probability that any individual bit is on. */ double density = (double)this.w / this.n; double averageOverlap = w * density; /* # We can do a better job of calculating the threshold. For now, just # something quick and dirty, which is the midway point between average # and full overlap. averageOverlap is always < w, so the threshold # is always < w. */ this.thresholdOverlap = (int)(averageOverlap + this.w) / 2; /* # 1.25 -- too sensitive for decode test, so make it less sensitive */ if (this.thresholdOverlap < this.w - 3) { this.thresholdOverlap = this.w - 3; } this.description.add(new Tuple(name, 0)); this.name = name; /* # Always include an 'unknown' category for # edge cases */ this.addCategory("<UNKNOWN>"); if (categoryList == null || categoryList.size() == 0) { this.setLearningEnabled(true); } else { this.setLearningEnabled(false); for (String category : categoryList) { this.addCategory(category); } } } /** * {@inheritDoc} */ @Override public int getWidth() { return this.getN(); } /** * {@inheritDoc} */ @Override public boolean isDelta() { return false; } /** * {@inheritDoc} */ @Override public void encodeIntoArray(String input, int[] output) { int index; if (input == null || input.isEmpty()) { Arrays.fill(output, 0); index = 0; } else { index = getBucketIndices(input)[0]; int[] categoryEncoding = sdrByCategory.getSdr(index); System.arraycopy(categoryEncoding, 0, output, 0, categoryEncoding.length); } LOG.trace("input:" + input + ", index:" + index + ", output:" + ArrayUtils.intArrayToString(output)); LOG.trace("decoded:" + decodedToStr(decode(output, ""))); } /** * {@inheritDoc} */ @Override public Set<FieldMetaType> getDecoderOutputFieldTypes() { return new HashSet<>(Arrays.asList(FieldMetaType.LIST, FieldMetaType.STRING)); } /** * {@inheritDoc} */ @Override public int[] getBucketIndices(String input) { return new int[]{(int)getScalars(input).get(0)}; } /** * {@inheritDoc} */ @Override public <S> TDoubleList getScalars(S input) { String inputCasted = (String)input; int index = 0; TDoubleList result = new TDoubleArrayList(); if (inputCasted == null || inputCasted.isEmpty()) { result.add(0); return result; } if (!sdrByCategory.containsKey(input)) { if (isEncoderLearningEnabled()) { index = sdrByCategory.size(); addCategory(inputCasted); } } else { index = sdrByCategory.getIndexByCategory(inputCasted); } result.add(index); return result; } /** * No parentFieldName parameter method overload for the {@link #decode(int[], String)}. * * @param encoded - bit array to be decoded * @return */ public DecodeResult decode(int[] encoded) { return decode(encoded, null); } /** * {@inheritDoc} */ @Override public DecodeResult decode(int[] encoded, String parentFieldName) { //assert (encoded[0:self.n] <= 1.0).all() assert ArrayUtils.all(encoded, new Condition.Adapter<Integer>() { @Override public boolean eval(int i) { return i <= 1; } }); //overlaps = (self.sdrs * encoded[0:self.n]).sum(axis=1) int[] overlap = new int[sdrByCategory.size()]; for (int i = 0; i < sdrByCategory.size(); i++) { int[] sdr = sdrByCategory.getSdr(i); for (int j = 0; j < sdr.length; j++) { if (sdr[j] == encoded[j] && encoded[j] == 1) { overlap[i]++; } } } LOG.trace("Overlaps for decoding:"); if (LOG.isTraceEnabled()){ int inx = 0; for (String category : sdrByCategory.keySet()) { LOG.trace(overlap[inx] + " " + category); inx++; } } //matchingCategories = (overlaps > self.thresholdOverlap).nonzero()[0] int[] matchingCategories = ArrayUtils.where(overlap, new Condition.Adapter<Integer>() { @Override public boolean eval(int overlaps) { return overlaps > thresholdOverlap; } }); StringBuilder resultString = new StringBuilder(); List<MinMax> resultRanges = new ArrayList<>(); String fieldName; for (int index : matchingCategories) { if (resultString.length() != 0) { resultString.append(" "); } resultString.append(sdrByCategory.getCategory(index)); resultRanges.add(new MinMax(index, index)); } if (parentFieldName == null || parentFieldName.isEmpty()) { fieldName = getName(); } else { fieldName = String.format("%s.%s", parentFieldName, getName()); } Map<String, RangeList> fieldsDict = new HashMap<>(); fieldsDict.put(fieldName, new RangeList(resultRanges, resultString.toString())); // return ({fieldName: (resultRanges, resultString)}, [fieldName]) return new DecodeResult(fieldsDict, Arrays.asList(fieldName)); } /** * {@inheritDoc} */ @Override public List<Encoding> topDownCompute(int[] encoded) { if (sdrByCategory.size() == 0) { return new ArrayList<>(); } //TODO the rightVecProd method belongs to SparseBinaryMatrix in Nupic Core, In python this method call stack: topDownCompute [sdrcategory.py:317]/rightVecProd [math.py:4474] -->return _math._SparseMatrix32_rightVecProd(self, *args) int categoryIndex = ArrayUtils.argmax(rightVecProd(getTopDownMapping(), encoded)); return getEncoderResultsByIndex(getTopDownMapping(), categoryIndex); } /** * {@inheritDoc} */ @Override public List<Encoding> getBucketInfo(int[] buckets) { if (sdrByCategory.size() == 0) { return new ArrayList<>(); } int categoryIndex = buckets[0]; return getEncoderResultsByIndex(getTopDownMapping(), categoryIndex); } /** * Return the internal topDownMapping matrix used for handling the * {@link #getBucketInfo(int[])} and {@link #topDownCompute(int[])} methods. This is a matrix, one row per * category (bucket) where each row contains the encoded output for that * category. * * @return {@link SparseObjectMatrix} */ public SparseObjectMatrix<int[]> getTopDownMapping() { if (topDownMapping == null) { topDownMapping = new SparseObjectMatrix<>( new int[]{sdrByCategory.size()}); int[] outputSpace = new int[getN()]; Set<String> categories = sdrByCategory.keySet(); int inx = 0; for (String category : categories) { encodeIntoArray(category, outputSpace); topDownMapping.set(inx, Arrays.copyOf(outputSpace, outputSpace.length)); inx++; } } return topDownMapping; } /** * {@inheritDoc} */ @SuppressWarnings("unchecked") @Override public <S> List<S> getBucketValues(Class<S> returnType) { return new ArrayList<>((Collection<S>)this.sdrByCategory.keySet()); } /** * Returns list of registered SDRs for this encoder * * @return {@link Collection} */ public Collection<int[]> getSDRs() { return Collections.unmodifiableCollection(sdrByCategory.values()); } private List<Encoding> getEncoderResultsByIndex(SparseObjectMatrix<int[]> topDownMapping, int categoryIndex) { List<Encoding> result = new ArrayList<>(); String category = sdrByCategory.getCategory(categoryIndex); int[] encoding = topDownMapping.getObject(categoryIndex); result.add(new Encoding(category, categoryIndex, encoding)); return result; } private void addCategory(String category) { if (this.sdrByCategory.containsKey(category)) { throw new IllegalArgumentException(String.format("Attempt to add encoder category '%s' that already exists", category)); } sdrByCategory.put(category, newRep()); //reset topDown mapping topDownMapping = null; } //replacement for Python sorted(self.random.sample(xrange(self.n), self.w)) private int[] getSortedSample(final int populationSize, final int sampleLength) { TIntSet resultSet = new TIntHashSet(); while (resultSet.size() < sampleLength) { resultSet.add(random.nextInt(populationSize)); } int[] result = resultSet.toArray(); Arrays.sort(result); return result; } private int[] newRep() { int maxAttempts = 1000; boolean foundUnique = true; int[] oneBits; int sdr[] = new int[n]; for (int index = 0; index < maxAttempts; index++) { foundUnique = true; oneBits = getSortedSample(n, w); sdr = new int[n]; for (int i = 0; i < oneBits.length; i++) { int oneBitInx = oneBits[i]; sdr[oneBitInx] = 1; } for (int[] existingSdr : this.sdrByCategory.values()) { if (Arrays.equals(sdr, existingSdr)) { foundUnique = false; break; } } if (foundUnique) { break; } } if (!foundUnique) { throw new RuntimeException(String.format("Error, could not find unique pattern %d after %d attempts", sdrByCategory.size(), maxAttempts)); } return sdr; } /** * Builder class for {@code SDRCategoryEncoder} * <p>N is total bits in output</p> * <p>W is the number of bits that are turned on for each rep</p> * <p>categoryList is a list of strings that define the categories.If no categories provided, then they will automatically be added as they are encountered.</p> * <p>forced (default false) : if true, skip checks for parameters settings</p> */ public static final class Builder extends Encoder.Builder<Builder, SDRCategoryEncoder> { private List<String> categoryList = new ArrayList<>(); private int encoderSeed = 1; @Override public SDRCategoryEncoder build() { if (n == 0) { throw new IllegalStateException("\"N\" should be set"); } if (w == 0) { throw new IllegalStateException("\"W\" should be set"); } if(categoryList == null) { throw new IllegalStateException("Category List cannot be null"); } SDRCategoryEncoder sdrCategoryEncoder = new SDRCategoryEncoder(); sdrCategoryEncoder.init(n, w, categoryList, name, encoderSeed, forced); return sdrCategoryEncoder; } public Builder categoryList(List<String> categoryList) { this.categoryList = categoryList; return this; } public Builder encoderSeed(int encoderSeed) { this.encoderSeed = encoderSeed; return this; } @Override public Builder radius(double radius) { throw new IllegalArgumentException("Not supported for this SDRCategoryEncoder"); } @Override public Builder resolution(double resolution) { throw new IllegalArgumentException("Not supported for this SDRCategoryEncoder"); } @Override public Builder periodic(boolean periodic) { throw new IllegalArgumentException("Not supported for this SDRCategoryEncoder"); } @Override public Builder clipInput(boolean clipInput) { throw new IllegalArgumentException("Not supported for this SDRCategoryEncoder"); } @Override public Builder maxVal(double maxVal) { throw new IllegalArgumentException("Not supported for this SDRCategoryEncoder"); } @Override public Builder minVal(double minVal) { throw new IllegalArgumentException("Not supported for this SDRCategoryEncoder"); } } }