package edu.berkeley.nlp.crf;
import java.io.Serializable;
import edu.berkeley.nlp.classify.Encoding;
import edu.berkeley.nlp.classify.FeatureExtractor;
import edu.berkeley.nlp.classify.IndexLinearizer;
import edu.berkeley.nlp.util.ArrayUtil;
import edu.berkeley.nlp.util.Counter;
public class ScoreCalculator<V, E, F, L> implements Serializable {
/**
*
*/
private static final long serialVersionUID = 6864706229279071608L;
private final Encoding<F, L> encoding;
private final FeatureExtractor<V, F> vertexExtractor;
private final FeatureExtractor<E, F> edgeExtractor;
private final IndexLinearizer il;
public ScoreCalculator(Encoding<F, L> encoding, FeatureExtractor<V, F> vertexExtractor,
FeatureExtractor<E, F> edgeExtractor) {
this.encoding = encoding;
this.vertexExtractor = vertexExtractor;
this.edgeExtractor = edgeExtractor;
this.il = new IndexLinearizer(encoding.getNumFeatures(), encoding.getNumLabels());
}
public double[][] getScoreMatrix(InstanceSequence<V, E, L> sequence, int index, double[] w) {
double[][] M = getLinearScoreMatrix(sequence, index, w);
for (int i=0; i<M.length; i++) {
M[i] = ArrayUtil.exp(M[i]);
}
return M;
}
public double[] getVertexScores(InstanceSequence<V, E, L> sequence, int index, double[] w) {
return ArrayUtil.exp(getLinearVertexScores(sequence, index, w));
}
public double[][] getLinearScoreMatrix(InstanceSequence<V, E, L> sequence, int index, double[] w) {
int numLabels = encoding.getNumLabels();
double[][] M = new double[numLabels][numLabels];
Counter<F> vertexFeatures = vertexExtractor.extractFeatures(sequence.getVertexInstance(index));
for (int vc = 0; vc<numLabels; vc++) {
double vertexScore = dotProduct(vertexFeatures, vc, w);
for (int vp = 0; vp<numLabels; vp++) {
L previousLabel = encoding.getLabel(vp);
Counter<F> edgeFeatures = edgeExtractor.extractFeatures(sequence.getEdgeInstance(index, previousLabel));
double edgeScore = dotProduct(edgeFeatures, vc, w);
M[vp][vc] = vertexScore + edgeScore;
}
}
return M;
}
public double[] getLinearVertexScores(InstanceSequence<V, E, L> sequence, int index, double[] w) {
int numLabels = encoding.getNumLabels();
double[] s = new double[numLabels];
Counter<F> vertexFeatures = vertexExtractor.extractFeatures(sequence.getVertexInstance(index));
for (int vc = 0; vc<numLabels; vc++) {
double vertexScore = dotProduct(vertexFeatures, vc, w);
s[vc] = vertexScore;
}
return s;
}
private double dotProduct(Counter<F> features, int labelIndex, double[] w) {
double val = 0.0;
for (F feature : features.keySet()) {
if (encoding.hasFeature(feature)) {
int featureIndex = encoding.getFeatureIndex(feature);
int linearIndex = il.getLinearIndex(featureIndex, labelIndex);
val += features.getCount(feature) * w[linearIndex];
}
}
return val;
}
}