package mstparser;
/**
 * Chart holding, for every cell, the K highest-scoring {@link ParseForestItem}s,
 * used to extract k-best parses.
 *
 * <p>Cells are addressed as {@code chart[s][t][dir][comp][k]} where {@code s} and
 * {@code t} range over {@code [0, end]} (presumably span endpoints), {@code dir} and
 * {@code comp} are each 0 or 1, and {@code k} is the rank in {@code [0, K)}. Within a
 * cell the entries are kept in non-increasing order of {@code prob}; ranks that have
 * not been filled hold placeholder items whose {@code prob} is
 * {@link Double#NEGATIVE_INFINITY}.
 */
public class KBestParseForest {

    // Public static; not read anywhere in this class.
    public static int rootType;

    /** The k-best chart; see the class comment for the index layout. */
    public ParseForestItem[][][][][] chart;

    // Copied from the DependencyInstance; not read by any method of this class.
    private String[] sent, pos;

    // start is stored but not read here; end selects the final cell used by getBestParse(s).
    private int start, end;

    // Beam width: number of entries kept per chart cell.
    private int K;

    /**
     * Allocates an empty chart for positions {@code 0..end} with beam width {@code K}.
     *
     * @param start stored but not otherwise used by this class
     * @param end   largest position index; the chart is {@code (end + 1) x (end + 1)}
     * @param inst  instance whose {@code forms} and {@code postags} are retained
     * @param K     number of entries kept per cell
     */
    public KBestParseForest(int start, int end, DependencyInstance inst, int K) {
        this.K = K;
        chart = new ParseForestItem[end + 1][end + 1][2][2][K];
        this.start = start;
        this.end = end;
        this.sent = inst.forms;
        this.pos = inst.postags;
    }

    /**
     * Offers a single-position item to cell {@code chart[s][s][dir][0]}.
     *
     * <p>On first use the cell is filled with K -infinity placeholders. The item is
     * inserted before the first entry whose score is strictly lower, and later entries
     * shift down one rank (the last one falls off). Ties never displace existing entries.
     *
     * @return true if the item was inserted, false if it did not make the top K
     */
    public boolean add(int s, int type, int dir, double score, FeatureVector fv) {
        ParseForestItem[] beam = chart[s][s][dir][0];
        if (beam[0] == null) {
            for (int i = 0; i < K; i++) {
                beam[i] = new ParseForestItem(s, type, dir, Double.NEGATIVE_INFINITY, null);
            }
        }
        int slot = insertionSlot(beam, score);
        if (slot < 0) {
            return false;
        }
        placeAt(beam, slot, new ParseForestItem(s, type, dir, score, fv));
        return true;
    }

    /**
     * Offers an item built from children {@code p1} and {@code p2} to cell
     * {@code chart[s][t][dir][comp]}, with the same top-K insertion rules as the
     * single-position overload.
     *
     * @return true if the item was inserted, false if it did not make the top K
     */
    public boolean add(int s, int r, int t, int type,
                       int dir, int comp, double score,
                       FeatureVector fv,
                       ParseForestItem p1, ParseForestItem p2) {
        ParseForestItem[] beam = chart[s][t][dir][comp];
        if (beam[0] == null) {
            for (int i = 0; i < K; i++) {
                beam[i] = new ParseForestItem(s, r, t, type, dir, comp,
                        Double.NEGATIVE_INFINITY, null, null, null);
            }
        }
        int slot = insertionSlot(beam, score);
        if (slot < 0) {
            return false;
        }
        placeAt(beam, slot, new ParseForestItem(s, r, t, type, dir, comp, score, fv, p1, p2));
        return true;
    }

    /**
     * Returns the rank where an item scoring {@code score} belongs in the sorted
     * {@code beam}: the first rank whose current score is strictly lower, or -1 if
     * there is none (ties keep the existing entry).
     */
    private int insertionSlot(ParseForestItem[] beam, double score) {
        // Fast reject: the beam is sorted, so if even the last entry beats us, all do.
        if (beam[K - 1].prob > score) {
            return -1;
        }
        for (int i = 0; i < K; i++) {
            if (beam[i].prob < score) {
                return i;
            }
        }
        return -1;
    }

    /**
     * Stores {@code item} at rank {@code slot} and shifts the displaced entries one
     * rank down; whatever is pushed past rank K-1 is dropped.
     */
    private void placeAt(ParseForestItem[] beam, int slot, ParseForestItem item) {
        ParseForestItem displaced = beam[slot];
        beam[slot] = item;
        // Once a -infinity entry is displaced, the sorted beam holds only -infinity
        // entries below it, so shifting further would change nothing.
        for (int j = slot + 1; j < K && displaced.prob != Double.NEGATIVE_INFINITY; j++) {
            ParseForestItem next = beam[j];
            beam[j] = displaced;
            displaced = next;
        }
    }

    /** Returns the score of the best entry in the cell, or -infinity if the cell is empty. */
    public double getProb(int s, int t, int dir, int comp) {
        return getProb(s, t, dir, comp, 0);
    }

    /** Returns the score of the rank-{@code i} entry, or -infinity if the cell is empty. */
    public double getProb(int s, int t, int dir, int comp, int i) {
        ParseForestItem item = chart[s][t][dir][comp][i];
        return item != null ? item.prob : Double.NEGATIVE_INFINITY;
    }

    /** Returns the K scores of the cell in rank order; -infinity for empty slots. */
    public double[] getProbs(int s, int t, int dir, int comp) {
        ParseForestItem[] beam = chart[s][t][dir][comp];
        double[] result = new double[K];
        for (int i = 0; i < K; i++) {
            result[i] = beam[i] != null ? beam[i].prob : Double.NEGATIVE_INFINITY;
        }
        return result;
    }

    /** Returns the best entry of the cell, or null if nothing was ever added to it. */
    public ParseForestItem getItem(int s, int t, int dir, int comp) {
        return getItem(s, t, dir, comp, 0);
    }

    /** Returns the rank-{@code k} entry of the cell, or null if nothing was ever added to it. */
    public ParseForestItem getItem(int s, int t, int dir, int comp, int k) {
        return chart[s][t][dir][comp][k];
    }

    /**
     * Returns the cell's K-entry array itself (not a copy), or null if nothing was
     * ever added to the cell.
     */
    public ParseForestItem[] getItems(int s, int t, int dir, int comp) {
        if (chart[s][t][dir][comp][0] != null) {
            return chart[s][t][dir][comp];
        }
        return null;
    }

    /**
     * Returns {@code {featureVector, depString}} for the top entry of cell
     * {@code chart[0][end][0][0]}, or {@code {null, null}} if that cell was never filled.
     */
    public Object[] getBestParse() {
        Object[] d = new Object[2];
        ParseForestItem best = chart[0][end][0][0][0];
        if (best != null) {
            d[0] = getFeatureVector(best);
            d[1] = getDepString(best);
        }
        return d;
    }

    /**
     * Returns K rows of {@code {featureVector, depString}} for the entries of cell
     * {@code chart[0][end][0][0]} in rank order. Rows whose entry is missing or scores
     * -infinity are {@code {null, null}}.
     */
    public Object[][] getBestParses() {
        Object[][] d = new Object[K][2]; // rows start out as {null, null}
        for (int k = 0; k < K; k++) {
            ParseForestItem item = chart[0][end][0][0][k];
            if (item != null && item.prob != Double.NEGATIVE_INFINITY) {
                d[k][0] = getFeatureVector(item);
                d[k][1] = getDepString(item);
            }
        }
        return d;
    }

    /**
     * Concatenates {@code pfi.fv} with the feature vectors of its left and right
     * subtrees, recursively. An item with no left child returns its own {@code fv}.
     * NOTE(review): assumes {@code right} is non-null whenever {@code left} is.
     */
    public FeatureVector getFeatureVector(ParseForestItem pfi) {
        if (pfi.left == null) {
            return pfi.fv;
        }
        return cat(pfi.fv, cat(getFeatureVector(pfi.left), getFeatureVector(pfi.right)));
    }

    /**
     * Builds a space-separated string of {@code a|b:type} tokens for the subtree at
     * {@code pfi}. An item with no left child contributes "". When {@code comp == 0}
     * only the children's tokens are joined. Otherwise a token for this item is added:
     * {@code s|t:type} after the children when {@code dir == 0}, else {@code t|s:type}
     * before them (presumably head|dependent in both cases — confirm).
     */
    public String getDepString(ParseForestItem pfi) {
        if (pfi.left == null) {
            return "";
        }
        if (pfi.comp == 0) {
            return (getDepString(pfi.left) + " " + getDepString(pfi.right)).trim();
        } else if (pfi.dir == 0) {
            return ((getDepString(pfi.left) + " " + getDepString(pfi.right)).trim() + " "
                    + pfi.s + "|" + pfi.t + ":" + pfi.type).trim();
        } else {
            return (pfi.t + "|" + pfi.s + ":" + pfi.type + " "
                    + (getDepString(pfi.left) + " " + getDepString(pfi.right)).trim()).trim();
        }
    }

    /** Returns {@code fv1.cat(fv2)}. */
    public FeatureVector cat(FeatureVector fv1, FeatureVector fv2) {
        return fv1.cat(fv2);
    }

    /**
     * Enumerates, best first, up to K index pairs (i, j) ranked by
     * {@code items1[i].prob + items2[j].prob}. This is the top K only if both arrays are
     * sorted by non-increasing {@code prob}, as chart cells filled by {@code add} are.
     * Both arrays are assumed to have length K.
     *
     * <p>It uses lazy frontier expansion. After popping (i, j) from a max-heap, it pushes
     * (i+1, j) and (i, j+1) unless they were pushed before.
     *
     * @return a K x 2 array of index pairs. Rows after the last pair found are
     *     {@code {-1, -1}}. All rows are {@code {-1, -1}} when either input is null or
     *     unfilled. Enumeration stops at the first pair whose sum is -infinity.
     */
    public int[][] getKBestPairs(ParseForestItem[] items1, ParseForestItem[] items2) {
        boolean[][] beenPushed = new boolean[K][K];
        int[][] result = new int[K][2];
        for (int i = 0; i < K; i++) {
            result[i][0] = -1;
            result[i][1] = -1;
        }
        if (items1 == null || items2 == null || items1[0] == null || items2[0] == null) {
            return result;
        }
        BinaryHeap heap = new BinaryHeap(K + 1);
        int n = 0;
        ValueIndexPair vip = new ValueIndexPair(items1[0].prob + items2[0].prob, 0, 0);
        heap.add(vip);
        beenPushed[0][0] = true;
        while (n < K) {
            vip = heap.removeMax();
            if (vip.val == Double.NEGATIVE_INFINITY) {
                break;
            }
            result[n][0] = vip.i1;
            result[n][1] = vip.i2;
            n++;
            if (n >= K) {
                break;
            }
            // A pair is pushed only after a neighbour with index sum one smaller was
            // popped, so the n-th popped pair has i1 + i2 <= n - 1 <= K - 2 here. Hence
            // i1 + 1 and i2 + 1 are both < K and the indexing below stays in bounds.
            if (!beenPushed[vip.i1 + 1][vip.i2]) {
                heap.add(new ValueIndexPair(items1[vip.i1 + 1].prob + items2[vip.i2].prob, vip.i1 + 1, vip.i2));
                beenPushed[vip.i1 + 1][vip.i2] = true;
            }
            if (!beenPushed[vip.i1][vip.i2 + 1]) {
                heap.add(new ValueIndexPair(items1[vip.i1].prob + items2[vip.i2 + 1].prob, vip.i1, vip.i2 + 1));
                beenPushed[vip.i1][vip.i2 + 1] = true;
            }
        }
        return result;
    }
}
/**
 * A double value tagged with two integer indices. {@code KBestParseForest.getKBestPairs}
 * uses it to remember which pair of beam ranks produced a summed score.
 */
class ValueIndexPair {
    public double val;
    public int i1, i2;

    /** Creates a pair holding {@code val} with indices {@code i1} and {@code i2}. */
    public ValueIndexPair(double val, int i1, int i2) {
        this.i1 = i1;
        this.i2 = i2;
        this.val = val;
    }

    /**
     * Orders by {@code val} alone and returns 1, -1 or 0. Unlike {@link Double#compare},
     * -0.0 and 0.0 compare equal here, and NaN compares equal to every value.
     */
    public int compareTo(ValueIndexPair other) {
        double mine = val;
        double theirs = other.val;
        return mine > theirs ? 1 : (mine < theirs ? -1 : 0);
    }
}
/**
 * Array-backed binary max-heap of {@link ValueIndexPair}s, ordered by
 * {@link ValueIndexPair#compareTo}.
 *
 * <p>The root lives at index 1. Index 0 holds a +infinity sentinel that no element
 * compares greater than, so bubble-up stops at the root without a bounds check. The
 * backing array doubles when full, so the constructor argument is an initial capacity,
 * not a hard limit.
 */
class BinaryHeap {
    private int currentSize;
    private ValueIndexPair[] theArray;

    /** @param def_cap number of elements the heap can hold before its array must grow */
    public BinaryHeap(int def_cap) {
        theArray = new ValueIndexPair[def_cap + 1];
        // theArray[0] serves as dummy parent for the root (at index 1).
        theArray[0] = new ValueIndexPair(Double.POSITIVE_INFINITY, -1, -1);
        currentSize = 0;
    }

    /** Returns the largest element without removing it, or null if the heap is empty. */
    public ValueIndexPair getMax() {
        return currentSize == 0 ? null : theArray[1];
    }

    private int parent(int i) {
        return i / 2;
    }

    private int leftChild(int i) {
        return 2 * i;
    }

    private int rightChild(int i) {
        return 2 * i + 1;
    }

    /** Inserts {@code e}, growing the backing array if it is full. */
    public void add(ValueIndexPair e) {
        if (currentSize + 1 >= theArray.length) {
            ValueIndexPair[] bigger = new ValueIndexPair[theArray.length * 2];
            System.arraycopy(theArray, 0, bigger, 0, theArray.length);
            theArray = bigger;
        }
        // Bubble up from the new last slot. The sentinel at index 0 guarantees termination.
        int where = currentSize + 1;
        while (e.compareTo(theArray[parent(where)]) > 0) {
            theArray[where] = theArray[parent(where)];
            where = parent(where);
        }
        theArray[where] = e;
        currentSize++;
    }

    /**
     * Removes and returns the largest element, or returns null (leaving the heap
     * unchanged) if the heap is empty.
     */
    public ValueIndexPair removeMax() {
        if (currentSize == 0) {
            return null;
        }
        ValueIndexPair max = theArray[1];
        theArray[1] = theArray[currentSize];
        theArray[currentSize] = null; // release the vacated slot's reference
        currentSize--;
        siftDown(1);
        return max;
    }

    /** Moves the element at {@code index} down until neither child compares greater. */
    private void siftDown(int index) {
        int parent = index;
        while (true) {
            int left = leftChild(parent);
            if (left > currentSize) {
                return;
            }
            int larger = left;
            int right = rightChild(parent);
            if (right <= currentSize && theArray[right].compareTo(theArray[left]) > 0) {
                larger = right;
            }
            if (theArray[larger].compareTo(theArray[parent]) <= 0) {
                return;
            }
            ValueIndexPair temp = theArray[larger];
            theArray[larger] = theArray[parent];
            theArray[parent] = temp;
            parent = larger;
        }
    }
}