package edu.stanford.nlp.semparse.open.model;

import edu.stanford.nlp.semparse.open.model.candidate.Candidate;
import fig.basic.Fmt;
import fig.basic.LogInfo;

/**
 * Full-rank parameterization of the advanced word vector model: the score of a
 * candidate is the bilinear form x^T W y, where x is the query word vector,
 * y is the entities word vector, and W is a dense dim x dim weight matrix.
 */
public class AdvancedWordVectorParamsFullRank extends AdvancedWordVectorParams {
  double[][] weights;

  // Shorthand for the word vector dimension
  protected final int dim;

  public AdvancedWordVectorParamsFullRank() {
    dim = getDim();
    weights = new double[dim][dim];
    if (Params.opts.initWeightsRandomly) {
      for (int i = 0; i < dim; i++) {
        for (int j = 0; j < dim; j++) {
          // Initialize each weight uniformly at random in [-1, 1)
          weights[i][j] = 2 * Params.opts.initRandom.nextDouble() - 1;
        }
      }
    }
    initGradientStats();
  }

  // ============================================================
  // Get score
  // ============================================================

  @Override
  public double getScore(Candidate candidate) {
    return getScore(getX(candidate), getY(candidate));
  }

  /** Compute the bilinear score x^T W y; a missing vector scores 0. */
  public double getScore(double[] x, double[] y) {
    if (x == null || y == null) return 0;
    double answer = 0;
    for (int i = 0; i < dim; i++) {
      for (int j = 0; j < dim; j++) {
        answer += weights[i][j] * x[i] * y[j];
      }
    }
    return answer;
  }

  // ============================================================
  // Compute gradient
  // ============================================================

  @Override
  public AdvancedWordVectorGradient createGradient() {
    return new AdvancedWordVectorGradientFullRank();
  }

  class AdvancedWordVectorGradientFullRank implements AdvancedWordVectorGradient {
    protected final double[][] grad;

    public AdvancedWordVectorGradientFullRank() {
      grad = new double[dim][dim];
    }

    @Override
    public void addToGradient(Candidate candidate, double factor) {
      addToGradient(getX(candidate), getY(candidate), factor);
    }

    /**
     * Compute the gradient for the word vector pair (x, y) and add it to the
     * accumulated gradient. Since d(x^T W y)/dW = x y^T, entry (i, j)
     * receives x[i] * y[j] scaled by the given factor.
     */
    private void addToGradient(double[] x, double[] y, double factor) {
      if (x == null || y == null) return;
      for (int i = 0; i < dim; i++) {
        for (int j = 0; j < dim; j++) {
          grad[i][j] += x[i] * y[j] * factor;
        }
      }
    }

    @Override
    public void addL2Regularization(double beta) {
      // Subtract beta * W from the gradient (the update step adds the
      // gradient to the weights, so this penalizes large weights).
      for (int i = 0; i < dim; i++) {
        for (int j = 0; j < dim; j++) {
          grad[i][j] -= beta * weights[i][j];
        }
      }
    }
  }

  // ============================================================
  // Weight update
  // ============================================================

  // For AdaGrad
  double[][] sumSquaredGradients;
  // For dual averaging
  double[][] sumGradients;

  protected void initGradientStats() {
    if (Params.opts.adaptiveStepSize) sumSquaredGradients = new double[dim][dim];
    if (Params.opts.dualAveraging) sumGradients = new double[dim][dim];
  }
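  /*
   * Sketch of the per-coordinate update rules implemented in update() below.
   * Writing g_t for the gradient of a single weight w at update t (illustrative
   * notation only, not identifiers from the code):
   *
   *   AdaGrad step size:    eta_t = initStepSize / sqrt(g_1^2 + ... + g_t^2)
   *   Decaying step size:   eta_t = initStepSize / t^stepSizeReduction
   *   Dual averaging:       w     = eta_t * (g_1 + ... + g_t)
   *   Plain gradient step:  w    += eta_t * g_t
   *
   * Hypothetical example: with initStepSize = 1 and past gradients 3 and 4 on
   * one coordinate, the AdaGrad step size is 1 / sqrt(9 + 16) = 0.2.
   */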
  // Number of stochastic updates we've made so far (for determining the step size)
  int numUpdates;

  @Override
  public void update(AdvancedWordVectorGradient gradient) {
    AdvancedWordVectorGradientFullRank grad = (AdvancedWordVectorGradientFullRank) gradient;
    numUpdates++;
    for (int i = 0; i < dim; i++) {
      for (int j = 0; j < dim; j++) {
        double g = grad.grad[i][j];
        // Skip negligible gradient entries
        if (Math.abs(g) < 1e-6) continue;
        double stepSize;
        if (Params.opts.adaptiveStepSize) {
          // AdaGrad: per-coordinate step size shrinks with the accumulated squared gradients
          sumSquaredGradients[i][j] += g * g;
          stepSize = Params.opts.initStepSize / Math.sqrt(sumSquaredGradients[i][j]);
        } else {
          // Polynomially decaying global step size
          stepSize = Params.opts.initStepSize / Math.pow(numUpdates, Params.opts.stepSizeReduction);
        }
        if (Params.opts.dualAveraging) {
          // Dual averaging: recompute the weight from the running gradient sum
          sumGradients[i][j] += g;
          weights[i][j] = stepSize * sumGradients[i][j];
        } else {
          weights[i][j] += stepSize * g;
        }
      }
    }
  }

  @Override
  public void applyL1Regularization(double cutoff) {
    if (cutoff <= 0) return;
    for (int i = 0; i < dim; i++) {
      for (int j = 0; j < dim; j++) {
        weights[i][j] = Params.L1Cut(weights[i][j], cutoff);
      }
    }
  }

  // ============================================================
  // Logging
  // ============================================================

  @Override
  public void log() {
    LogInfo.begin_track("Advanced Word Vector Params");
    for (int i = 0; i < dim; i++) {
      StringBuilder sb = new StringBuilder();
      for (int j = 0; j < dim; j++) {
        sb.append(String.format("%10s ", Fmt.D(weights[i][j])));
      }
      LogInfo.log(sb.toString());
    }
    LogInfo.end_track();
  }

  @Override
  public void logFeatureWeights(Candidate candidate) {
    LogInfo.begin_track("Advanced Word Vector feature weights");
    double[] x = getX(candidate), y = getY(candidate);
    if (x == null) {
      LogInfo.log("NONE: x (query word vector) is null");
    } else if (y == null) {
      LogInfo.log("NONE: y (entities word vector) is null");
    } else {
      LogInfo.logs("Advanced Word Vector: %s", Fmt.D(getScore(x, y)));
    }
    LogInfo.end_track();
  }

  @Override
  public void logFeatureDiff(Candidate trueCandidate, Candidate predCandidate) {
    LogInfo.begin_track("Advanced Word Vector feature weights");
    // The candidates should come from the same example, so assume both share the same x
    double[] x = getX(trueCandidate), yTrue = getY(trueCandidate), yPred = getY(predCandidate);
    if (x == null) {
      LogInfo.log("NONE: x (query word vector) is null");
    } else if (yTrue == null) {
      LogInfo.log("NONE: y (entities word vector) is null for trueCandidate");
    } else if (yPred == null) {
      LogInfo.log("NONE: y (entities word vector) is null for predCandidate");
    } else {
      double trueScore = getScore(x, yTrue), predScore = getScore(x, yPred);
      LogInfo.logs("Advanced Word Vector: %s [ %s - %s ]",
          Fmt.D(trueScore - predScore), Fmt.D(trueScore), Fmt.D(predScore));
    }
    LogInfo.end_track();
  }
}
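// Usage sketch (assumed driver-side interaction; the actual training loop lives
// in the learner, and trueCandidate / predCandidate / beta / cutoff below are
// hypothetical placeholders, not values defined in this file):
//
//   AdvancedWordVectorParamsFullRank params = new AdvancedWordVectorParamsFullRank();
//   AdvancedWordVectorGradient grad = params.createGradient();
//   grad.addToGradient(trueCandidate, +1);   // raise the true candidate's score
//   grad.addToGradient(predCandidate, -1);   // lower the predicted candidate's score
//   grad.addL2Regularization(beta);          // subtract beta * W from the gradient
//   params.update(grad);                     // AdaGrad / dual-averaging step
//   params.applyL1Regularization(cutoff);    // push small weights to zero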