package mstparser;
/**
 * Weight vector used to score feature vectors, together with a running
 * {@code total} array that {@link #averageParams(double)} turns into the
 * final weights.
 *
 * <p>The class also implements the loss used during training (number of
 * mismatching fields between a predicted and a gold parse string) and a
 * MIRA-style update ({@link #updateParamsMIRA}) whose step sizes are found by
 * the private {@code hildreth} solver.
 *
 * <p>Parse strings are space-separated tokens; each token is split on
 * {@code ':'} and the first field is compared by the "dep" loss, the second
 * field by the "label" loss.
 */
public class Parameters {

    /** POS tags matching this pattern in full are skipped by the "nopunc" loss. */
    private static final java.util.regex.Pattern PUNCTUATION_TAG =
            java.util.regex.Pattern.compile("[,:.'`]+");

    /** Index of the {@code ':'}-separated field compared by the "dep" loss. */
    private static final int DEP_FIELD = 0;

    /** Index of the {@code ':'}-separated field compared by the "label" loss. */
    private static final int LABEL_FIELD = 1;

    /** Current weights; indexed by feature index. */
    public double[] parameters;

    /**
     * Accumulated weights used for averaging. NOTE: this stays {@code null}
     * when the instance is built with {@link #Parameters(double[])}, so
     * {@link #averageParams(double)} and {@link #updateParamsMIRA} must not be
     * called on such an instance.
     */
    public double[] total;

    /** Loss selector; {@code "nopunc"} ignores punctuation, anything else does not. */
    public String lossType = "punc";

    /**
     * Creates zero-initialised {@code parameters} and {@code total} arrays.
     *
     * @param size number of weights
     */
    public Parameters(int size) {
        // Java zero-fills new arrays, so no explicit initialisation loop is needed.
        parameters = new double[size];
        total = new double[size];
    }

    /**
     * Wraps an existing weight array (no copy is made). {@code total} is left
     * {@code null}.
     *
     * @param parameters weights to use as-is
     */
    public Parameters(double[] parameters) {
        this.parameters = parameters;
    }

    /**
     * Selects the loss used by {@link #numErrors}.
     *
     * @param lt {@code "nopunc"} to skip punctuation tokens; any other value
     *           selects the default loss
     */
    public void setLoss(String lt) {
        lossType = lt;
    }

    /**
     * Scales every entry of {@code total} by {@code 1 / avVal} and then makes
     * {@code parameters} refer to that same array.
     *
     * <p>After this call {@code parameters} and {@code total} are aliases of
     * one array.
     *
     * @param avVal divisor applied to the accumulated totals
     */
    public void averageParams(double avVal) {
        final double scale = 1.0 / avVal;
        for (int j = 0; j < total.length; j++) {
            total[j] *= scale;
        }
        parameters = total;
    }

    /**
     * Performs one MIRA-style update of the weights for a training instance.
     *
     * <p>The leading entries of {@code d} whose first element is non-null are
     * used as constraints. For each of them the margin deficit
     * {@code loss - (score(gold) - score(candidate))} is computed, the step
     * sizes are obtained from {@code hildreth}, and each difference vector is
     * applied through {@code FeatureVector.update}.
     *
     * @param inst training instance providing the gold parse string
     *             ({@code actParseTree}) and gold feature vector ({@code fv})
     * @param d    candidates; {@code d[k][0]} is cast to {@code FeatureVector}
     *             and {@code d[k][1]} to the predicted parse {@code String}
     * @param upd  passed through to {@code FeatureVector.update}; presumably
     *             the weight of this update in the running {@code total}
     *             (verify against {@code FeatureVector})
     */
    public void updateParamsMIRA(DependencyInstance inst, Object[][] d, double upd) {
        String actParseTree = inst.actParseTree;
        FeatureVector actFV = inst.fv;

        // K = number of leading candidates that actually carry a feature vector.
        int K = 0;
        while (K < d.length && d[K][0] != null) {
            K++;
        }

        double[] b = new double[K];
        double[] lam_dist = new double[K];
        FeatureVector[] dist = new FeatureVector[K];

        for (int k = 0; k < K; k++) {
            FeatureVector candidate = (FeatureVector) d[k][0];
            lam_dist[k] = getScore(actFV) - getScore(candidate);
            b[k] = (double) numErrors(inst, (String) d[k][1], actParseTree);
            b[k] -= lam_dist[k];
            dist[k] = actFV.getDistVector(candidate);
        }

        double[] alpha = hildreth(dist, b);

        for (int k = 0; k < K; k++) {
            dist[k].update(parameters, total, alpha[k], upd);
        }
    }

    /**
     * Scores a feature vector against the current weights.
     *
     * @param fv feature vector to score
     * @return the value of {@code fv.getScore(parameters)}
     */
    public double getScore(FeatureVector fv) {
        return fv.getScore(parameters);
    }

    /**
     * Iteratively computes non-negative multipliers {@code alpha} for the
     * constraints described by {@code a} and {@code b}.
     *
     * <p>Each iteration picks the index with the largest KKT value, moves its
     * multiplier by {@code F / A[i][i]} (clipped so the multiplier never drops
     * below zero) and updates the residuals {@code F}. Columns of the Gram
     * matrix {@code A} are computed lazily, the first time their index is
     * selected. Iteration stops when the largest KKT value falls below
     * {@code eps} or after {@code maxIter} rounds.
     *
     * @param a constraint vectors; assumed to have the same length as {@code b}
     * @param b initial residual for each constraint
     * @return one multiplier per constraint
     */
    private double[] hildreth(FeatureVector[] a, double[] b) {
        final int maxIter = 10000;
        final double eps = 0.00000001;
        final double zero = 0.000000000001;
        final int K = a.length;

        double[] alpha = new double[b.length];
        double[] F = new double[b.length];
        double[] kkt = new double[b.length];

        double[][] A = new double[K][K];
        boolean[] isComputed = new boolean[K];
        for (int i = 0; i < K; i++) {
            A[i][i] = a[i].dotProduct(a[i]);
        }

        double maxKkt = Double.NEGATIVE_INFINITY;
        int maxKktIdx = -1;
        for (int i = 0; i < F.length; i++) {
            F[i] = b[i];
            kkt[i] = F[i];
            if (kkt[i] > maxKkt) {
                maxKkt = kkt[i];
                maxKktIdx = i;
            }
        }

        int iter = 0;
        while (maxKkt >= eps && iter < maxIter) {
            // A (near-)zero diagonal means a zero-length vector: take no step.
            double diffAlpha = A[maxKktIdx][maxKktIdx] <= zero
                    ? 0.0
                    : F[maxKktIdx] / A[maxKktIdx][maxKktIdx];
            double tryAlpha = alpha[maxKktIdx] + diffAlpha;

            // Clip so that the multiplier stays non-negative.
            double addAlpha = tryAlpha < 0.0 ? -1.0 * alpha[maxKktIdx] : diffAlpha;
            alpha[maxKktIdx] = alpha[maxKktIdx] + addAlpha;

            if (!isComputed[maxKktIdx]) {
                for (int i = 0; i < K; i++) {
                    A[i][maxKktIdx] = a[i].dotProduct(a[maxKktIdx]);
                }
                isComputed[maxKktIdx] = true;
            }

            for (int i = 0; i < F.length; i++) {
                F[i] -= addAlpha * A[i][maxKktIdx];
                kkt[i] = F[i];
                // For active multipliers the residual must be zero, so its
                // magnitude (not just its positive part) counts as violation.
                if (alpha[i] > zero) {
                    kkt[i] = Math.abs(F[i]);
                }
            }

            maxKkt = Double.NEGATIVE_INFINITY;
            maxKktIdx = -1;
            for (int i = 0; i < F.length; i++) {
                if (kkt[i] > maxKkt) {
                    maxKkt = kkt[i];
                    maxKktIdx = i;
                }
            }

            iter++;
        }

        return alpha;
    }

    /**
     * Returns the total loss of a predicted parse: dep errors plus label
     * errors, using the punctuation-skipping variants when {@code lossType}
     * is {@code "nopunc"}.
     *
     * @param inst instance; its {@code postags} are read only for "nopunc"
     * @param pred predicted parse string
     * @param act  gold parse string
     * @return number of dep errors plus number of label errors
     */
    public double numErrors(DependencyInstance inst, String pred, String act) {
        // Constant-first comparison: a null lossType selects the default loss.
        if ("nopunc".equals(lossType)) {
            return numErrorsDepNoPunc(inst, pred, act) + numErrorsLabelNoPunc(inst, pred, act);
        }
        return numErrorsDep(inst, pred, act) + numErrorsLabel(inst, pred, act);
    }

    /**
     * Counts tokens whose first {@code ':'}-separated field differs between
     * prediction and gold.
     *
     * @param inst unused by this variant
     * @param pred predicted parse string
     * @param act  gold parse string
     * @return gold token count minus the number of matching tokens
     */
    public double numErrorsDep(DependencyInstance inst, String pred, String act) {
        return countErrors(inst, pred, act, DEP_FIELD, false);
    }

    /**
     * Counts tokens whose second {@code ':'}-separated field differs between
     * prediction and gold.
     *
     * @param inst unused by this variant
     * @param pred predicted parse string
     * @param act  gold parse string
     * @return gold token count minus the number of matching tokens
     */
    public double numErrorsLabel(DependencyInstance inst, String pred, String act) {
        return countErrors(inst, pred, act, LABEL_FIELD, false);
    }

    /**
     * Like {@link #numErrorsDep} but tokens whose POS tag is punctuation are
     * neither counted as correct nor as errors.
     *
     * @param inst instance supplying {@code postags}; token {@code i} uses
     *             {@code postags[i + 1]}
     * @param pred predicted parse string
     * @param act  gold parse string
     * @return gold token count minus punctuation tokens minus matching tokens
     */
    public double numErrorsDepNoPunc(DependencyInstance inst, String pred, String act) {
        return countErrors(inst, pred, act, DEP_FIELD, true);
    }

    /**
     * Like {@link #numErrorsLabel} but tokens whose POS tag is punctuation are
     * neither counted as correct nor as errors.
     *
     * @param inst instance supplying {@code postags}; token {@code i} uses
     *             {@code postags[i + 1]}
     * @param pred predicted parse string
     * @param act  gold parse string
     * @return gold token count minus punctuation tokens minus matching tokens
     */
    public double numErrorsLabelNoPunc(DependencyInstance inst, String pred, String act) {
        return countErrors(inst, pred, act, LABEL_FIELD, true);
    }

    /**
     * Shared implementation of the four {@code numErrors*} variants.
     *
     * <p>Iterates over the predicted tokens, so {@code pred} must not contain
     * more tokens than {@code act}.
     *
     * @param inst            instance; dereferenced only when
     *                        {@code skipPunctuation} is true
     * @param pred            predicted parse string
     * @param act             gold parse string
     * @param field           index of the {@code ':'}-separated field to compare
     * @param skipPunctuation whether to ignore tokens with a punctuation POS tag
     * @return gold token count minus skipped tokens minus matching tokens
     */
    private static double countErrors(DependencyInstance inst, String pred, String act,
                                      int field, boolean skipPunctuation) {
        String[] actSpans = act.split(" ");
        String[] predSpans = pred.split(" ");
        String[] pos = skipPunctuation ? inst.postags : null;

        int correct = 0;
        int numPunc = 0;
        for (int i = 0; i < predSpans.length; i++) {
            String p = predSpans[i].split(":")[field];
            String a = actSpans[i].split(":")[field];

            // postags is offset by one relative to the parse tokens.
            if (skipPunctuation && PUNCTUATION_TAG.matcher(pos[i + 1]).matches()) {
                numPunc++;
                continue;
            }
            if (p.equals(a)) {
                correct++;
            }
        }
        return ((double) actSpans.length - numPunc - correct);
    }
}