package edu.berkeley.nlp.crf; import java.io.Serializable; import edu.berkeley.nlp.classify.Encoding; import edu.berkeley.nlp.classify.FeatureExtractor; import edu.berkeley.nlp.math.DoubleArrays; import edu.berkeley.nlp.math.DoubleMatrices; import edu.berkeley.nlp.util.ArrayUtil; import edu.berkeley.nlp.util.Pair; import edu.berkeley.nlp.util.PriorityQueue; public class Inference<V, E, F, L> implements Serializable { /** * */ private static final long serialVersionUID = 1948395432745606240L; private final Encoding<F, L> encoding; private final ScoreCalculator<V, E, F, L> scoreCalculator; public Inference(Encoding<F, L> encoding, FeatureExtractor<V, F> vertexExtractor, FeatureExtractor<E, F> edgeExtractor) { this.encoding = encoding; this.scoreCalculator = new ScoreCalculator<V, E, F, L>(encoding, vertexExtractor, edgeExtractor); } public double[][] getAlphas(InstanceSequence<V, E, L> sequence, double[] w) { int n = sequence.getSequenceLength(); double[][] alpha = new double[n][]; alpha[0] = scoreCalculator.getVertexScores(sequence, 0, w); for (int i=1; i<n; i++) { double[][] scoreMatrix = scoreCalculator.getScoreMatrix(sequence, i, w); alpha[i] = DoubleMatrices.product(alpha[i-1], scoreMatrix); } return alpha; } public double[][] getBetas(InstanceSequence<V, E, L> sequence, double[] w) { int n = sequence.getSequenceLength(); double[][] beta = new double[n][]; beta[n-1] = DoubleArrays.constantArray(1.0, encoding.getNumLabels()); for (int i=n-2; i>=0; i--) { double[][] scoreMatrix = scoreCalculator.getScoreMatrix(sequence, i+1, w); beta[i] = DoubleMatrices.product(scoreMatrix, beta[i+1]); } return beta; } public Pair<int[][][][], double[][][]> getKBestChartAndBacktrace(InstanceSequence<V, E, L> sequence, double[] w, int k) { int n = sequence.getSequenceLength(); int numLabels = encoding.getNumLabels(); int[][][][] bestLabels = new int[n][numLabels][][]; double[][][] bestScores = new double[n][numLabels][]; double[] startScores = scoreCalculator.getLinearVertexScores(sequence, 0, w); for (int l=0; l<numLabels; l++) { bestScores[0][l] = new double[] { startScores[l] }; bestLabels[0][l] = new int[][] { new int[] {-1, 0 } }; } for (int i=1; i<n; i++) { double[][] scoreMatrix = scoreCalculator.getLinearScoreMatrix(sequence, i, w); for (int l=0; l<numLabels; l++) { PriorityQueue<Pair<Integer, Integer>> pq = new PriorityQueue<Pair<Integer, Integer>>(); for (int pl=0; pl<numLabels; pl++) { double edgeScore = scoreMatrix[pl][l]; for (int c=0; c<bestScores[i-1][pl].length; c++) { double totalScore = edgeScore + bestScores[i-1][pl][c]; pq.add(Pair.makePair(pl, c), totalScore); } } int cands = Math.min(k, pq.size()); bestScores[i][l] = new double[cands]; bestLabels[i][l] = new int[cands][2]; for (int c=0; c<cands; c++) { bestScores[i][l][c] = pq.getPriority(); Pair<Integer, Integer> backtrace = pq.next(); bestLabels[i][l][c][0] = backtrace.getFirst(); bestLabels[i][l][c][1] = backtrace.getSecond(); } } } return Pair.makePair(bestLabels, bestScores); } public double[][] getVertexPosteriors(double[][] alpha, double[][] beta) { double[][] p = new double[alpha.length][encoding.getNumLabels()]; for (int i=0; i<p.length; i++) { for (int l=0; l<p[i].length; l++) { p[i][l] = alpha[i][l] * beta[i][l]; } ArrayUtil.normalize(p[i]); } return p; } public double[][][] getEdgePosteriors(InstanceSequence<V, E, L> sequence, double[] w, double[][] alpha, double[][] beta) { int numLabels = encoding.getNumLabels(); int n = sequence.getSequenceLength(); double[][][] p = new double[n][numLabels][numLabels]; for (int i=1; i<p.length; i++) { double[][] scoreMatrix = scoreCalculator.getScoreMatrix(sequence, i, w); for (int lp=0; lp<numLabels; lp++) { for (int lc=0; lc<numLabels; lc++) { p[i][lp][lc] = alpha[i-1][lp] * scoreMatrix[lp][lc] * beta[i][lc]; } } ArrayUtil.normalize(p[i]); } return p; } public double getNormalizationConstant(double[][] alpha, double[][] beta) { int anyIndex = 0; double[] p = new double[alpha[anyIndex].length]; for (int l=0; l<p.length; l++) { p[l] = alpha[anyIndex][l] * beta[anyIndex][l]; } return ArrayUtil.sum(p); } }