package edu.stanford.nlp.semparse.open.model;
import java.util.Random;
import edu.stanford.nlp.semparse.open.model.candidate.Candidate;
import fig.basic.Fmt;
import fig.basic.ListUtils;
import fig.basic.LogInfo;
import fig.basic.Option;
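/**
 * Low-rank parameterization of the bilinear word-vector score x^T A y.
 * Instead of storing the full (dim x dim) matrix A, we factor it as
 * A = sum{u[i] v[i]^T} for i = 0, ..., rank - 1, which requires only
 * 2 * rank * dim parameters instead of dim^2 and lets the score be
 * computed in O(rank * dim) time as sum{(u[i]^T x)(v[i]^T y)}.
 */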
public class AdvancedWordVectorParamsLowRank extends AdvancedWordVectorParams {
public static class Options {
@Option(gloss = "Rank of advanced word vector feature parameters (ignored when full rank)")
public int vecRank = 5;
@Option(gloss = "Randomly initialize the weights (ignored when full rank)")
public boolean vecInitWeightsRandomly = true;
@Option(gloss = "Randomly initialize the weights (ignored when full rank)")
public Random vecInitRandom = new Random(1);
}
public static Options opts = new Options();
// A = sum{u[i] v[i]^T} (i = 0, ..., rank - 1)
protected final double[][] u, v;
// Shorthands for rank and word vector dimension
protected final int rank, dim;
public AdvancedWordVectorParamsLowRank() {
rank = opts.vecRank;
dim = getDim();
u = new double[rank][dim];
v = new double[rank][dim];
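    // Random initialization breaks the symmetry of the bilinear form: with
    // all-zero u and v, the gradient with respect to u[i] (proportional to
    // v[i]^T y) and to v[i] (proportional to u[i]^T x) would both vanish,
    // so learning could never leave the zero point.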
if (opts.vecInitWeightsRandomly) {
for (int i = 0; i < rank; i++) {
for (int j = 0; j < dim; j++) {
u[i][j] = 2 * opts.vecInitRandom.nextDouble() - 1;
v[i][j] = 2 * opts.vecInitRandom.nextDouble() - 1;
}
}
}
initGradientStats();
}
// ============================================================
// Get score
// ============================================================
@Override
public double getScore(Candidate candidate) {
return getScore(getX(candidate), getY(candidate));
}
  /**
   * Return the score x^T A y = sum{(u[i]^T x)(v[i]^T y)},
   * or 0 if either word vector is missing.
   */
public double getScore(double[] x, double[] y) {
if (x == null || y == null)
return 0;
double answer = 0;
for (int i = 0; i < rank; i++) {
answer += ListUtils.dot(u[i], x) * ListUtils.dot(v[i], y);
}
return answer;
}
// ============================================================
// Compute gradient
// ============================================================
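  // Gradients are accumulated in a separate object so that multiple
  // candidates can contribute (via addToGradient) before a single call
  // to update applies the result to u and v.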
@Override
public AdvancedWordVectorGradient createGradient() {
return new AdvancedWordVectorGradientLowRank();
}
class AdvancedWordVectorGradientLowRank implements AdvancedWordVectorGradient {
    protected final double[][] gradU, gradV;
public AdvancedWordVectorGradientLowRank() {
gradU = new double[rank][dim];
gradV = new double[rank][dim];
}
@Override
public void addToGradient(Candidate candidate, double factor) {
addToGradient(getX(candidate), getY(candidate), factor);
}
    /**
     * Compute the gradient for the word vector pair (x, y) and add it to the
     * accumulated gradient.
     *
     * The gradient of (x^T A y) with respect to u[i] is (v[i]^T y) x.
     * The gradient of (x^T A y) with respect to v[i] is (u[i]^T x) y.
     */
private void addToGradient(double[] x, double[] y, double factor) {
if (x == null || y == null)
return;
for (int i = 0; i < rank; i++) {
ListUtils.incr(gradU[i], factor * ListUtils.dot(v[i], y), x);
ListUtils.incr(gradV[i], factor * ListUtils.dot(u[i], x), y);
}
}
@Override
public void addL2Regularization(double beta) {
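      // The L2 penalty -(beta / 2) * (||u[i]||^2 + ||v[i]||^2) contributes
      // -beta * u[i] and -beta * v[i] to the gradient (ascent convention).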
for (int i = 0; i < rank; i++) {
ListUtils.incr(gradU[i], -beta, u[i]);
ListUtils.incr(gradV[i], -beta, v[i]);
}
}
}
// ============================================================
// Weight update
// ============================================================
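  // Per-coordinate statistics for the two supported update schemes.
  // AdaGrad scales each coordinate's step by 1 / sqrt of its accumulated
  // squared gradient; dual averaging recomputes each weight from the
  // running sum of all gradients seen so far.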
// For AdaGrad
double[][] sumSquaredGradientsU, sumSquaredGradientsV;
// For dual averaging
double[][] sumGradientsU, sumGradientsV;
protected void initGradientStats() {
if (Params.opts.adaptiveStepSize) {
sumSquaredGradientsU = new double[rank][dim];
sumSquaredGradientsV = new double[rank][dim];
}
if (Params.opts.dualAveraging) {
sumGradientsU = new double[rank][dim];
sumGradientsV = new double[rank][dim];
}
}
// Number of stochastic updates we've made so far (for determining step size).
int numUpdates;
  /**
   * Update u and v with the given gradient, using AdaGrad step sizes and/or
   * dual averaging when the corresponding options are enabled.
   */
@Override
public void update(AdvancedWordVectorGradient gradient) {
AdvancedWordVectorGradientLowRank g = (AdvancedWordVectorGradientLowRank) gradient;
numUpdates++;
for (int i = 0; i < rank; i++) {
if (Params.opts.adaptiveStepSize) {
ListUtils.incr(sumSquaredGradientsU[i], 1, ListUtils.sq(g.gradU[i]));
ListUtils.incr(sumSquaredGradientsV[i], 1, ListUtils.sq(g.gradV[i]));
}
if (Params.opts.dualAveraging) {
ListUtils.incr(sumGradientsU[i], 1, g.gradU[i]);
ListUtils.incr(sumGradientsV[i], 1, g.gradV[i]);
}
for (int j = 0; j < dim; j++) {
double stepSizeU, stepSizeV;
if (Params.opts.adaptiveStepSize) {
stepSizeU = Params.opts.initStepSize / Math.sqrt(sumSquaredGradientsU[i][j]);
stepSizeV = Params.opts.initStepSize / Math.sqrt(sumSquaredGradientsV[i][j]);
} else {
stepSizeU = stepSizeV = Params.opts.initStepSize / Math.pow(numUpdates, Params.opts.stepSizeReduction);
}
if (Params.opts.dualAveraging) {
u[i][j] = stepSizeU * sumGradientsU[i][j];
v[i][j] = stepSizeV * sumGradientsV[i][j];
} else {
u[i][j] += stepSizeU * g.gradU[i][j];
v[i][j] += stepSizeV * g.gradV[i][j];
}
}
}
}
  /**
   * Apply L1 regularization (soft thresholding):
   * - If weight > cutoff, then weight := weight - cutoff
   * - If weight < -cutoff, then weight := weight + cutoff
   * - Otherwise, weight := 0
   *
   * @param cutoff regularization parameter (>= 0)
   */
@Override
public void applyL1Regularization(double cutoff) {
if (cutoff <= 0)
return;
for (int i = 0; i < rank; i++) {
for (int j = 0; j < dim; j++) {
u[i][j] = Params.L1Cut(u[i][j], cutoff);
v[i][j] = Params.L1Cut(v[i][j], cutoff);
}
}
}
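  // For example, with cutoff = 0.1: a weight of 0.25 becomes 0.15, a weight
  // of -0.25 becomes -0.15, and a weight of 0.05 becomes 0 (assuming
  // Params.L1Cut implements the soft thresholding described above).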
// ============================================================
// Logging
// ============================================================
@Override
public void log() {
LogInfo.begin_track("Advanced Word Vector Params");
for (int i = 0; i < rank; i++) {
LogInfo.begin_track("u[%d] and v[%d]", i, i);
for (int j = 0; j < dim; j++) {
LogInfo.logs("%4d %6s %6s", j, Fmt.D(u[i][j]), Fmt.D(v[i][j]));
}
LogInfo.end_track();
}
LogInfo.end_track();
}
@Override
public void logFeatureWeights(Candidate candidate) {
LogInfo.begin_track("Advanced Word Vector feature weights");
double[] x = getX(candidate), y = getY(candidate);
if (x == null) {
LogInfo.log("NONE: x (query word vector) is null");
} else if (y == null) {
LogInfo.log("NONE: y (entities word vector) is null");
} else {
for (int i = 0; i < rank; i++) {
LogInfo.logs("%4d: %6s", i, Fmt.D(ListUtils.dot(u[i], x) * ListUtils.dot(v[i], y)));
}
}
LogInfo.end_track();
}
@Override
public void logFeatureDiff(Candidate trueCandidate, Candidate predCandidate) {
LogInfo.begin_track("Advanced Word Vector feature weights");
    // The candidates should be from the same example, so assume x is the same.
double[] x = getX(trueCandidate), yTrue = getY(trueCandidate), yPred = getY(predCandidate);
if (x == null) {
LogInfo.log("NONE: x (query word vector) is null");
} else if (yTrue == null) {
LogInfo.log("NONE: y (entities word vector) is null for trueCandidate");
} else if (yPred == null) {
LogInfo.log("NONE: y (entities word vector) is null for predCandidate");
} else {
for (int i = 0; i < rank; i++) {
double trueScore = ListUtils.dot(u[i], x) * ListUtils.dot(v[i], yTrue);
double predScore = ListUtils.dot(u[i], x) * ListUtils.dot(v[i], yPred);
LogInfo.logs("%4d: %6s [ %s - %s ]", i, Fmt.D(trueScore - predScore),
Fmt.D(trueScore), Fmt.D(predScore));
}
}
LogInfo.end_track();
}
}