package edu.berkeley.nlp.classify; import java.io.Serializable; import java.util.ArrayList; import java.util.Arrays; import java.util.List; import edu.berkeley.nlp.math.DifferentiableFunction; import edu.berkeley.nlp.math.DoubleArrays; import edu.berkeley.nlp.math.LBFGSMinimizer; import edu.berkeley.nlp.math.SloppyMath; import edu.berkeley.nlp.util.Counter; import edu.berkeley.nlp.util.Indexer; import edu.berkeley.nlp.util.Logger; import edu.berkeley.nlp.util.Pair; /** * Maximum entropy classifier for assignment 2. * * @author Dan Klein */ public class MaximumEntropyClassifier<I, F, L> implements ProbabilisticClassifier<I, L>, Serializable { private static final long serialVersionUID = 1L; /** * Factory for training MaximumEntropyClassifiers. */ public static class Factory<I, F, L> implements ProbabilisticClassifierFactory<I, L> { double sigma; int iterations; FeatureExtractor<I, F> featureExtractor; public ProbabilisticClassifier<I, L> trainClassifier( List<LabeledInstance<I, L>> trainingData) { return trainClassifier(trainingData, true); } public ProbabilisticClassifier<I, L> trainClassifier( List<LabeledInstance<I, L>> trainingData, boolean verbose) { // build data encodings so the inner loops can be efficient if (verbose) Logger.i().startTrack("Building encoding"); Encoding<F, L> encoding = buildEncoding(trainingData); IndexLinearizer indexLinearizer = buildIndexLinearizer(encoding); double[] initialWeights = buildInitialWeights(indexLinearizer); EncodedDatum[] data = encodeData(trainingData, encoding); if (verbose) Logger.i().endTrack(); // build a minimizer object LBFGSMinimizer minimizer = new LBFGSMinimizer(iterations); // build the objective function for this data DifferentiableFunction objective = new ObjectiveFunction<F, L>(encoding, data, indexLinearizer, sigma); // learn our voting weights if (verbose) Logger.i().startTrack("Training weights"); double[] weights = minimizer.minimize(objective, initialWeights, 1e-4, verbose); if (verbose) Logger.i().endTrack(); // build a classifer using these weights (and the data encodings) return new MaximumEntropyClassifier<I, F, L>(weights, encoding, indexLinearizer, featureExtractor); } private double[] buildInitialWeights(IndexLinearizer indexLinearizer) { return DoubleArrays.constantArray(0.0, indexLinearizer.getNumLinearIndexes()); } private IndexLinearizer buildIndexLinearizer(Encoding<F, L> encoding) { return new IndexLinearizer(encoding.getNumFeatures(), encoding.getNumLabels()); } private Encoding<F, L> buildEncoding(List<LabeledInstance<I, L>> data) { Indexer<F> featureIndexer = new Indexer<F>(); Indexer<L> labelIndexer = new Indexer<L>(); for (LabeledInstance<I, L> labeledInstance : data) { L label = labeledInstance.getLabel(); Counter<F> features = featureExtractor.extractFeatures(labeledInstance .getInput()); LabeledFeatureVector<F, L> labeledDatum = new BasicLabeledFeatureVector<F, L>( label, features); labelIndexer.getIndex(labeledDatum.getLabel()); for (F feature : labeledDatum.getFeatures().keySet()) { featureIndexer.getIndex(feature); } } return new Encoding<F, L>(featureIndexer, labelIndexer); } private EncodedDatum[] encodeData(List<LabeledInstance<I, L>> data, Encoding<F, L> encoding) { EncodedDatum[] encodedData = new EncodedDatum[data.size()]; for (int i = 0; i < data.size(); i++) { LabeledInstance<I, L> labeledInstance = data.get(i); L label = labeledInstance.getLabel(); Counter<F> features = featureExtractor.extractFeatures(labeledInstance .getInput()); LabeledFeatureVector<F, L> labeledFeatureVector = new BasicLabeledFeatureVector<F, L>( label, features); encodedData[i] = EncodedDatum.encodeLabeledDatum(labeledFeatureVector, encoding); } return encodedData; } /** * Sigma controls the variance on the prior / penalty term. 1.0 is a * reasonable value for large problems, bigger sigma means LESS * smoothing. Zero sigma is a special indicator that no smoothing is to * be done. <p/> Iterations determines the maximum number of iterations * the optimization code can take before stopping. */ public Factory(double sigma, int iterations, FeatureExtractor<I, F> featureExtractor) { this.sigma = sigma; this.iterations = iterations; this.featureExtractor = featureExtractor; } } /** * This is the MaximumEntropy objective function: the (negative) log * conditional likelihood of the training data, possibly with a penalty for * large weights. Note that this objective get MINIMIZED so it's the * negative of the objective we normally think of. */ public static class ObjectiveFunction<F, L> implements DifferentiableFunction { IndexLinearizer indexLinearizer; Encoding<F, L> encoding; EncodedDatum[] data; double sigma; double lastValue; double[] lastDerivative; double[] lastX; public int dimension() { return indexLinearizer.getNumLinearIndexes(); } public double valueAt(double[] x) { ensureCache(x); return lastValue; } public double[] derivativeAt(double[] x) { ensureCache(x); return lastDerivative; } private void ensureCache(double[] x) { if (requiresUpdate(lastX, x)) { Pair<Double, double[]> currentValueAndDerivative = calculate(x); lastValue = currentValueAndDerivative.getFirst(); lastDerivative = currentValueAndDerivative.getSecond(); lastX = x; } } private boolean requiresUpdate(double[] lastX, double[] x) { if (lastX == null) return true; for (int i = 0; i < x.length; i++) { if (lastX[i] != x[i]) return true; } return false; } /** * The most important part of the classifier learning process! This * method determines, for the given weight vector x, what the (negative) * log conditional likelihood of the data is, as well as the derivatives * of that likelihood wrt each weight parameter. */ private Pair<Double, double[]> calculate(double[] x) { double objective = 0.0; double[] derivatives = DoubleArrays.constantArray(0.0, dimension()); double[] classActivations = new double[encoding.getNumLabels()]; double[] classPosteriors = new double[encoding.getNumLabels()]; for (EncodedDatum datum : data) { // For each datum we get the activation for each class // and then the posteriors int numActiveFeatures = datum.getNumActiveFeatures(); for (int labelIndex = 0; labelIndex < encoding.getNumLabels(); ++labelIndex) { double activation = 0.0; for (int num = 0; num < numActiveFeatures; ++num) { int featureIndex = datum.getFeatureIndex(num); double featureCount = datum.getFeatureCount(num); int linearFeatureIndex = indexLinearizer.getLinearIndex( featureIndex, labelIndex); activation += x[linearFeatureIndex] * featureCount; } classActivations[labelIndex] = activation; } double logSumActivation = SloppyMath.logAdd(classActivations); int correctLabelIndex = datum.getLabelIndex(); // Log Prob objective += (classActivations[correctLabelIndex] - logSumActivation); // Class Posteriors for (int labelIndex = 0; labelIndex < encoding.getNumLabels(); ++labelIndex) { classPosteriors[labelIndex] = SloppyMath .exp(classActivations[labelIndex] - logSumActivation); } // Derivative: Feature Expectations for (int num = 0; num < numActiveFeatures; ++num) { int featureIndex = datum.getFeatureIndex(num); int correctLinearFeatureIndex = indexLinearizer.getLinearIndex( featureIndex, correctLabelIndex); double featureCount = datum.getFeatureCount(num); derivatives[correctLinearFeatureIndex] += featureCount; for (int labelIndex = 0; labelIndex < encoding.getNumLabels(); ++labelIndex) { int linearFeatureIndex = indexLinearizer.getLinearIndex( featureIndex, labelIndex); double classProb = classPosteriors[labelIndex]; derivatives[linearFeatureIndex] -= classProb * featureCount; } } } // Scale by -1 since we are minimizing negative log-liklihood objective *= -1; DoubleArrays.scale(derivatives, -1); // L2 Penalty for (int i = 0; i < x.length; ++i) { double weight = x[i]; objective += (weight * weight) / (2 * sigma * sigma); derivatives[i] += (weight) / (sigma * sigma); } return new Pair<Double, double[]>(objective, derivatives); } public ObjectiveFunction(Encoding<F, L> encoding, EncodedDatum[] data, IndexLinearizer indexLinearizer, double sigma) { this.indexLinearizer = indexLinearizer; this.encoding = encoding; this.data = data; this.sigma = sigma; } public double[] unregularizedDerivativeAt(double[] x) { // TODO Auto-generated method stub return null; } } /** * EncodedDatums are sparse representations of (labeled) feature count * vectors for a given data point. Use getNumActiveFeatures() to see how * many features have non-zero count in a datum. Then, use getFeatureIndex() * and getFeatureCount() to retreive the number and count of each non-zero * feature. Use getLabelIndex() to get the label's number. */ public static class EncodedDatum { public static <F, L> EncodedDatum encodeDatum(FeatureVector<F> featureVector, Encoding<F, L> encoding) { Counter<F> features = featureVector.getFeatures(); Counter<F> knownFeatures = new Counter<F>(); for (F feature : features.keySet()) { if (encoding.getFeatureIndex(feature) < 0) continue; knownFeatures.incrementCount(feature, features.getCount(feature)); } int numActiveFeatures = knownFeatures.keySet().size(); int[] featureIndexes = new int[numActiveFeatures]; double[] featureCounts = new double[knownFeatures.keySet().size()]; int i = 0; for (F feature : knownFeatures.keySet()) { int index = encoding.getFeatureIndex(feature); double count = knownFeatures.getCount(feature); featureIndexes[i] = index; featureCounts[i] = count; i++; } EncodedDatum encodedDatum = new EncodedDatum(-1, featureIndexes, featureCounts); return encodedDatum; } public static <F, L> EncodedDatum encodeLabeledDatum( LabeledFeatureVector<F, L> labeledDatum, Encoding<F, L> encoding) { EncodedDatum encodedDatum = encodeDatum(labeledDatum, encoding); encodedDatum.labelIndex = encoding.getLabelIndex(labeledDatum.getLabel()); return encodedDatum; } int labelIndex; int[] featureIndexes; double[] featureCounts; public int getLabelIndex() { return labelIndex; } public int getNumActiveFeatures() { return featureCounts.length; } public int getFeatureIndex(int num) { return featureIndexes[num]; } public double getFeatureCount(int num) { return featureCounts[num]; } public EncodedDatum(int labelIndex, int[] featureIndexes, double[] featureCounts) { this.labelIndex = labelIndex; this.featureIndexes = featureIndexes; this.featureCounts = featureCounts; } } private double[] weights; private Encoding<F, L> encoding; private IndexLinearizer indexLinearizer; private transient FeatureExtractor<I, F> featureExtractor; /** * */ public void setFeatureExtractor(FeatureExtractor<I, F> featureExtractor) { this.featureExtractor = featureExtractor; } /** * Calculate the log probabilities of each class, for the given datum * (feature bundle). Note that the weighted votes (refered to as * activations) are *almost* log probabilities, but need to be normalized. */ private static <F, L> double[] getLogProbabilities(EncodedDatum datum, double[] weights, Encoding<F, L> encoding, IndexLinearizer indexLinearizer) { double[] logProbabilities = new double[encoding.getNumLabels()]; for (int labelIndex = 0; labelIndex < encoding.getNumLabels(); ++labelIndex) { for (int num = 0; num < datum.getNumActiveFeatures(); ++num) { int featureIndex = datum.getFeatureIndex(num); double featureCount = datum.getFeatureCount(num); int linearFeatureIndex = indexLinearizer.getLinearIndex(featureIndex, labelIndex); logProbabilities[labelIndex] += weights[linearFeatureIndex] * featureCount; } } double logSumProb = SloppyMath.logAdd(logProbabilities); for (int labelIndex = 0; labelIndex < encoding.getNumLabels(); ++labelIndex) { logProbabilities[labelIndex] -= logSumProb; } return logProbabilities; } public Counter<L> getProbabilities(I input) { FeatureVector<F> featureVector = new BasicFeatureVector<F>(featureExtractor .extractFeatures(input)); return getProbabilities(featureVector); } private Counter<L> getProbabilities(FeatureVector<F> featureVector) { EncodedDatum encodedDatum = EncodedDatum.encodeDatum(featureVector, encoding); double[] logProbabilities = getLogProbabilities(encodedDatum, weights, encoding, indexLinearizer); return logProbabiltyArrayToProbabiltyCounter(logProbabilities); } private Counter<L> logProbabiltyArrayToProbabiltyCounter(double[] logProbabilities) { Counter<L> probabiltyCounter = new Counter<L>(); for (int labelIndex = 0; labelIndex < logProbabilities.length; labelIndex++) { double logProbability = logProbabilities[labelIndex]; double probability = Math.exp(logProbability); L label = encoding.getLabel(labelIndex); probabiltyCounter.setCount(label, probability); } return probabiltyCounter; } public L getLabel(I input) { return getProbabilities(input).argMax(); } public MaximumEntropyClassifier(double[] weights, Encoding<F, L> encoding, IndexLinearizer indexLinearizer, FeatureExtractor<I, F> featureExtractor) { this.weights = weights; this.encoding = encoding; this.indexLinearizer = indexLinearizer; this.featureExtractor = featureExtractor; } public static void main(String[] args) { // Execution.init(args); // create datums LabeledInstance<String[], String> datum1 = new LabeledInstance<String[], String>( "cat", new String[] { "fuzzy", "claws", "small" }); LabeledInstance<String[], String> datum2 = new LabeledInstance<String[], String>( "bear", new String[] { "fuzzy", "claws", "big" }); LabeledInstance<String[], String> datum3 = new LabeledInstance<String[], String>( "cat", new String[] { "claws", "medium" }); LabeledInstance<String[], String> datum4 = new LabeledInstance<String[], String>( "cat", new String[] { "claws", "small" }); // create training set List<LabeledInstance<String[], String>> trainingData = new ArrayList<LabeledInstance<String[], String>>(); trainingData.add(datum1); trainingData.add(datum2); trainingData.add(datum3); // create test set List<LabeledInstance<String[], String>> testData = new ArrayList<LabeledInstance<String[], String>>(); testData.add(datum4); // build classifier FeatureExtractor<String[], String> featureExtractor = new FeatureExtractor<String[], String>() { /** * */ private static final long serialVersionUID = 8296036312980792350L; public Counter<String> extractFeatures(String[] featureArray) { return new Counter<String>(Arrays.asList(featureArray)); } }; MaximumEntropyClassifier.Factory<String[], String, String> maximumEntropyClassifierFactory = new MaximumEntropyClassifier.Factory<String[], String, String>( 1.0, 20, featureExtractor); ProbabilisticClassifier<String[], String> maximumEntropyClassifier = maximumEntropyClassifierFactory .trainClassifier(trainingData); System.out.println("Probabilities on test instance: " + maximumEntropyClassifier.getProbabilities(datum4.getInput())); // Execution.finish(); } }