package mstparser;

import gnu.trove.TIntIntHashMap;

import java.util.ArrayList;
import java.util.Arrays;

/**
 * Decodes dependency structures from precomputed edge scores and feature vectors.
 *
 * <p>Two decoders are provided: {@link #decodeProjective}, which fills a
 * {@link KBestParseForest} chart bottom-up over spans, and
 * {@link #decodeNonProjective}, which runs {@code chuLiuEdmonds} on a dense score
 * matrix and then derives additional parses through single-edge changes.
 *
 * <p>Edge labels are chosen "statically" (best label per ordered node pair, see
 * {@link #getTypes}) only when {@code pipe.labeled && !pipe.separateLab}.
 */
public class DependencyDecoder {

    DependencyPipe pipe;

    public DependencyDecoder(DependencyPipe pipe) {
        this.pipe = pipe;
    }

    /**
     * Picks, for every ordered pair {@code (i, j)} with {@code i != j}, the label index
     * {@code t} in {@code [0, pipe.types.length)} that maximises
     * {@code nt_probs[i][t][dir][1] + nt_probs[j][t][dir][0]}, where {@code dir} is 0 when
     * {@code i < j} and 1 otherwise.
     *
     * @param nt_probs label scores indexed as {@code [node][type][dir][flag]}
     * @param len      number of nodes
     * @return a {@code len x len} table of label indices; the diagonal is 0, and an entry is
     *         -1 if no label scored above negative infinity
     */
    protected int[][] getTypes(double[][][][] nt_probs, int len) {
        int[][] static_types = new int[len][len];
        for (int i = 0; i < len; i++) {
            for (int j = 0; j < len; j++) {
                if (i == j) {
                    static_types[i][j] = 0;
                    continue;
                }
                int wh = -1;
                double best = Double.NEGATIVE_INFINITY;
                for (int t = 0; t < pipe.types.length; t++) {
                    double score;
                    if (i < j) {
                        score = nt_probs[i][t][0][1] + nt_probs[j][t][0][0];
                    } else {
                        score = nt_probs[i][t][1][1] + nt_probs[j][t][1][0];
                    }
                    if (score > best) {
                        wh = t;
                        best = score;
                    }
                }
                static_types[i][j] = wh;
            }
        }
        return static_types;
    }

    /**
     * Chart-based projective decoding over a {@link KBestParseForest}.
     *
     * <p>Original author's note: static type for each edge: run time O(n^3 + Tn^2),
     * T is number of types.
     *
     * @param inst     instance being parsed; only {@code inst.forms} is read here
     * @param fvs      edge feature vectors, indexed {@code [s][t][dir]} with {@code s < t}
     * @param probs    edge scores, indexed like {@code fvs}
     * @param nt_fvs   label feature vectors, indexed {@code [node][type][dir][flag]}
     * @param nt_probs label scores, indexed like {@code nt_fvs}
     * @param K        passed through to the {@link KBestParseForest} constructor
     * @return whatever {@code KBestParseForest.getBestParses()} returns
     */
    public Object[][] decodeProjective(DependencyInstance inst, FeatureVector[][][] fvs, double[][][] probs,
            FeatureVector[][][][] nt_fvs, double[][][][] nt_probs, int K) {

        String[] forms = inst.forms;
        final boolean staticLabels = pipe.labeled && !pipe.separateLab;

        int[][] static_types = null;
        if (staticLabels) {
            static_types = getTypes(nt_probs, forms.length);
        }

        KBestParseForest pf = new KBestParseForest(0, forms.length - 1, inst, K);

        // Seed the chart with one item per position and direction.
        for (int s = 0; s < forms.length; s++) {
            pf.add(s, -1, 0, 0.0, new FeatureVector());
            pf.add(s, -1, 1, 0.0, new FeatureVector());
        }

        // j is the span width; spans are processed from narrow to wide.
        for (int j = 1; j < forms.length; j++) {
            for (int s = 0; s < forms.length && s + j < forms.length; s++) {
                int t = s + j;

                FeatureVector prodFV_st = fvs[s][t][0];
                FeatureVector prodFV_ts = fvs[s][t][1];
                double prodProb_st = probs[s][t][0];
                double prodProb_ts = probs[s][t][1];

                int type1 = staticLabels ? static_types[s][t] : 0;
                int type2 = staticLabels ? static_types[t][s] : 0;

                FeatureVector nt_fv_s_01 = nt_fvs[s][type1][0][1];
                FeatureVector nt_fv_s_10 = nt_fvs[s][type2][1][0];
                FeatureVector nt_fv_t_00 = nt_fvs[t][type1][0][0];
                FeatureVector nt_fv_t_11 = nt_fvs[t][type2][1][1];
                double nt_prob_s_01 = nt_probs[s][type1][0][1];
                double nt_prob_s_10 = nt_probs[s][type2][1][0];
                double nt_prob_t_00 = nt_probs[t][type1][0][0];
                double nt_prob_t_11 = nt_probs[t][type2][1][1];

                // Original author's notes: in getItems/add the first flag is direction, the
                // second is "complete"; a suffix _s means s is the parent.
                for (int r = s; r < t; r++) {
                    ParseForestItem[] b1 = pf.getItems(s, r, 0, 0);
                    ParseForestItem[] c1 = pf.getItems(r + 1, t, 1, 0);
                    if (b1 == null || c1 == null) {
                        continue;
                    }
                    int[][] pairs = pf.getKBestPairs(b1, c1);
                    for (int k = 0; k < pairs.length; k++) {
                        if (pairs[k][0] == -1 || pairs[k][1] == -1) {
                            break;
                        }
                        int comp1 = pairs[k][0];
                        int comp2 = pairs[k][1];
                        double bc = b1[comp1].prob + c1[comp2].prob;

                        // Item built with the s-t edge score (direction flag 0).
                        double prob_fin = bc + prodProb_st;
                        FeatureVector fv_fin = prodFV_st;
                        if (staticLabels) {
                            fv_fin = nt_fv_s_01.cat(nt_fv_t_00.cat(fv_fin));
                            prob_fin += nt_prob_s_01 + nt_prob_t_00;
                        }
                        pf.add(s, r, t, type1, 0, 1, prob_fin, fv_fin, b1[comp1], c1[comp2]);

                        // Item built with the t-s edge score (direction flag 1).
                        prob_fin = bc + prodProb_ts;
                        fv_fin = prodFV_ts;
                        if (staticLabels) {
                            fv_fin = nt_fv_t_11.cat(nt_fv_s_10.cat(fv_fin));
                            prob_fin += nt_prob_t_11 + nt_prob_s_10;
                        }
                        pf.add(s, r, t, type2, 1, 1, prob_fin, fv_fin, b1[comp1], c1[comp2]);
                    }
                }

                for (int r = s; r <= t; r++) {
                    if (r != s) {
                        ParseForestItem[] b1 = pf.getItems(s, r, 0, 1);
                        ParseForestItem[] c1 = pf.getItems(r, t, 0, 0);
                        if (b1 != null && c1 != null) {
                            int[][] pairs = pf.getKBestPairs(b1, c1);
                            for (int k = 0; k < pairs.length; k++) {
                                if (pairs[k][0] == -1 || pairs[k][1] == -1) {
                                    break;
                                }
                                int comp1 = pairs[k][0];
                                int comp2 = pairs[k][1];
                                double bc = b1[comp1].prob + c1[comp2].prob;
                                // Stop with this split point as soon as the forest rejects an item.
                                if (!pf.add(s, r, t, -1, 0, 0, bc, new FeatureVector(), b1[comp1], c1[comp2])) {
                                    break;
                                }
                            }
                        }
                    }
                    if (r != t) {
                        ParseForestItem[] b1 = pf.getItems(s, r, 1, 0);
                        ParseForestItem[] c1 = pf.getItems(r, t, 1, 1);
                        if (b1 != null && c1 != null) {
                            int[][] pairs = pf.getKBestPairs(b1, c1);
                            for (int k = 0; k < pairs.length; k++) {
                                if (pairs[k][0] == -1 || pairs[k][1] == -1) {
                                    break;
                                }
                                int comp1 = pairs[k][0];
                                int comp2 = pairs[k][1];
                                double bc = b1[comp1].prob + c1[comp2].prob;
                                if (!pf.add(s, r, t, -1, 1, 0, bc, new FeatureVector(), b1[comp1], c1[comp2])) {
                                    break;
                                }
                            }
                        }
                    }
                }
            }
        }

        return pf.getBestParses();
    }

    /**
     * Non-projective decoding: builds a dense score matrix, runs {@code chuLiuEdmonds} to get
     * the best parent assignment, then adds further parses that each differ from the best one
     * by a single re-attached node (see {@code getKChanges}).
     *
     * <p>Label scores and label feature vectors are only used when
     * {@code pipe.labeled && !pipe.separateLab}; otherwise every edge is reported with
     * label 0 and contributes only its edge feature vector.
     *
     * @param inst     instance being parsed; {@code inst.length()} and {@code inst.postags} are read
     * @param fvs      edge feature vectors, indexed {@code [lo][hi][dir]}
     * @param probs    edge scores, indexed like {@code fvs}
     * @param nt_fvs   label feature vectors, indexed {@code [node][type][dir][flag]}
     * @param nt_probs label scores, indexed like {@code nt_fvs}
     * @param K        upper bound on the number of single-edge alternatives requested
     * @return one row per parse: {@code [k][0]} is the concatenated {@link FeatureVector},
     *         {@code [k][1]} is a space-separated string of {@code parent|child:type} entries
     *         for children 1..n-1; row 0 is the Chu-Liu-Edmonds parse
     */
    public Object[][] decodeNonProjective(DependencyInstance inst, FeatureVector[][][] fvs, double[][][] probs,
            FeatureVector[][][][] nt_fvs, double[][][][] nt_probs, int K) {

        String[] pos = inst.postags;
        int numWords = inst.length();
        final boolean staticLabels = pipe.labeled && !pipe.separateLab;

        int[][] oldI = new int[numWords][numWords];
        int[][] oldO = new int[numWords][numWords];
        double[][] scoreMatrix = new double[numWords][numWords];
        double[][] orig_scoreMatrix = new double[numWords][numWords];
        boolean[] curr_nodes = new boolean[numWords];
        TIntIntHashMap[] reps = new TIntIntHashMap[numWords];

        int[][] static_types = null;
        if (staticLabels) { // afm 06-03-08
            static_types = getTypes(nt_probs, pos.length);
        }

        for (int i = 0; i < numWords; i++) {
            curr_nodes[i] = true;
            reps[i] = new TIntIntHashMap();
            reps[i].put(i, 0);
            for (int j = 0; j < numWords; j++) {
                // score of edge (i,j) i --> j
                int lo = i < j ? i : j;
                int hi = i < j ? j : i;
                int dir = i < j ? 0 : 1;
                double score = probs[lo][hi][dir]
                        + (staticLabels
                                ? nt_probs[i][static_types[i][j]][dir][1] + nt_probs[j][static_types[i][j]][dir][0]
                                : 0.0);
                // scoreMatrix is rewritten by chuLiuEdmonds; orig_scoreMatrix keeps the raw scores.
                scoreMatrix[i][j] = score;
                orig_scoreMatrix[i][j] = score;
                oldI[i][j] = i;
                oldO[i][j] = j;
            }
        }

        TIntIntHashMap final_edges =
                chuLiuEdmonds(scoreMatrix, curr_nodes, oldI, oldO, false, new TIntIntHashMap(), reps);

        // final_edges maps child -> parent.
        int[] par = new int[numWords];
        int[] ns = final_edges.keys();
        for (int i = 0; i < ns.length; i++) {
            int ch = ns[i];
            int pr = final_edges.get(ns[i]);
            par[ch] = pr;
        }

        int[] n_par = getKChanges(par, orig_scoreMatrix, Math.min(K, par.length));
        int new_k = 1;
        for (int i = 0; i < n_par.length; i++) {
            if (n_par[i] > -1) {
                new_k++;
            }
        }

        // Parent arrays: the best parse first, then one copy per accepted single-edge change.
        int[][] fin_par = new int[new_k][numWords];
        FeatureVector[][] fin_fv = new FeatureVector[new_k][numWords];
        fin_par[0] = par;
        int c = 1;
        for (int i = 0; i < n_par.length; i++) {
            if (n_par[i] > -1) {
                int[] t_par = new int[par.length];
                System.arraycopy(par, 0, t_par, 0, t_par.length);
                t_par[i] = n_par[i];
                fin_par[c] = t_par;
                c++;
            }
        }

        // Create feature vectors, one per (parse, child).
        for (int k = 0; k < fin_par.length; k++) {
            for (int i = 0; i < fin_par[k].length; i++) {
                int ch = i;
                int pr = fin_par[k][i];
                if (pr != -1) {
                    fin_fv[k][ch] = fvs[ch < pr ? ch : pr][ch < pr ? pr : ch][ch < pr ? 1 : 0];
                    // static_types is only allocated when staticLabels holds; guarding on
                    // pipe.labeled alone dereferenced null when pipe.separateLab was set.
                    if (staticLabels) {
                        fin_fv[k][ch] = fin_fv[k][ch].cat(nt_fvs[ch][static_types[pr][ch]][ch < pr ? 1 : 0][0]);
                        fin_fv[k][ch] = fin_fv[k][ch].cat(nt_fvs[pr][static_types[pr][ch]][ch < pr ? 1 : 0][1]);
                    }
                } else {
                    fin_fv[k][ch] = new FeatureVector();
                }
            }
        }

        FeatureVector[] fin = new FeatureVector[new_k];
        String[] result = new String[new_k];
        for (int k = 0; k < fin.length; k++) {
            fin[k] = new FeatureVector();
            for (int i = 1; i < fin_fv[k].length; i++) {
                fin[k] = fin_fv[k][i].cat(fin[k]);
            }
            StringBuilder sb = new StringBuilder();
            for (int i = 1; i < par.length; i++) {
                sb.append(fin_par[k][i]).append("|").append(i);
                if (staticLabels) {
                    sb.append(":").append(static_types[fin_par[k][i]][i]);
                } else {
                    sb.append(":0");
                }
                sb.append(" ");
            }
            result[k] = sb.toString();
        }

        // create d.
        Object[][] d = new Object[new_k][2];
        for (int k = 0; k < new_k; k++) {
            d[k][0] = fin[k];
            d[k][1] = result[k].trim();
        }
        return d;
    }

    /**
     * Proposes up to {@code K} single-node re-attachments of the tree {@code par}.
     *
     * <p>For every node {@code i >= 1} the highest-scoring alternative parent is found among
     * nodes that are not {@code i} itself, not its current parent, and not flagged by
     * {@code isChild[i][j]} (i.e. not reached from {@code j} by following parent links up to
     * {@code i}). The candidates are then taken in descending score order, at most {@code K}
     * of them.
     *
     * @param par         parent array; the parent chain of every node must end in -1
     * @param scoreMatrix scores indexed {@code [parent][child]}
     * @param K           maximum number of changes to return
     * @return array where entry {@code i} is the new parent chosen for node {@code i}, or -1
     *         if node {@code i} is left unchanged
     */
    private int[] getKChanges(int[] par, double[][] scoreMatrix, int K) {
        int[] result = new int[par.length];
        int[] n_par = new int[par.length];
        Arrays.fill(result, -1);
        Arrays.fill(n_par, -1);

        boolean[][] isChild = calcChilds(par);

        // Best alternative parent for each non-root node.
        for (int i = 1; i < n_par.length; i++) {
            double max = Double.NEGATIVE_INFINITY;
            int wh = -1;
            for (int j = 0; j < n_par.length; j++) {
                if (i == j || par[i] == j || isChild[i][j]) {
                    continue;
                }
                if (scoreMatrix[j][i] > max) {
                    max = scoreMatrix[j][i];
                    wh = j;
                }
            }
            n_par[i] = wh;
        }

        // Repeatedly extract the best remaining candidate.
        for (int k = 0; k < K; k++) {
            double max = Double.NEGATIVE_INFINITY;
            int wh = -1;
            int whI = -1;
            for (int i = 0; i < n_par.length; i++) {
                if (n_par[i] == -1) {
                    continue;
                }
                double score = scoreMatrix[n_par[i]][i];
                if (score > max) {
                    max = score;
                    whI = i;
                    wh = n_par[i];
                }
            }
            if (max == Double.NEGATIVE_INFINITY) {
                break;
            }
            result[whI] = wh;
            n_par[whI] = -1; // consumed
        }
        return result;
    }

    /**
     * Builds an ancestor table from a parent array: {@code isChild[a][i]} is true when
     * {@code a} is reached from node {@code i >= 1} by following parent links.
     *
     * @param par parent array; every chain of parents must terminate in -1
     */
    private boolean[][] calcChilds(int[] par) {
        boolean[][] isChild = new boolean[par.length][par.length];
        for (int i = 1; i < par.length; i++) {
            int l = par[i];
            while (l != -1) {
                isChild[l][i] = true;
                l = par[l];
            }
        }
        return isChild;
    }

    /**
     * Recursive Chu-Liu-Edmonds step. Each active node greedily takes its best incoming edge;
     * if that graph has no cycle its edges are written to {@code final_edges}, otherwise the
     * cycle is contracted into one representative node, the method recurses, and the cycle's
     * internal edges (minus the one that is broken) are added afterwards.
     *
     * <p>This method mutates {@code scoreMatrix}, {@code curr_nodes}, {@code oldI},
     * {@code oldO}, {@code reps} and {@code final_edges}.
     *
     * @param scoreMatrix scores indexed {@code [parent][child]}
     * @param curr_nodes  which nodes are still active (not folded into a representative)
     * @param oldI        original parent endpoint for each current edge
     * @param oldO        original child endpoint for each current edge
     * @param print       whether to write debugging output to {@code DependencyParser.out}
     * @param final_edges accumulator mapping child to parent; node 0 is mapped to -1
     * @param reps        for each node, the set of original nodes it currently represents
     * @return {@code final_edges}
     */
    private static TIntIntHashMap chuLiuEdmonds(double[][] scoreMatrix, boolean[] curr_nodes, int[][] oldI,
            int[][] oldO, boolean print, TIntIntHashMap final_edges, TIntIntHashMap[] reps) {

        // need to construct for each node list of nodes they represent (here only!)
        int[] par = new int[curr_nodes.length];
        int numWords = curr_nodes.length;

        // create best graph
        par[0] = -1;
        for (int i = 1; i < par.length; i++) {
            // only interested in current nodes
            if (!curr_nodes[i]) {
                continue;
            }
            double maxScore = scoreMatrix[0][i];
            par[i] = 0;
            for (int j = 0; j < par.length; j++) {
                if (j == i || !curr_nodes[j]) {
                    continue;
                }
                double newScore = scoreMatrix[j][i];
                if (newScore > maxScore) {
                    maxScore = newScore;
                    par[i] = j;
                }
            }
        }

        if (print) {
            DependencyParser.out.println("After init");
            for (int i = 0; i < par.length; i++) {
                if (curr_nodes[i]) {
                    DependencyParser.out.print(par[i] + "|" + i + " ");
                }
            }
            DependencyParser.out.println();
        }

        // Find a cycle (the outer loop stops as soon as one has been recorded).
        ArrayList<TIntIntHashMap> cycles = new ArrayList<TIntIntHashMap>();
        boolean[] added = new boolean[numWords];
        for (int i = 0; i < numWords && cycles.isEmpty(); i++) {
            // if I have already considered this or
            // this is not a valid node (i.e. has been contracted)
            if (added[i] || !curr_nodes[i]) {
                continue;
            }
            added[i] = true;
            TIntIntHashMap cycle = new TIntIntHashMap();
            cycle.put(i, 0);
            int l = i;
            while (true) {
                if (par[l] == -1) {
                    added[l] = true;
                    break;
                }
                if (cycle.contains(par[l])) {
                    // The walk re-entered its own path: collect exactly the nodes on the loop.
                    cycle = new TIntIntHashMap();
                    int lorg = par[l];
                    cycle.put(lorg, par[lorg]);
                    added[lorg] = true;
                    int l1 = par[lorg];
                    while (l1 != lorg) {
                        cycle.put(l1, par[l1]);
                        added[l1] = true;
                        l1 = par[l1];
                    }
                    cycles.add(cycle);
                    break;
                }
                cycle.put(l, 0);
                l = par[l];
                if (added[l] && !cycle.contains(l)) {
                    break;
                }
                added[l] = true;
            }
        }

        // get all edges and return them
        if (cycles.isEmpty()) {
            for (int i = 0; i < par.length; i++) {
                if (!curr_nodes[i]) {
                    continue;
                }
                if (par[i] != -1) {
                    int pr = oldI[par[i]][i];
                    int ch = oldO[par[i]][i];
                    final_edges.put(ch, pr);
                } else {
                    final_edges.put(0, -1);
                }
            }
            return final_edges;
        }

        int max_cyc = 0;
        int wh_cyc = 0;
        for (int i = 0; i < cycles.size(); i++) {
            TIntIntHashMap candidate = cycles.get(i);
            if (candidate.size() > max_cyc) {
                max_cyc = candidate.size();
                wh_cyc = i;
            }
        }

        TIntIntHashMap cycle = cycles.get(wh_cyc);
        int[] cyc_nodes = cycle.keys();
        int rep = cyc_nodes[0];

        if (print) {
            DependencyParser.out.println("Found Cycle");
            for (int i = 0; i < cyc_nodes.length; i++) {
                DependencyParser.out.print(cyc_nodes[i] + " ");
            }
            DependencyParser.out.println();
        }

        double cyc_weight = 0.0;
        for (int j = 0; j < cyc_nodes.length; j++) {
            cyc_weight += scoreMatrix[par[cyc_nodes[j]]][cyc_nodes[j]];
        }

        // Re-score edges between every outside node i and the contracted node rep.
        for (int i = 0; i < numWords; i++) {
            if (!curr_nodes[i] || cycle.contains(i)) {
                continue;
            }

            double max1 = Double.NEGATIVE_INFINITY;
            int wh1 = -1;
            double max2 = Double.NEGATIVE_INFINITY;
            int wh2 = -1;

            for (int j = 0; j < cyc_nodes.length; j++) {
                int j1 = cyc_nodes[j];

                // best edge leaving the cycle towards i
                if (scoreMatrix[j1][i] > max1) {
                    max1 = scoreMatrix[j1][i];
                    wh1 = j1;
                }

                // cycle weight + new edge - removal of old
                double scr = cyc_weight + scoreMatrix[i][j1] - scoreMatrix[par[j1]][j1];
                if (scr > max2) {
                    max2 = scr;
                    wh2 = j1;
                }
            }

            scoreMatrix[rep][i] = max1;
            oldI[rep][i] = oldI[wh1][i];
            oldO[rep][i] = oldO[wh1][i];
            scoreMatrix[i][rep] = max2;
            oldO[i][rep] = oldO[i][wh2];
            oldI[i][rep] = oldI[i][wh2];
        }

        // Snapshot what each cycle node represented before the sets are merged into rep.
        TIntIntHashMap[] rep_cons = new TIntIntHashMap[cyc_nodes.length];
        for (int i = 0; i < cyc_nodes.length; i++) {
            rep_cons[i] = new TIntIntHashMap();
            int[] keys = reps[cyc_nodes[i]].keys();
            Arrays.sort(keys);
            if (print) {
                DependencyParser.out.print(cyc_nodes[i] + ": ");
            }
            for (int j = 0; j < keys.length; j++) {
                rep_cons[i].put(keys[j], 0);
                if (print) {
                    DependencyParser.out.print(keys[j] + " ");
                }
            }
            if (print) {
                DependencyParser.out.println();
            }
        }

        // don't consider not representative nodes
        // these nodes have been folded
        for (int i = 1; i < cyc_nodes.length; i++) {
            curr_nodes[cyc_nodes[i]] = false;
            int[] keys = reps[cyc_nodes[i]].keys();
            for (int j = 0; j < keys.length; j++) {
                reps[rep].put(keys[j], 0);
            }
        }

        chuLiuEdmonds(scoreMatrix, curr_nodes, oldI, oldO, print, final_edges, reps);

        // check each node in cycle, if one of its representatives
        // is a key in the final_edges, it is the one.
        int wh = -1;
        boolean found = false;
        for (int i = 0; i < rep_cons.length && !found; i++) {
            int[] keys = rep_cons[i].keys();
            for (int j = 0; j < keys.length && !found; j++) {
                if (final_edges.contains(keys[j])) {
                    wh = cyc_nodes[i];
                    found = true;
                }
            }
        }

        // NOTE(review): if no cycle node were found above, wh stays -1 and par[wh] throws;
        // this presumably cannot happen after a successful recursion - confirm.
        // Walk the cycle from wh, adding every cycle edge except the one entering wh.
        int l = par[wh];
        while (l != wh) {
            int ch = oldO[par[l]][l];
            int pr = oldI[par[l]][l];
            final_edges.put(ch, pr);
            l = par[l];
        }

        if (print) {
            int[] keys = final_edges.keys();
            Arrays.sort(keys);
            for (int i = 0; i < keys.length; i++) {
                DependencyParser.out.print(final_edges.get(keys[i]) + "|" + keys[i] + " ");
            }
            DependencyParser.out.println();
        }

        return final_edges;
    }
}