package edu.stanford.nlp.classify; import java.util.Arrays; import java.util.HashSet; import java.util.Set; import edu.stanford.nlp.classify.LogPrior.LogPriorType; import edu.stanford.nlp.optimization.AbstractCachingDiffFunction; import edu.stanford.nlp.optimization.HasRegularizerParamRange; /** * @author jtibs */ public class ShiftParamsLogisticObjectiveFunction extends AbstractCachingDiffFunction implements HasRegularizerParamRange { private final int[][] data; private final double[][] dataValues; private final int numClasses; private final int numFeatures; private final int[][] labels; private final int numL2Parameters; private final LogPrior prior; public ShiftParamsLogisticObjectiveFunction(int[][] data, double[][] dataValues, int[][] labels, int numClasses, int numFeatures, int numL2Parameters, LogPrior prior) { this.data = data; this.dataValues = dataValues; this.labels = labels; this.numClasses = numClasses; this.numFeatures = numFeatures; this.numL2Parameters = numL2Parameters; this.prior = prior; } @Override public int domainDimension() { return (numClasses - 1) * numFeatures; } @Override protected void calculate(double[] thetasArray) { clearResults(); double[][] thetas = new double[numClasses - 1][numFeatures]; LogisticUtils.unflatten(thetasArray, thetas); for (int i = 0; i < data.length; i++) { int[] featureIndices = data[i]; double[] featureValues = dataValues[i]; double[] sums = LogisticUtils.calculateSums(thetas, featureIndices, featureValues); for (int c = 0; c < numClasses; c++) { double sum = sums[c]; value -= sum * labels[i][c]; if (c == 0) continue; int offset = (c - 1) * numFeatures; double error = Math.exp(sum) - labels[i][c]; for (int f = 0; f < featureIndices.length; f++) { int index = featureIndices[f]; double x = featureValues[f]; derivative[offset + index] -= error * x; } } } // incorporate prior if (prior.getType().equals(LogPriorType.NULL)) return; double sigma = prior.getSigma(); for (int c = 0; c < numClasses; c++) { if (c == 0) continue; int offset = (c - 1) * numFeatures; for (int j = 0; j < numL2Parameters; j++) { double theta = thetasArray[offset + j]; value += theta * theta / (sigma * 2.0); derivative[offset + j] += theta / sigma; } } } private void clearResults() { value = 0.0; Arrays.fill(derivative, 0.0); } @Override public Set<Integer> getRegularizerParamRange(double[] x) { Set<Integer> result = new HashSet<>(); for (int i = numL2Parameters; i < x.length; i++) result.add(i); return result; } }