package edu.berkeley.nlp.crf;
import java.util.ArrayList;
import java.util.List;
import edu.berkeley.nlp.classify.Encoding;
import edu.berkeley.nlp.classify.FeatureExtractor;
import edu.berkeley.nlp.util.Counter;
import edu.berkeley.nlp.util.Logger;
import edu.berkeley.nlp.util.Pair;
/**
 * Accumulates per-label feature counts over collections of sequences.
 *
 * <p>Both public methods return a list holding one {@link Counter} per label
 * index, where the number of labels is taken from
 * {@code encoding.getNumLabels()}.
 *
 * @param <V> vertex instance type handed to the vertex feature extractor
 * @param <E> edge instance type handed to the edge feature extractor
 * @param <F> feature type
 * @param <L> label type
 */
public class Counts<V, E, F, L> {

    private final Encoding<F, L> encoding;
    private final FeatureExtractor<V, F> vertexExtractor;
    private final FeatureExtractor<E, F> edgeExtractor;
    private final Inference<V, E, F, L> inf;

    /**
     * Creates a counter and an {@link Inference} instance built from the same
     * encoding and extractors.
     *
     * @param encoding        supplies the number of labels and the label/index mapping
     * @param vertexExtractor extracts features from vertex instances
     * @param edgeExtractor   extracts features from edge instances
     */
    public Counts(Encoding<F, L> encoding, FeatureExtractor<V, F> vertexExtractor, FeatureExtractor<E, F> edgeExtractor) {
        this.encoding = encoding;
        this.vertexExtractor = vertexExtractor;
        this.edgeExtractor = edgeExtractor;
        this.inf = new Inference<V, E, F, L>(encoding, vertexExtractor, edgeExtractor);
    }

    /**
     * Sums the features observed alongside the gold labels.
     *
     * <p>For every position, the vertex features are added to the counter of
     * that position's gold label. For every position after the first, the
     * features of the edge instance obtained with the previous position's gold
     * label are added to the same counter.
     *
     * @param sequences labeled sequences to count over
     * @return one counter per label index
     */
    public List<Counter<F>> getEmpiricalCounts(List<? extends LabeledInstanceSequence<V, E, L>> sequences) {
        List<Counter<F>> counts = newLabelCounters(encoding.getNumLabels());
        for (LabeledInstanceSequence<V, E, L> s : sequences) {
            for (int i = 0; i < s.getSequenceLength(); i++) {
                Counter<F> vertexFeatures = vertexExtractor.extractFeatures(s.getVertexInstance(i));
                int goldLabelIndex = encoding.getLabelIndex(s.getGoldLabel(i));
                Counter<F> goldCounter = counts.get(goldLabelIndex);
                goldCounter.incrementAll(vertexFeatures);
                if (i > 0) {
                    Counter<F> edgeFeatures =
                            edgeExtractor.extractFeatures(s.getEdgeInstance(i, s.getGoldLabel(i - 1)));
                    goldCounter.incrementAll(edgeFeatures);
                }
            }
        }
        return counts;
    }

    /**
     * Sums posterior-scaled feature counts and the log of each sequence's
     * normalization constant.
     *
     * <p>For every position and label, the vertex features scaled by
     * {@code vertexPosteriors[i][label]} are added to that label's counter.
     * For every position after the first and every (previous, current) label
     * pair, the features of the edge instance obtained with the previous label
     * are scaled by {@code edgePosteriors[i][previous][current]} and added to
     * the current label's counter.
     *
     * <p>Progress is logged inside a {@link Logger} track; the track is ended
     * even when processing a sequence throws.
     *
     * @param sequences sequences to run inference on
     * @param w         passed unchanged to {@link Inference}; presumably the
     *                  model weight vector (TODO confirm against Inference)
     * @return pair of (sum over sequences of {@code Math.log} of the
     *         normalization constant, one counter per label index)
     */
    public Pair<Double,List<Counter<F>>> getLogNormalizationAndExpectedCounts(List<? extends InstanceSequence<V, E, L>> sequences, double[] w) {
        int numLabels = encoding.getNumLabels();
        List<Counter<F>> counts = newLabelCounters(numLabels);
        double totalLogZ = 0.0;
        Logger.startTrack("Computing expected counts");
        try {
            int index = 0;
            for (InstanceSequence<V, E, L> s : sequences) {
                double[][] alpha = inf.getAlphas(s, w);
                double[][] beta = inf.getBetas(s, w);
                totalLogZ += Math.log(inf.getNormalizationConstant(alpha, beta));
                double[][] vertexPosteriors = inf.getVertexPosteriors(alpha, beta);
                double[][][] edgePosteriors = inf.getEdgePosteriors(s, w, alpha, beta);
                for (int i = 0; i < s.getSequenceLength(); i++) {
                    Counter<F> vertexFeatures = vertexExtractor.extractFeatures(s.getVertexInstance(i));
                    for (int l = 0; l < numLabels; l++) {
                        counts.get(l).incrementAll(vertexFeatures.scaledClone(vertexPosteriors[i][l]));
                    }
                    if (i > 0) {
                        for (int pl = 0; pl < numLabels; pl++) {
                            // Edge features depend on the previous label only, so extract once per pl.
                            Counter<F> edgeFeatures =
                                    edgeExtractor.extractFeatures(s.getEdgeInstance(i, encoding.getLabel(pl)));
                            for (int cl = 0; cl < numLabels; cl++) {
                                counts.get(cl).incrementAll(edgeFeatures.scaledClone(edgePosteriors[i][pl][cl]));
                            }
                        }
                    }
                }
                Logger.logs("Processed %d/%d sentences", ++index, sequences.size());
            }
        } finally {
            // Always close the track so a failure here does not leave it open.
            Logger.endTrack();
        }
        return Pair.makePair(totalLogZ, counts);
    }

    /** Builds a list of {@code numLabels} fresh, empty counters, one per label index. */
    private List<Counter<F>> newLabelCounters(int numLabels) {
        List<Counter<F>> counters = new ArrayList<Counter<F>>(numLabels);
        for (int l = 0; l < numLabels; l++) {
            counters.add(new Counter<F>());
        }
        return counters;
    }
}