package edu.stanford.nlp.classify;

import java.util.Arrays;
import java.util.List;

import edu.stanford.nlp.classify.LogPrior.LogPriorType;
import edu.stanford.nlp.ling.RVFDatum;
import edu.stanford.nlp.optimization.DiffFunction;
import edu.stanford.nlp.optimization.QNMinimizer;
import edu.stanford.nlp.util.logging.Redwood;

/**
 * Trains a multinomial logistic classifier whose design matrix is augmented with
 * one "shift" (indicator) column per training datum, i.e. X := [X|I]. The shift
 * parameters are pushed toward zero by an L1 penalty via OWL-QN, so any that
 * survive as non-zero flag data points the base model cannot fit.
 *
 * @author jtibs
 */
public class ShiftParamsLogisticClassifierFactory<L, F> implements ClassifierFactory<L, F, MultinomialLogisticClassifier<L, F>> {

  private static final long serialVersionUID = -8977510677251295037L;

  // Sparse training data: data[i] holds the feature indices of datum i,
  // dataValues[i] the corresponding feature values.
  private int[][] data;
  private double[][] dataValues;
  private int[] labels;
  private int numClasses;
  private int numFeatures;

  // Prior over the regular weights (quadratic or none — see constructor note).
  private final LogPrior prior;
  // Strength of the L1 penalty that OWL-QN applies to the shift parameters.
  private final double lambda;

  /** Creates a factory with no prior and a default shift penalty of 0.1. */
  public ShiftParamsLogisticClassifierFactory() {
    this(new LogPrior(LogPriorType.NULL), 0.1);
  }

  /** Creates a factory with no prior and the given shift penalty {@code lambda}. */
  public ShiftParamsLogisticClassifierFactory(double lambda) {
    this(new LogPrior(LogPriorType.NULL), lambda);
  }

  // NOTE: the current implementation only supports quadratic priors (or no prior)
  public ShiftParamsLogisticClassifierFactory(LogPrior prior, double lambda) {
    this.prior = prior;
    this.lambda = lambda;
  }

  /**
   * Learns weights on the shift-augmented data and packages them with the
   * dataset's feature and label indices.
   *
   * @param dataset the training data; its data arrays are extended in place
   *                with the per-datum shift columns
   * @return a trained multinomial logistic classifier over the original features
   */
  public MultinomialLogisticClassifier<L, F> trainClassifier(GeneralDataset<L, F> dataset) {
    numClasses = dataset.numClasses();
    numFeatures = dataset.numFeatures();

    data = dataset.getDataArray();
    // RVF datasets carry real feature values; binary datasets get implicit 1.0s.
    dataValues = (dataset instanceof RVFDataset<?, ?>)
        ? dataset.getValuesArray()
        : LogisticUtils.initializeDataValues(data);
    appendShiftColumns(data, dataValues);

    labels = dataset.getLabelsArray();
    return new MultinomialLogisticClassifier<>(learnWeights(), dataset.featureIndex, dataset.labelIndex);
  }

  /**
   * Minimizes the shift-params objective with OWL-QN and strips the learned
   * parameter vector back down to the regular (non-shift) weights.
   */
  private double[][] learnWeights() {
    QNMinimizer minimizer = new QNMinimizer(15, true);
    minimizer.useOWLQN(true, lambda);

    int augmentedDim = numFeatures + data.length;
    DiffFunction objective = new ShiftParamsLogisticObjectiveFunction(data, dataValues,
        toIndicatorMatrix(labels), numClasses, augmentedDim, numFeatures, prior);

    double[] initial = new double[(numClasses - 1) * augmentedDim];
    double[] solution = minimizer.minimize(objective, 1e-4, initial);

    // count the parameters past the first feature block that survived the
    // L1 penalty as non-zero, for debugging
    int nonzero = 0;
    for (int j = numFeatures; j < solution.length; j++) {
      if (solution[j] != 0) nonzero++;
    }
    Redwood.log("NUM NONZERO PARAMETERS: " + nonzero);

    double[][] thetas = new double[numClasses - 1][numFeatures];
    LogisticUtils.unflatten(solution, thetas);
    return thetas;
  }

  // augments the feature matrix to account for shift parameters, setting X := [X|I]
  private void appendShiftColumns(int[][] data, double[][] dataValues) {
    for (int i = 0; i < data.length; i++) {
      int extended = data[i].length + 1;
      data[i] = Arrays.copyOf(data[i], extended);
      dataValues[i] = Arrays.copyOf(dataValues[i], extended);
      // datum i gets its own indicator feature at index numFeatures + i
      data[i][extended - 1] = numFeatures + i;
      dataValues[i][extended - 1] = 1;
    }
  }

  // converts labels to the one-hot matrix form that the objective function expects
  private int[][] toIndicatorMatrix(int[] labels) {
    int[][] indicators = new int[labels.length][numClasses];
    for (int i = 0; i < labels.length; i++) {
      indicators[i][labels[i]] = 1;
    }
    return indicators;
  }
}