package mstparser;
import gnu.trove.TIntIntHashMap;
import java.util.ArrayList;
import java.util.Arrays;
public class DependencyDecoder {
/**
 * Pipe this decoder was created with. The decoding methods in this class read
 * its {@code labeled}, {@code separateLab} and {@code types} members.
 */
DependencyPipe pipe;
/**
 * Creates a decoder bound to the given pipe.
 *
 * @param pipe pipe whose labeling configuration the decoding methods consult;
 *             stored as-is (no copy, no null check)
 */
public DependencyDecoder(DependencyPipe pipe) {
this.pipe = pipe;
}
/**
 * Chooses, for every ordered pair of distinct positions {@code (i, j)}, the
 * label index whose combined label score is highest.
 *
 * <p>The score of label {@code t} for the pair is
 * {@code nt_probs[i][t][dir][1] + nt_probs[j][t][dir][0]}, where {@code dir}
 * is 0 when {@code i < j} and 1 otherwise. Ties keep the lowest label index
 * because only a strictly greater score replaces the current best.
 *
 * @param nt_probs label scores indexed as {@code [position][label][dir][role]}
 * @param len      number of positions to consider
 * @return a {@code len x len} table of label indices; the diagonal is 0, and
 *         an entry is -1 when no label scored above negative infinity
 */
protected int[][] getTypes(double[][][][] nt_probs, int len) {
    int[][] static_types = new int[len][len];
    for (int first = 0; first < len; first++) {
        for (int second = 0; second < len; second++) {
            if (first == second) {
                // Diagonal keeps the array default of 0.
                continue;
            }
            final int dir = first < second ? 0 : 1;
            final int numTypes = pipe.types.length;
            int argmax = -1;
            double top = Double.NEGATIVE_INFINITY;
            for (int label = 0; label < numTypes; label++) {
                double candidate = nt_probs[first][label][dir][1] + nt_probs[second][label][dir][0];
                if (candidate > top) {
                    top = candidate;
                    argmax = label;
                }
            }
            static_types[first][second] = argmax;
        }
    }
    return static_types;
}
/**
 * K-best projective decoding over a span chart.
 *
 * <p>A static label is chosen for each edge up front, so the run time is
 * O(n^3 + Tn^2), where T is the number of types.
 *
 * <p>Chart items are addressed as {@code (start, end, direction, completeness)}.
 * Direction 0 means the left end {@code s} is the parent, direction 1 means the
 * right end {@code t} is the parent. Label scores and features are folded in
 * only when the pipe is labeled and labels are not handled separately.
 *
 * @param inst     instance being parsed; only its {@code forms} length is used here
 *                 (the instance itself is handed to the forest)
 * @param fvs      edge feature vectors, indexed {@code [s][t][direction]} with {@code s < t}
 * @param probs    edge scores, indexed like {@code fvs}
 * @param nt_fvs   label feature vectors, indexed {@code [position][label][direction][role]}
 * @param nt_probs label scores, indexed like {@code nt_fvs}
 * @param K        number of best parses requested from the forest
 * @return whatever {@code KBestParseForest.getBestParses()} yields for the filled chart
 */
public Object[][] decodeProjective(DependencyInstance inst,
        FeatureVector[][][] fvs,
        double[][][] probs,
        FeatureVector[][][][] nt_fvs,
        double[][][][] nt_probs, int K) {
    final String[] forms = inst.forms;
    final int n = forms.length;
    final boolean jointLabels = pipe.labeled && !pipe.separateLab;
    final int[][] static_types = jointLabels ? getTypes(nt_probs, n) : null;

    KBestParseForest pf = new KBestParseForest(0, n - 1, inst, K);

    // Seed the chart with an empty item per word and direction.
    for (int w = 0; w < n; w++) {
        pf.add(w, -1, 0, 0.0, new FeatureVector());
        pf.add(w, -1, 1, 0.0, new FeatureVector());
    }

    for (int width = 1; width < n; width++) {
        for (int s = 0; s + width < n; s++) {
            final int t = s + width;

            // Edge between s and t: index 0 has s as parent, index 1 has t as parent.
            FeatureVector arcFvFromS = fvs[s][t][0];
            FeatureVector arcFvFromT = fvs[s][t][1];
            double arcScoreFromS = probs[s][t][0];
            double arcScoreFromT = probs[s][t][1];

            int typeFromS = jointLabels ? static_types[s][t] : 0;
            int typeFromT = jointLabels ? static_types[t][s] : 0;

            // Label features/scores; last index 1 is the parent side, 0 the child side.
            FeatureVector labelFvSParent = nt_fvs[s][typeFromS][0][1];
            FeatureVector labelFvSChild = nt_fvs[s][typeFromT][1][0];
            FeatureVector labelFvTChild = nt_fvs[t][typeFromS][0][0];
            FeatureVector labelFvTParent = nt_fvs[t][typeFromT][1][1];
            double labelScoreSParent = nt_probs[s][typeFromS][0][1];
            double labelScoreSChild = nt_probs[s][typeFromT][1][0];
            double labelScoreTChild = nt_probs[t][typeFromS][0][0];
            double labelScoreTParent = nt_probs[t][typeFromT][1][1];

            // Incomplete items: join a complete left half [s..r] with a complete
            // right half [r+1..t] and add the edge in both directions.
            for (int r = s; r < t; r++) {
                ParseForestItem[] leftHalf = pf.getItems(s, r, 0, 0);
                ParseForestItem[] rightHalf = pf.getItems(r + 1, t, 1, 0);
                if (leftHalf == null || rightHalf == null) {
                    continue;
                }
                int[][] pairs = pf.getKBestPairs(leftHalf, rightHalf);
                for (int k = 0; k < pairs.length; k++) {
                    int[] pair = pairs[k];
                    if (pair[0] == -1 || pair[1] == -1) {
                        break;
                    }
                    ParseForestItem leftItem = leftHalf[pair[0]];
                    ParseForestItem rightItem = rightHalf[pair[1]];
                    double inside = leftItem.prob + rightItem.prob;

                    // s is the parent of t.
                    double score = inside + arcScoreFromS;
                    FeatureVector fv = arcFvFromS;
                    if (jointLabels) {
                        fv = labelFvSParent.cat(labelFvTChild.cat(fv));
                        score += labelScoreSParent + labelScoreTChild;
                    }
                    pf.add(s, r, t, typeFromS, 0, 1, score, fv, leftItem, rightItem);

                    // t is the parent of s.
                    score = inside + arcScoreFromT;
                    fv = arcFvFromT;
                    if (jointLabels) {
                        fv = labelFvTParent.cat(labelFvSChild.cat(fv));
                        score += labelScoreTParent + labelScoreSChild;
                    }
                    pf.add(s, r, t, typeFromT, 1, 1, score, fv, leftItem, rightItem);
                }
            }

            // Complete items: extend an incomplete item with a complete one,
            // for each direction in turn at every split point.
            for (int r = s; r <= t; r++) {
                if (r != s) {
                    mergeComplete(pf, s, r, t, 0,
                            pf.getItems(s, r, 0, 1), pf.getItems(r, t, 0, 0));
                }
                if (r != t) {
                    mergeComplete(pf, s, r, t, 1,
                            pf.getItems(s, r, 1, 0), pf.getItems(r, t, 1, 1));
                }
            }
        }
    }
    return pf.getBestParses();
}

/**
 * Adds complete items for span {@code [s..t]} split at {@code r} in direction
 * {@code dir}, combining the k-best pairs of the two given item lists. Stops at
 * the first exhausted pair or as soon as the forest rejects an item.
 */
private static void mergeComplete(KBestParseForest pf, int s, int r, int t, int dir,
        ParseForestItem[] leftHalf, ParseForestItem[] rightHalf) {
    if (leftHalf == null || rightHalf == null) {
        return;
    }
    int[][] pairs = pf.getKBestPairs(leftHalf, rightHalf);
    for (int k = 0; k < pairs.length; k++) {
        int[] pair = pairs[k];
        if (pair[0] == -1 || pair[1] == -1) {
            break;
        }
        ParseForestItem leftItem = leftHalf[pair[0]];
        ParseForestItem rightItem = rightHalf[pair[1]];
        double inside = leftItem.prob + rightItem.prob;
        boolean accepted = pf.add(s, r, t, -1, dir, 0, inside,
                new FeatureVector(), leftItem, rightItem);
        if (!accepted) {
            break;
        }
    }
}
/**
 * Non-projective decoding: finds the best tree with {@code chuLiuEdmonds} and
 * then derives up to {@code K - 1} extra parses, each differing from the best
 * tree by a single re-attached word (see {@code getKChanges}).
 *
 * <p>Label scores and label features are used only when the pipe is labeled
 * and labels are not handled separately ({@code pipe.labeled && !pipe.separateLab});
 * that is the only case in which the static label table is built. In every
 * other configuration each edge is reported with label index 0.
 *
 * @param inst     instance being parsed; its length and {@code postags} are read
 * @param fvs      edge feature vectors, indexed {@code [lo][hi][direction]} with {@code lo < hi}
 * @param probs    edge scores, indexed like {@code fvs}
 * @param nt_fvs   label feature vectors, indexed {@code [position][label][direction][role]}
 * @param nt_probs label scores, indexed like {@code nt_fvs}
 * @param K        maximum number of parses to return
 * @return one row per parse: element 0 is the parse's combined {@code FeatureVector},
 *         element 1 is a space-separated string of {@code parent|child:label} entries
 *         for children 1..n-1
 */
public Object[][] decodeNonProjective(DependencyInstance inst,
        FeatureVector[][][] fvs,
        double[][][] probs,
        FeatureVector[][][][] nt_fvs,
        double[][][][] nt_probs, int K) {
    String[] pos = inst.postags;
    int numWords = inst.length();
    // Single source of truth for "labels are decoded jointly with edges".
    final boolean jointLabels = pipe.labeled && !pipe.separateLab;

    int[][] oldI = new int[numWords][numWords];
    int[][] oldO = new int[numWords][numWords];
    double[][] scoreMatrix = new double[numWords][numWords];
    double[][] orig_scoreMatrix = new double[numWords][numWords];
    boolean[] curr_nodes = new boolean[numWords];
    TIntIntHashMap[] reps = new TIntIntHashMap[numWords];

    int[][] static_types = jointLabels ? getTypes(nt_probs, pos.length) : null;

    for (int i = 0; i < numWords; i++) {
        curr_nodes[i] = true;
        reps[i] = new TIntIntHashMap();
        reps[i].put(i, 0);
        for (int j = 0; j < numWords; j++) {
            // score of edge (i,j) i --> j
            int lo = i < j ? i : j;
            int hi = i < j ? j : i;
            int dir = i < j ? 0 : 1;
            double score = probs[lo][hi][dir]
                    + (jointLabels
                            ? nt_probs[i][static_types[i][j]][dir][1]
                                    + nt_probs[j][static_types[i][j]][dir][0]
                            : 0.0);
            // chuLiuEdmonds overwrites scoreMatrix in place, so keep a pristine copy
            // for the k-best re-attachment step.
            scoreMatrix[i][j] = score;
            orig_scoreMatrix[i][j] = score;
            oldI[i][j] = i;
            oldO[i][j] = j;
        }
    }

    TIntIntHashMap final_edges =
            chuLiuEdmonds(scoreMatrix, curr_nodes, oldI, oldO, false, new TIntIntHashMap(), reps);

    // final_edges maps child -> parent.
    int[] par = new int[numWords];
    int[] ns = final_edges.keys();
    for (int i = 0; i < ns.length; i++) {
        par[ns[i]] = final_edges.get(ns[i]);
    }

    // n_par[i] > -1 means "an alternative parse re-attaches word i to n_par[i]".
    int[] n_par = getKChanges(par, orig_scoreMatrix, Math.min(K, par.length));
    int new_k = 1;
    for (int i = 0; i < n_par.length; i++) {
        if (n_par[i] > -1) {
            new_k++;
        }
    }

    // Parent arrays: the best tree first, then one single-change variant each.
    int[][] fin_par = new int[new_k][numWords];
    fin_par[0] = par;
    int c = 1;
    for (int i = 0; i < n_par.length; i++) {
        if (n_par[i] > -1) {
            int[] t_par = new int[par.length];
            System.arraycopy(par, 0, t_par, 0, t_par.length);
            t_par[i] = n_par[i];
            fin_par[c] = t_par;
            c++;
        }
    }

    // Create Feature Vectors;
    FeatureVector[][] fin_fv = new FeatureVector[new_k][numWords];
    for (int k = 0; k < fin_par.length; k++) {
        for (int ch = 0; ch < fin_par[k].length; ch++) {
            int pr = fin_par[k][ch];
            if (pr == -1) {
                fin_fv[k][ch] = new FeatureVector();
                continue;
            }
            int lo = ch < pr ? ch : pr;
            int hi = ch < pr ? pr : ch;
            int dir = ch < pr ? 1 : 0;
            FeatureVector fv = fvs[lo][hi][dir];
            // static_types only exists in the joint-label configuration; guarding on
            // pipe.labeled alone dereferenced null when labels are handled separately.
            if (jointLabels) {
                int type = static_types[pr][ch];
                fv = fv.cat(nt_fvs[ch][type][dir][0]);
                fv = fv.cat(nt_fvs[pr][type][dir][1]);
            }
            fin_fv[k][ch] = fv;
        }
    }

    FeatureVector[] fin = new FeatureVector[new_k];
    String[] result = new String[new_k];
    for (int k = 0; k < fin.length; k++) {
        fin[k] = new FeatureVector();
        for (int i = 1; i < fin_fv[k].length; i++) {
            fin[k] = fin_fv[k][i].cat(fin[k]);
        }
        StringBuilder sb = new StringBuilder();
        for (int i = 1; i < par.length; i++) {
            sb.append(fin_par[k][i]).append("|").append(i);
            if (jointLabels) {
                sb.append(":").append(static_types[fin_par[k][i]][i]);
            } else {
                sb.append(":0");
            }
            sb.append(" ");
        }
        result[k] = sb.toString();
    }

    // create d.
    Object[][] d = new Object[new_k][2];
    for (int k = 0; k < new_k; k++) {
        d[k][0] = fin[k];
        d[k][1] = result[k].trim();
    }
    return d;
}
/**
 * Selects up to {@code K} single-word re-attachments of the tree {@code par}.
 *
 * <p>For every word except 0, the best-scoring alternative parent is found,
 * skipping the word itself, its current parent, and any node flagged in
 * {@code calcChilds(par)[word][candidate]}. The {@code K} highest-scoring of
 * those alternatives are then returned.
 *
 * @param par         current parent of each word
 * @param scoreMatrix {@code scoreMatrix[p][c]} is the score of attaching {@code c} to {@code p}
 * @param K           maximum number of re-attachments to select
 * @return array where entry {@code i} is the new parent chosen for word {@code i},
 *         or -1 when word {@code i} was not selected
 */
private int[] getKChanges(int[] par, double[][] scoreMatrix, int K) {
    final int n = par.length;
    int[] result = new int[n];
    int[] alternative = new int[n];
    Arrays.fill(result, -1);
    Arrays.fill(alternative, -1);

    boolean[][] isChild = calcChilds(par);

    // Best alternative parent per word.
    for (int word = 1; word < n; word++) {
        double top = Double.NEGATIVE_INFINITY;
        int bestParent = -1;
        for (int cand = 0; cand < n; cand++) {
            boolean allowed = word != cand && par[word] != cand && !isChild[word][cand];
            if (allowed && scoreMatrix[cand][word] > top) {
                top = scoreMatrix[cand][word];
                bestParent = cand;
            }
        }
        alternative[word] = bestParent;
    }

    // Repeatedly take the highest-scoring remaining alternative.
    for (int picked = 0; picked < K; picked++) {
        double top = Double.NEGATIVE_INFINITY;
        int topWord = -1;
        int topParent = -1;
        for (int word = 0; word < n; word++) {
            int cand = alternative[word];
            if (cand != -1) {
                double score = scoreMatrix[cand][word];
                if (score > top) {
                    top = score;
                    topWord = word;
                    topParent = cand;
                }
            }
        }
        if (top == Double.NEGATIVE_INFINITY) {
            break;
        }
        result[topWord] = topParent;
        alternative[topWord] = -1;
    }
    return result;
}
/**
 * Builds an ancestor table for the tree given by {@code par}.
 *
 * @param par parent of each node; the parent chain of every node from index 1
 *            on is followed until a -1 entry is reached
 * @return table where {@code [a][d]} is true when {@code a} lies on the parent
 *         chain of {@code d}
 */
private boolean[][] calcChilds(int[] par) {
    final int n = par.length;
    boolean[][] isChild = new boolean[n][n];
    for (int node = 1; node < n; node++) {
        for (int ancestor = par[node]; ancestor != -1; ancestor = par[ancestor]) {
            isChild[ancestor][node] = true;
        }
    }
    return isChild;
}
/**
 * Recursive maximum-spanning-tree search rooted at node 0 that contracts one
 * cycle per recursion level.
 *
 * <p>Steps, as implemented below:
 * <ol>
 *   <li>Every active node other than 0 picks its highest-scoring incoming edge
 *       among active nodes.</li>
 *   <li>If that choice contains no cycle, the edges are translated back through
 *       {@code oldI}/{@code oldO} and stored in {@code final_edges}.</li>
 *   <li>Otherwise the cycle is collapsed onto one representative node: the
 *       representative's row and column in {@code scoreMatrix}, {@code oldI} and
 *       {@code oldO} are overwritten, the other cycle nodes are deactivated, and
 *       the method recurses.</li>
 *   <li>After the recursion, the cycle's own edges are added back, except the
 *       one entering the cycle node through which the cycle was attached.</li>
 * </ol>
 *
 * <p>{@code scoreMatrix}, {@code curr_nodes}, {@code oldI}, {@code oldO},
 * {@code final_edges} and {@code reps} are all modified in place.
 *
 * @param scoreMatrix {@code scoreMatrix[p][c]} is the score of edge p --> c
 * @param curr_nodes  flags the nodes that are still active (not folded into a cycle)
 * @param oldI        for the current edge {@code [p][c]}, the parent endpoint that is
 *                    written to {@code final_edges}
 * @param oldO        for the current edge {@code [p][c]}, the child endpoint that is
 *                    written to {@code final_edges}
 * @param print       when true, trace output is written to {@code DependencyParser.out}
 * @param final_edges accumulator mapping child to parent; node 0 maps to -1
 * @param reps        {@code reps[i]} holds, as keys, the nodes that node {@code i} stands for
 * @return the same {@code final_edges} instance that was passed in
 */
private static TIntIntHashMap chuLiuEdmonds(double[][] scoreMatrix, boolean[] curr_nodes,
int[][] oldI, int[][] oldO, boolean print,
TIntIntHashMap final_edges, TIntIntHashMap[] reps) {
// need to construct for each node list of nodes they represent (here only!)
int[] par = new int[curr_nodes.length];
int numWords = curr_nodes.length;
// create best graph
// Node 0 is the root and never gets a parent.
par[0] = -1;
for (int i = 1; i < par.length; i++) {
// only interested in current nodes
if (!curr_nodes[i]) {
continue;
}
// Start from the edge 0 --> i and keep any strictly better active parent.
double maxScore = scoreMatrix[0][i];
par[i] = 0;
for (int j = 0; j < par.length; j++) {
if (j == i) {
continue;
}
if (!curr_nodes[j]) {
continue;
}
double newScore = scoreMatrix[j][i];
if (newScore > maxScore) {
maxScore = newScore;
par[i] = j;
}
}
}
if (print) {
DependencyParser.out.println("After init");
for (int i = 0; i < par.length; i++) {
if (curr_nodes[i]) {
DependencyParser.out.print(par[i] + "|" + i + " ");
}
}
DependencyParser.out.println();
}
//Find a cycle
// The loop condition stops the search as soon as one cycle has been recorded,
// so `cycles` holds at most one entry.
ArrayList cycles = new ArrayList();
boolean[] added = new boolean[numWords];
for (int i = 0; i < numWords && cycles.isEmpty(); i++) {
// if I have already considered this or
// This is not a valid node (i.e. has been contracted)
if (added[i] || !curr_nodes[i]) {
continue;
}
added[i] = true;
// `cycle` first serves as the set of nodes visited on the walk from i.
TIntIntHashMap cycle = new TIntIntHashMap();
cycle.put(i, 0);
int l = i;
while (true) {
// Reached the root: this walk contains no cycle.
if (par[l] == -1) {
added[l] = true;
break;
}
// The parent was already visited on this walk, so a cycle closes at par[l].
// Rebuild `cycle` so that it contains exactly the cycle's nodes (node -> parent).
if (cycle.contains(par[l])) {
cycle = new TIntIntHashMap();
int lorg = par[l];
cycle.put(lorg, par[lorg]);
added[lorg] = true;
int l1 = par[lorg];
while (l1 != lorg) {
cycle.put(l1, par[l1]);
added[l1] = true;
l1 = par[l1];
}
cycles.add(cycle);
break;
}
cycle.put(l, 0);
l = par[l];
// Ran into a node handled by an earlier walk: nothing new to find from here.
if (added[l] && !cycle.contains(l)) {
break;
}
added[l] = true;
}
}
// get all edges and return them
if (cycles.isEmpty()) {
//DependencyParser.out.println("TREE:");
for (int i = 0; i < par.length; i++) {
if (!curr_nodes[i]) {
continue;
}
if (par[i] != -1) {
// Translate the (possibly contracted) edge back to the endpoints recorded for it.
int pr = oldI[par[i]][i];
int ch = oldO[par[i]][i];
final_edges.put(ch, pr);
//DependencyParser.out.print(pr+"|"+ch + " ");
} else {
final_edges.put(0, -1);
}
}
//DependencyParser.out.println();
return final_edges;
}
// Pick the largest recorded cycle (with at most one recorded, this is index 0).
int max_cyc = 0;
int wh_cyc = 0;
for (int i = 0; i < cycles.size(); i++) {
TIntIntHashMap cycle = (TIntIntHashMap) cycles.get(i);
if (cycle.size() > max_cyc) {
max_cyc = cycle.size();
wh_cyc = i;
}
}
TIntIntHashMap cycle = (TIntIntHashMap) cycles.get(wh_cyc);
int[] cyc_nodes = cycle.keys();
// The first key returned for the cycle becomes the node that stands for the whole cycle.
int rep = cyc_nodes[0];
if (print) {
DependencyParser.out.println("Found Cycle");
for (int i = 0; i < cyc_nodes.length; i++) {
DependencyParser.out.print(cyc_nodes[i] + " ");
}
DependencyParser.out.println();
}
// Total score of the edges that form the cycle.
double cyc_weight = 0.0;
for (int j = 0; j < cyc_nodes.length; j++) {
cyc_weight += scoreMatrix[par[cyc_nodes[j]]][cyc_nodes[j]];
}
// Recompute, for every active node i outside the cycle, the edges between i and
// the contracted cycle (stored on `rep`).
for (int i = 0; i < numWords; i++) {
if (!curr_nodes[i] || cycle.contains(i)) {
continue;
}
// max1/wh1: best edge leaving the cycle towards i.
double max1 = Double.NEGATIVE_INFINITY;
int wh1 = -1;
// max2/wh2: best way for i to enter the cycle.
double max2 = Double.NEGATIVE_INFINITY;
int wh2 = -1;
for (int j = 0; j < cyc_nodes.length; j++) {
int j1 = cyc_nodes[j];
if (scoreMatrix[j1][i] > max1) {
max1 = scoreMatrix[j1][i];
wh1 = j1;//oldI[j1][i];
}
// cycle weight + new edge - removal of old
double scr = cyc_weight + scoreMatrix[i][j1] - scoreMatrix[par[j1]][j1];
if (scr > max2) {
max2 = scr;
wh2 = j1;//oldO[i][j1];
}
}
// Overwrite rep's row/column and carry over the endpoints of the chosen edges.
scoreMatrix[rep][i] = max1;
oldI[rep][i] = oldI[wh1][i];//wh1;
oldO[rep][i] = oldO[wh1][i];//oldO[wh1][i];
scoreMatrix[i][rep] = max2;
oldO[i][rep] = oldO[i][wh2];//wh2;
oldI[i][rep] = oldI[i][wh2];//oldI[i][wh2];
}
// Snapshot what each cycle node stood for BEFORE reps[rep] absorbs the others below.
TIntIntHashMap[] rep_cons = new TIntIntHashMap[cyc_nodes.length];
for (int i = 0; i < cyc_nodes.length; i++) {
rep_cons[i] = new TIntIntHashMap();
int[] keys = reps[cyc_nodes[i]].keys();
Arrays.sort(keys);
if (print) {
DependencyParser.out.print(cyc_nodes[i] + ": ");
}
for (int j = 0; j < keys.length; j++) {
rep_cons[i].put(keys[j], 0);
if (print) {
DependencyParser.out.print(keys[j] + " ");
}
}
if (print) {
DependencyParser.out.println();
}
}
// don't consider not representative nodes
// these nodes have been folded
// Starts at index 1 so that cyc_nodes[0] (== rep) stays active.
for (int i = 1; i < cyc_nodes.length; i++) {
curr_nodes[cyc_nodes[i]] = false;
int[] keys = reps[cyc_nodes[i]].keys();
for (int j = 0; j < keys.length; j++) {
reps[rep].put(keys[j], 0);
}
}
// Solve the contracted problem; results accumulate in final_edges.
chuLiuEdmonds(scoreMatrix, curr_nodes, oldI, oldO, print, final_edges, reps);
// check each node in cycle, if one of its representatives
// is a key in the final_edges, it is the one.
int wh = -1;
boolean found = false;
for (int i = 0; i < rep_cons.length && !found; i++) {
int[] keys = rep_cons[i].keys();
for (int j = 0; j < keys.length && !found; j++) {
if (final_edges.contains(keys[j])) {
wh = cyc_nodes[i];
found = true;
}
}
}
// NOTE(review): if no cycle node matched, wh is still -1 here and par[wh] throws
// ArrayIndexOutOfBoundsException — presumably unreachable for valid input; confirm.
// Walk the cycle backwards from wh's parent, adding every cycle edge except the
// one that enters wh.
int l = par[wh];
while (l != wh) {
int ch = oldO[par[l]][l];
int pr = oldI[par[l]][l];
final_edges.put(ch, pr);
l = par[l];
}
if (print) {
int[] keys = final_edges.keys();
Arrays.sort(keys);
for (int i = 0; i < keys.length; i++) {
DependencyParser.out.print(final_edges.get(keys[i]) + "|" + keys[i] + " ");
}
DependencyParser.out.println();
}
return final_edges;
}
}