/* Copyright (C) 2006, Xuan-Hieu Phan Email: hieuxuan@ecei.tohoku.ac.jp pxhieu@gmail.com URL: http://www.hori.ecei.tohoku.ac.jp/~hieuxuan Graduate School of Information Sciences, Tohoku University */ package crf.tagger; import java.io.*; import java.util.*; import java.util.Arrays; public class Viterbi { public Model model = null; int numLabels = 0; // state potentials DoubleMatrix Mi = null; // edge potentials DoubleVector Vi = null; public class PairDblInt { public double first = 0.0; public int second = -1; } // enf of class PairDblInt public int memorySize = 0; public PairDblInt[][] memory = null; public Viterbi() { } public void init(Model model) { this.model = model; numLabels = model.taggerMaps.numLabels(); Mi = new DoubleMatrix(numLabels, numLabels); Vi = new DoubleVector(numLabels); allocateMemory(100); // compute Mi once at initialization computeMi(false); } public void allocateMemory(int memorySize) { this.memorySize = memorySize; memory = new PairDblInt[memorySize][numLabels]; for (int i = 0; i < memorySize; i++) { for (int j = 0; j < numLabels; j++) { memory[i][j] = new PairDblInt(); } } } public void computeMi(boolean isExp) { Mi.assign(0.0); model.taggerFGen.startScanEFeatures(); while (model.taggerFGen.hasNextEFeature()) { Feature f = model.taggerFGen.nextEFeature(); if (f.ftype == Feature.EDGE_FEATURE1) { Mi.mtrx[f.yp][f.y] += model.lambda[f.idx] * f.val; } } if (isExp) { for (int i = 0; i < Mi.rows; i++) { for (int j = 0; j < Mi.cols; j++) { Mi.mtrx[i][j] = Math.exp(Mi.mtrx[i][j]); } } } } public void computeVi(List seq, int pos, DoubleVector Vi, boolean isExp) { Vi.assign(0.0); // start scan features for sequence "seq" at position "pos" model.taggerFGen.startScanSFeaturesAt(seq, pos); // examine all features at position "pos" while (model.taggerFGen.hasNextSFeature()) { Feature f = model.taggerFGen.nextSFeature(); if (f.ftype == Feature.STAT_FEATURE1) { Vi.vect[f.y] += model.lambda[f.idx] * f.val; } } // take exponential operator if (isExp) { for (int i = 0; i < Vi.len; i++) { Vi.vect[i] = Math.exp(Vi.vect[i]); } } } // list is a List of PairDblInt public double sum(PairDblInt[] cols) { double res = 0.0; for (int i = 0; i < numLabels; i++) { res += cols[i].first; } if (res < 1 && res > -1) { res = 1; } return res; } // list is a List of PairDblInt public void divide(PairDblInt[] cols, double val) { for (int i = 0; i < numLabels; i++) { cols[i].first /= val; } } // list is a List of PairDblInt public int findMax(PairDblInt[] cols) { int maxIdx = 0; double maxVal = -1.0; for (int i = 0; i < numLabels; i++) { if (cols[i].first > maxVal) { maxVal = cols[i].first; maxIdx = i; } } return maxIdx; } class SortBeam implements Comparator<PairDblInt> { public int compare(PairDblInt a, PairDblInt b) { return ((Double)b.first).compareTo( a.first); } } private Set construct_beam(PairDblInt[] cols, int beam_size) { Set best = new HashSet(); PairDblInt[] cols2 = new PairDblInt[numLabels]; for (int i = 0; i < numLabels; i++) { cols2[i] = new PairDblInt(); cols2[i].first = cols[i].first; cols2[i].second = i; } Arrays.sort(cols2, new SortBeam()); for (int i = 0; i < Math.min(beam_size, cols2.length); i++) { best.add(cols2[i].second); } return best; } public void viterbiInference(List seq) { int i, j, k; // add in beam search int beam_size = 50; Set full_beam = new HashSet(); int seqLen = seq.size(); if (seqLen <= 0) { return; } for (i=0; i < numLabels; i++) { full_beam.add(i); } if (memorySize < seqLen) { allocateMemory(seqLen); } // compute Vi for the first position in the sequence int node_label[][] = new int[seqLen][numLabels]; //computeVi(seq, 0, Vi, true); computeVi(seq, 0, Vi, false); System.out.println("LATTICE: START"); int cur_node = 0; System.out.println("LATTICE: NODE " + cur_node + " START"); cur_node++; for (j = 0; j < numLabels; j++) { memory[0][j].first = Vi.vect[j]; memory[0][j].second = j; System.out.println("LATTICE: NODE " + cur_node + " 0:" + j); node_label[0][j] = cur_node; cur_node++; System.out.println("LATTICE: EDGE 0:"+ j + ":0 " + (0) + " " + (cur_node -1)+ " " + Vi.vect[j] ); } Set best_beam; if (((Observation)seq.get(0)).knownWord) { best_beam = construct_beam(memory[0], beam_size); } else { best_beam = full_beam; } // scaling for the first position //divide(memory[0], sum(memory[0])); // the main loop for (i = 1; i < seqLen; i++) { // compute Vi at the position i //computeVi(seq, i, Vi, true); computeVi(seq, i, Vi, false); // for all possible labels at the position i for (j = 0; j < numLabels; j++) { memory[i][j].first = 0.0; memory[i][j].second = 0; //System.out.println("LATTICE: NODE " + cur_node + " "+ i +":"+ j ); //node_label[i][j] = cur_node; //cur_node++; // find the maximal value and its index and store them in memory // for later tracing back to find the best path for (k = 0; k < numLabels; k++) { if (!best_beam.contains(k)) continue; //double tempVal = memory[i - 1][k].first * //Mi.mtrx[k][j] * Vi.vect[j]; double tempVal = memory[i - 1][k].first + Mi.mtrx[k][j] + Vi.vect[j]; //System.out.println("LATTICE: EDGE " + i+":"+j+":"+k + " "+ node_label[i - 1][k] + " " + (cur_node-1) + " "+(Mi.mtrx[k][j] + Vi.vect[j])); //System.out.println(i); //System.out.println(j); //System.out.println(k); //System.out.println(Mi.mtrx[k][j] * Vi.vect[j]); if (tempVal > memory[i][j].first) { memory[i][j].first = tempVal; memory[i][j].second = k; } } } // scaling for memory at position i //divide(memory[i], sum(memory[i])); Set last_beam = new HashSet(best_beam); System.out.println("Known " + ((Observation)seq.get(i)).knownWord + ((Observation)seq.get(i)).originalData); if (((Observation)seq.get(i)).knownWord) { best_beam = construct_beam(memory[i], beam_size); } else { best_beam = full_beam; } for (j = 0; j < numLabels; j++) { if (!best_beam.contains(j)) continue; System.out.println("LATTICE: NODE " + cur_node + " "+ i +":"+ j ); node_label[i][j] = cur_node; cur_node++; // find the maximal value and its index and store them in memory // for later tracing back to find the best path for (k = 0; k < numLabels; k++) { if (!last_beam.contains(k)) continue; //double tempVal = memory[i - 1][k].first * //Mi.mtrx[k][j] * Vi.vect[j]; //double tempVal = memory[i - 1][k].first + // Mi.mtrx[k][j] + Vi.vect[j]; System.out.println("LATTICE: EDGE " + i+":"+j+":"+k + " "+ node_label[i - 1][k] + " " + (cur_node-1) + " "+(Mi.mtrx[k][j] + Vi.vect[j])); //System.out.println(i); //System.out.println(j); //System.out.println(k); //System.out.println(Mi.mtrx[k][j] * Vi.vect[j]); //if (tempVal > memory[i][j].first) { // memory[i][j].first = tempVal; // memory[i][j].second = k; //} } } } // viterbi backtrack to find the best label path int maxIdx = findMax(memory[seqLen - 1]); System.out.println("LATTICE: NODE " + cur_node + " final" ); cur_node++; for (i = 0; i < numLabels; i++) { if (!best_beam.contains(i)) continue; System.out.println("LATTICE: EDGE " + "last" + " "+ node_label[seqLen-1][i] + " " + (cur_node-1) + " 0.0"); } System.out.println("LATTICE: END"); System.out.println("Final score"); System.out.println(memory[seqLen - 1][maxIdx].first); ((Observation)seq.get(seqLen - 1)).modelLabel = maxIdx; for (i = seqLen - 2; i >= 0; i--) { ((Observation)seq.get(i)).modelLabel = memory[i + 1][maxIdx].second; maxIdx = ((Observation)seq.get(i)).modelLabel; } } } // end of class Viterbi