package mstparser; public class Parameters { private double SCORE = 0.0; public double[] parameters; public double[] total; public String lossType = "punc"; public Parameters(int size) { parameters = new double[size]; total = new double[size]; for (int i = 0; i < parameters.length; i++) { parameters[i] = 0.0; total[i] = 0.0; } lossType = "punc"; } public Parameters(double[] parameters) { this.parameters = parameters; } public void setLoss(String lt) { lossType = lt; } public void averageParams(double avVal) { for (int j = 0; j < total.length; j++) { total[j] *= 1.0 / ((double) avVal); } parameters = total; } public void updateParamsMIRA(DependencyInstance inst, Object[][] d, double upd) { String actParseTree = inst.actParseTree; FeatureVector actFV = inst.fv; int K = 0; for (int i = 0; i < d.length && d[i][0] != null; i++) { K = i + 1; } double[] b = new double[K]; double[] lam_dist = new double[K]; FeatureVector[] dist = new FeatureVector[K]; for (int k = 0; k < K; k++) { lam_dist[k] = getScore(actFV) - getScore((FeatureVector) d[k][0]); b[k] = (double) numErrors(inst, (String) d[k][1], actParseTree); b[k] -= lam_dist[k]; dist[k] = actFV.getDistVector((FeatureVector) d[k][0]); } double[] alpha = hildreth(dist, b); FeatureVector fv; int res = 0; for (int k = 0; k < K; k++) { fv = dist[k]; fv.update(parameters, total, alpha[k], upd); //for(FeatureVector curr = fv; curr.index >= 0; curr = curr.next) { // if(curr.index < 0) // continue; // parameters[curr.index] += alpha[k]*curr.value; // total[curr.index] += upd*alpha[k]*curr.value; //} } } public double getScore(FeatureVector fv) { return fv.getScore(parameters); //double score = 0.0; //for(FeatureVector curr = fv; curr.index >= 0; curr = curr.next) // score += parameters[curr.index]*curr.value; //return score; } private double[] hildreth(FeatureVector[] a, double[] b) { int i; int max_iter = 10000; double eps = 0.00000001; double zero = 0.000000000001; double[] alpha = new double[b.length]; double[] F = new double[b.length]; double[] kkt = new double[b.length]; double max_kkt = Double.NEGATIVE_INFINITY; int K = a.length; double[][] A = new double[K][K]; boolean[] is_computed = new boolean[K]; for (i = 0; i < K; i++) { A[i][i] = a[i].dotProduct(a[i]); is_computed[i] = false; } int max_kkt_i = -1; for (i = 0; i < F.length; i++) { F[i] = b[i]; kkt[i] = F[i]; if (kkt[i] > max_kkt) { max_kkt = kkt[i]; max_kkt_i = i; } } int iter = 0; double diff_alpha; double try_alpha; double add_alpha; while (max_kkt >= eps && iter < max_iter) { diff_alpha = A[max_kkt_i][max_kkt_i] <= zero ? 0.0 : F[max_kkt_i] / A[max_kkt_i][max_kkt_i]; try_alpha = alpha[max_kkt_i] + diff_alpha; if (try_alpha < 0.0) { add_alpha = -1.0 * alpha[max_kkt_i]; } else { add_alpha = diff_alpha; } alpha[max_kkt_i] = alpha[max_kkt_i] + add_alpha; if (!is_computed[max_kkt_i]) { for (i = 0; i < K; i++) { A[i][max_kkt_i] = a[i].dotProduct(a[max_kkt_i]); // for version 1 is_computed[max_kkt_i] = true; } } for (i = 0; i < F.length; i++) { F[i] -= add_alpha * A[i][max_kkt_i]; kkt[i] = F[i]; if (alpha[i] > zero) { kkt[i] = Math.abs(F[i]); } } max_kkt = Double.NEGATIVE_INFINITY; max_kkt_i = -1; for (i = 0; i < F.length; i++) { if (kkt[i] > max_kkt) { max_kkt = kkt[i]; max_kkt_i = i; } } iter++; } return alpha; } public double numErrors(DependencyInstance inst, String pred, String act) { if (lossType.equals("nopunc")) { return numErrorsDepNoPunc(inst, pred, act) + numErrorsLabelNoPunc(inst, pred, act); } return numErrorsDep(inst, pred, act) + numErrorsLabel(inst, pred, act); } public double numErrorsDep(DependencyInstance inst, String pred, String act) { String[] act_spans = act.split(" "); String[] pred_spans = pred.split(" "); int correct = 0; for (int i = 0; i < pred_spans.length; i++) { String p = pred_spans[i].split(":")[0]; String a = act_spans[i].split(":")[0]; if (p.equals(a)) { correct++; } } return ((double) act_spans.length - correct); } public double numErrorsLabel(DependencyInstance inst, String pred, String act) { String[] act_spans = act.split(" "); String[] pred_spans = pred.split(" "); int correct = 0; for (int i = 0; i < pred_spans.length; i++) { String p = pred_spans[i].split(":")[1]; String a = act_spans[i].split(":")[1]; if (p.equals(a)) { correct++; } } return ((double) act_spans.length - correct); } public double numErrorsDepNoPunc(DependencyInstance inst, String pred, String act) { String[] act_spans = act.split(" "); String[] pred_spans = pred.split(" "); String[] pos = inst.postags; int correct = 0; int numPunc = 0; for (int i = 0; i < pred_spans.length; i++) { String p = pred_spans[i].split(":")[0]; String a = act_spans[i].split(":")[0]; if (pos[i + 1].matches("[,:.'`]+")) { numPunc++; continue; } if (p.equals(a)) { correct++; } } return ((double) act_spans.length - numPunc - correct); } public double numErrorsLabelNoPunc(DependencyInstance inst, String pred, String act) { String[] act_spans = act.split(" "); String[] pred_spans = pred.split(" "); String[] pos = inst.postags; int correct = 0; int numPunc = 0; for (int i = 0; i < pred_spans.length; i++) { String p = pred_spans[i].split(":")[1]; String a = act_spans[i].split(":")[1]; if (pos[i + 1].matches("[,:.'`]+")) { numPunc++; continue; } if (p.equals(a)) { correct++; } } return ((double) act_spans.length - numPunc - correct); } }