// Stanford Classifier - a multiclass maxent classifier // LinearClassifier // Copyright (c) 2003-2007 The Board of Trustees of // The Leland Stanford Junior University. All Rights Reserved. // // This program is free software; you can redistribute it and/or // modify it under the terms of the GNU General Public License // as published by the Free Software Foundation; either version 2 // of the License, or (at your option) any later version. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU General Public License for more details. // // You should have received a copy of the GNU General Public License // along with this program; if not, write to the Free Software // Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA. // // For more information, bug reports, fixes, contact: // Christopher Manning // Dept of Computer Science, Gates 1A // Stanford CA 94305-9010 // USA // Support/Questions: java-nlp-user@lists.stanford.edu // Licensing: java-nlp-support@lists.stanford.edu // http://www-nlp.stanford.edu/software/classifier.shtml package edu.stanford.nlp.classify; import edu.stanford.nlp.io.IOUtils; import edu.stanford.nlp.ling.BasicDatum; import edu.stanford.nlp.ling.Datum; import edu.stanford.nlp.ling.RVFDatum; import edu.stanford.nlp.util.*; import edu.stanford.nlp.stats.ClassicCounter; import edu.stanford.nlp.stats.Counter; import edu.stanford.nlp.stats.Distribution; import edu.stanford.nlp.stats.Counters; import java.io.*; import java.text.DecimalFormat; import java.text.NumberFormat; import java.util.*; import java.util.function.Function; import edu.stanford.nlp.util.logging.Redwood; /** * Implements a multiclass linear classifier. 
 * At classification time this
 * can be any generalized linear model classifier (such as a perceptron,
 * a maxent classifier (softmax logistic regression), or an SVM).
 *
 * @author Dan Klein
 * @author Jenny Finkel
 * @author Galen Andrew (converted to arrays and indices)
 * @author Christopher Manning (most of the printing options)
 * @author Eric Yeh (save to text file, new constructor w/thresholds)
 * @author Sarah Spikes (sdspikes@cs.stanford.edu) (Templatization)
 * @author {@literal nmramesh@cs.stanford.edu} {@link #weightsAsMapOfCounters()}
 * @author Angel Chang (Add functions to get top features, and number of features with weights above a certain threshold)
 *
 * @param <L> The type of the labels in the Classifier
 * @param <F> The type of the features in the Classifier
 */
public class LinearClassifier<L, F> implements ProbabilisticClassifier<L, F>, RVFClassifier<L, F> {

  /** A logger for this class */
  private static final Redwood.RedwoodChannels logger = Redwood.channels(LinearClassifier.class);

  /** Classifier weights. First index is the featureIndex value and second index is the labelIndex value. */
  private double[][] weights;
  /** Maps label objects to the second (label) index of {@link #weights}. */
  private Index<L> labelIndex;
  /** Maps feature objects to the first (feature) index of {@link #weights}. */
  private Index<F> featureIndex;
  // NOTE(review): public mutable flag, presumably controls String interning when
  // loading a serialized classifier — confirm against the (de)serialization code.
  public boolean intern = false; // variable should be deleted when breaking serialization anyway....
  /** Per-label additive bias added to every score (see the scoreOf methods). */
  private double[] thresholds; // = null;

  private static final long serialVersionUID = 8499574525453275255L;

  /** Cap on the column width used when pretty-printing feature names. */
  private static final int MAX_FEATURE_ALIGN_WIDTH = 50;

  /** Field delimiter for the text (human-readable) serialization format. */
  public static final String TEXT_SERIALIZATION_DELIMITER = "\t";

  /** Returns the labels this classifier can assign. */
  @Override
  public Collection<L> labels() {
    return labelIndex.objectsList();
  }

  /** Returns all features known to this classifier. */
  public Collection<F> features() {
    return featureIndex.objectsList();
  }

  public Index<L> labelIndex() {
    return labelIndex;
  }

  public Index<F> featureIndex() {
    return featureIndex;
  }

  /**
   * Weight for an (indexed feature, indexed label) pair.
   * A negative feature index means "feature not in the index" and scores 0.0.
   */
  private double weight(int iFeature, int iLabel) {
    if (iFeature < 0) {
      //logger.info("feature not seen ");
      return 0.0;
    }
    assert iFeature < weights.length;
    assert iLabel < weights[iFeature].length;
    return weights[iFeature][iLabel];
  }

  /** Weight for a feature object paired with an already-indexed label. */
  private double weight(F feature, int iLabel) {
    int f = featureIndex.indexOf(feature);
    return weight(f, iLabel);
  }

  /** Weight assigned to the given feature for the given label (0.0 for unknown features). */
  public double weight(F feature, L label) {
    int f = featureIndex.indexOf(feature);
    int iLabel = labelIndex.indexOf(label);
    return weight(f, iLabel);
  }

  /* --- obsolete method from before this class was rewritten using arrays
  public Counter scoresOf(Datum example) {
    Counter scores = new Counter();
    for (L l : labels()) {
      scores.setCount(l, scoreOf(example, l));
    }
    return scores;
  }
  --- */

  /** Construct a counter with keys the labels of the classifier and
   *  values the score (unnormalized log probability) of each class.
*/ @Override public Counter<L> scoresOf(Datum<L, F> example) { if(example instanceof RVFDatum<?, ?>)return scoresOfRVFDatum((RVFDatum<L,F>)example); Collection<F> feats = example.asFeatures(); int[] features = new int[feats.size()]; int i = 0; for (F f : feats) { int index = featureIndex.indexOf(f); if (index >= 0) { features[i++] = index; // } else { //logger.info("FEATURE LESS THAN ZERO: " + f); } } int[] activeFeatures = new int[i]; synchronized (System.class) { System.arraycopy(features, 0, activeFeatures, 0, i); } Counter<L> scores = new ClassicCounter<>(); for (L lab : labels()) { scores.setCount(lab, scoreOf(activeFeatures, lab)); } return scores; } /** Given a datum's features, construct a counter with keys * the labels and values the score (unnormalized log probability) * for each class. */ public Counter<L> scoresOf(int[] features) { Counter<L> scores = new ClassicCounter<>(); for (L label : labels()) scores.setCount(label, scoreOf(features, label)); return scores; } /** Returns of the score of the Datum for the specified label. * Ignores the true label of the Datum. */ public double scoreOf(Datum<L, F> example, L label) { if (example instanceof RVFDatum<?, ?>) { return scoreOfRVFDatum((RVFDatum<L,F>)example, label); } int iLabel = labelIndex.indexOf(label); double score = 0.0; for (F f : example.asFeatures()) { score += weight(f, iLabel); } return score + thresholds[iLabel]; } /** Construct a counter with keys the labels of the classifier and * values the score (unnormalized log probability) of each class * for an RVFDatum. 
 */
  @Override
  @Deprecated
  public Counter<L> scoresOf(RVFDatum<L, F> example) {
    Counter<L> scores = new ClassicCounter<>();
    for (L l : labels()) {
      scores.setCount(l, scoreOfRVFDatum(example, l));
    }
    //System.out.println("Scores are: " + scores + " (gold: " + example.label() + ")");
    return scores;
  }

  /** Construct a counter with keys the labels of the classifier and
   *  values the score (unnormalized log probability) of each class
   *  for an RVFDatum.
   */
  private Counter<L> scoresOfRVFDatum(RVFDatum<L, F> example) {
    Counter<L> scores = new ClassicCounter<>();
    // Index the features in the datum once, so scoring each label
    // avoids repeated featureIndex lookups
    Counter<F> asCounter = example.asFeaturesCounter();
    Counter<Integer> asIndexedCounter = new ClassicCounter<>(asCounter.size());
    for (Map.Entry<F, Double> entry : asCounter.entrySet()) {
      asIndexedCounter.setCount(featureIndex.indexOf(entry.getKey()), entry.getValue());
    }
    // Set the scores appropriately
    for (L l : labels()) {
      scores.setCount(l, scoreOfRVFDatum(asIndexedCounter, l));
    }
    //System.out.println("Scores are: " + scores + " (gold: " + example.label() + ")");
    return scores;
  }

  /** Returns the score of the RVFDatum for the specified label.
   *  Ignores the true label of the RVFDatum.
   *
   *  @param example Used to get the observed x value. Its label is ignored.
   *  @param label The label y that the observed value is scored with.
   *  @return A linear classifier score
   */
  private double scoreOfRVFDatum(RVFDatum<L, F> example, L label) {
    int iLabel = labelIndex.indexOf(label);
    double score = 0.0;
    Counter<F> features = example.asFeaturesCounter();
    for (Map.Entry<F, Double> entry : features.entrySet()) {
      // weighted sum: feature weight times the feature's real value
      score += weight(entry.getKey(), iLabel) * entry.getValue();
    }
    return score + thresholds[iLabel];
  }

  /** Returns the score of the RVFDatum for the specified label.
   *  Ignores the true label of the RVFDatum.
   */
  private double scoreOfRVFDatum(Counter<Integer> features, L label) {
    int iLabel = labelIndex.indexOf(label);
    double score = 0.0;
    for (Map.Entry<Integer, Double> entry : features.entrySet()) {
      // keys are pre-indexed features; weight() returns 0.0 for index -1
      score += weight(entry.getKey(), iLabel) * entry.getValue();
    }
    return score + thresholds[iLabel];
  }

  /** Returns of the score of the Datum as internalized features for the
   *  specified label. Ignores the true label of the Datum.
   *  Doesn't consider a value for each feature (each counts as 1.0).
   */
  private double scoreOf(int[] feats, L label) {
    int iLabel = labelIndex.indexOf(label);
    assert iLabel >= 0;
    double score = 0.0;
    for (int feat : feats) {
      score += weight(feat, iLabel);
    }
    return score + thresholds[iLabel];
  }

  /**
   * Returns a counter mapping from each class name to the probability of
   * that class for a certain example.
   * The sum of the counts should be 1.0.
   */
  @Override
  public Counter<L> probabilityOf(Datum<L, F> example) {
    if (example instanceof RVFDatum<?, ?>) return probabilityOfRVFDatum((RVFDatum<L, F>) example);
    Counter<L> scores = logProbabilityOf(example);
    for (L label : scores.keySet()) {
      scores.setCount(label, Math.exp(scores.getCount(label)));
    }
    return scores;
  }

  /**
   * Returns a counter mapping from each class name to the probability of
   * that class for a certain example.
   * The sum of the counts should be 1.0.
   */
  private Counter<L> probabilityOfRVFDatum(RVFDatum<L, F> example) {
    // NB: this duplicate method is needed so it calls the scoresOf method
    // with a RVFDatum signature
    Counter<L> scores = logProbabilityOfRVFDatum(example);
    for (L label : scores.keySet()) {
      scores.setCount(label, Math.exp(scores.getCount(label)));
    }
    return scores;
  }

  /**
   * Returns a counter mapping from each class name to the probability of
   * that class for a certain example.
   * The sum of the counts should be 1.0.
   */
  @Deprecated
  public Counter<L> probabilityOf(RVFDatum<L, F> example) {
    // NB: this duplicate method is needed so it calls the scoresOf method
    // with a RVFDatum signature
    Counter<L> scores = logProbabilityOf(example);
    for (L label : scores.keySet()) {
      scores.setCount(label, Math.exp(scores.getCount(label)));
    }
    return scores;
  }

  /**
   * Returns a counter mapping from each class name to the log probability of
   * that class for a certain example.
   * The sum of e^v for each count v should be 1.0.
   */
  @Override
  public Counter<L> logProbabilityOf(Datum<L, F> example) {
    if (example instanceof RVFDatum<?, ?>) return logProbabilityOfRVFDatum((RVFDatum<L, F>) example);
    Counter<L> scores = scoresOf(example);
    Counters.logNormalizeInPlace(scores);
    return scores;
  }

  /**
   * Given a datum's features, returns a counter mapping from each
   * class name to the log probability of that class.
   * The sum of e^v for each count v should be 1.
   */
  public Counter<L> logProbabilityOf(int[] features) {
    Counter<L> scores = scoresOf(features);
    Counters.logNormalizeInPlace(scores);
    return scores;
  }

  /** Returns class probabilities for a datum given as pre-indexed features. */
  public Counter<L> probabilityOf(int[] features) {
    Counter<L> scores = logProbabilityOf(features);
    for (L label : scores.keySet()) {
      scores.setCount(label, Math.exp(scores.getCount(label)));
    }
    return scores;
  }

  /**
   * Returns a counter for the log probability of each of the classes.
   * The sum of e^v for each count v should be 1.
   */
  private Counter<L> logProbabilityOfRVFDatum(RVFDatum<L, F> example) {
    // NB: this duplicate method is needed so it calls the scoresOf method
    // with an RVFDatum signature!! Don't remove it!
    // JLS: type resolution of method parameters is static
    Counter<L> scores = scoresOfRVFDatum(example);
    Counters.logNormalizeInPlace(scores);
    return scores;
  }

  /**
   * Returns a counter for the log probability of each of the classes.
   * The sum of e^v for each count v should give 1.
   */
  @Deprecated
  public Counter<L> logProbabilityOf(RVFDatum<L, F> example) {
    // NB: this duplicate method is needed so it calls the scoresOf method
    // with an RVFDatum signature!! Don't remove it!
    // JLS: type resolution of method parameters is static
    Counter<L> scores = scoresOf(example);
    Counters.logNormalizeInPlace(scores);
    return scores;
  }

  /**
   * Returns indices of labels
   * @param labels - Set of labels to get indices
   * @return Set of indices
   * @throws IllegalArgumentException if any label is not in the label index
   */
  protected Set<Integer> getLabelIndices(Set<L> labels) {
    Set<Integer> iLabels = Generics.newHashSet();
    for (L label : labels) {
      int iLabel = labelIndex.indexOf(label);
      // the -1 is added before the check, but the throw discards the set anyway
      iLabels.add(iLabel);
      if (iLabel < 0) throw new IllegalArgumentException("Unknown label " + label);
    }
    return iLabels;
  }

  /**
   * Returns number of features with weight above a certain threshold
   * (across all labels).
   *
   * @param threshold Threshold above which we will count the feature
   * @param useMagnitude Whether the notion of "large" should ignore
   *     the sign of the feature weight.
   * @return number of features satisfying the specified conditions
   */
  public int getFeatureCount(double threshold, boolean useMagnitude) {
    int n = 0;
    for (double[] weightArray : weights) {
      for (double weight : weightArray) {
        double thisWeight = (useMagnitude) ? Math.abs(weight) : weight;
        if (thisWeight > threshold) {
          n++;
        }
      }
    }
    return n;
  }

  /**
   * Returns number of features with weight above a certain threshold.
   *
   * @param labels Set of labels we care about when counting features
   *     Use null to get counts across all labels
   * @param threshold Threshold above which we will count the feature
   * @param useMagnitude Whether the notion of "large" should ignore
   *     the sign of the feature weight.
   * @return number of features satisfying the specified conditions
   */
  public int getFeatureCount(Set<L> labels, double threshold, boolean useMagnitude) {
    if (labels != null) {
      Set<Integer> iLabels = getLabelIndices(labels);
      return getFeatureCountLabelIndices(iLabels, threshold, useMagnitude);
    } else {
      return getFeatureCount(threshold, useMagnitude);
    }
  }

  /**
   * Returns number of features with weight above a certain threshold.
   *
   * @param iLabels Set of label indices we care about when counting features
   *     Use null to get counts across all labels
   * @param threshold Threshold above which we will count the feature
   * @param useMagnitude Whether the notion of "large" should ignore
   *     the sign of the feature weight.
   * @return number of features satisfying the specified conditions
   */
  protected int getFeatureCountLabelIndices(Set<Integer> iLabels, double threshold, boolean useMagnitude) {
    int n = 0;
    for (double[] weightArray : weights) {
      for (int labIndex : iLabels) {
        double thisWeight = (useMagnitude) ? Math.abs(weightArray[labIndex]) : weightArray[labIndex];
        if (thisWeight > threshold) {
          n++;
        }
      }
    }
    return n;
  }

  /**
   * Returns list of top features with weight above a certain threshold
   * (list is descending and across all labels).
   *
   * @param threshold Threshold above which we will count the feature
   * @param useMagnitude Whether the notion of "large" should ignore
   *     the sign of the feature weight.
   * @param numFeatures How many top features to return (-1 for unlimited)
   * @return List of triples indicating feature, label, weight
   */
  public List<Triple<F,L,Double>> getTopFeatures(double threshold, boolean useMagnitude, int numFeatures) {
    return getTopFeatures(null, threshold, useMagnitude, numFeatures, true);
  }

  /**
   * Returns list of top features with weight above a certain threshold
   * @param labels Set of labels we care about when getting features
   *     Use null to get features across all labels
   * @param threshold Threshold above which we will count the feature
   * @param useMagnitude Whether the notion of "large" should ignore
   *     the sign of the feature weight.
   * @param numFeatures How many top features to return (-1 for unlimited)
   * @param descending Return weights in descending order
   * @return List of triples indicating feature, label, weight
   */
  public List<Triple<F,L,Double>> getTopFeatures(Set<L> labels,
      double threshold, boolean useMagnitude, int numFeatures, boolean descending) {
    if (labels != null) {
      Set<Integer> iLabels = getLabelIndices(labels);
      return getTopFeaturesLabelIndices(iLabels, threshold, useMagnitude, numFeatures, descending);
    } else {
      // null label-index set means "consider every label"
      return getTopFeaturesLabelIndices(null, threshold, useMagnitude, numFeatures, descending);
    }
  }

  /**
   * Returns list of top features with weight above a certain threshold
   * @param iLabels Set of label indices we care about when getting features
   *     Use null to get features across all labels
   * @param threshold Threshold above which we will count the feature
   * @param useMagnitude Whether the notion of "large" should ignore
   *     the sign of the feature weight.
* @param numFeatures How many top features to return (-1 for unlimited) * @param descending Return weights in descending order * @return List of triples indicating feature, label, weight */ protected List<Triple<F,L,Double>> getTopFeaturesLabelIndices(Set<Integer> iLabels, double threshold, boolean useMagnitude, int numFeatures, boolean descending) { edu.stanford.nlp.util.PriorityQueue<Pair<Integer,Integer>> biggestKeys = new FixedPrioritiesPriorityQueue<>(); // locate biggest keys for (int feat = 0; feat < weights.length; feat++) { for (int lab = 0; lab < weights[feat].length; lab++) { if (iLabels != null && !iLabels.contains(lab)) { continue; } double thisWeight; if (useMagnitude) { thisWeight = Math.abs(weights[feat][lab]); } else { thisWeight = weights[feat][lab]; } if (thisWeight > threshold) { // reverse the weight, so get smallest first thisWeight = -thisWeight; if (biggestKeys.size() == numFeatures) { // have enough features, add only if bigger double lowest = biggestKeys.getPriority(); if (thisWeight < lowest) { // remove smallest biggestKeys.removeFirst(); biggestKeys.add(new Pair<>(feat, lab), thisWeight); } } else { // always add it if don't have enough features yet biggestKeys.add(new Pair<>(feat, lab), thisWeight); } } } } List<Triple<F,L,Double>> topFeatures = new ArrayList<>(biggestKeys.size()); while (!biggestKeys.isEmpty()) { Pair<Integer,Integer> p = biggestKeys.removeFirst(); double weight = weights[p.first()][p.second()]; F feat = featureIndex.get(p.first()); L label = labelIndex.get(p.second()); topFeatures.add(new Triple<>(feat, label, weight)); } if (descending) { Collections.reverse(topFeatures); } return topFeatures; } /** * Returns string representation of a list of top features * @param topFeatures List of triples indicating feature, label, weight * @return String representation of the list of features */ public String topFeaturesToString(List<Triple<F,L,Double>> topFeatures) { // find longest key length (for pretty printing) with a 
limit int maxLeng = 0; for (Triple<F,L,Double> t : topFeatures) { String key = "(" + t.first + "," + t.second + ")"; int leng = key.length(); if (leng > maxLeng) { maxLeng = leng; } } maxLeng = Math.min(64, maxLeng); // set up pretty printing of weights NumberFormat nf = NumberFormat.getNumberInstance(); nf.setMinimumFractionDigits(4); nf.setMaximumFractionDigits(4); if (nf instanceof DecimalFormat) { ((DecimalFormat) nf).setPositivePrefix(" "); } //print high weight features to a String StringBuilder sb = new StringBuilder(); for (Triple<F,L,Double> t : topFeatures) { String key = "(" + t.first + "," + t.second + ")"; sb.append(StringUtils.pad(key, maxLeng)); sb.append(" "); double cnt = t.third(); if (Double.isInfinite(cnt)) { sb.append(cnt); } else { sb.append(nf.format(cnt)); } sb.append("\n"); } return sb.toString(); } /** Return a String that prints features with large weights. * * @param useMagnitude Whether the notion of "large" should ignore * the sign of the feature weight. * @param numFeatures How many top features to print * @param printDescending Print weights in descending order * @return The String representation of features with large weights */ public String toBiggestWeightFeaturesString(boolean useMagnitude, int numFeatures, boolean printDescending) { // this used to try to use a TreeSet, but that was WRONG.... 
edu.stanford.nlp.util.PriorityQueue<Pair<Integer,Integer>> biggestKeys = new FixedPrioritiesPriorityQueue<>(); // locate biggest keys for (int feat = 0; feat < weights.length; feat++) { for (int lab = 0; lab < weights[feat].length; lab++) { double thisWeight; // reverse the weight, so get smallest first if (useMagnitude) { thisWeight = -Math.abs(weights[feat][lab]); } else { thisWeight = -weights[feat][lab]; } if (biggestKeys.size() == numFeatures) { // have enough features, add only if bigger double lowest = biggestKeys.getPriority(); if (thisWeight < lowest) { // remove smallest biggestKeys.removeFirst(); biggestKeys.add(new Pair<>(feat, lab), thisWeight); } } else { // always add it if don't have enough features yet biggestKeys.add(new Pair<>(feat, lab), thisWeight); } } } // Put in List either reversed or not // (Note: can't repeatedly iterate over PriorityQueue.) int actualSize = biggestKeys.size(); Pair<Integer, Integer>[] bigArray = ErasureUtils.mkTArray(Pair.class, actualSize); // logger.info("biggestKeys is " + biggestKeys); if (printDescending) { for (int j = actualSize - 1; j >= 0; j--) { bigArray[j] = biggestKeys.removeFirst(); } } else { for (int j = 0; j < actualSize; j--) { bigArray[j] = biggestKeys.removeFirst(); } } List<Pair<Integer, Integer>> bigColl = Arrays.asList(bigArray); // logger.info("bigColl is " + bigColl); // find longest key length (for pretty printing) with a limit int maxLeng = 0; for (Pair<Integer,Integer> p : bigColl) { String key = "(" + featureIndex.get(p.first) + "," + labelIndex.get(p.second) + ")"; int leng = key.length(); if (leng > maxLeng) { maxLeng = leng; } } maxLeng = Math.min(64, maxLeng); // set up pretty printing of weights NumberFormat nf = NumberFormat.getNumberInstance(); nf.setMinimumFractionDigits(4); nf.setMaximumFractionDigits(4); if (nf instanceof DecimalFormat) { ((DecimalFormat) nf).setPositivePrefix(" "); } //print high weight features to a String StringBuilder sb = new StringBuilder("LinearClassifier 
[printing top " + numFeatures + " features]\n"); for (Pair<Integer, Integer> p : bigColl) { String key = "(" + featureIndex.get(p.first) + "," + labelIndex.get(p.second) + ")"; sb.append(StringUtils.pad(key, maxLeng)); sb.append(" "); double cnt = weights[p.first][p.second]; if (Double.isInfinite(cnt)) { sb.append(cnt); } else { sb.append(nf.format(cnt)); } sb.append("\n"); } return sb.toString(); } /** * Similar to histogram but exact values of the weights * to see whether there are many equal weights. * * @return A human readable string about the classifier distribution. */ public String toDistributionString(int threshold) { Counter<Double> weightCounts = new ClassicCounter<>(); StringBuilder s = new StringBuilder(); s.append("Total number of weights: ").append(totalSize()); for (double[] weightArray : weights) { for (double weight : weightArray) { weightCounts.incrementCount(weight); } } s.append("Counts of weights\n"); Set<Double> keys = Counters.keysAbove(weightCounts, threshold); s.append(keys.size()).append(" keys occur more than ").append(threshold).append(" times "); return s.toString(); } public int totalSize() { return labelIndex.size() * featureIndex.size(); } public String toHistogramString() { // big classifiers double[][] hist = new double[3][202]; Object[][] histEg = new Object[3][202]; int num = 0; int pos = 0; int neg = 0; int zero = 0; double total = 0.0; double x2total = 0.0; double max = 0.0, min = 0.0; for (int f = 0; f < weights.length; f++) { for (int l = 0; l < weights[f].length; l++) { Pair<F, L> feat = new Pair<>(featureIndex.get(f), labelIndex.get(l)); num++; double wt = weights[f][l]; total += wt; x2total += wt * wt; if (wt > max) { max = wt; } if (wt < min) { min = wt; } if (wt < 0.0) { neg++; } else if (wt > 0.0) { pos++; } else { zero++; } int index; index = bucketizeValue(wt); hist[0][index]++; if (histEg[0][index] == null) { histEg[0][index] = feat; } if (wt < 0.1 && wt >= -0.1) { index = bucketizeValue(wt * 100.0); 
hist[1][index]++; if (histEg[1][index] == null) { histEg[1][index] = feat; } if (wt < 0.001 && wt >= -0.001) { index = bucketizeValue(wt * 10000.0); hist[2][index]++; if (histEg[2][index] == null) { histEg[2][index] = feat; } } } } } double ave = total / num; double stddev = (x2total / num) - ave * ave; StringWriter sw = new StringWriter(); PrintWriter pw = new PrintWriter(sw); pw.println("Linear classifier with " + num + " f(x,y) features"); pw.println("Average weight: " + ave + "; std dev: " + stddev); pw.println("Max weight: " + max + " min weight: " + min); pw.println("Weights: " + neg + " negative; " + pos + " positive; " + zero + " zero."); printHistCounts(0, "Counts of lambda parameters between [-10, 10)", pw, hist, histEg); printHistCounts(1, "Closeup view of [-0.1, 0.1) depicted * 10^2", pw, hist, histEg); printHistCounts(2, "Closeup view of [-0.001, 0.001) depicted * 10^4", pw, hist, histEg); pw.close(); return sw.toString(); } /** Print out a partial representation of a linear classifier. * This just calls toString("WeightHistogram", 0) */ @Override public String toString() { return toString("WeightHistogram", 0); } /** * Print out a partial representation of a linear classifier in one of * several ways. 
   *
   * @param style Options are:
   *     HighWeight: print out the param parameters with largest weights;
   *     HighMagnitude: print out the param parameters for which the absolute
   *     value of their weight is largest;
   *     AllWeights: print out the weights of all features;
   *     WeightHistogram: print out a particular hard-coded textual histogram
   *     representation of a classifier;
   *     WeightDistribution;
   *
   * @param param Determines the number of things printed in certain styles
   * @throws IllegalArgumentException if the style name is unrecognized
   */
  public String toString(String style, int param) {
    if (style == null || style.isEmpty()) {
      return "LinearClassifier with " + featureIndex.size() + " features, " +
          labelIndex.size() + " classes, and " +
          labelIndex.size() * featureIndex.size() + " parameters.\n";
    } else if (style.equalsIgnoreCase("HighWeight")) {
      return toBiggestWeightFeaturesString(false, param, true);
    } else if (style.equalsIgnoreCase("HighMagnitude")) {
      return toBiggestWeightFeaturesString(true, param, true);
    } else if (style.equalsIgnoreCase("AllWeights")) {
      return toAllWeightsString();
    } else if (style.equalsIgnoreCase("WeightHistogram")) {
      return toHistogramString();
    } else if (style.equalsIgnoreCase("WeightDistribution")) {
      return toDistributionString(param);
    } else {
      throw new IllegalArgumentException("Unknown style: " + style);
    }
  }

  /**
   * Convert parameter value into number between 0 and 201
   */
  private static int bucketizeValue(double wt) {
    int index;
    if (wt >= 0.0) {
      index = ((int) (wt * 10.0)) + 100;
    } else {
      // floor so that negative weights bucket downwards consistently
      index = ((int) (Math.floor(wt * 10.0))) + 100;
    }
    if (index < 0) {
      index = 201;
    } else if (index > 200) {
      index = 200;
    }
    return index;
  }

  /**
   * Print histogram counts from hist and examples over a certain range
   */
  private static void printHistCounts(int ind, String title, PrintWriter pw, double[][] hist, Object[][] histEg) {
    pw.println(title);
    for (int i = 0; i < 200; i++) {
      // reconstruct the bucket's lower bound (intPart.fracPart) from the index
      int intPart, fracPart;
      if (i < 100) {
        intPart = 10 - ((i + 9) / 10);
        fracPart = (10 - (i % 10)) % 10;
      } else {
        intPart = (i / 10) - 10;
        fracPart = i % 10;
      }
      pw.print("[" + ((i < 100) ? "-" : "") + intPart + "." + fracPart + ", " + ((i < 100) ? "-" : "") + intPart + "." + fracPart + "+0.1): " + hist[ind][i]);
      if (histEg[ind][i] != null) {
        pw.print(" [" + histEg[ind][i] + ((hist[ind][i] > 1) ? ", ..." : "") + "]");
      }
      pw.println();
    }
  }

  //TODO: Sort of assumes that Labels are Strings...
  public String toAllWeightsString() {
    StringWriter sw = new StringWriter();
    PrintWriter pw = new PrintWriter(sw);
    pw.println("Linear classifier with the following weights");
    // a fake datum containing every known feature makes justificationOf print all weights
    Datum<L, F> allFeatures = new BasicDatum<>(features(), (L) null);
    justificationOf(allFeatures, pw);
    return sw.toString();
  }

  /**
   * Print all features in the classifier and the weight that they assign
   * to each class. Print to stderr.
   */
  public void dump() {
    Datum<L, F> allFeatures = new BasicDatum<>(features(), (L) null);
    justificationOf(allFeatures);
  }

  /**
   * Print all features in the classifier and the weight that they assign
   * to each class. Print to the given PrintWriter.
   */
  public void dump(PrintWriter pw) {
    Datum<L, F> allFeatures = new BasicDatum<>(features(), (L) null);
    justificationOf(allFeatures, pw);
  }

  /**
   * Print all features in the classifier and the weight that they assign
   * to each class. The feature names are printed in sorted order.
   */
  public void dumpSorted() {
    Datum<L, F> allFeatures = new BasicDatum<>(features(), (L) null);
    justificationOf(allFeatures, new PrintWriter(System.err, true), true);
  }

  /**
   * Print all features active for a particular datum and the weight that
   * the classifier assigns to each class for those features.
   */
  private void justificationOfRVFDatum(RVFDatum<L, F> example, PrintWriter pw) {
    int featureLength = 0;
    int labelLength = 6;
    NumberFormat nf = NumberFormat.getNumberInstance();
    nf.setMinimumFractionDigits(2);
    nf.setMaximumFractionDigits(2);
    if (nf instanceof DecimalFormat) {
      ((DecimalFormat) nf).setPositivePrefix(" ");
    }
    Counter<F> features = example.asFeaturesCounter();
    // feature column must fit "name[value]"
    for (F f : features.keySet()) {
      featureLength = Math.max(featureLength, f.toString().length() + 2 +
          nf.format(features.getCount(f)).length());
    }
    // make as wide as total printout
    featureLength = Math.max(featureLength, "Total:".length());
    // don't make it ridiculously wide
    featureLength = Math.min(featureLength, MAX_FEATURE_ALIGN_WIDTH);
    for (L l : labels()) {
      labelLength = Math.max(labelLength, l.toString().length());
    }
    // header row: one padded column per label
    StringBuilder header = new StringBuilder();
    for (int s = 0; s < featureLength; s++) {
      header.append(' ');
    }
    for (L l : labels()) {
      header.append(' ');
      header.append(StringUtils.pad(l, labelLength));
    }
    pw.println(header);
    // one row per feature: "name[value]" then that feature's weight per label
    for (F f : features.keySet()) {
      String fStr = f.toString();
      StringBuilder line = new StringBuilder(fStr);
      line.append('[').append(nf.format(features.getCount(f))).append(']');
      fStr = line.toString();
      for (int s = fStr.length(); s < featureLength; s++) {
        line.append(' ');
      }
      for (L l : labels()) {
        String lStr = nf.format(weight(f, l));
        line.append(' ');
        line.append(lStr);
        for (int s = lStr.length(); s < labelLength; s++) {
          line.append(' ');
        }
      }
      pw.println(line);
    }
    // footer rows: total score and logistic probability per label
    Counter<L> scores = scoresOfRVFDatum(example);
    StringBuilder footer = new StringBuilder("Total:");
    for (int s = footer.length(); s < featureLength; s++) {
      footer.append(' ');
    }
    for (L l : labels()) {
      footer.append(' ');
      String str = nf.format(scores.getCount(l));
      footer.append(str);
      for (int s = str.length(); s < labelLength; s++) {
        footer.append(' ');
      }
    }
    pw.println(footer);
    Distribution<L> distr = Distribution.distributionFromLogisticCounter(scores);
    footer = new StringBuilder("Prob:");
    for (int s = footer.length(); s < featureLength; s++) {
      footer.append(' ');
    }
    for (L l : labels()) {
      footer.append(' ');
      String str = nf.format(distr.getCount(l));
      footer.append(str);
      for (int s = str.length(); s < labelLength; s++) {
        footer.append(' ');
      }
    }
    pw.println(footer);
  }

  /** Prints the justification for a datum to stderr. */
  public void justificationOf(Datum<L, F> example) {
    PrintWriter pw = new PrintWriter(System.err, true);
    justificationOf(example, pw);
  }

  /**
   * Print all features active for a particular datum and the weight that
   * the classifier assigns to each class for those features.
   */
  public void justificationOf(Datum<L, F> example, PrintWriter pw) {
    justificationOf(example, pw, null);
  }

  /**
   * Print all features active for a particular datum and the weight that
   * the classifier assigns to each class for those features. Sorts by feature
   * name if 'sorted' is true.
   */
  public void justificationOf(Datum<L, F> example, PrintWriter pw, boolean sorted) {
    // NOTE(review): this overload prints only for RVFDatum inputs; a plain
    // Datum is silently ignored here — confirm this is intentional.
    if (example instanceof RVFDatum<?, ?>)
      justificationOf(example, pw, null, sorted);
  }

  public <T> void justificationOf(Datum<L, F> example, PrintWriter pw, Function<F, T> printer) {
    justificationOf(example, pw, printer, false);
  }

  /** Print all features active for a particular datum and the weight that
   *  the classifier assigns to each class for those features.
   *
   * @param example The datum for which features are to be printed
   * @param pw Where to print it to
   * @param printer If this is non-null, then it is applied to each
   *     feature to convert it to a more readable form
   * @param sortedByFeature Whether to sort by feature names
   */
  public <T> void justificationOf(Datum<L, F> example, PrintWriter pw,
      Function<F, T> printer, boolean sortedByFeature) {
    // RVF data carry per-feature values and are formatted separately
    if (example instanceof RVFDatum<?, ?>) {
      justificationOfRVFDatum((RVFDatum<L, F>) example, pw);
      return;
    }
    NumberFormat nf = NumberFormat.getNumberInstance();
    nf.setMinimumFractionDigits(2);
    nf.setMaximumFractionDigits(2);
    if (nf instanceof DecimalFormat) {
      ((DecimalFormat) nf).setPositivePrefix(" ");
    }

    // determine width for features, making it at least total's width
    int featureLength = 0;
    //TODO: not really sure what this Printer is supposed to spit out...
    for (F f : example.asFeatures()) {
      int length = f.toString().length();
      if (printer != null) {
        length = printer.apply(f).toString().length();
      }
      featureLength = Math.max(featureLength, length);
    }
    // make as wide as total printout
    featureLength = Math.max(featureLength, "Total:".length());
    // don't make it ridiculously wide
    featureLength = Math.min(featureLength, MAX_FEATURE_ALIGN_WIDTH);

    // determine width for labels
    int labelLength = 6;
    for (L l : labels()) {
      labelLength = Math.max(labelLength, l.toString().length());
    }

    // print header row of output listing classes
    StringBuilder header = new StringBuilder("");
    for (int s = 0; s < featureLength; s++) {
      header.append(' ');
    }
    for (L l : labels()) {
      header.append(' ');
      header.append(StringUtils.pad(l, labelLength));
    }
    pw.println(header);

    // print active features and weights per class
    Collection<F> featColl = example.asFeatures();
    if (sortedByFeature) {
      featColl = ErasureUtils.sortedIfPossible(featColl);
    }
    for (F f : featColl) {
      String fStr;
      if (printer != null) {
        fStr = printer.apply(f).toString();
      } else {
        fStr = f.toString();
      }
      StringBuilder line = new StringBuilder(fStr);
      for (int s = fStr.length(); s < featureLength; s++) {
        line.append(' ');
      }
      for (L l : labels()) {
        String lStr = nf.format(weight(f, l));
        line.append(' ');
        line.append(lStr);
        for (int s = lStr.length(); s < labelLength; s++) {
          line.append(' ');
        }
      }
      pw.println(line);
    }

    // Print totals, probs, etc.
    Counter<L> scores = scoresOf(example);
    StringBuilder footer = new StringBuilder("Total:");
    for (int s = footer.length(); s < featureLength; s++) {
      footer.append(' ');
    }
    for (L l : labels()) {
      footer.append(' ');
      String str = nf.format(scores.getCount(l));
      footer.append(str);
      for (int s = str.length(); s < labelLength; s++) {
        footer.append(' ');
      }
    }
    pw.println(footer);
    Distribution<L> distr = Distribution.distributionFromLogisticCounter(scores);
    footer = new StringBuilder("Prob:");
    for (int s = footer.length(); s < featureLength; s++) {
      footer.append(' ');
    }
    for (L l : labels()) {
      footer.append(' ');
      String str = nf.format(distr.getCount(l));
      footer.append(str);
      for (int s = str.length(); s < labelLength; s++) {
        footer.append(' ');
      }
    }
    pw.println(footer);
  }

  /**
   * This method returns a map from each label to a counter of feature weights for that label.
   * Useful for feature analysis.
   *
   * @return a map of counters
   */
  public Map<L,Counter<F>> weightsAsMapOfCounters() {
    Map<L,Counter<F>> mapOfCounters = Generics.newHashMap();
    for(L label : labelIndex){
      int labelID = labelIndex.indexOf(label);
      Counter<F> c = new ClassicCounter<>();
      mapOfCounters.put(label, c);
      // copy this label's column of the weight matrix into the counter
      for (F f : featureIndex) {
        c.incrementCount(f, weights[featureIndex.indexOf(f)][labelID]);
      }
    }
    return mapOfCounters;
  }

  /**
   * Returns scores for only the given candidate labels.
   * Labels not present in this classifier's labelIndex are silently skipped,
   * so the returned counter may have fewer entries than possibleLabels.
   *
   * @param example the datum to score
   * @param possibleLabels the labels to restrict scoring to
   * @return a counter mapping each known candidate label to its linear score
   */
  public Counter<L> scoresOf(Datum<L, F> example, Collection<L> possibleLabels) {
    Counter<L> scores = new ClassicCounter<>();
    for (L l : possibleLabels) {
      if (labelIndex.indexOf(l) == -1) {
        continue;  // unknown label: skip rather than score
      }
      double score = scoreOf(example, l);
      scores.setCount(l, score);
    }
    return scores;
  }

  /* -- looks like a failed attempt at micro-optimization --

  public L experimentalClassOf(Datum<L,F> example) {
    if(example instanceof RVFDatum<?, ?>) {
      throw new UnsupportedOperationException();
    }

    int labelCount = weights[0].length;
    //System.out.printf("labelCount: %d\n", labelCount);
    Collection<F> features = example.asFeatures();

    int[] featureInts = new int[features.size()];
    int fI = 0;
    for (F feature : features) {
      featureInts[fI++] = featureIndex.indexOf(feature);
    }
    //System.out.println("Features: "+features);

    double bestScore = Double.NEGATIVE_INFINITY;
    int bestI = 0;
    for (int i = 0; i < labelCount; i++) {
      double score = 0;
      for (int j = 0; j < featureInts.length; j++) {
        if (featureInts[j] < 0) continue;
        score += weights[featureInts[j]][i];
      }
      if (score > bestScore) {
        bestI = i;
        bestScore = score;
      }
      //System.out.printf("Score: %s(%d): %e\n", labelIndex.get(i), i, score);
    }
    //System.out.printf("label(%d): %s\n", bestI, labelIndex.get(bestI));;

    return labelIndex.get(bestI);
  }
  -- */

  /**
   * Returns the highest-scoring label for the datum.
   * RVFDatum instances are dispatched to the real-valued-feature scorer.
   */
  @Override
  public L classOf(Datum<L, F> example) {
    if(example instanceof RVFDatum<?, ?>)return classOfRVFDatum((RVFDatum<L,F>)example);
    Counter<L> scores = scoresOf(example);
    return Counters.argmax(scores);
  }

  /** Returns the highest-scoring label for a real-valued-feature datum. */
  private L classOfRVFDatum(RVFDatum<L, F> example) {
    Counter<L> scores = scoresOfRVFDatum(example);
    return
Counters.argmax(scores); } @Override @Deprecated public L classOf(RVFDatum<L, F> example) { Counter<L> scores = scoresOf(example); return Counters.argmax(scores); } /** For Kryo -- can be private */ private LinearClassifier() { } /** Make a linear classifier from the parameters. The parameters are used, not copied. * * @param weights The parameters of the classifier. The first index is the * featureIndex value and second index is the labelIndex value. * @param featureIndex An index from F to integers used to index the features in the weights array * @param labelIndex An index from L to integers used to index the labels in the weights array */ public LinearClassifier(double[][] weights, Index<F> featureIndex, Index<L> labelIndex) { this.featureIndex = featureIndex; this.labelIndex = labelIndex; this.weights = weights; thresholds = new double[labelIndex.size()]; // Arrays.fill(thresholds, 0.0); // not needed; Java arrays zero initialized } // todo: This is unused and seems broken (ignores passed in thresholds) public LinearClassifier(double[][] weights, Index<F> featureIndex, Index<L> labelIndex, double[] thresholds) throws Exception { this.featureIndex = featureIndex; this.labelIndex = labelIndex; this.weights = weights; if (thresholds.length != labelIndex.size()) throw new Exception("Number of thresholds and number of labels do not match."); thresholds = new double[thresholds.length]; int curr = 0; for (double tval : thresholds) { thresholds[curr++] = tval; } Arrays.fill(thresholds, 0.0); } private static <F, L> Counter<Pair<F, L>> makeWeightCounter(double[] weights, Index<Pair<F, L>> weightIndex) { Counter<Pair<F,L>> weightCounter = new ClassicCounter<>(); for (int i = 0; i < weightIndex.size(); i++) { if (weights[i] == 0) { continue; // no need to save 0 weights } weightCounter.setCount(weightIndex.get(i), weights[i]); } return weightCounter; } public LinearClassifier(double[] weights, Index<Pair<F, L>> weightIndex) { this(makeWeightCounter(weights, 
weightIndex)); } public LinearClassifier(Counter<? extends Pair<F, L>> weightCounter) { this(weightCounter, new ClassicCounter<>()); } public LinearClassifier(Counter<? extends Pair<F, L>> weightCounter, Counter<L> thresholdsC) { Collection<? extends Pair<F, L>> keys = weightCounter.keySet(); featureIndex = new HashIndex<>(); labelIndex = new HashIndex<>(); for (Pair<F, L> p : keys) { featureIndex.add(p.first()); labelIndex.add(p.second()); } thresholds = new double[labelIndex.size()]; for (L label : labelIndex) { thresholds[labelIndex.indexOf(label)] = thresholdsC.getCount(label); } weights = new double[featureIndex.size()][labelIndex.size()]; Pair<F, L> tempPair = new Pair<>(); for (int f = 0; f < weights.length; f++) { for (int l = 0; l < weights[f].length; l++) { tempPair.first = featureIndex.get(f); tempPair.second = labelIndex.get(l); weights[f][l] = weightCounter.getCount(tempPair); } } } public void adaptWeights(Dataset<L, F> adapt,LinearClassifierFactory<L, F> lcf) { logger.info("before adapting, weights size="+weights.length); weights = lcf.adaptWeights(weights,adapt); logger.info("after adapting, weights size=" + weights.length); } public double[][] weights() { return weights; } public void setWeights(double[][] newWeights) { weights = newWeights; } /** * Loads a classifier from a file. * Simple convenience wrapper for IOUtils.readFromString. 
*/ public static <L, F> LinearClassifier<L, F> readClassifier(String loadPath) { logger.info("Deserializing classifier from " + loadPath + "..."); try { ObjectInputStream ois = IOUtils.readStreamFromString(loadPath); LinearClassifier<L, F> classifier = ErasureUtils.<LinearClassifier<L, F>>uncheckedCast(ois.readObject()); ois.close(); return classifier; } catch (Exception e) { throw new RuntimeException("Deserialization failed: "+e.getMessage(), e); } } /** * Convenience wrapper for IOUtils.writeObjectToFile */ public static void writeClassifier(LinearClassifier<?, ?> classifier, String writePath) { try { IOUtils.writeObjectToFile(classifier, writePath); } catch (Exception e) { throw new RuntimeException("Serialization failed: "+e.getMessage(), e); } } /** * Saves this out to a standard text file, instead of as a serialized Java object. * NOTE: this currently assumes feature and weights are represented as Strings. * @param file String filepath to write out to. */ public void saveToFilename(String file) { try { File tgtFile = new File(file); BufferedWriter out = new BufferedWriter(new FileWriter(tgtFile)); // output index first, blank delimiter, outline feature index, then weights labelIndex.saveToWriter(out); featureIndex.saveToWriter(out); int numLabels = labelIndex.size(); int numFeatures = featureIndex.size(); for (int featIndex=0; featIndex<numFeatures; featIndex++) { for (int labelIndex=0;labelIndex<numLabels;labelIndex++) { out.write(String.valueOf(featIndex)); out.write(TEXT_SERIALIZATION_DELIMITER); out.write(String.valueOf(labelIndex)); out.write(TEXT_SERIALIZATION_DELIMITER); out.write(String.valueOf(weight(featIndex, labelIndex))); out.write("\n"); } } // write out thresholds: first item after blank is the number of thresholds, after is the threshold array values. 
      out.write("\n");
      // threshold count, then one threshold value per line
      out.write(String.valueOf(thresholds.length));
      out.write("\n");
      for (double val : thresholds) {
        out.write(String.valueOf(val));
        out.write("\n");
      }
      out.close();
    } catch (Exception e) {
      // NOTE(review): errors are swallowed here -- the failure is only logged and
      // the caller gets no indication that the save did not complete.
      logger.info("Error attempting to save classifier to file=" + file);
      logger.info(e);
    }
  }

}