/**
*
*/
package edu.berkeley.nlp.discPCFG;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.zip.GZIPInputStream;
import edu.berkeley.nlp.PCFGLA.ArrayParser;
import edu.berkeley.nlp.PCFGLA.Binarization;
import edu.berkeley.nlp.PCFGLA.ConditionalTrainer;
import edu.berkeley.nlp.PCFGLA.ConstrainedHierarchicalTwoChartParser;
import edu.berkeley.nlp.PCFGLA.ConstrainedTwoChartsParser;
import edu.berkeley.nlp.PCFGLA.Grammar;
import edu.berkeley.nlp.PCFGLA.Lexicon;
import edu.berkeley.nlp.PCFGLA.ParserData;
import edu.berkeley.nlp.PCFGLA.SimpleLexicon;
import edu.berkeley.nlp.PCFGLA.SpanPredictor;
import edu.berkeley.nlp.PCFGLA.StateSetTreeList;
import edu.berkeley.nlp.math.SloppyMath;
import edu.berkeley.nlp.syntax.StateSet;
import edu.berkeley.nlp.syntax.Tree;
import edu.berkeley.nlp.util.Numberer;
/**
* @author petrov
*
*/
public class ParsingObjectiveFunction implements ObjectiveFunction {
public static final int NO_REGULARIZATION = 0;
public static final int L1_REGULARIZATION = 1;
public static final int L2_REGULARIZATION = 2;
Grammar grammar;
SimpleLexicon lexicon;
SpanPredictor spanPredictor;
Linearizer linearizer;
int myRegularization;
double sigma;
double lastValue;
double[] lastDerivative;
double[] lastUnregularizedDerivative;
double[] x;
int dimension;
int nGrammarWeights, nLexiconWeights, nSpanWeights;
int nProcesses;
String consBaseName;
StateSetTreeList[] trainingTrees;
ExecutorService pool;
Calculator[] tasks;
double bestObjectiveSoFar;
String outFileName;
double[] spanGoldCounts;
/** Returns the number of weights in the linearized parameter vector. */
public int dimension() {
return dimension;
}
/**
 * Objective value at x: the negated (regularized) conditional log-likelihood.
 * Recomputed only when x differs from the cached point (see ensureCache).
 */
public double valueAt(double[] x) {
ensureCache(x);
return lastValue;
}
/** Gradient of the regularized objective at x (same layout as x). */
public double[] derivativeAt(double[] x) {
ensureCache(x);
return lastDerivative;
}
/** Gradient at x without the regularization penalty term (same layout as x). */
public double[] unregularizedDerivativeAt(double[] x) {
ensureCache(x);
return lastUnregularizedDerivative;
}
/**
 * Recomputes the objective and gradient at proposed_x if it differs from the
 * cached point. Installs the weights into the grammar/lexicon/span predictor,
 * runs the worker tasks (in parallel when nProcesses &gt; 1), accumulates their
 * Counts, applies regularization, negates everything for the minimizer, and
 * saves the grammar when a new best objective is reached.
 *
 * Fixes over the original: the busy-wait loop spinning on Future.isDone() is
 * removed (Future.get() below already blocks until completion), and worker
 * failures now raise instead of silently leaving counts null and NPE-ing.
 */
private void ensureCache(double[] proposed_x) {
    if (!requiresUpdate(proposed_x)) return;
    // Push the proposed weights into the model components.
    linearizer.delinearizeWeights(proposed_x);
    grammar = linearizer.getGrammar();
    lexicon = linearizer.getLexicon();
    spanPredictor = linearizer.getSpanPredictor();
    if (this.x == null) this.x = proposed_x.clone();
    else System.arraycopy(proposed_x, 0, this.x, 0, this.x.length);

    System.out.print("Task: ");
    Future<?>[] submits = new Future<?>[nProcesses];
    if (nProcesses > 1) {
        for (int i = 0; i < nProcesses; i++) {
            submits[i] = pool.submit(tasks[i]);
        }
        // No polling loop needed: submits[i].get() below blocks until done.
    }
    // Accumulate the per-worker objectives and gradient contributions.
    double objective = 0;
    int nUnparsable = 0, nIncorrectLL = 0;
    double[] derivatives = new double[dimension];
    for (int i = 0; i < nProcesses; i++) {
        Counts counts;
        if (nProcesses == 1) {
            counts = tasks[0].call();
        } else {
            try {
                counts = (Counts) submits[i].get();
            } catch (ExecutionException e) {
                e.printStackTrace();
                throw new RuntimeException("Worker task " + i + " failed", e);
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt(); // restore interrupt status
                throw new RuntimeException("Interrupted while waiting for worker " + i, e);
            }
        }
        objective += counts.myObjective;
        for (int j = 0; j < dimension; j++) {
            derivatives[j] += counts.myDerivatives[j];
        }
        nUnparsable += counts.unparsableTrees;
        nIncorrectLL += counts.incorrectLLTrees;
    }
    if (spanPredictor != null) {
        // Span gold counts are constant per dataset, so they were precomputed
        // once in the constructor and are added to the gradient here.
        int offset = dimension - spanGoldCounts.length;
        double total = 0;
        for (int rule = 0; rule < spanGoldCounts.length; rule++) {
            total += derivatives[offset + rule];
            derivatives[offset + rule] += spanGoldCounts[rule];
            if (SloppyMath.isVeryDangerous(derivatives[offset + rule]))
                System.out.print(derivatives[offset + rule] + " ");
        }
        System.out.println(total);
    }
    System.out.print(" done. ");
    if (nUnparsable > 0)
        System.out.println(nUnparsable + " trees were not parsable.");
    if (nIncorrectLL > 0)
        System.out.println(nIncorrectLL + " trees had a higher gold LL than all LL.");
    System.out.print("\nThe objective was " + objective);
    lastUnregularizedDerivative = derivatives.clone();
    switch (myRegularization) {
        case L2_REGULARIZATION:
            objective = l2_regularize(objective, derivatives);
            System.out.print(" and is " + objective + " after L2 regularization");
            break;
        case L1_REGULARIZATION:
            objective = l1_regularize(objective, derivatives);
            System.out.print(" and is " + objective + " after L1 regularization");
            break; // was missing; previously fell through to the (empty) default
        default:
            break;
    }
    System.out.print(".\n");
    // Flip sign since we are working with a minimizer rather than a maximizer.
    objective *= -1.0;
    for (int index = 0; index < derivatives.length; index++) {
        // 'x' and 'derivatives' have the same layout.
        derivatives[index] *= -1.0;
        lastUnregularizedDerivative[index] *= -1.0;
    }
    lastValue = objective;
    lastDerivative = derivatives;

    if (objective < bestObjectiveSoFar && !ConditionalTrainer.Options.dontSaveGrammarsAfterEachIteration) {
        bestObjectiveSoFar = objective;
        ParserData pData = new ParserData(lexicon, grammar, spanPredictor, Numberer.getNumberers(),
                grammar.numSubStates, 1, 0, Binarization.RIGHT);
        // Scale the objective into a large integer so successive checkpoints
        // get distinguishable filenames.
        double val = objective;
        if (val != 0.0) {
            while (Math.abs(val) < 10000)
                val *= 10.0;
        }
        int value = (int) val;
        System.out.println("Saving grammar to " + outFileName + "-" + value + ".");
        if (!pData.Save(outFileName + "-" + value)) System.out.println("Saving failed!");
    }
}
/**
 * Returns true if proposed_x differs from the cached point (or nothing is
 * cached yet), i.e. the objective/gradient must be recomputed. As a side
 * effect, NaN entries proposed by the optimizer are replaced by -Infinity.
 *
 * Fix: the original tested {@code proposed_x[i] == Double.NaN}, which is
 * always false under IEEE-754, so NaNs were never sanitized; it also logged
 * the cached value instead of the proposed one.
 *
 * @param proposed_x candidate weight vector (modified in place if it has NaNs)
 * @return true when a recomputation is required
 */
private boolean requiresUpdate(double[] proposed_x) {
    if (this.x == null) return true;
    for (int i = 0; i < x.length; i++) {
        if (Double.isNaN(proposed_x[i])) {
            System.out.println("Optimizer proposed " + proposed_x[i]);
            proposed_x[i] = Double.NEGATIVE_INFINITY;
        }
        if (this.x[i] != proposed_x[i])
            return true;
    }
    return false;
}
/**
 * Immutable-by-convention result bundle produced by one Calculator worker:
 * its summed objective, its gradient contribution (same layout as the weight
 * vector), and the counts of trees it had to skip.
 */
class Counts {
    double myObjective;                    // sum of (goldLL - allLL) over the worker's trees
    double[] myDerivatives;                // gradient contribution
    int unparsableTrees, incorrectLLTrees; // skipped-tree diagnostics

    public Counts(double objective, double[] derivatives, int nUnparsable, int nIncorrectLL) {
        myObjective = objective;
        myDerivatives = derivatives;
        unparsableTrees = nUnparsable;
        incorrectLLTrees = nIncorrectLL;
    }
}
/**
 * Worker task: computes the objective (sum of goldLL - allLL) and gradient
 * contributions for one shard of the training trees. One Calculator runs per
 * worker thread; its results are returned as a Counts object.
 */
class Calculator implements Callable{
// int nGrWeights;
// Parses the gold tree (inside/outside restricted to the observed tree).
ArrayParser gParser;
// Parses all (constrained) analyses of the sentence.
ConstrainedTwoChartsParser eParser;
// This worker's shard of the training trees.
StateSetTreeList myTrees;
// Base filename for on-disk constraint blocks; null means no constraints.
String consName;
// Worker index; selects which constraint blocks belong to this worker.
int myID;
int nCounts;
Counts myCounts;
// Lazily loaded pruning constraints, one entry per tree in myTrees.
boolean[][][][][] myConstraints;
int unparsableTrees, incorrectLLTrees;
// If true, constraints are used exactly as loaded, without projection.
boolean doNotProjectConstraints;
double[] myDerivatives;
Calculator(StateSetTreeList myT, String consN, int i,
Grammar gr, Lexicon lex, SpanPredictor sp, int dimension, boolean notProject){
// this.nGrWeights = nGrWeights;
this.nCounts = dimension;
this.consName = consN;
this.myTrees = myT;
this.doNotProjectConstraints = notProject;
this.myID = i;
gParser = new ArrayParser(gr, lex);
eParser = newEParser(gr, lex, sp);
}
/**
 * Creates the all-parses chart parser; a hierarchical variant is used when
 * the trainer is configured for hierarchical charts.
 * @param gr grammar to parse with
 * @param lex lexicon to parse with
 * @param sp span predictor (may be null)
 * @return a constrained two-charts parser over (gr, lex, sp)
 */
protected ConstrainedTwoChartsParser newEParser(Grammar gr, Lexicon lex, SpanPredictor sp) {
if (!ConditionalTrainer.Options.hierarchicalChart) return new ConstrainedTwoChartsParser(gr, lex, sp);
return new ConstrainedHierarchicalTwoChartParser(gr, lex, sp, gr.finalLevel);
}
// Loads this worker's constraint blocks from disk (files named
// consName-<blockNumber>.data, striped across workers) and optionally
// projects each sentence's constraints before storing them.
protected void loadConstraints(){
myConstraints = new boolean[myTrees.size()][][][][];
boolean[][][][][] curBlock = null;
int block = 0;
int i = 0;
if (consName==null) return;
for (int tree=0; tree<myTrees.size(); tree++){
if (curBlock == null || i >= curBlock.length){
// Blocks are striped: worker myID owns blocks myID, myID+nProcesses, ...
int blockNumber = ((block*nProcesses)+myID);
curBlock = loadData(consName+"-"+blockNumber+".data");
block++;
i = 0;
System.out.print(".");
}
if (!doNotProjectConstraints) eParser.projectConstraints(curBlock[i], false);
myConstraints[tree] = curBlock[i];
i++;
if (myConstraints[tree].length!=myTrees.get(tree).getYield().size()){
System.out.println("My ID: "+myID+", block: "+block+", sentence: "+i);
System.out.println("Sentence length and constraints length do not match!");
// NOTE(review): nulling this entry makes call() dereference null
// (cons.length) later when consName != null — confirm intended.
myConstraints[tree] = null;
}
}
}
/**
 * The most important part of the classifier learning process! This method determines, for the given weight vector
 * x, what the (negative) log conditional likelihood of the data is, as well as the derivatives of that likelihood
 * wrt each weight parameter.
 */
public Counts call() {
double myObjective = 0;
myDerivatives = new double[dimension];
// double[] myDerivatives = new double[nCounts];
unparsableTrees = 0;
incorrectLLTrees = 0;
if (myConstraints==null) loadConstraints();
int i = -1;
// NOTE(review): 'block' is never incremented in this method, so the error
// message below always reports block 0 — confirm intended.
int block = 0;
double totalBias = 0;
for (Tree<StateSet> stateSetTree : myTrees) {
i++;
List<StateSet> yield = stateSetTree.getYield();
boolean noSmoothing = false /*true*/, debugOutput = false;
// parse the sentence
boolean[][][][] cons = null;
if (consName!=null){
cons = myConstraints[i];
if (cons.length != yield.size()){
System.out.println("My ID: "+myID+", block: "+block+", sentence: "+i);
System.out.println("Sentence length ("+yield.size()+") and constraints length ("+cons.length+") do not match!");
System.exit(-1);
}
}
// Log-likelihood over all (constrained) parses of the sentence.
double allLL = eParser.doConstrainedInsideOutsideScores(yield,cons,noSmoothing,null,null,false);
// compute the ll of the gold tree
double goldLL = (ConditionalTrainer.Options.hierarchicalChart) ?
eParser.doInsideOutsideScores(stateSetTree, noSmoothing, debugOutput, eParser.spanScores):
gParser.doInsideOutsideScores(stateSetTree, noSmoothing, debugOutput, eParser.spanScores);
if (i%500==0) System.out.print(".");
// Skip trees with non-finite or inconsistent LLs; apply a fixed penalty.
if (!sanityCheckLLs(goldLL, allLL, stateSetTree)) {
myObjective += -1000;
continue;
}
if (false){ // compute exhaustive iS/oS to get exact expectations and then compute bias
double[] myExpectedCounts = new double[myDerivatives.length];
eParser.incrementExpectedCounts(linearizer, myExpectedCounts, yield);
double[] myExactExpectedCounts = new double[myDerivatives.length];
double exactLL = eParser.doConstrainedInsideOutsideScores(yield,null,noSmoothing,null,null,false);
eParser.incrementExpectedCounts(linearizer, myExactExpectedCounts, yield);
double bias = 0;
for (int ii=0; ii<myDerivatives.length; ii++){
double diff = myExpectedCounts[ii] - myExactExpectedCounts[ii];
bias += diff * diff;
}
totalBias += bias;
System.out.println(allLL + "\t" + exactLL + "\t" + bias);
}
// Gradient contribution: expected counts under all parses and under the gold tree.
eParser.incrementExpectedCounts(linearizer, myDerivatives, yield);
if (ConditionalTrainer.Options.hierarchicalChart)
eParser.incrementExpectedGoldCounts(linearizer, myDerivatives, stateSetTree);
else
gParser.incrementExpectedGoldCounts(linearizer, myDerivatives, stateSetTree);
myObjective += (goldLL - allLL);
// System.out.println(stateSetTree);
// double old = gParser.doInsideOutsideScores(stateSetTree, noSmoothing, debugOutput, eParser.spanScores);
// double old2 = eParser.doInsideOutsideScores(stateSetTree, noSmoothing, debugOutput, eParser.spanScores);
// System.out.println(stateSetTree);
}
myCounts = new Counts(myObjective,myDerivatives,unparsableTrees,incorrectLLTrees);
// NOTE(review): the bias computation above is disabled (if (false)), so this
// always prints 0.0 (NaN for an empty shard) — confirm still wanted.
totalBias /= myTrees.size();
System.out.println("\nAverage bias: "+totalBias+"\n");
System.out.print(" "+myID+" ");
return myCounts;
}
// Reads one gzipped, serialized block of constraints from fileName;
// returns null (after logging) on I/O or deserialization failure.
public boolean[][][][][] loadData(String fileName) {
boolean[][][][][] data = null;
try {
FileInputStream fis = new FileInputStream(fileName); // Load from file
GZIPInputStream gzis = new GZIPInputStream(fis); // Compressed
ObjectInputStream in = new ObjectInputStream(gzis); // Load objects
data = (boolean[][][][][])in.readObject(); // Read the mix of grammars
in.close(); // And close the stream.
gzis.close();
fis.close();
} catch (IOException e) {
System.out.println("IOException\n"+e);
return null;
} catch (ClassNotFoundException e) {
System.out.println("Class not found!");
return null;
}
return data;
}
/**
 * Checks that both log-likelihoods are finite and that the gold LL does not
 * exceed the all-parses LL (beyond a small tolerance). Increments the
 * corresponding failure counter when a check fails.
 * @param goldLL log-likelihood of the gold tree
 * @param allLL log-likelihood over all parses
 * @param stateSetTree the tree, printed on an inconsistent-LL failure
 * @return true if the LLs are sane and the tree can contribute to training
 */
protected boolean sanityCheckLLs(double goldLL, double allLL, Tree<StateSet> stateSetTree) {
if (SloppyMath.isVeryDangerous(allLL) || SloppyMath.isVeryDangerous(goldLL)) {
unparsableTrees++;
return false;
}
if (goldLL - allLL > 1.0e-4){
System.out.println("Something is wrong! The gold LL is " + goldLL + " and the all LL is " + allLL);//+"\n"+sentence+"\n"+stateSetTree);
System.out.println(stateSetTree);
incorrectLLTrees++;
return false;
}
return true;
}
}
/**
 * Adds the Gaussian (L2) prior to the objective and its gradient:
 * subtracts sum(x_i^2) / (2*sigma^2) from the objective and x_i / sigma^2
 * from each derivative (in place). Non-finite regularized derivatives are
 * reset to zero. Returns the objective unchanged when it is already
 * non-finite.
 *
 * @param objective current (unregularized) objective value
 * @param derivatives gradient, same layout as x; modified in place
 * @return the L2-regularized objective
 */
public double l2_regularize(double objective, double[] derivatives){
    if (SloppyMath.isVeryDangerous(objective)) return objective;
    final double variance = sigma * sigma;
    double sumOfSquares = 0.0;
    for (double weight : x) {
        sumOfSquares += weight * weight;
    }
    objective -= sumOfSquares / (2 * variance);
    for (int i = 0; i < x.length; i++) {
        // 'x' and 'derivatives' have the same layout.
        derivatives[i] -= x[i] / variance;
        if (SloppyMath.isVeryDangerous(derivatives[i])) {
            System.out.println("Setting regularized derivative to zero because it is Inf.");
            derivatives[i] = 0;
        }
    }
    return objective;
}
/**
 * Adds an L1 (Laplacian) penalty to the objective and applies the L1
 * subgradient to the derivatives in place. Grammar and lexicon weights use
 * strength 1/(2*sigma^2); span weights use strength 1/2 (sigma2span = 1).
 * Returns the objective unchanged when it is already non-finite.
 *
 * Fix: the penalty accumulators were declared {@code int}; the compound
 * {@code +=} and {@code /=} operators silently narrowed each intermediate
 * result to an integer (JLS 15.26.2), truncating the penalty. They are now
 * doubles. The three identical subgradient loops are factored into a helper.
 *
 * @param objective current (unregularized) objective value
 * @param derivatives gradient, same layout as x; modified in place
 * @return the L1-regularized objective
 */
public double l1_regularize(double objective, double[] derivatives){
    if (SloppyMath.isVeryDangerous(objective)) return objective;
    double sigma2 = sigma*sigma;
    double sigma2span = 1;//(sigma-2)*(sigma-2);
    double sigma2lex = sigma2;//1;//1;//(sigma-2)*(sigma-2);
    int ind = 0;
    double penaltyGr = 0, penaltyLex = 0, penaltySpan = 0; // were int: truncated
    for (int i = 0; i < nGrammarWeights; i++) {
        penaltyGr += Math.abs(x[ind++]);
    }
    penaltyGr /= (2*sigma2);
    for (int i = 0; i < nLexiconWeights; i++) {
        penaltyLex += Math.abs(x[ind++]);
    }
    penaltyLex /= (2*sigma2lex);
    for (int i = 0; i < nSpanWeights; i++) {
        penaltySpan += Math.abs(x[ind++]);
    }
    penaltySpan /= (2*sigma2span);
    objective -= (penaltyGr + penaltyLex + penaltySpan);
    int index = 0;
    index = applyL1Subgradient(derivatives, index, nGrammarWeights, sigma2);
    index = applyL1Subgradient(derivatives, index, nLexiconWeights, sigma2lex);
    index = applyL1Subgradient(derivatives, index, nSpanWeights, sigma2span);
    return objective;
}
/**
 * Applies the L1 subgradient with strength 1/mySigma2 to derivatives[index ..
 * index+nWeights), clamping at zero for weights that are exactly zero and
 * zeroing any non-finite or huge results (mirroring the original loops).
 *
 * @param derivatives gradient, modified in place
 * @param index first index of the weight group
 * @param nWeights number of weights in the group
 * @param mySigma2 squared sigma for this group
 * @return the index just past the processed group
 */
private int applyL1Subgradient(double[] derivatives, int index, int nWeights, double mySigma2) {
    for (int i = 0; i < nWeights; i++) {
        if (x[index] < 0) derivatives[index] -= -1.0/mySigma2;
        else if (x[index] > 0) derivatives[index] -= 1.0/mySigma2;
        else {
            // Zero weight: only push the derivative if it escapes the
            // [-1/sigma2, 1/sigma2] dead zone; otherwise pin it at zero.
            if (derivatives[index] < -1.0/mySigma2) derivatives[index] -= 1.0/mySigma2;
            else if (derivatives[index] > 1.0/mySigma2) derivatives[index] -= -1.0/mySigma2;
            else { derivatives[index] = 0; lastUnregularizedDerivative[index] = 0; } // probably already 0;
        }
        if (SloppyMath.isVeryDangerous(derivatives[index])||Math.abs(derivatives[index])>1.0e10){
            System.out.println("Setting regularized derivative to zero because it is "+derivatives[index]);
            derivatives[index] = 0; lastUnregularizedDerivative[index] = 0;
        }
        index++;
    }
    return index;
}
//
// public double[] computeDerivatives(double[] expectedGoldCounts, double[] expectedCounts){
// double[] derivatives = new double[dimension()];
//
// int nDangerous = 0;
// if (spanPredictor!=null){
// int offset = dimension - spanGoldCounts.length;
// for (int rule=0; rule<spanGoldCounts.length; rule++){
// expectedGoldCounts[offset+rule] = spanGoldCounts[rule];
// }
// }
// for (int rule=0; rule<derivatives.length;rule++){
// derivatives[rule] = (expectedGoldCounts[rule]-expectedCounts[rule]);
// if (SloppyMath.isVeryDangerous(derivatives[rule])||Math.abs(derivatives[rule])>1.0e10){
// nDangerous++;
// System.out.println("Setting derivative to zero because it is "+expectedGoldCounts[rule]+" - "+expectedCounts[rule]+" = "+derivatives[rule]);
// derivatives[rule] = 0;
// }
// }
//
// if (nDangerous>0) System.out.println("Set "+nDangerous+" derivatives to 0 since they were dangerous.");
// return derivatives;
// }
/** No-arg constructor for subclasses; fields must be initialized by the caller. */
public ParsingObjectiveFunction() {
}
/**
 * Sets up the objective: caches the model components from the linearizer,
 * precomputes the gold span-feature counts, shards the training trees into
 * nProc blocks (block size taken from the first constraint file when one
 * exists), and creates the worker pool and Calculator tasks.
 *
 * Fixes over the original: guards against nTreesPerBlock == 0 (fewer trees
 * than processes, or an empty constraint block), which caused a division by
 * zero in the sharding loop; removes the dead local 'inBlock'.
 *
 * @param linearizer maps between model weights and the flat weight vector
 * @param trainTrees full training set, sharded across workers
 * @param sigma regularization strength
 * @param regularization one of NO/L1/L2_REGULARIZATION
 * @param consName base filename for constraint files (null for none)
 * @param nProc number of worker threads/shards
 * @param outName base filename for saved grammars
 * @param doNotProjectConstraints if true, constraints are used as loaded
 * @param combinedLexicon unused here; kept for interface compatibility
 */
public ParsingObjectiveFunction(Linearizer linearizer, StateSetTreeList trainTrees,
        double sigma, int regularization, String consName, int nProc, String outName,
        boolean doNotProjectConstraints, boolean combinedLexicon) {
    this.sigma = sigma;
    this.myRegularization = regularization;
    this.grammar = linearizer.getGrammar();
    this.lexicon = linearizer.getLexicon();
    this.spanPredictor = linearizer.getSpanPredictor();
    this.linearizer = linearizer;
    this.outFileName = outName;
    this.dimension = linearizer.dimension();
    nGrammarWeights = linearizer.getNGrammarWeights();
    nLexiconWeights = linearizer.getNLexiconWeights();
    nSpanWeights = linearizer.getNSpanWeights();
    if (spanPredictor != null)
        this.spanGoldCounts = spanPredictor.countGoldSpanFeatures(trainTrees);
    int nTreesPerBlock = trainTrees.size() / nProc;
    this.consBaseName = consName;
    // When constraint files exist, their block size dictates the sharding.
    boolean[][][][][] tmp = edu.berkeley.nlp.PCFGLA.ParserConstrainer.loadData(consName + "-0.data");
    if (tmp != null) nTreesPerBlock = tmp.length;
    // Guard: avoid division by zero in the sharding loop below.
    if (nTreesPerBlock <= 0) nTreesPerBlock = 1;
    // Split the trees into chunks, round-robin over blocks of nTreesPerBlock.
    this.nProcesses = nProc;
    trainingTrees = new StateSetTreeList[nProcesses];
    for (int i = 0; i < nProcesses; i++) {
        trainingTrees[i] = new StateSetTreeList();
    }
    int block = -1;
    for (int i = 0; i < trainTrees.size(); i++) {
        if (i % nTreesPerBlock == 0) block++;
        trainingTrees[block % nProcesses].add(trainTrees.get(i));
    }
    for (int i = 0; i < nProcesses; i++) {
        System.out.println("Process " + i + " has " + trainingTrees[i].size() + " trees.");
    }
    trainTrees = null; // allow the full list to be collected
    pool = Executors.newFixedThreadPool(nProcesses);
    tasks = new Calculator[nProcesses];
    for (int i = 0; i < nProcesses; i++) {
        tasks[i] = newCalculator(doNotProjectConstraints, i);
    }
    this.bestObjectiveSoFar = Double.POSITIVE_INFINITY;
}
/** Shuts down the worker thread pool; call when training is finished. */
public void shutdown(){
pool.shutdown();
}
/**
 * Factory for worker tasks; subclasses may override to supply a different
 * Calculator implementation.
 * @param doNotProjectConstraints if true, constraints are used as loaded
 * @param i worker index (selects the tree shard and constraint blocks)
 * @return a Calculator for worker i over trainingTrees[i]
 */
protected Calculator newCalculator(boolean doNotProjectConstraints, int i) {
return new Calculator(trainingTrees[i],consBaseName,i, grammar, lexicon, spanPredictor, dimension, doNotProjectConstraints);
}
/** Returns the current model weights in linearized (flat vector) form. */
public double[] getCurrentWeights(){
return linearizer.getLinearizedWeights();
}
/* (non-Javadoc)
* @see edu.berkeley.nlp.classify.ObjectiveFunction#getLogProbabilities(edu.berkeley.nlp.classify.EncodedDatum, double[], edu.berkeley.nlp.classify.Encoding, edu.berkeley.nlp.classify.IndexLinearizer)
*/
// Unimplemented interface stub: always returns null.
public <F, L> double[] getLogProbabilities(EncodedDatum datum, double[] weights, Encoding<F, L> encoding, IndexLinearizer indexLinearizer) {
// TODO Auto-generated method stub
return null;
}
/**
 * Changes the regularization strength and invalidates the cached point
 * (x = null) so the next valueAt/derivativeAt recomputes with the new sigma;
 * also resets the best-objective tracker used for grammar checkpointing.
 * @param newSigma new regularization sigma
 */
public void setSigma(double newSigma) {
sigma = newSigma;
x = null;
bestObjectiveSoFar = Double.POSITIVE_INFINITY;
}
}