package edu.berkeley.nlp.classify;

import java.util.Collection;
import java.util.List;

import edu.berkeley.nlp.math.CachingDifferentiableFunction;
import edu.berkeley.nlp.math.GradientMinimizer;
import edu.berkeley.nlp.math.LBFGSMinimizer;
import edu.berkeley.nlp.util.CollectionUtils;
import edu.berkeley.nlp.util.Counter;
import edu.berkeley.nlp.util.Pair;

/**
 * Linear regression over string-named features.
 *
 * The predicted response for an input is the sum, over the features extracted
 * from it, of (feature count * learned weight). Models are created through
 * {@link Factory#train(Collection)}, which fits the weights by minimizing the
 * summed squared error 0.5 * (guess - gold)^2 with an {@link LBFGSMinimizer}.
 *
 * @param <I> type of the inputs that features are extracted from
 */
public class LinearRegression<I> {

  private final FeatureExtractor<I, String> featureExtractor;
  private final double[] weights;
  private final FeatureManager featureManager;

  /**
   * Trains {@link LinearRegression} models for a fixed feature extractor.
   *
   * @param <I> type of the inputs that features are extracted from
   */
  public static class Factory<I> {
    double[] weights;
    FeatureManager featureManager;
    FeatureExtractor<I, String> featureExtractor;
    Collection<Pair<I, Double>> trainingData;

    /**
     * Creates a factory with a fresh, empty feature manager.
     *
     * @param featureExtractor maps an input to a counter of feature names
     */
    public Factory(FeatureExtractor<I, String> featureExtractor) {
      this.featureExtractor = featureExtractor;
      this.featureManager = new FeatureManager();
    }

    /** Extracts the features of {@code input} and re-keys their counts by {@link Feature}. */
    private Counter<Feature> getFeatures(I input) {
      Counter<String> strCounts = featureExtractor.extractFeatures(input);
      Counter<Feature> featCounts = new Counter<Feature>();
      for (String f : strCounts.keySet()) {
        double count = strCounts.getCount(f);
        Feature feat = featureManager.getFeature(f);
        featCounts.setCount(feat, count);
      }
      return featCounts;
    }

    /** Returns the dot product of the feature counts with the current {@code weights}. */
    private double getScore(Counter<Feature> featureCounts) {
      double score = 0.0;
      for (Feature feat : featureCounts.keySet()) {
        double count = featureCounts.getCount(feat);
        score += count * weights[feat.getIndex()];
      }
      return score;
    }

    /**
     * Summed squared-error objective over {@code trainingData} and its gradient
     * with respect to the weight vector. No regularization term is applied.
     */
    private class ObjectiveFunction extends CachingDifferentiableFunction {

      /**
       * Computes the objective value and gradient at {@code x}.
       *
       * @param x candidate weight vector, indexed by feature index
       * @return pair of (objective value, gradient array of length {@link #dimension()})
       */
      @Override
      protected Pair<Double, double[]> calculate(double[] x) {
        // getScore reads the factory's weights field, so install the candidate point first.
        weights = x;
        double objective = 0.0;
        double[] gradient = new double[dimension()];
        for (Pair<I, Double> datum : trainingData) {
          I input = datum.getFirst();
          Counter<Feature> featCounts = getFeatures(input);
          double guessResponse = getScore(featCounts);
          double goldResponse = datum.getSecond();
          double diff = guessResponse - goldResponse;
          objective += 0.5 * diff * diff;
          // d/dw_j of 0.5 * diff^2 is count_j * diff.
          for (Feature feat : featCounts.keySet()) {
            double count = featCounts.getCount(feat);
            gradient[feat.getIndex()] += count * diff;
          }
        }
        return Pair.newPair(objective, gradient);
      }

      /** Returns the number of features known to the feature manager. */
      @Override
      public int dimension() {
        return featureManager.getNumFeatures();
      }

      /**
       * Returns the gradient at {@code x}. This objective has no regularizer, so the
       * unregularized derivative is the gradient produced by {@link #calculate(double[])}.
       * Like {@code calculate}, this installs {@code x} as the factory's current weights.
       *
       * @param x weight vector to differentiate at
       * @return gradient array of length {@link #dimension()}
       */
      public double[] unregularizedDerivativeAt(double[] x) {
        return calculate(x).getSecond();
      }
    }

    /**
     * Registers every feature name occurring in the training data with the
     * feature manager, then calls {@code lock()} on it.
     */
    private void extractAllFeatures() {
      for (Pair<I, Double> datum : trainingData) {
        Counter<String> counts = featureExtractor.extractFeatures(datum.getFirst());
        for (String f : counts.keySet()) {
          featureManager.getFeature(f);
        }
      }
      featureManager.lock();
    }

    /** Debugging aid: renders the current weights as a feature-keyed counter string. */
    private String examineWeights() {
      Counter<Feature> counts = new Counter<Feature>();
      for (int i = 0; i < weights.length; ++i) {
        Feature feat = featureManager.getFeature(i);
        counts.setCount(feat, weights[i]);
      }
      return counts.toString();
    }

    /**
     * Fits a model to the given (input, response) pairs.
     *
     * Minimization starts from an all-zero weight vector; {@code 1.0e-4} is passed
     * as the third argument of {@code minimize} (presumably the convergence
     * tolerance - confirm against GradientMinimizer).
     *
     * @param trainingData pairs of input and gold response
     * @return a model sharing this factory's feature extractor and feature manager
     */
    public LinearRegression<I> train(Collection<Pair<I, Double>> trainingData) {
      this.trainingData = trainingData;
      extractAllFeatures();
      ObjectiveFunction objFn = new ObjectiveFunction();
      GradientMinimizer gradMinimizer = new LBFGSMinimizer();
      double[] initial = new double[objFn.dimension()];
      this.weights = gradMinimizer.minimize(objFn, initial, 1.0e-4);
      return new LinearRegression<I>(featureExtractor, featureManager, weights);
    }
  }

  /** Builds a trained model; only {@link Factory#train(Collection)} creates instances. */
  private LinearRegression(FeatureExtractor<I, String> featureExtractor,
      FeatureManager featureManager, double[] weights) {
    this.featureExtractor = featureExtractor;
    this.featureManager = featureManager;
    this.weights = weights;
  }

  /**
   * Predicts the response for {@code input} as the weighted sum of its feature counts.
   *
   * Features without a learned weight contribute nothing to the score.
   *
   * @param input instance to score
   * @return predicted response
   */
  public double getResponse(I input) {
    Counter<String> featCounts = featureExtractor.extractFeatures(input);
    double score = 0.0;
    for (String f : featCounts.keySet()) {
      Feature feat = featureManager.getFeature(f);
      // NOTE(review): what FeatureManager returns for a name first seen after lock()
      // is not visible from this file; guard against both a null feature and an
      // index outside the trained weight vector so unseen features score as zero.
      if (feat == null) {
        continue;
      }
      int index = feat.getIndex();
      if (index < 0 || index >= weights.length) {
        continue;
      }
      score += featCounts.getCount(f) * weights[index];
    }
    return score;
  }

  /** Smoke test: trains on two tiny bag-of-strings examples and prints one prediction. */
  public static void main(String[] args) {
    List<String> elem1 = CollectionUtils.makeList("a", "b", "c");
    List<String> elem2 = CollectionUtils.makeList("a", "b");
    Pair<List<String>, Double> d1 = Pair.newPair(elem1, 3.0);
    Pair<List<String>, Double> d2 = Pair.newPair(elem2, 2.0);
    // Each list element is a feature; its count is the number of occurrences.
    FeatureExtractor<List<String>, String> featExtractor =
        new FeatureExtractor<List<String>, String>() {
          public Counter<String> extractFeatures(List<String> instance) {
            Counter<String> counts = new Counter<String>();
            for (String elem : instance) {
              counts.incrementCount(elem, 1.0);
            }
            return counts;
          }
        };
    LinearRegression.Factory<List<String>> factory =
        new LinearRegression.Factory<List<String>>(featExtractor);
    List<Pair<List<String>, Double>> datums = CollectionUtils.makeList(d1, d2);
    LinearRegression<List<String>> linearRegressionModel = factory.train(datums);
    double guess = linearRegressionModel.getResponse(elem1);
    System.out.println("guess: " + guess);
  }
}