package mstparser; public class KBestParseForest { public static int rootType; public ParseForestItem[][][][][] chart; private String[] sent, pos; private int start, end; private int K; public KBestParseForest(int start, int end, DependencyInstance inst, int K) { this.K = K; chart = new ParseForestItem[end + 1][end + 1][2][2][K]; this.start = start; this.end = end; this.sent = inst.forms; this.pos = inst.postags; } public boolean add(int s, int type, int dir, double score, FeatureVector fv) { boolean added = false; if (chart[s][s][dir][0][0] == null) { for (int i = 0; i < K; i++) { chart[s][s][dir][0][i] = new ParseForestItem(s, type, dir, Double.NEGATIVE_INFINITY, null); } } if (chart[s][s][dir][0][K - 1].prob > score) { return false; } for (int i = 0; i < K; i++) { if (chart[s][s][dir][0][i].prob < score) { ParseForestItem tmp = chart[s][s][dir][0][i]; chart[s][s][dir][0][i] = new ParseForestItem(s, type, dir, score, fv); for (int j = i + 1; j < K && tmp.prob != Double.NEGATIVE_INFINITY; j++) { ParseForestItem tmp1 = chart[s][s][dir][0][j]; chart[s][s][dir][0][j] = tmp; tmp = tmp1; } added = true; break; } } return added; } public boolean add(int s, int r, int t, int type, int dir, int comp, double score, FeatureVector fv, ParseForestItem p1, ParseForestItem p2) { boolean added = false; if (chart[s][t][dir][comp][0] == null) { for (int i = 0; i < K; i++) { chart[s][t][dir][comp][i] = new ParseForestItem(s, r, t, type, dir, comp, Double.NEGATIVE_INFINITY, null, null, null); } } if (chart[s][t][dir][comp][K - 1].prob > score) { return false; } for (int i = 0; i < K; i++) { if (chart[s][t][dir][comp][i].prob < score) { ParseForestItem tmp = chart[s][t][dir][comp][i]; chart[s][t][dir][comp][i] = new ParseForestItem(s, r, t, type, dir, comp, score, fv, p1, p2); for (int j = i + 1; j < K && tmp.prob != Double.NEGATIVE_INFINITY; j++) { ParseForestItem tmp1 = chart[s][t][dir][comp][j]; chart[s][t][dir][comp][j] = tmp; tmp = tmp1; } added = true; break; } } return added; } public double getProb(int s, int t, int dir, int comp) { return getProb(s, t, dir, comp, 0); } public double getProb(int s, int t, int dir, int comp, int i) { if (chart[s][t][dir][comp][i] != null) { return chart[s][t][dir][comp][i].prob; } return Double.NEGATIVE_INFINITY; } public double[] getProbs(int s, int t, int dir, int comp) { double[] result = new double[K]; for (int i = 0; i < K; i++) { result[i] = chart[s][t][dir][comp][i] != null ? chart[s][t][dir][comp][i].prob : Double.NEGATIVE_INFINITY; } return result; } public ParseForestItem getItem(int s, int t, int dir, int comp) { return getItem(s, t, dir, comp, 0); } public ParseForestItem getItem(int s, int t, int dir, int comp, int k) { if (chart[s][t][dir][comp][k] != null) { return chart[s][t][dir][comp][k]; } return null; } public ParseForestItem[] getItems(int s, int t, int dir, int comp) { if (chart[s][t][dir][comp][0] != null) { return chart[s][t][dir][comp]; } return null; } public Object[] getBestParse() { Object[] d = new Object[2]; d[0] = getFeatureVector(chart[0][end][0][0][0]); d[1] = getDepString(chart[0][end][0][0][0]); return d; } public Object[][] getBestParses() { Object[][] d = new Object[K][2]; for (int k = 0; k < K; k++) { if (chart[0][end][0][0][k].prob != Double.NEGATIVE_INFINITY) { d[k][0] = getFeatureVector(chart[0][end][0][0][k]); d[k][1] = getDepString(chart[0][end][0][0][k]); } else { d[k][0] = null; d[k][1] = null; } } return d; } public FeatureVector getFeatureVector(ParseForestItem pfi) { if (pfi.left == null) { return pfi.fv; } return cat(pfi.fv, cat(getFeatureVector(pfi.left), getFeatureVector(pfi.right))); } public String getDepString(ParseForestItem pfi) { if (pfi.left == null) { return ""; } if (pfi.comp == 0) { return (getDepString(pfi.left) + " " + getDepString(pfi.right)).trim(); } else if (pfi.dir == 0) { return ((getDepString(pfi.left) + " " + getDepString(pfi.right)).trim() + " " + pfi.s + "|" + pfi.t + ":" + pfi.type).trim(); } else { return (pfi.t + "|" + pfi.s + ":" + pfi.type + " " + (getDepString(pfi.left) + " " + getDepString(pfi.right)).trim()).trim(); } } public FeatureVector cat(FeatureVector fv1, FeatureVector fv2) { return fv1.cat(fv2); } // returns pairs of indeces and -1,-1 if < K pairs public int[][] getKBestPairs(ParseForestItem[] items1, ParseForestItem[] items2) { // in this case K = items1.length boolean[][] beenPushed = new boolean[K][K]; int[][] result = new int[K][2]; for (int i = 0; i < K; i++) { result[i][0] = -1; result[i][1] = -1; } if (items1 == null || items2 == null || items1[0] == null || items2[0] == null) { return result; } BinaryHeap heap = new BinaryHeap(K + 1); int n = 0; ValueIndexPair vip = new ValueIndexPair(items1[0].prob + items2[0].prob, 0, 0); heap.add(vip); beenPushed[0][0] = true; while (n < K) { vip = heap.removeMax(); if (vip.val == Double.NEGATIVE_INFINITY) { break; } result[n][0] = vip.i1; result[n][1] = vip.i2; n++; if (n >= K) { break; } if (!beenPushed[vip.i1 + 1][vip.i2]) { heap.add(new ValueIndexPair(items1[vip.i1 + 1].prob + items2[vip.i2].prob, vip.i1 + 1, vip.i2)); beenPushed[vip.i1 + 1][vip.i2] = true; } if (!beenPushed[vip.i1][vip.i2 + 1]) { heap.add(new ValueIndexPair(items1[vip.i1].prob + items2[vip.i2 + 1].prob, vip.i1, vip.i2 + 1)); beenPushed[vip.i1][vip.i2 + 1] = true; } } return result; } } class ValueIndexPair { public double val; public int i1, i2; public ValueIndexPair(double val, int i1, int i2) { this.val = val; this.i1 = i1; this.i2 = i2; } public int compareTo(ValueIndexPair other) { if (val < other.val) { return -1; } if (val > other.val) { return 1; } return 0; } } // Max Heap // We know that never more than K elements on Heap class BinaryHeap { private int DEFAULT_CAPACITY; private int currentSize; private ValueIndexPair[] theArray; public BinaryHeap(int def_cap) { DEFAULT_CAPACITY = def_cap; theArray = new ValueIndexPair[DEFAULT_CAPACITY + 1]; // theArray[0] serves as dummy parent for root (who is at 1) // "largest" is guaranteed to be larger than all keys in heap theArray[0] = new ValueIndexPair(Double.POSITIVE_INFINITY, -1, -1); currentSize = 0; } public ValueIndexPair getMax() { return theArray[1]; } private int parent(int i) { return i / 2; } private int leftChild(int i) { return 2 * i; } private int rightChild(int i) { return 2 * i + 1; } public void add(ValueIndexPair e) { // bubble up: int where = currentSize + 1; // new last place while (e.compareTo(theArray[parent(where)]) > 0) { theArray[where] = theArray[parent(where)]; where = parent(where); } theArray[where] = e; currentSize++; } public ValueIndexPair removeMax() { ValueIndexPair min = theArray[1]; theArray[1] = theArray[currentSize]; currentSize--; boolean switched = true; // bubble down for (int parent = 1; switched && parent < currentSize;) { switched = false; int leftChild = leftChild(parent); int rightChild = rightChild(parent); if (leftChild <= currentSize) { // if there is a right child, see if we should bubble down there int largerChild = leftChild; if ((rightChild <= currentSize) && (theArray[rightChild].compareTo(theArray[leftChild])) > 0) { largerChild = rightChild; } if (theArray[largerChild].compareTo(theArray[parent]) > 0) { ValueIndexPair temp = theArray[largerChild]; theArray[largerChild] = theArray[parent]; theArray[parent] = temp; parent = largerChild; switched = true; } } } return min; } }