package edu.berkeley.nlp.discPCFG;
import edu.berkeley.nlp.math.DifferentiableFunction;
import edu.berkeley.nlp.math.DoubleArrays;
import edu.berkeley.nlp.math.SloppyMath;
import edu.berkeley.nlp.util.Pair;
/**
* @author petrov
*
*/
/**
* This is the MaximumEntropy objective function: the (negative) log conditional likelihood of the training data,
* possibly with a penalty for large weights. Note that this objective gets MINIMIZED, so it is the negative of the
* objective we normally think of.
*/
public class ProperNameObjectiveFunction<F, L> implements ObjectiveFunction {
    /*
     * Implementation notes:
     *  - The objective is sum over data of -(labelWeight * log p(substate | datum)) plus an L2 penalty
     *    sum(x_i^2) / (2 * sigma^2); the derivative carries the matching x_i / sigma^2 term.
     *  - Caching is driven by the externally controlled isUpToDate flag (see ensureCache). valueAt and
     *    derivativeAt reset the flag after every call, so unless a caller sets it again through
     *    isUpToDate(true) each call recomputes the full objective.
     */

    /** Maps a (feature index, sub-label index) pair to a position in the weight vector. */
    IndexLinearizer indexLinearizer;
    /** Supplies the number of sub-labels and the first sub-label index of each label. */
    Encoding<F, L> encoding;
    /** The encoded training data the objective is evaluated on. */
    EncodedDatum[] data;
    /** The weight vector that calculate() evaluates; set by setX or by ensureCache. */
    double[] x;
    /** Regularization parameter; the L2 penalty is divided by 2 * sigma^2. */
    double sigma;
    /** Objective value from the most recent computation. */
    double lastValue;
    /** Derivative from the most recent computation; null until something has been computed. */
    double[] lastDerivative;
    /** The weight vector (by reference, not copied) used for the most recent computation. */
    double[] lastX;
    /** When true, the next valueAt/derivativeAt call reuses the cached result instead of recomputing. */
    boolean isUpToDate;

    /**
     * Creates the objective function.
     *
     * @param encoding        provides the sub-label layout of the labels
     * @param data            the encoded training data
     * @param indexLinearizer maps (feature, sub-label) pairs to weight-vector positions
     * @param sigma           regularization parameter; the penalty is sum(x_i^2) / (2 * sigma^2), so a
     *                        value of zero yields non-finite results
     */
    public ProperNameObjectiveFunction(Encoding<F, L> encoding, EncodedDatum[] data, IndexLinearizer indexLinearizer, double sigma) {
        this.indexLinearizer = indexLinearizer;
        this.encoding = encoding;
        this.data = data;
        this.sigma = sigma;
    }

    /** No-op in this implementation. */
    public void shutdown() {
    }

    /** No-op in this implementation. */
    public void updateGoldCountsNextRound() {
    }

    /**
     * @return the length of the weight vector, as reported by the index linearizer
     */
    public int dimension() {
        return indexLinearizer.getNumLinearIndexes();
    }

    /**
     * Returns the objective value for the given weights.
     * The cached value is reused only if the caller marked the cache current via isUpToDate(true);
     * the flag is cleared again before returning.
     *
     * @param x the weight vector
     * @return the regularized negative log conditional likelihood
     */
    public double valueAt(double[] x) {
        ensureCache(x);
        isUpToDate = false;
        return lastValue;
    }

    /**
     * Returns the derivative of the objective for the given weights.
     * The cached derivative is reused only if the caller marked the cache current via isUpToDate(true);
     * the flag is cleared again before returning.
     *
     * @param x the weight vector
     * @return the derivative array produced by calculate() (the cached array itself, not a copy)
     */
    public double[] derivativeAt(double[] x) {
        ensureCache(x);
        isUpToDate = false;
        return lastDerivative;
    }

    /**
     * Not implemented for this objective function.
     *
     * @param x the weight vector (ignored)
     * @return always null
     */
    public double[] unregularizedDerivativeAt(double[] x) {
        return null;
    }

    /**
     * Recomputes value and derivative for x unless the cache has been declared current.
     * A cache that was declared current but never filled is still recomputed, so callers can never
     * observe the default 0.0 value or a null derivative.
     */
    private void ensureCache(double[] x) {
        boolean cacheUsable = isUpToDate && lastDerivative != null;
        if (cacheUsable) {
            return;
        }
        this.x = x;
        Pair<Double, double[]> valueAndDerivative = calculate();
        lastValue = valueAndDerivative.getFirst();
        lastDerivative = valueAndDerivative.getSecond();
        lastX = x;
    }

    /**
     * Sets the weight vector that a subsequent direct call to calculate() will evaluate.
     *
     * @param x the weight vector (stored by reference)
     */
    public void setX(double[] x) {
        this.x = x;
    }

    /**
     * Declares whether the cached value/derivative may be reused by the next valueAt/derivativeAt call.
     *
     * @param b true to reuse the cache once, false to force recomputation
     */
    public void isUpToDate(boolean b) {
        isUpToDate = b;
    }

    /**
     * The most important part of the classifier learning process! This method determines, for the current weight
     * vector x, what the (negative) log conditional likelihood of the data is, as well as the derivatives of that
     * likelihood wrt each weight parameter. Both include the L2 regularization term.
     *
     * @return a pair of (objective value, derivative array of length dimension())
     * @throws IllegalStateException if no weight vector has been set yet
     */
    public Pair<Double, double[]> calculate() {
        if (x == null) {
            throw new IllegalStateException("Weight vector x has not been set; call setX, valueAt or derivativeAt first");
        }
        double objective = 0.0;
        System.out.println("In Calculate...");
        double[] derivatives = DoubleArrays.constantArray(0.0, dimension());
        int numSubLabels = encoding.getNumSubLabels();

        for (int d = 0; d < data.length; ++d) {
            EncodedDatum datum = data[d];
            double[] logProbabilities = getLogProbabilities(datum, x, encoding, indexLinearizer);

            // The gold label owns a contiguous range of sub-labels starting at firstSubstate.
            int goldLabel = datum.getLabelIndex();
            double[] labelWeights = datum.getWeights();
            int numGoldSubstates = labelWeights.length;
            int firstSubstate = encoding.getLabelSubindexBegin(goldLabel);

            for (int s = 0; s < numGoldSubstates; s++) {
                objective -= labelWeights[s] * logProbabilities[firstSubstate + s];
            }

            // Convert to probabilities and sanity-check the normalization.
            double[] probabilities = new double[numSubLabels];
            double sum = 0.0;
            for (int s = 0; s < numSubLabels; ++s) {
                probabilities[s] = Math.exp(logProbabilities[s]);
                sum += probabilities[s];
            }
            if (Math.abs(sum - 1.0) > 1e-3) {
                System.err.println("Probabilities do not sum to 1!");
            }

            // Derivative: expected feature counts under the model minus weighted gold feature counts.
            int numActiveFeatures = datum.getNumActiveFeatures();
            for (int i = 0; i < numActiveFeatures; ++i) {
                int featureIndex = datum.getFeatureIndex(i);
                double featureCount = datum.getFeatureCount(i);
                for (int s = 0; s < numSubLabels; ++s) {
                    int index = indexLinearizer.getLinearIndex(featureIndex, s);
                    derivatives[index] += featureCount * probabilities[s];
                }
                for (int s = 0; s < numGoldSubstates; s++) {
                    int index = indexLinearizer.getLinearIndex(featureIndex, firstSubstate + s);
                    derivatives[index] -= labelWeights[s] * featureCount;
                }
            }
        }

        // Incorporate penalty terms (regularization) into the objective and derivatives.
        // 'x' and 'derivatives' have the same layout.
        double sigma2 = sigma * sigma;
        double penalty = 0.0;
        for (int index = 0; index < x.length; ++index) {
            penalty += x[index] * x[index];
            derivatives[index] += x[index] / sigma2;
        }
        objective += penalty / (2 * sigma2);

        return new Pair<Double, double[]>(objective, derivatives);
    }

    /**
     * Calculate the log probabilities of each class, for the given datum (feature bundle).
     * Each sub-label's score is the sum of weight * count over the datum's active features; the scores
     * are then normalized by subtracting their log-sum.
     * Note that the method-level type parameters F and L shadow the class-level ones; the signature is
     * kept as is so existing callers and any interface declaration stay compatible.
     *
     * @param datum           the datum whose active features are scored
     * @param weights         the weight vector, indexed through indexLinearizer
     * @param encoding        provides the number of sub-labels
     * @param indexLinearizer maps (feature, sub-label) pairs to positions in weights
     * @return a new array with one normalized log probability per sub-label
     */
    public <F,L> double[] getLogProbabilities(EncodedDatum datum, double[] weights, Encoding<F, L> encoding, IndexLinearizer indexLinearizer) {
        // Compute unnormalized log probabilities
        int numSubLabels = encoding.getNumSubLabels();
        double[] logProbabilities = DoubleArrays.constantArray(0.0, numSubLabels);
        int numActiveFeatures = datum.getNumActiveFeatures();
        for (int i = 0; i < numActiveFeatures; i++) {
            int featureIndex = datum.getFeatureIndex(i);
            double featureCount = datum.getFeatureCount(i);
            for (int j = 0; j < numSubLabels; j++) {
                int index = indexLinearizer.getLinearIndex(featureIndex, j);
                logProbabilities[j] += weights[index] * featureCount;
            }
        }
        // Normalize
        double logNormalizer = SloppyMath.logAdd(logProbabilities);
        for (int j = 0; j < numSubLabels; j++) {
            logProbabilities[j] -= logNormalizer;
        }
        return logProbabilities;
    }
}