package edu.berkeley.nlp.crf;
import java.util.List;
import java.util.Map;
import edu.berkeley.nlp.classify.Encoding;
import edu.berkeley.nlp.classify.FeatureExtractor;
import edu.berkeley.nlp.classify.IndexLinearizer;
import edu.berkeley.nlp.math.DifferentiableFunction;
import edu.berkeley.nlp.util.Counter;
import edu.berkeley.nlp.util.Pair;
/**
 * Regularized training objective for a CRF, exposed as a {@link DifferentiableFunction}
 * over a linearized (feature, label) weight vector.
 *
 * <p>For a weight vector {@code x} the value is
 * {@code -sum(empiricalCount * x) + logNormalization + sum(x_i^2) / (2 * sigma^2)}, where the
 * empirical counts and the log-normalization / expected counts are supplied by {@link Counts}.
 * The gradient is {@code expectedCounts - empiricalCounts + x / sigma^2}.
 *
 * <p>The most recent evaluation is cached, so calling {@link #valueAt(double[])} and
 * {@link #derivativeAt(double[])} with the same point performs the computation only once.
 * The cache is not synchronized.
 *
 * @param <V> vertex type consumed by the vertex feature extractor
 * @param <E> edge type consumed by the edge feature extractor
 * @param <F> feature type
 * @param <L> label type
 */
public class CRFObjectiveFunction<V, E, F, L> implements DifferentiableFunction {
    private final List<? extends LabeledInstanceSequence<V, E, L>> trainingData;
    private final Encoding<F, L> encoding;
    private final Counts<V, E, F, L> counts;
    private final IndexLinearizer il;
    private final double sigma;

    // Single-entry cache of the last evaluation (point, value, gradient).
    double lastValue;
    double[] lastDerivative;
    double[] lastX;

    /**
     * Creates the objective for the given training data.
     *
     * @param trainingData labeled sequences the counts are computed from
     * @param encoding feature/label encoding; its sizes determine {@link #dimension()}
     * @param vertexExtractor feature extractor for vertices
     * @param edgeExtractor feature extractor for edges
     * @param sigma regularization scale; the penalty is {@code x_i^2 / (2 * sigma^2)} per weight
     */
    public CRFObjectiveFunction(List<? extends LabeledInstanceSequence<V, E, L>> trainingData, Encoding<F, L> encoding,
            FeatureExtractor<V, F> vertexExtractor, FeatureExtractor<E, F> edgeExtractor, double sigma) {
        this.trainingData = trainingData;
        this.encoding = encoding;
        this.counts = new Counts<V, E, F, L>(encoding, vertexExtractor, edgeExtractor);
        this.il = new IndexLinearizer(encoding.getNumFeatures(), encoding.getNumLabels());
        this.sigma = sigma;
    }

    /**
     * Returns the number of weights, i.e. the number of linear indexes of the
     * (feature, label) linearizer.
     */
    public int dimension() {
        return il.getNumLinearIndexes();
    }

    /**
     * Returns the objective value at {@code x}, reusing the cached evaluation when
     * {@code x} equals the previously evaluated point.
     *
     * @param x weight vector
     * @return objective value at {@code x}
     */
    public double valueAt(double[] x) {
        ensureCache(x);
        return lastValue;
    }

    /**
     * Returns the gradient at {@code x}, reusing the cached evaluation when
     * {@code x} equals the previously evaluated point.
     *
     * <p>The returned array is the cached instance itself, not a copy.
     *
     * @param x weight vector
     * @return gradient at {@code x}
     */
    public double[] derivativeAt(double[] x) {
        ensureCache(x);
        return lastDerivative;
    }

    /** Recomputes value and gradient if {@code x} differs from the cached point. */
    private void ensureCache(double[] x) {
        if (requiresUpdate(lastX, x)) {
            Pair<Double, double[]> currentValueAndDerivative = calculate(x);
            lastValue = currentValueAndDerivative.getFirst();
            lastDerivative = currentValueAndDerivative.getSecond();
            // Snapshot the point: keeping the caller's array by reference would make the
            // comparison in requiresUpdate trivially succeed after an in-place update of x,
            // and the stale value/gradient would be returned.
            lastX = x.clone();
        }
    }

    /**
     * Tells whether the cache (keyed by {@code lastX}) is unusable for {@code x}: there is no
     * cached point yet, the lengths differ, or any component differs.
     */
    private boolean requiresUpdate(double[] lastX, double[] x) {
        if (lastX == null) return true;
        if (lastX.length != x.length) return true;
        for (int i = 0; i < x.length; i++) {
            if (lastX[i] != x[i]) return true;
        }
        return false;
    }

    /**
     * Computes the objective value and gradient at {@code x}.
     *
     * @param x weight vector indexed by the linearized (feature, label) index
     * @return pair of (objective value, gradient array of length {@link #dimension()})
     */
    private Pair<Double, double[]> calculate(double[] x) {
        double objective = 0.0;
        double[] derivatives = new double[dimension()];

        // Empirical term: subtract count * weight from the objective and the count from the gradient.
        List<Counter<F>> empiricalCounts = counts.getEmpiricalCounts(trainingData);
        for (int l = 0; l < empiricalCounts.size(); l++) {
            for (Map.Entry<F, Double> entry : empiricalCounts.get(l).entrySet()) {
                int index = il.getLinearIndex(encoding.getFeatureIndex(entry.getKey()), l);
                objective -= entry.getValue() * x[index];
                derivatives[index] -= entry.getValue();
            }
        }

        // Model term: add the log-normalization to the objective and the expected counts to the gradient.
        Pair<Double, List<Counter<F>>> results = counts.getLogNormalizationAndExpectedCounts(trainingData, x);
        objective += results.getFirst();
        addCounts(results.getSecond(), derivatives);

        // Quadratic penalty; same expressions as before, only hoisted out of the loop.
        final double twoSigmaSquared = 2 * sigma * sigma;
        final double sigmaSquared = sigma * sigma;
        for (int i = 0; i < x.length; ++i) {
            double weight = x[i];
            objective += (weight * weight) / twoSigmaSquared;
            derivatives[i] += (weight) / sigmaSquared;
        }
        return Pair.makePair(objective, derivatives);
    }

    /**
     * Adds per-label feature counts into {@code target} at their linearized
     * (feature, label) positions; the list position is used as the label index.
     */
    private void addCounts(List<Counter<F>> countsByLabel, double[] target) {
        for (int l = 0; l < countsByLabel.size(); l++) {
            for (Map.Entry<F, Double> entry : countsByLabel.get(l).entrySet()) {
                int index = il.getLinearIndex(encoding.getFeatureIndex(entry.getKey()), l);
                target[index] += entry.getValue();
            }
        }
    }
}