package mstparser; public class KBestParseForest2O { private ParseForestItem[][][][][] chart; private String[] sent, pos; private int start, end; private int K; public KBestParseForest2O(int start, int end, DependencyInstance inst, int K) { this.K = K; chart = new ParseForestItem[end + 1][end + 1][2][3][K]; this.start = start; this.end = end; this.sent = inst.forms; this.pos = inst.postags; } public boolean add(int s, int type, int dir, double score, FeatureVector fv) { boolean added = false; if (chart[s][s][dir][0][0] == null) { for (int i = 0; i < K; i++) { chart[s][s][dir][0][i] = new ParseForestItem(s, type, dir, Double.NEGATIVE_INFINITY, null); } } if (chart[s][s][dir][0][K - 1].prob > score) { return false; } for (int i = 0; i < K; i++) { if (chart[s][s][dir][0][i].prob < score) { ParseForestItem tmp = chart[s][s][dir][0][i]; chart[s][s][dir][0][i] = new ParseForestItem(s, type, dir, score, fv); for (int j = i + 1; j < K && tmp.prob != Double.NEGATIVE_INFINITY; j++) { ParseForestItem tmp1 = chart[s][s][dir][0][j]; chart[s][s][dir][0][j] = tmp; tmp = tmp1; } added = true; break; } } return added; } public boolean add(int s, int r, int t, int type, int dir, int comp, double score, FeatureVector fv, ParseForestItem p1, ParseForestItem p2) { boolean added = false; if (chart[s][t][dir][comp][0] == null) { for (int i = 0; i < K; i++) { chart[s][t][dir][comp][i] = new ParseForestItem(s, r, t, type, dir, comp, Double.NEGATIVE_INFINITY, null, null, null); } } if (chart[s][t][dir][comp][K - 1].prob > score) { return false; } for (int i = 0; i < K; i++) { if (chart[s][t][dir][comp][i].prob < score) { ParseForestItem tmp = chart[s][t][dir][comp][i]; chart[s][t][dir][comp][i] = new ParseForestItem(s, r, t, type, dir, comp, score, fv, p1, p2); for (int j = i + 1; j < K && tmp.prob != Double.NEGATIVE_INFINITY; j++) { ParseForestItem tmp1 = chart[s][t][dir][comp][j]; chart[s][t][dir][comp][j] = tmp; tmp = tmp1; } added = true; break; } } return added; } public double getProb(int s, int t, int dir, int comp) { return getProb(s, t, dir, comp, 0); } public double getProb(int s, int t, int dir, int comp, int i) { if (chart[s][t][dir][comp][i] != null) { return chart[s][t][dir][comp][i].prob; } return Double.NEGATIVE_INFINITY; } public double[] getProbs(int s, int t, int dir, int comp) { double[] result = new double[K]; for (int i = 0; i < K; i++) { result[i] = chart[s][t][dir][comp][i] != null ? chart[s][t][dir][comp][i].prob : Double.NEGATIVE_INFINITY; } return result; } public ParseForestItem getItem(int s, int t, int dir, int comp) { return getItem(s, t, dir, comp, 0); } public ParseForestItem getItem(int s, int t, int dir, int comp, int i) { if (chart[s][t][dir][comp][i] != null) { return chart[s][t][dir][comp][i]; } return null; } public ParseForestItem[] getItems(int s, int t, int dir, int comp) { if (chart[s][t][dir][comp][0] != null) { return chart[s][t][dir][comp]; } return null; } public Object[] getBestParse() { Object[] d = new Object[2]; d[0] = getFeatureVector(chart[0][end][0][0][0]); d[1] = getDepString(chart[0][end][0][0][0]); return d; } public Object[][] getBestParses() { Object[][] d = new Object[K][2]; for (int k = 0; k < K; k++) { if (chart[0][end][0][0][k].prob != Double.NEGATIVE_INFINITY) { d[k][0] = getFeatureVector(chart[0][end][0][0][k]); d[k][1] = getDepString(chart[0][end][0][0][k]); } else { d[k][0] = null; d[k][1] = null; } } return d; } public FeatureVector getFeatureVector(ParseForestItem pfi) { if (pfi.left == null) { return pfi.fv; } return cat(pfi.fv, cat(getFeatureVector(pfi.left), getFeatureVector(pfi.right))); } public String getDepString(ParseForestItem pfi) { if (pfi.left == null) { return ""; } if (pfi.dir == 0 && pfi.comp == 1) { return ((getDepString(pfi.left) + " " + getDepString(pfi.right)).trim() + " " + pfi.s + "|" + pfi.t + ":" + pfi.type).trim(); } else if (pfi.dir == 1 && pfi.comp == 1) { return (pfi.t + "|" + pfi.s + ":" + pfi.type + " " + (getDepString(pfi.left) + " " + getDepString(pfi.right)).trim()).trim(); } return (getDepString(pfi.left) + " " + getDepString(pfi.right)).trim(); } public FeatureVector cat(FeatureVector fv1, FeatureVector fv2) { return fv1.cat(fv2); } // returns pairs of indeces and -1,-1 if < K pairs public int[][] getKBestPairs(ParseForestItem[] items1, ParseForestItem[] items2) { // in this case K = items1.length boolean[][] beenPushed = new boolean[K][K]; int[][] result = new int[K][2]; for (int i = 0; i < K; i++) { result[i][0] = -1; result[i][1] = -1; } BinaryHeap heap = new BinaryHeap(K + 1); int n = 0; ValueIndexPair vip = new ValueIndexPair(items1[0].prob + items2[0].prob, 0, 0); heap.add(vip); beenPushed[0][0] = true; while (n < K) { vip = heap.removeMax(); if (vip.val == Double.NEGATIVE_INFINITY) { break; } result[n][0] = vip.i1; result[n][1] = vip.i2; n++; if (n >= K) { break; } if (!beenPushed[vip.i1 + 1][vip.i2]) { heap.add(new ValueIndexPair(items1[vip.i1 + 1].prob + items2[vip.i2].prob, vip.i1 + 1, vip.i2)); beenPushed[vip.i1 + 1][vip.i2] = true; } if (!beenPushed[vip.i1][vip.i2 + 1]) { heap.add(new ValueIndexPair(items1[vip.i1].prob + items2[vip.i2 + 1].prob, vip.i1, vip.i2 + 1)); beenPushed[vip.i1][vip.i2 + 1] = true; } } return result; } }