package edu.stanford.nlp.semparse.open.model;
import fig.basic.*;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.*;
/**
* Params contains the parameters of the model. Currently consists of a map from
* features to weights.
*
* @author Percy Liang
*/
/**
 * Params contains the parameters of the model: a map from feature names to
 * real-valued weights, together with the bookkeeping needed to update those
 * weights with stochastic gradient steps (plain SGD with a decaying step
 * size, AdaGrad, and/or dual averaging).
 *
 * @author Percy Liang
 */
public class Params {
  public static class Options {
    @Option(gloss = "By default, all features have this weight")
    public double defaultWeight = 0;
    @Option(gloss = "Randomly initialize the weights")
    public boolean initWeightsRandomly = false;
    @Option(gloss = "Random number generator used for random weight initialization")
    public Random initRandom = new Random(1);
    @Option(gloss = "Initial step size")
    public double initStepSize = 1;
    @Option(gloss = "How fast to reduce the step size")
    public double stepSizeReduction = 0;
    @Option(gloss = "Use the AdaGrad algorithm (different step size for each coordinate)")
    public boolean adaptiveStepSize = true;
    @Option(gloss = "Use dual averaging")
    public boolean dualAveraging = false;
  }
  public static Options opts = new Options();

  // Discriminative weights: feature name -> weight.
  HashMap<String, Double> weights = new HashMap<>();

  /**
   * Return the weight of feature |f|.  If the feature has no weight yet,
   * return the default: a fresh uniform draw from [-1, 1) when
   * opts.initWeightsRandomly is set, otherwise opts.defaultWeight.
   * Note: the returned default is NOT stored into the map.
   */
  public double getWeight(String f) {
    Double w = weights.get(f);
    if (w != null) return w;
    // Only consume randomness when the feature is actually absent.  Passing
    // the draw as an (eagerly evaluated) default argument to
    // MapUtils.getDouble would advance the RNG on every single lookup.
    if (opts.initWeightsRandomly)
      return 2 * opts.initRandom.nextDouble() - 1;
    return opts.defaultWeight;
  }

  // ============================================================
  // Weight update
  // ============================================================

  // For AdaGrad: per-feature sum of squared gradients seen so far.
  Map<String, Double> sumSquaredGradients = new HashMap<>();
  // For dual averaging: per-feature sum of (unscaled) gradients seen so far.
  Map<String, Double> sumGradients = new HashMap<>();
  // Number of stochastic updates we've made so far (for determining step size).
  int numUpdates;

  /**
   * Update weights by adding |gradient| (modified appropriately with step size).
   *
   * With adaptiveStepSize, each coordinate's step is
   * initStepSize / sqrt(sum of squared gradients for that coordinate);
   * otherwise the global step is initStepSize / numUpdates^stepSizeReduction.
   * With dualAveraging, the weight is set to stepSize * (sum of gradients)
   * instead of incremented.
   */
  public void update(Map<String, Double> gradient) {
    numUpdates++;
    for (Map.Entry<String, Double> entry : gradient.entrySet()) {
      String f = entry.getKey();
      double g = entry.getValue();
      // Skip (near-)zero gradients: they don't move the weight but would
      // still pollute the AdaGrad/dual-averaging statistics.
      if (Math.abs(g) < 1e-6) continue;
      double stepSize;
      if (opts.adaptiveStepSize) {
        MapUtils.incr(sumSquaredGradients, f, g * g);
        stepSize = opts.initStepSize / Math.sqrt(sumSquaredGradients.get(f));
      } else {
        stepSize = opts.initStepSize / Math.pow(numUpdates, opts.stepSizeReduction);
      }
      if (Double.isNaN(stepSize) || Double.isNaN(g)) {
        LogInfo.fails("WTF? %s %s %s", f, g, sumSquaredGradients.get(f));
      }
      if (opts.dualAveraging) {
        // Dual averaging assumes the step size for a feature is the same on
        // every iteration (unless AdaGrad fixes it per-coordinate), because
        // weights of zero-gradient features are never rescaled here.
        if (!opts.adaptiveStepSize && opts.stepSizeReduction != 0)
          throw new RuntimeException("Dual averaging not supported when " +
                                     "step-size changes across iterations for " +
                                     "features for which the gradient is zero");
        MapUtils.incr(sumGradients, f, g);
        MapUtils.set(weights, f, stepSize * sumGradients.get(f));
      } else {
        MapUtils.incr(weights, f, stepSize * g);
      }
    }
  }

  /** Soft-threshold |x| toward zero by |cutoff| (the L1 proximal operator). */
  public static double L1Cut(double x, double cutoff) {
    return (x > cutoff) ? (x - cutoff) : (x < -cutoff) ? (x + cutoff) : 0;
  }

  /**
   * Apply L1 regularization:
   * - If weight > cutoff, then weight := weight - cutoff
   * - If weight < -cutoff, then weight := weight + cutoff
   * - Otherwise, weight := 0
   * @param cutoff regularization parameter (>= 0)
   */
  public void applyL1Regularization(double cutoff) {
    if (cutoff <= 0) return;
    for (Map.Entry<String, Double> entry : weights.entrySet()) {
      entry.setValue(L1Cut(entry.getValue(), cutoff));
    }
  }

  /**
   * Prune features with small weights.
   * @param threshold the maximum absolute value for weights to be pruned
   */
  public void prune(double threshold) {
    if (threshold <= 0) return;
    // Iterator.remove is the only safe way to delete while iterating.
    Iterator<Map.Entry<String, Double>> iter = weights.entrySet().iterator();
    while (iter.hasNext()) {
      if (Math.abs(iter.next().getValue()) < threshold) iter.remove();
    }
  }

  // ============================================================
  // Persistence
  // ============================================================

  /**
   * Read parameters from |path|.  Each line is tab-separated: feature, weight.
   * Existing weights for other features are kept; duplicates in the file win.
   * @throws RuntimeException wrapping any IOException from reading the file
   */
  public void read(String path) {
    LogInfo.begin_track("Reading parameters from %s", path);
    // try-with-resources: the reader is closed even if parsing throws.
    try (BufferedReader in = IOUtils.openIn(path)) {
      String line;
      while ((line = in.readLine()) != null) {
        String[] pair = line.split("\t");
        weights.put(pair[0], Double.parseDouble(pair[1]));
      }
    } catch (IOException e) {
      throw new RuntimeException(e);
    }
    LogInfo.logs("Read %s weights", weights.size());
    LogInfo.end_track();
  }

  /** Write all weights to |out|, sorted by decreasing weight. */
  public void write(PrintWriter out) { write(null, out); }

  /**
   * Write all weights to |out|, sorted by decreasing weight.
   * @param prefix optional string prepended (tab-separated) to each line; may be null
   */
  public void write(String prefix, PrintWriter out) {
    List<Map.Entry<String, Double>> entries = new ArrayList<>(weights.entrySet());
    Collections.sort(entries, new ValueComparator<String, Double>(true));
    for (Map.Entry<String, Double> entry : entries) {
      double value = entry.getValue();
      out.println((prefix == null ? "" : prefix + "\t") + entry.getKey() + "\t" + value);
    }
  }

  /** Write all weights to the file at |path| (see {@link #write(PrintWriter)}). */
  public void write(String path) {
    LogInfo.begin_track("Params.write(%s)", path);
    // try-with-resources: the writer is closed (and flushed) even on error.
    try (PrintWriter out = IOUtils.openOutHard(path)) {
      write(out);
    }
    LogInfo.end_track();
  }

  // ============================================================
  // Logging
  // ============================================================

  /** Log all weights, sorted by decreasing weight. */
  public void log() {
    LogInfo.begin_track("Params");
    List<Map.Entry<String, Double>> entries = new ArrayList<>(weights.entrySet());
    Collections.sort(entries, new ValueComparator<String, Double>(true));
    for (Map.Entry<String, Double> entry : entries) {
      double value = entry.getValue();
      LogInfo.logs("%s\t%s", entry.getKey(), value);
    }
    LogInfo.end_track();
  }
}