package edu.stanford.nlp.classify; import edu.stanford.nlp.math.ArrayMath; import java.io.Serializable; /** * A Prior for functions. Immutable. * * @author Galen Andrew */ public class LogPrior implements Serializable { private static final long serialVersionUID = 7826853908892790965L; public enum LogPriorType { NULL, QUADRATIC, HUBER, QUARTIC, COSH, ADAPT, MULTIPLE_QUADRATIC } public static LogPriorType getType(String name) { if (name.equalsIgnoreCase("null")) { return LogPriorType.NULL; } else if (name.equalsIgnoreCase("quadratic")) { return LogPriorType.QUADRATIC; } else if (name.equalsIgnoreCase("huber")) { return LogPriorType.HUBER; } else if (name.equalsIgnoreCase("quartic")) { return LogPriorType.QUARTIC; } else if (name.equalsIgnoreCase("cosh")) { return LogPriorType.COSH; } // else if (name.equalsIgnoreCase("multiple")) { return LogPriorType.MULTIPLE; } else { throw new RuntimeException("Unknown LogPriorType: " + name); } } // these fields are just for the ADAPT prior - // is there a better way to do this? private double[] means = null; private LogPrior otherPrior = null; public static LogPrior getAdaptationPrior(double[] means, LogPrior otherPrior) { LogPrior lp = new LogPrior(LogPriorType.ADAPT); lp.means = means; lp.otherPrior = otherPrior; return lp; } public LogPriorType getType() { return type; } private final LogPriorType type; public LogPrior() { this(LogPriorType.QUADRATIC); } public LogPrior(int intPrior) { this(intPrior, 1.0, 0.1); } public LogPrior(LogPriorType type) { this(type, 1.0, 0.1); } // why isn't this functionality in enum? private static LogPriorType intToType(int intPrior) { LogPriorType[] values = LogPriorType.values(); for (LogPriorType val : values) { if (val.ordinal() == intPrior) { return val; } } throw new IllegalArgumentException(intPrior + " is not a legal LogPrior."); } public LogPrior(int intPrior, double sigma, double epsilon) { this(intToType(intPrior), sigma, epsilon); } public LogPrior(LogPriorType type, double sigma, double epsilon) { this.type = type; if (type != LogPriorType.ADAPT) { setSigma(sigma); setEpsilon(epsilon); } } // this is the C variable in CSFoo's MM paper C = 1/\sigma^2 // private double[] regularizationHyperparameters = null; private double[] sigmaSqM = null; private double[] sigmaQuM = null; // public double[] getRegularizationHyperparameters() { // return regularizationHyperparameters; // } // // public void setRegularizationHyperparameters( // double[] regularizationHyperparameters) { // this.regularizationHyperparameters = regularizationHyperparameters; // } /** * IMPORTANT NOTE: This constructor allows non-uniform regularization, but it * transforms the inputs C (like the machine learning people like) to sigma * (like we NLP folks like). C = 1/\sigma^2 */ public LogPrior(double[] C) { this.type = LogPriorType.MULTIPLE_QUADRATIC; double[] sigmaSqM = new double[C.length]; for (int i=0;i<C.length;i++){ sigmaSqM[i] = 1./C[i]; } this.sigmaSqM = sigmaSqM; setSigmaSquaredM(sigmaSqM); // this.regularizationHyperparameters = regularizationHyperparameters; } // private double sigma; private double sigmaSq; private double sigmaQu; private double epsilon; public double getSigma() { if (type == LogPriorType.ADAPT) { return otherPrior.getSigma(); } else { return Math.sqrt(sigmaSq); } } public double getSigmaSquared() { if (type == LogPriorType.ADAPT) { return otherPrior.getSigmaSquared(); } else { return sigmaSq; } } public double[] getSigmaSquaredM() { if (type == LogPriorType.MULTIPLE_QUADRATIC) { return sigmaSqM; } else { throw new RuntimeException("LogPrior.getSigmaSquaredM is undefined for any prior but MULTIPLE_QUADRATIC" + this); } } public double getEpsilon() { if (type == LogPriorType.ADAPT) { return otherPrior.getEpsilon(); } else { return epsilon; } } public void setSigma(double sigma) { if (type == LogPriorType.ADAPT) { otherPrior.setSigma(sigma); } else { // this.sigma = sigma; this.sigmaSq = sigma * sigma; this.sigmaQu = sigmaSq * sigmaSq; } } // public void setSigmaM(double[] sigmaM) { // if (type == LogPriorType.MULTIPLE_QUADRATIC) { // // this.sigma = Math.sqrt(sigmaSq); // double[] sigmaSqM = new double[sigmaM.length]; // double[] sigmaQuM = new double[sigmaM.length]; // // for (int i = 0;i<sigmaM.length;i++){ // sigmaSqM[i] = sigmaM[i] * sigmaM[i]; // } // this.sigmaSqM = sigmaSqM; // // for (int i = 0;i<sigmaSqM.length;i++){ // sigmaQuM[i] = sigmaSqM[i] * sigmaSqM[i]; // } // this.sigmaQuM = sigmaQuM; // // } else { // throw new RuntimeException("LogPrior.getSigmaSquaredM is undefined for any prior but MULTIPLE_QUADRATIC" + this); // } // } public void setSigmaSquared(double sigmaSq) { if (type == LogPriorType.ADAPT) { otherPrior.setSigmaSquared(sigmaSq); } else { // this.sigma = Math.sqrt(sigmaSq); this.sigmaSq = sigmaSq; this.sigmaQu = sigmaSq * sigmaSq; } } public void setSigmaSquaredM(double[] sigmaSq) { if (type == LogPriorType.ADAPT) { otherPrior.setSigmaSquaredM(sigmaSq); } if (type == LogPriorType.MULTIPLE_QUADRATIC) { // this.sigma = Math.sqrt(sigmaSq); this.sigmaSqM = sigmaSq.clone(); double[] sigmaQuM = new double[sigmaSq.length]; for (int i = 0;i<sigmaSq.length;i++){ sigmaQuM[i] = sigmaSqM[i] * sigmaSqM[i]; } this.sigmaQuM = sigmaQuM; } else { throw new RuntimeException("LogPrior.getSigmaSquaredM is undefined for any prior but MULTIPLE_QUADRATIC" + this); } } public void setEpsilon(double epsilon) { if (type == LogPriorType.ADAPT) { otherPrior.setEpsilon(epsilon); } else { this.epsilon = epsilon; } } public double computeStochastic(double[] x, double[] grad, double fractionOfData) { if (type == LogPriorType.ADAPT) { double[] newX = ArrayMath.pairwiseSubtract(x, means); return otherPrior.computeStochastic(newX, grad, fractionOfData); } else if (type == LogPriorType.MULTIPLE_QUADRATIC) { double[] sigmaSquaredOld = getSigmaSquaredM(); double[] sigmaSquaredTemp = sigmaSquaredOld.clone(); for (int i = 0; i < x.length; i++) { sigmaSquaredTemp[i] /= fractionOfData; } setSigmaSquaredM(sigmaSquaredTemp); double val = compute(x, grad); setSigmaSquaredM(sigmaSquaredOld); return val; } else { double sigmaSquaredOld = getSigmaSquared(); setSigmaSquared(sigmaSquaredOld / fractionOfData); double val = compute(x, grad); setSigmaSquared(sigmaSquaredOld); return val; } } /** * Adjust the given grad array by adding the prior's gradient component * and return the value of the logPrior * @param x the input point * @param grad the gradient array * @return the value */ public double compute(double[] x, double[] grad) { double val = 0.0; switch (type) { case NULL: return val; case QUADRATIC: for (int i = 0; i < x.length; i++) { val += x[i] * x[i] / 2.0 / sigmaSq; grad[i] += x[i] / sigmaSq; } return val; case HUBER: // P.J. Huber. 1973. Robust regression: Asymptotics, conjectures and // Monte Carlo. The Annals of Statistics 1: 799-821. // See also: // P. J. Huber. Robust Statistics. John Wiley & Sons, New York, 1981. for (int i = 0; i < x.length; i++) { if (x[i] < -epsilon) { val += (-x[i] - epsilon / 2.0) / sigmaSq; grad[i] += -1.0 / sigmaSq; } else if (x[i] < epsilon) { val += x[i] * x[i] / 2.0 / epsilon / sigmaSq; grad[i] += x[i] / epsilon / sigmaSq; } else { val += (x[i] - epsilon / 2.0) / sigmaSq; grad[i] += 1.0 / sigmaSq; } } return val; case QUARTIC: for (int i = 0; i < x.length; i++) { val += (x[i] * x[i]) * (x[i] * x[i]) / 2.0 / sigmaQu; grad[i] += x[i] / sigmaQu; } return val; case ADAPT: double[] newX = ArrayMath.pairwiseSubtract(x, means); val += otherPrior.compute(newX, grad); return val; case COSH: double norm = ArrayMath.norm_1(x) / sigmaSq; double d; if (norm > 30.0) { val = norm - Math.log(2); d = 1.0 / sigmaSq; } else { val = Math.log(Math.cosh(norm)); d = (2 * (1 / (Math.exp(-2.0 * norm) + 1)) - 1.0) / sigmaSq; } for (int i=0; i < x.length; i++) { grad[i] += Math.signum(x[i]) * d; } return val; case MULTIPLE_QUADRATIC: // for (int i = 0; i < x.length; i++) { // val += x[i] * x[i]* 1/2 * regularizationHyperparameters[i]; // grad[i] += x[i] * regularizationHyperparameters[i]; // } for (int i = 0; i < x.length; i++) { val += x[i] * x[i] / 2.0 / sigmaSqM[i]; grad[i] += x[i] / sigmaSqM[i]; } return val; default: throw new RuntimeException("LogPrior.valueAt is undefined for prior of type " + this); } } }