/* --------------------------------------------------------------------- * Numenta Platform for Intelligent Computing (NuPIC) * Copyright (C) 2014, Numenta, Inc. Unless you have an agreement * with Numenta, Inc., for a separate license for this software code, the * following terms and conditions apply: * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero Public License version 3 as * published by the Free Software Foundation. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. * See the GNU Affero Public License for more details. * * You should have received a copy of the GNU Affero Public License * along with this program. If not, see http://www.gnu.org/licenses. * * http://numenta.org/licenses/ * --------------------------------------------------------------------- */ package org.numenta.nupic.encoders; import gnu.trove.list.TDoubleList; import gnu.trove.list.array.TDoubleArrayList; import gnu.trove.map.TIntObjectMap; import gnu.trove.map.TObjectIntMap; import gnu.trove.map.hash.TIntObjectHashMap; import gnu.trove.map.hash.TObjectIntHashMap; import org.numenta.nupic.Parameters; import org.numenta.nupic.util.ArrayUtils; import org.numenta.nupic.util.MinMax; import org.numenta.nupic.util.SparseObjectMatrix; import org.numenta.nupic.util.Tuple; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; import java.util.List; import java.util.Map; /** * Encodes a list of discrete categories (described by strings), that aren't * related to each other, so we never emit a mixture of categories. * * The value of zero is reserved for "unknown category" * * Internally we use a ScalarEncoder with a radius of 1, but since we only encode * integers, we never get mixture outputs. * * The {@link SDRCategoryEncoder} uses a different method to encode categories. * * <P> * Typical usage is as follows: * <PRE> * CategoryEncoder.Builder builder = ((CategoryEncoder.Builder)CategoryEncoder.builder()) * .w(3) * .radius(0.0) * .minVal(0.0) * .maxVal(8.0) * .periodic(false) * .forced(true); * * CategoryEncoder encoder = builder.build(); * * <b>Above values are <i>not</i> an example of "sane" values.</b> * * </PRE> * * @author David Ray * @see ScalarEncoder * @see Encoder * @see Encoding * @see Parameters */ public class CategoryEncoder extends Encoder<String> { private static final long serialVersionUID = 1L; private static final Logger LOG = LoggerFactory.getLogger(CategoryEncoder.class); protected int ncategories; protected TObjectIntMap<String> categoryToIndex = new TObjectIntHashMap<String>(); protected TIntObjectMap<String> indexToCategory = new TIntObjectHashMap<String>(); protected List<String> categoryList; protected int width; private ScalarEncoder scalarEncoder; /** * Constructs a new {@code CategoryEncoder} */ private CategoryEncoder() { } /** * Returns a builder for building CategoryEncoders. * This builder may be reused to produce multiple builders * * @return a {@code CategoryEncoder.Builder} */ public static Encoder.Builder<CategoryEncoder.Builder, CategoryEncoder> builder() { return new CategoryEncoder.Builder(); } public void init() { // number of categories includes zero'th category: "unknown" ncategories = categoryList == null ? 0 : categoryList.size() + 1; minVal = 0; maxVal = ncategories - 1; try { scalarEncoder = ScalarEncoder.builder() .n(this.n) .w(this.w) .radius(this.radius) .minVal(this.minVal) .maxVal(this.maxVal) .periodic(this.periodic) .forced(this.forced).build(); }catch(Exception e) { String msg = null; int idx = -1; if((idx = (msg = e.getMessage()).indexOf("ScalarEncoder")) != -1) { msg = msg.substring(0, idx).concat("CategoryEncoder"); throw new IllegalStateException(msg); } } indexToCategory.put(0, "<UNKNOWN>"); if(categoryList != null && !categoryList.isEmpty()) { int len = categoryList.size(); for(int i = 0;i < len;i++) { categoryToIndex.put(categoryList.get(i), i + 1); indexToCategory.put(i + 1, categoryList.get(i)); } } width = n = w * ncategories; //TODO this is what the CategoryEncoder was doing before I added the ScalarEncoder delegate. //I'm concerned because we're changing n without calling init again on the scalar encoder. //In other words, if I move the scalarEncoder = ...build() from to here, the test cases fail //which indicates significant fragility and at some level a violation of encapsulation. scalarEncoder.n = n; if(getWidth() != width) { throw new IllegalStateException( "Width != w (num bits to represent output item) * #categories"); } description.add(new Tuple(name, 0)); } /** * {@inheritDoc} */ @Override public <T> TDoubleList getScalars(T d) { return new TDoubleArrayList(new double[] { categoryToIndex.get(d) }); } /** * {@inheritDoc} */ @Override public int[] getBucketIndices(String input) { if(input == null) return null; return scalarEncoder.getBucketIndices(categoryToIndex.get(input)); } /** * {@inheritDoc} */ @Override public void encodeIntoArray(String input, int[] output) { String val = null; double value = 0; if(input == null) { val = "<missing>"; }else{ value = categoryToIndex.get(input); value = value == categoryToIndex.getNoEntryValue() ? 0 : value; scalarEncoder.encodeIntoArray(value, output); } LOG.trace("input: {}, val: {}, value: {}, output: {}", input, val, value, Arrays.toString(output)); } /** * {@inheritDoc} */ @Override public DecodeResult decode(int[] encoded, String parentFieldName) { // Get the scalar values from the underlying scalar encoder DecodeResult result = scalarEncoder.decode(encoded, parentFieldName); if(result.getFields().size() == 0) { return result; } // Expect only 1 field if(result.getFields().size() != 1) { throw new IllegalStateException("Expecting only one field"); } //Get the list of categories the scalar values correspond to and // generate the description from the category name(s). Map<String, RangeList> fieldRanges = result.getFields(); List<MinMax> outRanges = new ArrayList<MinMax>(); StringBuilder desc = new StringBuilder(); for(String descripStr : fieldRanges.keySet()) { MinMax minMax = fieldRanges.get(descripStr).getRange(0); int minV = (int)Math.round(minMax.min()); int maxV = (int)Math.round(minMax.max()); outRanges.add(new MinMax(minV, maxV)); while(minV <= maxV) { if(desc.length() > 0) { desc.append(", "); } desc.append(indexToCategory.get(minV)); minV += 1; } } //Return result String fieldName; if(!parentFieldName.isEmpty()) { fieldName = String.format("%s.%s", parentFieldName, name); }else{ fieldName = name; } Map<String, RangeList> retVal = new HashMap<String, RangeList>(); retVal.put(fieldName, new RangeList(outRanges, desc.toString())); return new DecodeResult(retVal, Arrays.asList(new String[] { fieldName })); } /** * {@inheritDoc} */ @Override public TDoubleList closenessScores(TDoubleList expValues, TDoubleList actValues, boolean fractional) { double expValue = expValues.get(0); double actValue = actValues.get(0); double closeness = expValue == actValue ? 1.0 : 0; if(!fractional) closeness = 1.0 - closeness; return new TDoubleArrayList(new double[]{ closeness }); } /** * Returns a list of items, one for each bucket defined by this encoder. * Each item is the value assigned to that bucket, this is the same as the * EncoderResult.value that would be returned by getBucketInfo() for that * bucket and is in the same format as the input that would be passed to * encode(). * * This call is faster than calling getBucketInfo() on each bucket individually * if all you need are the bucket values. * * @param returnType class type parameter so that this method can return encoder * specific value types * * @return list of items, each item representing the bucket value for that * bucket. */ @SuppressWarnings("unchecked") @Override public <T> List<T> getBucketValues(Class<T> t) { if(bucketValues == null) { SparseObjectMatrix<int[]> topDownMapping = scalarEncoder.getTopDownMapping(); int numBuckets = topDownMapping.getMaxIndex() + 1; bucketValues = new ArrayList<String>(); for(int i = 0;i < numBuckets;i++) { ((List<String>)bucketValues).add((String)getBucketInfo(new int[] { i }).get(0).getValue()); } } return (List<T>)bucketValues; } /** * {@inheritDoc} */ @Override public List<Encoding> getBucketInfo(int[] buckets) { // For the category encoder, the bucket index is the category index List<Encoding> bucketInfo = scalarEncoder.getBucketInfo(buckets); int categoryIndex = (int)Math.round((double)bucketInfo.get(0).getValue()); String category = indexToCategory.get(categoryIndex); bucketInfo.set(0, new Encoding(category, categoryIndex, bucketInfo.get(0).getEncoding())); return bucketInfo; } /** * {@inheritDoc} */ @Override public List<Encoding> topDownCompute(int[] encoded) { //Get/generate the topDown mapping table SparseObjectMatrix<int[]> topDownMapping = scalarEncoder.getTopDownMapping(); // See which "category" we match the closest. int category = ArrayUtils.argmax(rightVecProd(topDownMapping, encoded)); return getBucketInfo(new int[] { category }); } public List<String> getCategoryList() { return categoryList; } public void setCategoryList(List<String> categoryList) { this.categoryList = categoryList; } /** * Returns a {@link EncoderBuilder} for constructing {@link CategoryEncoder}s * * The base class architecture is put together in such a way where boilerplate * initialization can be kept to a minimum for implementing subclasses, while avoiding * the mistake-proneness of extremely long argument lists. * * @see ScalarEncoder.Builder#setStuff(int) */ public static class Builder extends Encoder.Builder<CategoryEncoder.Builder, CategoryEncoder> { private List<String> categoryList; private Builder() {} @Override public CategoryEncoder build() { //Must be instantiated so that super class can initialize //boilerplate variables. encoder = new CategoryEncoder(); //Call super class here super.build(); //////////////////////////////////////////////////////// // Implementing classes would do setting of specific // // vars here together with any sanity checking // //////////////////////////////////////////////////////// if(categoryList == null) { throw new IllegalStateException("Category List cannot be null"); } //Set CategoryEncoder specific field ((CategoryEncoder)encoder).setCategoryList(this.categoryList); //Call init ((CategoryEncoder)encoder).init(); return (CategoryEncoder)encoder; } /** * Never called - just here as an example of specialization for a specific * subclass of Encoder.Builder * * Example specific method!! * * @param stuff * @return */ public CategoryEncoder.Builder categoryList(List<String> categoryList) { this.categoryList = categoryList; return this; } } @Override public int getWidth() { return getN(); } @Override public boolean isDelta() { return false; } }