package edu.berkeley.nlp.PCFGLA;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import edu.berkeley.nlp.PCFGLA.Corpus.TreeBankType;
import edu.berkeley.nlp.PCFGLA.smoothing.NoSmoothing;
import edu.berkeley.nlp.PCFGLA.smoothing.SmoothAcrossParentBits;
import edu.berkeley.nlp.PCFGLA.smoothing.Smoother;
import edu.berkeley.nlp.discPCFG.ConditionalMerger;
import edu.berkeley.nlp.discPCFG.DefaultLinearizer;
import edu.berkeley.nlp.discPCFG.HiearchicalAdaptiveLinearizer;
import edu.berkeley.nlp.discPCFG.HierarchicalLinearizer;
import edu.berkeley.nlp.discPCFG.Linearizer;
import edu.berkeley.nlp.discPCFG.ParsingObjectiveFunction;
import edu.berkeley.nlp.math.LBFGSMinimizer;
import edu.berkeley.nlp.math.OW_LBFGSMinimizer;
import edu.berkeley.nlp.math.SloppyMath;
import edu.berkeley.nlp.syntax.StateSet;
import edu.berkeley.nlp.syntax.Tree;
import edu.berkeley.nlp.util.Numberer;
/**
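* Trains a latent-variable PCFG from a treebank (WSJ by default): loads and
* binarizes the trees, then either runs split/merge/smoothing rounds of EM or,
* with -doConditional, optimizes the grammar weights with L-BFGS, and saves the
* resulting grammar to the requested output file.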
*
* @author Slav Petrov
*/
public class ConditionalTrainer {
/**
 * Factory for the {@link ParsingObjectiveFunction} that the conditional
 * training loop in {@code main} hands to the minimizer. Supplying a different
 * implementation lets callers swap in another objective without touching the
 * training loop itself.
 *
 * @author adampauls
 */
public interface ParsingObjectFunctionFactory {

  /**
   * Builds the objective function for one conditional training run.
   *
   * @param opts               the parsed command-line options of the trainer
   * @param outFileName        name of the grammar output file
   * @param linearizer         linearizer wrapping the grammar, lexicon and span
   *                           predictor whose weights are being optimized
   * @param trainStateSetTrees the training trees
   * @param regularize         regularization mode; {@code main} passes
   *                           {@code opts.regularize} (0 = none, 1 = L1, 2 = L2)
   * @param newSigma           regularization coefficient to start with
   * @return the objective function to be minimized
   */
  ParsingObjectiveFunction newParsingObjectiveFunction(
      Options opts,
      String outFileName,
      Linearizer linearizer,
      StateSetTreeList trainStateSetTrees,
      int regularize,
      double newSigma);
}
public static class Options {
@Option(name = "-out", usage = "Output File for Grammar")
public String outFileName;
@Option(name = "-outDir", usage = "Output Directory for Grammar")
public String outDir;
@Option(name = "-path", usage = "Path to Corpus")
public String path = null;
@Option(name = "-SMcycles", usage = "The number of split&merge iterations (Default: 6)")
public int numSplits = 6;
@Option(name = "-mergingPercentage", usage = "Merging percentage (Default: 0.0)")
public double mergingPercentage = 0;
@Option(name = "-baseline", usage = "Just read of the MLE baseline grammar")
public boolean baseline = false;
@Option(name = "-treebank", usage = "Language: WSJ, CHNINESE, GERMAN, CONLL, SINGLEFILE (Default: ENGLISH)")
public TreeBankType treebank = TreeBankType.WSJ;
@Option(name = "-splitMaxIt", usage = "Maximum number of EM iterations after splitting (Default: 50)")
public int splitMaxIterations = 100;
@Option(name = "-splitMinIt", usage = "Minimum number of EM iterations after splitting (Default: 50)")
public int splitMinIterations = 50;
@Option(name = "-mergeMaxIt", usage = "Maximum number of EM iterations after merging (Default: 20)")
public int mergeMaxIterations = 20;
@Option(name = "-mergeMinIt", usage = "Minimum number of EM iterations after merging (Default: 20)")
public int mergeMinIterations = 20;
@Option(name = "-di", usage = "The number of allowed iterations in which the validation likelihood drops. (Default: 6)")
public int di = 6;
@Option(name = "-trfr", usage = "The fraction of the training corpus to keep (Default: 1.0)\n")
public double trainingFractionToKeep = 1.0;
@Option(name = "-filter", usage = "Filter rules with prob below this threshold (Default: 1.0e-30)")
public double filter = 1.0e-30;
@Option(name = "-smooth", usage = "Type of grammar smoothing used.")
public String smooth = "NoSmoothing";
@Option(name = "-b", usage = "LEFT/RIGHT Binarization (Default: RIGHT)")
public Binarization binarization = Binarization.RIGHT;
@Option(name = "-noSplit", usage = "Don't split - just load and continue training an existing grammar (true/false) (Default:false)")
public boolean noSplit = false;
@Option(name = "-initializeZero", usage = "Initialize conditional weights with zero")
public boolean initializeZero = false;
@Option(name = "-in", usage = "Input File for Grammar")
public String inFile = null;
@Option(name = "-randSeed", usage = "Seed for random number generator")
public int randSeed = 8;
@Option(name = "-sep", usage = "Set merging threshold for grammar and lexicon separately (Default: false)")
public boolean separateMergingThreshold = false;
@Option(name = "-hor", usage = "Horizontal Markovization (Default: 0)")
public int horizontalMarkovization = 0;
@Option(name = "-sub", usage = "Number of substates to split (Default: 1)")
public int nSubStates = 1;
@Option(name = "-ver", usage = "Vertical Markovization (Default: 1)")
public int verticalMarkovization = 1;
@Option(name = "-v", usage = "Verbose/Quiet (Default: Quiet)\n")
public boolean verbose = false;
@Option(name = "-r", usage = "Level of Randomness at init (Default: 1)\n")
public double randomization = 1.0;
@Option(name = "-sm1", usage = "Lexicon smoothing parameter 1")
public double smoothingParameter1 = 0.5;
@Option(name = "-sm2", usage = "Lexicon smoothing parameter 2)")
public double smoothingParameter2 = 0.1;
@Option(name = "-rare", usage = "Rare word threshold (Default 4)")
public int rare = 4;
@Option(name = "-spath", usage = "Whether or not to store the best path info (true/false) (Default: true)")
public boolean findClosedUnaryPaths = true;
@Option(name = "-unkT", usage = "Threshold for unknown words (Default: 5)")
public int unkThresh = 5;
@Option(name = "-doConditional", usage = "Do conditional training")
public boolean doConditional = false;
@Option(name = "-regularize", usage = "Regularize during optimization: 0-no regularization, 1-l1, 2-l2")
public int regularize = 0;
@Option(name = "-onlyMerge", usage = "Do only a conditional merge")
public boolean onlyMerge = false;
@Option(name = "-sigma", usage = "Regularization coefficient")
public double sigma = 1.0;
@Option(name = "-cons", usage = "File with constraints")
public String cons = null;
@Option(name = "-nProcess", usage = "Distribute on that many cores")
public int nProcess = 1;
@Option(name = "-doNOTprojectConstraints", usage = "Do NOT project constraints")
public boolean doNOTprojectConstraints = false;
@Option(name = "-section", usage = "Which section of the corpus to process.")
public String section = "train";
@Option(name = "-outputLog", usage = "Print output to this file rather than STDOUT.")
public String outputLog = null;
@Option(name = "-maxL", usage = "Skip sentences which are longer than this.")
public int maxL = 10000;
@Option(name = "-nChunks", usage = "Store constraints in that many files.")
public int nChunks = 1;
@Option(name = "-logT", usage = "Log threshold for pruning")
public double logT = -10;
@Option(name = "-lasso", usage="Start of by regularizing less and make the regularization stronger with time")
public boolean lasso = false;
@Option(name = "-hierarchical", usage="Use hierarchical rules")
public boolean hierarchical = false;
@Option(name = "-keepGoldTreeAlive", usage="Don't prune the gold train when computing constraints")
public boolean keepGoldTreeAlive = false;
@Option(name = "-flattenParameters", usage="Flatten parameters to reduce overconfidence")
public double flattenParameters = 1.0;
@Option(name = "-usePosteriorTraining", usage="Adam's new objective function")
public boolean usePosteriorTraining = false;
@Option(name = "-dontLoad", usage="Don't load anything from the pipeline")
public boolean dontLoad = false;
@Option(name = "-predefinedMaxSplit", usage="Use predifined number of subcategories")
public boolean predefinedMaxSplit = false;
@Option(name = "-collapseUnaries", usage="Dont throw away trees with unaries, just collapse the unary chains")
public boolean collapseUnaries = false;
@Option(name = "-connectedLexicon", usage="Score each word with the sum of its score and its signature score")
public boolean connectedLexicon = false;
@Option(name = "-adaptive", usage="Use adpatively refined rules")
public boolean adaptive = false;
@Option(name = "-checkDerivative", usage="Check the derivative of the objective function against an estimate with finite difference")
public boolean checkDerivative = false;
@Option(name = "-initRandomness", usage="Amount of randomness to initialize the grammar with")
public double initRandomness = 1.0;
@Option(name = "-markUnaryParents", usage="Filter all training trees with any unaries (other than lexical and ROOT productions)")
public boolean markUnaryParents = false;
@Option(name = "-filterAllUnaries", usage="Mark any unary parent with a ^u")
public boolean filterAllUnaries = false;
@Option(name = "-filterStupidFrickinWHNP", usage="Temp hack!")
public boolean filterStupidFrickinWHNP = false;
@Option(name = "-initializeDir", usage="Temp hack!")
public String initializeDir = null;
@Option(name = "-allPosteriorsWeight", usage="Weight for the all posteriors regularizer")
public double allPosteriorsWeight = 0.0;
@Option(name="-dontSaveGrammarsAfterEachIteration")
public static boolean dontSaveGrammarsAfterEachIteration = false;
@Option(name="-hierarchicalChart")
public static boolean hierarchicalChart = false;
@Option(name="-testAll", usage="Test grammars after each iteration, proceed by splitting the best")
public boolean testAll = false;
@Option(name="-lockGrammar", usage="Lock grammar weights, learn only span feature weights")
public static boolean lockGrammar = false;
@Option(name="-featurizedLexicon", usage="Use featurized lexicon (no fixed signature classes")
public boolean featurizedLexicon = false;
@Option(name = "-spanFeatures", usage="Use span features")
public boolean spanFeatures = false;
@Option(name="-useFirstAndLast", usage="Use first and last span words as span features")
public static boolean useFirstAndLast = false;
@Option(name="-usePreviousAndNext", usage="Use previous and next span words as span features")
public static boolean usePreviousAndNext = false;
@Option(name="-useBeginAndEndPairs", usage="Use begin and end word-pairs as span features")
public static boolean useBeginAndEndPairs = false;
@Option(name="-useSyntheticClass", usage="Distiguish between real and synthetic constituents")
public static boolean useSyntheticClass = false;
@Option(name="-usePunctuation", usage="Use punctuation cues")
public static boolean usePunctuation = false;
@Option(name="-minFeatureFrequency", usage="Use punctuation cues")
public static int minFeatureFrequency = 0;
@Option(name = "-lbfgsHistorySize", usage = "Max size of L-BFGS history (use -1 for defaults)")
public int lbfgsHistorySize = -1;
//-spanFeatures -usePunctuation -useSyntheticClass -useFirstAndLast -usePreviousAndNext -useBeginAndEndPairs
}
/**
 * Factory that {@code main} asks for the objective function of a conditional
 * training run. The default implementation simply forwards every argument to
 * the static {@code ConditionalTrainer.newParsingObjectiveFunction}; it can be
 * replaced through {@code setParsingObjectiveFunctionFactory}.
 */
private static ParsingObjectFunctionFactory parsingObjectFunctionFactory =
    new ParsingObjectFunctionFactory() {
      public ParsingObjectiveFunction newParsingObjectiveFunction(
          Options options,
          String grammarOutFile,
          Linearizer weightLinearizer,
          StateSetTreeList trainingTrees,
          int regularizationMode,
          double sigma) {
        // Plain delegation; arguments are passed through unchanged and in order.
        ParsingObjectiveFunction objective =
            ConditionalTrainer.newParsingObjectiveFunction(
                options, grammarOutFile, weightLinearizer,
                trainingTrees, regularizationMode, sigma);
        return objective;
      }
    };
/**
 * Installs the factory that {@code main} will use to create the objective
 * function for conditional training, replacing the current one.
 *
 * @param fact the factory to use from now on; stored as given, without any
 *             validation
 */
public static void setParsingObjectiveFunctionFactory(ParsingObjectFunctionFactory fact) {
  ConditionalTrainer.parsingObjectFunctionFactory = fact;
}
public static void main(String[] args) {
OptionParser optParser = new OptionParser(Options.class);
Options opts = (Options) optParser.parse(args, false);
// provide feedback on command-line arguments
System.out.println("Calling ConditionalTrainer with " + optParser.getPassedInOptions());
String path = opts.path;
// int lang = opts.lang;
System.out.println("Loading trees from "+path+" and using language "+opts.treebank);
double trainingFractionToKeep = opts.trainingFractionToKeep;
int maxSentenceLength = opts.maxL;
System.out.println("Will remove sentences with more than "+maxSentenceLength+" words.");
Binarization binarization = opts.binarization;
System.out.println("Using "+ binarization.name() + " binarization.");// and "+annotateString+".");
double randomness = opts.randomization;
System.out.println("Using a randomness value of "+randomness);
String outFileName = opts.outFileName;
if (outFileName==null) {
System.out.println("Output File name is required.");
System.exit(-1);
}
else System.out.println("Using grammar output file "+outFileName+".");
GrammarTrainer.VERBOSE = opts.verbose;
GrammarTrainer.RANDOM = new Random(opts.randSeed);
System.out.println("Random number generator seeded at "+opts.randSeed+".");
boolean manualAnnotation = false;
boolean baseline = opts.baseline;
boolean noSplit = opts.noSplit;
int numSplitTimes = opts.numSplits;
if (baseline) numSplitTimes = 0;
String splitGrammarFile = opts.inFile;
int allowedDroppingIters = opts.di;
int maxIterations = opts.splitMaxIterations;
int minIterations = opts.splitMinIterations;
if (minIterations>0)
System.out.println("I will do at least "+minIterations+" iterations.");
double[] smoothParams = {opts.smoothingParameter1,opts.smoothingParameter2};
System.out.println("Using smoothing parameters "+smoothParams[0]+" and "+smoothParams[1]);
if (opts.connectedLexicon) System.out.println("Using connected lexicon.");
if (opts.featurizedLexicon) System.out.println("Using featuized lexicon.");
// boolean allowMoreSubstatesThanCounts = false;
boolean findClosedUnaryPaths = opts.findClosedUnaryPaths;
Corpus corpus = new Corpus(path,opts.treebank,trainingFractionToKeep,false);
List<Tree<String>> trainTrees = Corpus.binarizeAndFilterTrees(corpus
.getTrainTrees(), opts.verticalMarkovization,
opts.horizontalMarkovization, maxSentenceLength, binarization, manualAnnotation,GrammarTrainer.VERBOSE, opts.markUnaryParents);
List<Tree<String>> validationTrees = Corpus.binarizeAndFilterTrees(corpus
.getValidationTrees(), opts.verticalMarkovization,
opts.horizontalMarkovization, maxSentenceLength, binarization, manualAnnotation,GrammarTrainer.VERBOSE, opts.markUnaryParents);
Numberer tagNumberer = Numberer.getGlobalNumberer("tags");
if (opts.collapseUnaries) System.out.println("Collpasing unary chains.");
if (trainTrees!=null)trainTrees = Corpus.filterTreesForConditional(trainTrees, opts.filterAllUnaries,opts.filterStupidFrickinWHNP,opts.collapseUnaries);
if (validationTrees!=null) validationTrees = Corpus.filterTreesForConditional(validationTrees,opts.filterAllUnaries,opts.filterStupidFrickinWHNP,opts.collapseUnaries);
int nTrees = trainTrees.size();
System.out.println("There are "+nTrees+" trees in the training set.");
double filter = opts.filter;
short nSubstates = (short)opts.nSubStates;
short[] numSubStatesArray = initializeSubStateArray(trainTrees, validationTrees, tagNumberer, nSubstates);
if (baseline) {
short one = 1;
Arrays.fill(numSubStatesArray, one);
System.out.println("Training just the baseline grammar (1 substate for all states)");
randomness = 0.0f;
}
if (GrammarTrainer.VERBOSE){
for (int i=0; i<numSubStatesArray.length; i++){
System.out.println("Tag "+(String)tagNumberer.object(i)+" "+i);
}
}
//initialize lexicon and grammar
SimpleLexicon lexicon = null, maxLexicon = null, previousLexicon = null;
Grammar grammar = null, maxGrammar = null, previousGrammar = null;
SpanPredictor spanPredictor = null;
double maxLikelihood = Double.NEGATIVE_INFINITY;
// EM: iterate until the validation likelihood drops for four consecutive
// iterations
int iter = 0;
int droppingIter = 0;
// If we are splitting, we load the old grammar and start off by splitting.
int startSplit = 0;
Linearizer linearizer = null;
if (splitGrammarFile!=null) {
System.out.println("Loading old grammar from "+splitGrammarFile);
startSplit = 1; // we've already trained the grammar
ParserData pData = ParserData.Load(splitGrammarFile);
Numberer.setNumberers(pData.getNumbs());
tagNumberer = Numberer.getGlobalNumberer("tags");
boolean noUnaryChains = true;
previousGrammar = pData.gr.copyGrammar(noUnaryChains);
previousLexicon = (SimpleLexicon)pData.lex.copyLexicon();
maxGrammar = pData.gr.copyGrammar(noUnaryChains);
maxLexicon = (SimpleLexicon)pData.lex.copyLexicon();
spanPredictor = pData.getSpanPredictor();
if ( opts.hierarchical && previousGrammar.numSubStates[1]==1){ // the previous grammar was the baseline grammar
System.out.println("Converting grammar to hierarchical rules.");
// convert it to a hierarchical grammar
if (opts.adaptive){
// maxGrammar = new HierarchicalAdaptiveGrammar(previousGrammar);
// maxLexicon = new HierarchicalFullyConnectedAdaptiveLexicon(previousLexicon, opts.unkThresh);
} else {
maxGrammar = new HierarchicalGrammar(previousGrammar);
if (opts.connectedLexicon){
maxLexicon = new HierarchicalFullyConnectedLexicon(previousLexicon, opts.unkThresh);
}
else
maxLexicon = new HierarchicalLexicon(previousLexicon);
}
}
if (!opts.noSplit){
System.out.println("Splitting the input grammar and lexicon");
boolean allowMoreSubstatesThanCounts = true;//false;
StateSetTreeList trainStateSetTrees = new StateSetTreeList(trainTrees, numSubStatesArray, false, tagNumberer);
CorpusStatistics corpusStatistics = new CorpusStatistics(tagNumberer,trainStateSetTrees);
int[] counts = corpusStatistics.getSymbolCounts();
if (opts.predefinedMaxSplit){
System.out.println("Using predefnied max number of subcategories!");
allowMoreSubstatesThanCounts = false;
int[] tmp = new int[]{1, 18, 26, 62, 64, 49, 21, 2, 58, 6, 35, 15, 6, 5, 59, 1, 46, 33, 21, 61, 36, 29, 7, 28, 21, 59, 4, 37, 39, 1, 6, 1, 2, 17, 28, 25, 2, 3, 1, 1, 1, 3, 2, 6, 3, 6, 2, 2, 1, 2, 2, 2, 9, 2, 2, 6, 6, 2, 1, 2, 2, 2, 1, 1, 2, 1, 2, 5, 3, 3, 5, 7, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2};
if (tmp.length != counts.length) throw new Error("counts do not match");
counts = tmp;
}
System.out.println(Arrays.toString(counts));
int mode = (opts.hierarchical) ? 2 : 1;
maxGrammar = maxGrammar.splitAllStates(randomness, counts, allowMoreSubstatesThanCounts, mode);
maxLexicon = maxLexicon.splitAllStates(counts, allowMoreSubstatesThanCounts, mode);
}
if (opts.hierarchical){
System.out.println("Using hierarchical rules!");
short finalLevel = (short)(Math.log(maxGrammar.numSubStates[1])/Math.log(2));
System.out.println("The final level of refinement will be: "+finalLevel);
maxGrammar.finalLevel = finalLevel;
if (opts.adaptive){
linearizer = new HiearchicalAdaptiveLinearizer(maxGrammar, maxLexicon, spanPredictor, finalLevel);
} else
linearizer = new HierarchicalLinearizer(maxGrammar, maxLexicon, spanPredictor, finalLevel);
}
else linearizer = new DefaultLinearizer(maxGrammar, maxLexicon, spanPredictor);
numSubStatesArray = maxGrammar.numSubStates;
previousGrammar = grammar = maxGrammar;
previousLexicon = lexicon = maxLexicon;
System.out.println("Loading old grammar complete.");
if (noSplit){
System.out.println("Will NOT split the loaded grammar.");
startSplit=0;
}
}
double mergingPercentage = opts.mergingPercentage;
boolean separateMergingThreshold = opts.separateMergingThreshold;
if (mergingPercentage>0){
System.out.println("Will merge "+(int)(mergingPercentage*100)+"% of the splits in each round.");
System.out.println("The threshold for merging lexical and phrasal categories will be set separately: "+separateMergingThreshold);
}
StateSetTreeList trainStateSetTrees = new StateSetTreeList(trainTrees, numSubStatesArray, false, tagNumberer);
StateSetTreeList validationStateSetTrees = new StateSetTreeList(validationTrees, numSubStatesArray, false, tagNumberer);//deletePC);
// replaces rare words with their signatures
if (!(opts.connectedLexicon)||!opts.doConditional||opts.unkThresh<0){
System.out.println("Replacing words which have been seen less than "+opts.unkThresh+" times with their signature.");
Corpus.replaceRareWords(trainStateSetTrees,new SimpleLexicon(numSubStatesArray,-1), Math.abs(opts.unkThresh));
}
if (splitGrammarFile!=null)
maxLexicon.labelTrees(trainStateSetTrees);
if (splitGrammarFile!=null) lexicon = maxLexicon;
if (splitGrammarFile!=null && spanPredictor==null && opts.spanFeatures){
System.out.println("Adding a span predictor since there was none!");
spanPredictor = new SpanPredictor(maxLexicon.nWords, trainStateSetTrees, tagNumberer, maxLexicon.wordIndexer);
linearizer = new HiearchicalAdaptiveLinearizer(maxGrammar, maxLexicon, spanPredictor, maxGrammar.finalLevel);
}
// get rid of the old trees
trainTrees = null;
validationTrees = null;
corpus = null;
System.gc();
// If we're training without loading a split grammar, then we run once without splitting.
if (splitGrammarFile==null) {
int n = 0;
grammar = new Grammar(numSubStatesArray, findClosedUnaryPaths, new NoSmoothing(), null, filter);
lexicon = new SimpleLexicon(numSubStatesArray,-1,smoothParams, new NoSmoothing(),filter, trainStateSetTrees);
boolean secondHalf = false;
for (Tree<StateSet> stateSetTree : trainStateSetTrees) {
secondHalf = (n++>nTrees/2.0);
lexicon.trainTree(stateSetTree, randomness, null, secondHalf,false,opts.rare);
grammar.tallyUninitializedStateSetTree(stateSetTree);
}
lexicon.optimize();
grammar.optimize(randomness);
// System.out.println(grammar);
boolean noUnaryChains = true;
Grammar grammar2 = grammar.copyGrammar(noUnaryChains);
SimpleLexicon lexicon2 = lexicon.copyLexicon();
System.out.println("Known word cut-off at "+opts.unkThresh+" occurences.");
if (opts.adaptive){
System.out.println("Using hierarchical adaptive grammar and lexicon.");
grammar2 = new HierarchicalAdaptiveGrammar(grammar2);
lexicon2 = (opts.featurizedLexicon) ?
new HierarchicalFullyConnectedAdaptiveLexiconWithFeatures(numSubStatesArray,-1,smoothParams, new NoSmoothing(), trainStateSetTrees, opts.unkThresh):
new HierarchicalFullyConnectedAdaptiveLexicon(numSubStatesArray,-1,smoothParams, new NoSmoothing(), trainStateSetTrees, opts.unkThresh);
if (opts.spanFeatures)
spanPredictor = new SpanPredictor(lexicon2.nWords, trainStateSetTrees, tagNumberer, lexicon2.wordIndexer);
linearizer = new HiearchicalAdaptiveLinearizer(grammar2, lexicon2, spanPredictor, 0);
} else if (opts.connectedLexicon&&opts.doConditional) {
lexicon2 = new HierarchicalFullyConnectedLexicon(numSubStatesArray,-1,smoothParams, new NoSmoothing(), trainStateSetTrees, opts.unkThresh);
linearizer = new DefaultLinearizer(grammar2, lexicon2, spanPredictor);
} else {
linearizer = new DefaultLinearizer(grammar2, lexicon2, spanPredictor);
}
if (opts.initializeZero) System.out.println("Initializing weigths with zero!");
Random rand = GrammarTrainer.RANDOM;
double[] init = linearizer.getLinearizedWeights();
if (opts.initializeZero) {
// Arrays.fill(init, 0);
for (int i=0; i<init.length; i++){
init[i] = opts.initRandomness * rand.nextDouble()/100;
}
}
linearizer.delinearizeWeights(init);
grammar2 = linearizer.getGrammar();
lexicon2 = linearizer.getLexicon();
spanPredictor = linearizer.getSpanPredictor();
grammar2.splitRules();
previousGrammar = maxGrammar = grammar = grammar2; //needed for baseline - when there is no EM loop
previousLexicon = maxLexicon = lexicon = lexicon2;
}
if (opts.doConditional){
if (opts.onlyMerge){
System.out.println("Will do only a conditional merge.");
ConditionalMerger merger = new ConditionalMerger(opts.nProcess, opts.cons, trainStateSetTrees,
grammar, lexicon, opts.mergingPercentage, opts.outFileName);
merger.mergeGrammarAndLexicon();
System.exit(1);
}
ParsingObjectiveFunction objective = null;
int regularize = opts.regularize;
int iterations = opts.splitMaxIterations;
double sigma = opts.sigma;
if (regularize>0){
System.out.println("Regularizing with sigma="+sigma);
}
LBFGSMinimizer minimizer = null;
int maxIter = (opts.noSplit) ? 2 : 4;
for (int it=1; it<maxIter; it++){
if (opts.regularize==1) minimizer = new OW_LBFGSMinimizer(iterations);
else minimizer = new LBFGSMinimizer(iterations);
if (opts.lbfgsHistorySize >= 0)
minimizer.setMaxHistorySize(opts.lbfgsHistorySize);
double newSigma = sigma;
if (opts.lasso && !opts.noSplit){
newSigma = sigma + 3 - it;
System.out.println("The regularization parameter for this round will be: "+newSigma);
}
if (it==1) {
objective = parsingObjectFunctionFactory.newParsingObjectiveFunction(opts, outFileName, linearizer, trainStateSetTrees, regularize, newSigma);
minimizer.setMinIteratons(15);
} else {
minimizer.setMinIteratons(5);
}
objective.setSigma(newSigma);
double[] weights = objective.getCurrentWeights();
if (it == 1 && opts.checkDerivative)
{
System.out.print("\nChecking derivative: ");
double f = objective.valueAt(weights);
double[] deriv = objective.derivativeAt(weights);
double[] fDif = deriv.clone();
final double h = 1e-4;
for (int i = 0; i < 1; ++i)
{
double[] newWeights = weights.clone();
newWeights[i] += h;
double fplush = objective.valueAt(newWeights);
double finiteDif = (fplush - f) / h;
if (finiteDif - deriv[i] > 0.1)
{
System.out.println("Derivative is whack!");
}
fDif[i] = finiteDif;
}
System.out.println("done");
}
System.out.print("\nChecking weights: ");
int invalid = 0;
for (int i=0; i<weights.length; i++){
if (SloppyMath.isVeryDangerous(weights[i])){
invalid++;
weights[i] = 0;
}
}
System.out.print(invalid+" out of "+weights.length+" features had -Inf weight and have been set to 0.\n");
// objective.updateGoldCountsNextRound();
//objective = new ConstrainedParsingObjectiveFunction(grammar, startIndexGrammar, lexicon, startIndexLexicon, trainStateSetTrees, sigma, consFileName, regularize,false, nRules, nRules2);
System.out.println("In the "+it+". EM-like Iteration.");
weights = minimizer.minimize(objective, weights, 1e-4);
linearizer.delinearizeWeights(weights);
grammar = linearizer.getGrammar();
lexicon = linearizer.getLexicon();
spanPredictor = linearizer.getSpanPredictor();
ParserData pData = new ParserData(maxLexicon, maxGrammar, spanPredictor, Numberer.getNumberers(), numSubStatesArray, opts.verticalMarkovization, opts.horizontalMarkovization, binarization);
System.out.println("Saving grammar to "+outFileName+"-"+it+".");
if (pData.Save(outFileName+"-"+it)) System.out.println("Saving successful.");
else System.out.println("Saving failed!");
}
if (true){
if (opts.hierarchical&&splitGrammarFile!=null){
System.out.println("Collapsing unused parameters.");
HierarchicalGrammar hrGrammar = (HierarchicalGrammar)maxGrammar;
HierarchicalLexicon hrLexicon = (HierarchicalLexicon)maxLexicon;
if (opts.mergingPercentage!=-1){
hrGrammar.mergeGrammar();
hrLexicon.mergeLexicon();
}
maxGrammar = hrGrammar;
maxLexicon = hrLexicon;
}
objective.shutdown();
ParserData pData = new ParserData(maxLexicon, maxGrammar, spanPredictor, Numberer.getNumberers(), numSubStatesArray, opts.verticalMarkovization, opts.horizontalMarkovization, binarization);
System.out.println("Saving grammar to "+outFileName+".");
if (pData.Save(outFileName)) System.out.println("Saving successful.");
else System.out.println("Saving failed!");
return;
//System.exit(1);
}
}
boolean allowMoreSubstatesThanCounts = true;
// the main loop: split and train the grammar
for (int splitIndex = startSplit; splitIndex < 3*numSplitTimes; splitIndex++) {
// now do either a merge or a split and the end a smooth
// on odd iterations merge, on even iterations split
String opString = "";
if (splitIndex%3==2){//(splitIndex==numSplitTimes*2){
if (opts.onlyMerge) continue;
if (opts.smooth.equals("NoSmoothing")) continue;
System.out.println("Setting smoother for grammar and lexicon.");
Smoother grSmoother = new SmoothAcrossParentBits(0.01,maxGrammar.splitTrees);
Smoother lexSmoother = new SmoothAcrossParentBits(0.1,maxGrammar.splitTrees);
// Smoother grSmoother = new SmoothAcrossParentSubstate(0.01);
// Smoother lexSmoother = new SmoothAcrossParentSubstate(0.1);
maxGrammar.setSmoother(grSmoother);
maxLexicon.setSmoother(lexSmoother);
minIterations = maxIterations = 10;
opString = "smoothing";
}
else if (splitIndex%3==0) {
if (opts.onlyMerge) continue;
// the case where we split
if (!opts.noSplit){
System.out.println("Before splitting, we have a total of "+maxGrammar.totalSubStates()+" substates.");
CorpusStatistics corpusStatistics = new CorpusStatistics(tagNumberer,trainStateSetTrees);
int[] counts = corpusStatistics.getSymbolCounts();
maxGrammar = maxGrammar.splitAllStates(randomness, counts, allowMoreSubstatesThanCounts,0);
maxLexicon = maxLexicon.splitAllStates(counts, allowMoreSubstatesThanCounts,0);
Smoother grSmoother = new NoSmoothing();
Smoother lexSmoother = new NoSmoothing();
maxGrammar.setSmoother(grSmoother);
maxLexicon.setSmoother(lexSmoother);
System.out.println("After splitting, we have a total of "+maxGrammar.totalSubStates()+" substates.");
System.out.println("Rule probabilities are NOT normalized in the split, therefore the training LL is not guaranteed to improve between iteration 0 and 1!");
}
opString = "splitting";
maxIterations = opts.splitMaxIterations;
minIterations = opts.splitMinIterations;
}
else {
if (mergingPercentage==0) continue;
// the case where we merge
double[][] mergeWeights = GrammarMerger.computeMergeWeights(maxGrammar, maxLexicon,trainStateSetTrees);
double[][][] deltas = GrammarMerger.computeDeltas(maxGrammar, maxLexicon, mergeWeights, trainStateSetTrees);
boolean[][][] mergeThesePairs = GrammarMerger.determineMergePairs(deltas,separateMergingThreshold,mergingPercentage,maxGrammar);
// merges grammar and lexicon and returns the merged grammar while the lexicon is merged in place
grammar = GrammarMerger.doTheMerges(maxGrammar, maxLexicon, mergeThesePairs, mergeWeights);
lexicon = maxLexicon;
short[] newNumSubStatesArray = grammar.numSubStates;
trainStateSetTrees = new StateSetTreeList(trainStateSetTrees, newNumSubStatesArray, false);
validationStateSetTrees = new StateSetTreeList(validationStateSetTrees, newNumSubStatesArray, false);
// retrain lexicon to finish the lexicon merge (updates the unknown words model)...
// lexicon = new Lexicon(newNumSubStatesArray,Lexicon.DEFAULT_SMOOTHING_CUTOFF, maxLexicon.smooth, maxLexicon.smoother, maxLexicon.threshold);
// boolean updateOnlyLexicon = true;
// double trainingLikelihood = ConditionalTrainer.doOneEStep(grammar, maxLexicon, null, lexicon, trainStateSetTrees, updateOnlyLexicon);
// System.out.println("The training LL is "+trainingLikelihood);
// lexicon.optimize();//Grammar.RandomInitializationType.INITIALIZE_WITH_SMALL_RANDOMIZATION); // M Step
GrammarMerger.printMergingStatistics(maxGrammar, grammar);
opString = "merging";
maxGrammar = grammar; maxLexicon = lexicon;
maxIterations = opts.mergeMaxIterations;
minIterations = opts.mergeMinIterations;
}
//update the substate dependent objects
previousGrammar = grammar = maxGrammar;
previousLexicon = lexicon = maxLexicon;
droppingIter = 0;
numSubStatesArray = grammar.numSubStates;
trainStateSetTrees = new StateSetTreeList(trainStateSetTrees, numSubStatesArray, false);
validationStateSetTrees = new StateSetTreeList(validationStateSetTrees, numSubStatesArray, false);
maxLikelihood = calculateLogLikelihood(maxGrammar, maxLexicon, validationStateSetTrees);
System.out.println("After "+opString+" in the " + (splitIndex/3+1) + "th round, we get a validation likelihood of " + maxLikelihood);
iter = 0;
//the inner loop: train the grammar via EM until validation likelihood reliably drops
do {
if (maxIterations>0){
iter += 1;
System.out.println("Beginning iteration "+(iter-1)+":");
// 1) Compute the validation likelihood of the previous iteration
System.out.print("Calculating validation likelihood...");
double validationLikelihood = calculateLogLikelihood(previousGrammar, previousLexicon, validationStateSetTrees); // The validation LL of previousGrammar/previousLexicon
System.out.println("done: "+validationLikelihood);
// 2) Perform the E step while computing the training likelihood of the previous iteration
System.out.print("Calculating training likelihood...");
grammar = new Grammar(grammar.numSubStates, grammar.findClosedPaths, grammar.smoother, grammar, grammar.threshold);
// lexicon = new SimpleLexicon(grammar.numSubStates, SophisticatedLexicon.DEFAULT_SMOOTHING_CUTOFF, null, new NoSmoothing(), opts.unkThresh);
lexicon = maxLexicon.copyLexicon();
boolean updateOnlyLexicon = false;
double trainingLikelihood = doOneEStep(previousGrammar,previousLexicon,grammar,lexicon,trainStateSetTrees,updateOnlyLexicon,opts.rare); // The training LL of previousGrammar/previousLexicon
System.out.println("done: "+trainingLikelihood);
// 3) Perform the M-Step
lexicon.optimize(); // M Step
grammar.optimize(0); // M Step
// 4) Check whether previousGrammar/previousLexicon was in fact better than the best
if(iter<minIterations || validationLikelihood >= maxLikelihood) {
maxLikelihood = validationLikelihood;
maxGrammar = previousGrammar;
maxLexicon = previousLexicon;
droppingIter = 0;
} else { droppingIter++; }
// 5) advance the 'pointers'
previousGrammar = grammar;
previousLexicon = lexicon;
}
} while ((droppingIter < allowedDroppingIters) && (!baseline) && (iter<maxIterations));
// Dump a grammar file to disk from time to time
ParserData pData = new ParserData(maxLexicon, maxGrammar, null, Numberer.getNumberers(), numSubStatesArray, 1, 0, binarization);
String outTmpName = outFileName + "_"+ (splitIndex/3+1)+"_"+opString+".gr";
System.out.println("Saving grammar to "+outTmpName+".");
if (pData.Save(outTmpName)) System.out.println("Saving successful.");
else System.out.println("Saving failed!");
pData = null;
}
// The last grammar/lexicon has not yet been evaluated. Even though the validation likelihood
// has been dropping in the past few iteration, there is still a chance that the last one was in
// fact the best so just in case we evaluate it.
System.out.print("Calculating last validation likelihood...");
double validationLikelihood = calculateLogLikelihood(grammar, lexicon, validationStateSetTrees);
System.out.println("done.\n Iteration "+iter+" (final) gives validation likelihood "+validationLikelihood);
if (validationLikelihood > maxLikelihood) {
maxLikelihood = validationLikelihood;
maxGrammar = previousGrammar;
maxLexicon = previousLexicon;
}
// System.out.println(lexicon);
// System.out.println(grammar);
ParserData pData = new ParserData(maxLexicon, maxGrammar, null, Numberer.getNumberers(), numSubStatesArray, opts.verticalMarkovization, opts.horizontalMarkovization, binarization);
System.out.println("Saving grammar to "+outFileName+".");
System.out.println("It gives a validation data log likelihood of: "+maxLikelihood);
if (pData.Save(outFileName)) System.out.println("Saving successful.");
else System.out.println("Saving failed!");
//System.exit(0);
}
/**
 * Creates the objective function used for conditional training.
 * <p>
 * Only the plain {@link ParsingObjectiveFunction} is built here; a
 * posterior-training variant that used to be selectable through the options
 * is kept in the body as a comment for reference.
 *
 * @param opts supplies {@code cons}, {@code nProcess},
 *        {@code doNOTprojectConstraints} and {@code connectedLexicon}
 * @param outFileName forwarded unchanged to the objective function
 * @param linearizer forwarded unchanged to the objective function
 * @param trainStateSetTrees forwarded unchanged to the objective function
 * @param regularize forwarded unchanged to the objective function
 * @param sigma forwarded unchanged to the objective function
 * @return the newly constructed objective function
 */
private static ParsingObjectiveFunction newParsingObjectiveFunction(
    Options opts,
    String outFileName,
    Linearizer linearizer,
    StateSetTreeList trainStateSetTrees,
    int regularize,
    double sigma) {
  // Disabled alternative, kept for reference:
  //   opts.usePosteriorTraining ? new PosteriorTrainingObjectiveFunction(linearizer, trainStateSetTrees, sigma,
  //       regularize, opts.boostIncorrect, opts.cons, opts.nProcess, outFileName,
  //       opts.doGEM, opts.doNOTprojectConstraints, opts.allPosteriorsWeight) : <the call below>
  ParsingObjectiveFunction objective = new ParsingObjectiveFunction(
      linearizer, trainStateSetTrees, sigma, regularize,
      opts.cons, opts.nProcess, outFileName,
      opts.doNOTprojectConstraints, opts.connectedLexicon);
  return objective;
}
/**
 * Runs one E-step over the training trees.
 * <p>
 * Every tree is scored with a parser built from the previous grammar and
 * lexicon. Trees whose log likelihood is finite contribute their counts to
 * {@code lexicon} (via {@code trainTree}) and, unless
 * {@code updateOnlyLexicon} is set, to {@code grammar} (via
 * {@code tallyStateSetTree}). Trees with an infinite or NaN log likelihood
 * are skipped (and reported when {@code GrammarTrainer.VERBOSE} is on).
 *
 * @param previousGrammar grammar used for scoring the trees
 * @param previousLexicon lexicon used for scoring the trees
 * @param grammar grammar that receives the new counts
 * @param lexicon lexicon that receives the new counts
 * @param trainStateSetTrees the training trees
 * @param updateOnlyLexicon when true, the grammar counts are left untouched
 * @param unkThreshold forwarded to {@code Lexicon.trainTree}
 * @return the sum of the log likelihoods of all trees that had a finite one
 */
public static double doOneEStep(Grammar previousGrammar, Lexicon previousLexicon, Grammar grammar, Lexicon lexicon, StateSetTreeList trainStateSetTrees,
    boolean updateOnlyLexicon, int unkThreshold) {
  final boolean noSmoothing = true;
  final boolean debugOutput = false;
  ArrayParser parser = new ArrayParser(previousGrammar, previousLexicon);
  double trainingLikelihood = 0;
  int n = 0;
  int nTrees = trainStateSetTrees.size();
  for (Tree<StateSet> stateSetTree : trainStateSetTrees) {
    // True once more than half of the trees have been visited; forwarded to trainTree.
    boolean secondHalf = (n++ > nTrees / 2.0);
    parser.doInsideOutsideScores(stateSetTree, noSmoothing, debugOutput); // E Step
    double ll = stateSetTree.getLabel().getIScore(0);
    ll = Math.log(ll) + (100 * stateSetTree.getLabel().getIScale());
    if (Double.isInfinite(ll) || Double.isNaN(ll)) {
      // there are for some reason some sentences that are unparsable; they are skipped
      if (GrammarTrainer.VERBOSE) {
        System.out.println("Training sentence "+n+" is given "+ll+" log likelihood!");
        System.out.println("Root iScore "+ stateSetTree.getLabel().getIScore(0)+" scale "+stateSetTree.getLabel().getIScale());
      }
      continue;
    }
    lexicon.trainTree(stateSetTree, -1, previousLexicon, secondHalf, noSmoothing, unkThreshold);
    if (!updateOnlyLexicon) {
      grammar.tallyStateSetTree(stateSetTree, previousGrammar); // E Step
    }
    trainingLikelihood += ll;
  }
  //SSIE
  // The parameter is declared as Lexicon, so only apply the maxent overwrite when the
  // implementation actually offers it, instead of failing with a ClassCastException
  // after all the E-step work has been done.
  if (lexicon instanceof SophisticatedLexicon) {
    ((SophisticatedLexicon) lexicon).overwriteWithMaxent();
  }
  return trainingLikelihood;
}
/**
 * Computes the summed log likelihood of a set of trees under the given
 * grammar and lexicon.
 * <p>
 * Each tree's log likelihood is {@code log(iScore(0)) + 100 * iScale} of its
 * root label after the inside pass. Trees for which this value is infinite
 * or NaN are left out of the sum.
 *
 * @param maxGrammar grammar used for scoring
 * @param maxLexicon lexicon used for scoring
 * @param validationStateSetTrees trees to score
 * @return the sum of the finite per-tree log likelihoods
 */
public static double calculateLogLikelihood(Grammar maxGrammar, Lexicon maxLexicon, StateSetTreeList validationStateSetTrees) {
  ArrayParser parser = new ArrayParser(maxGrammar, maxLexicon);
  double total = 0;
  int skipped = 0;
  for (Tree<StateSet> tree : validationStateSetTrees) {
    parser.doInsideScores(tree, false, false, null); // Only inside scores are needed here
    StateSet root = tree.getLabel();
    double rootScore = root.getIScore(0);
    double treeLL = Math.log(rootScore) + (100 * root.getIScale());
    boolean finite = !Double.isInfinite(treeLL) && !Double.isNaN(treeLL);
    if (finite) {
      total += treeLL;
    } else {
      // there are for some reason some sentences that are unparsable
      skipped++;
      //printBadLLReason(tree, lexicon);
    }
  }
  // if (skipped>0) System.out.print("Number of unparsable trees: "+skipped+".");
  return total;
}
/**
 * Prints diagnostics for a tree whose likelihood is not usable.
 * <p>
 * The tree itself is printed first. Then every preterminal whose inside
 * scores are all infinite or NaN is reported together with the lexicon's
 * counts for the word and tag. The last line says whether the lexicon (at
 * least one such preterminal was found) or the grammar is blamed.
 *
 * @param stateSetTree the tree to inspect; its inside scores are read as they are
 * @param lexicon lexicon whose counters are printed
 */
public static void printBadLLReason(Tree<StateSet> stateSetTree, SophisticatedLexicon lexicon) {
  System.out.println(stateSetTree.toString());
  List<StateSet> terminals = stateSetTree.getYield();
  Iterator<StateSet> terminalIt = terminals.iterator();
  boolean blameLexicon = false;
  for (StateSet tagSet : stateSetTree.getPreTerminalYield()) {
    // Terminals and preterminals are walked in lock-step.
    String word = terminalIt.next().getWord();

    // A preterminal is suspicious only if none of its substates has a finite score.
    boolean allScoresBad = true;
    for (int sub = 0; sub < tagSet.numSubStates(); sub++) {
      double s = tagSet.getIScore(sub);
      if (!Double.isInfinite(s) && !Double.isNaN(s)) {
        allScoresBad = false;
      }
    }

    if (allScoresBad) {
      System.out.println("LEXICON PROBLEM ON STATE " + tagSet.getState()+" word "+word);
      System.out.println(" word "+lexicon.wordCounter.getCount(tagSet.getWord()));
      for (int sub = 0; sub < tagSet.numSubStates(); sub++) {
        System.out.println(" tag "+lexicon.tagCounter[tagSet.getState()][sub]);
        System.out.println(" word/state/sub "+lexicon.wordToTagCounters[tagSet.getState()].get(tagSet.getWord())[sub]);
      }
    }
    blameLexicon |= allScoresBad;
  }
  String verdict = blameLexicon
      ? " the likelihood is bad because of the lexicon"
      : " the likelihood is bad because of the grammar";
  System.out.println(verdict);
}
/**
 * This function probably doesn't belong here, but because it should be called
 * after {@link #updateStateSetTrees}, Leon left it here.
 * <p>
 * The per-tree log likelihood is computed the same way as in
 * {@link #calculateLogLikelihood} and {@link #doOneEStep}:
 * {@code log(iScore(0)) + 100 * iScale} of the root label. Previously the raw
 * root inside scores were summed, which is not a log likelihood and ignored
 * the scale. Trees whose value is infinite or NaN are reported and left out
 * of the sum.
 *
 * @param trees Trees which have already had their inside-outside probabilities calculated,
 *        as by {@link #updateStateSetTrees}.
 * @param verbose when true, the log likelihood of every tree is printed
 * @return The log likelihood of the trees (sum over all trees with a finite value).
 */
public static double logLikelihood(List<Tree<StateSet>> trees, boolean verbose) {
  double likelihood = 0;
  for (Tree<StateSet> tree : trees) {
    StateSet root = tree.getLabel();
    // Undo the scaling and move to log space, consistent with the other likelihood code in this class.
    double l = Math.log(root.getIScore(0)) + (100 * root.getIScale());
    if (verbose) {
      System.out.println("LL is "+l+".");
    }
    if (Double.isInfinite(l) || Double.isNaN(l)) {
      System.out.println("LL is not finite.");
    } else {
      likelihood += l;
    }
  }
  return likelihood;
}
/**
 * Recomputes the inside and outside scores of every tree in the list by
 * calling the parser's {@code doInsideOutsideScores} on it (with both boolean
 * flags set to {@code false}).
 *
 * @param trees A list of binarized, annotated StateSet Trees.
 * @param parser The parser to score the trees.
 */
public static void updateStateSetTrees (List<Tree<StateSet>> trees, ArrayParser parser) {
  Iterator<Tree<StateSet>> treeIt = trees.iterator();
  while (treeIt.hasNext()) {
    Tree<StateSet> current = treeIt.next();
    parser.doInsideOutsideScores(current, false, false);
  }
}
/**
 * Builds the initial array holding the number of substates per state.
 * <p>
 * Both tree collections are first run through {@link StateSetTreeList} and
 * {@code StateSetTreeList.initializeTagNumberer} so that the numberer has
 * seen every tag before the array is sized.
 *
 * @param trainTrees training trees
 * @param validationTrees validation trees
 * @param tagNumberer numberer that is filled with the tags of both tree sets
 * @param nSubStates number of substates assigned to every state but the first
 * @return an array of length {@code tagNumberer.total()} in which entry 0 is 1
 *         and every other entry is {@code nSubStates}
 */
public static short[] initializeSubStateArray(List<Tree<String>> trainTrees,
    List<Tree<String>> validationTrees, Numberer tagNumberer, short nSubStates){
  // first generate unsplit grammar and lexicon
  short[] unsplit = { 1, nSubStates };

  // The two lists below are built only for their side effect: constructing them
  // adds the tags of the trees to the tagNumberer, so that the numberer sees all
  // tags (including those of the validation set) and we can allocate big enough
  // arrays. The resulting objects themselves are not needed.
  new StateSetTreeList(trainTrees, unsplit, true, tagNumberer);
  new StateSetTreeList(validationTrees, unsplit, true, tagNumberer);

  StateSetTreeList.initializeTagNumberer(trainTrees, tagNumberer);
  StateSetTreeList.initializeTagNumberer(validationTrees, tagNumberer);

  short stateCount = (short) tagNumberer.total();
  short[] subStatesPerState = new short[stateCount];
  Arrays.fill(subStatesPerState, nSubStates);
  subStatesPerState[0] = 1; // that's the ROOT
  return subStatesPerState;
}
/**
 * Reads a serialized five-dimensional boolean array from an uncompressed file.
 * <p>
 * The underlying file stream is always closed, also when reading fails.
 *
 * @param fileName path of the file to read
 * @return the deserialized array, or {@code null} if the file cannot be read,
 *         a class in the stream cannot be resolved, or the stored object is
 *         not a {@code boolean[][][][][]}
 */
public static boolean[][][][][] loadDataNoZip(String fileName) {
  try {
    FileInputStream fis = new FileInputStream(fileName); // Load from file
    try {
      // GZIPInputStream gzis = new GZIPInputStream(fis); // Compressed
      ObjectInputStream in = new ObjectInputStream(fis); // Load objects
      Object payload = in.readObject(); // Read the mix of grammars
      if (payload != null && !(payload instanceof boolean[][][][][])) {
        // Report a foreign payload like any other load failure instead of
        // letting a ClassCastException escape.
        System.out.println("Unexpected object type in "+fileName+": "+payload.getClass().getName());
        return null;
      }
      return (boolean[][][][][]) payload;
    } finally {
      // Closing the file stream releases the descriptor even if the
      // ObjectInputStream could not be created or readObject failed.
      fis.close();
    }
  } catch (IOException e) {
    System.out.println("IOException\n"+e);
    return null;
  } catch (ClassNotFoundException e) {
    System.out.println("Class not found!");
    return null;
  }
}
/**
 * Serializes a five-dimensional boolean array to a file, without compression.
 * <p>
 * The underlying file stream is always closed, also when writing fails.
 *
 * @param data the array to store
 * @param fileName path of the file to write
 * @return {@code true} if the data was written, {@code false} if an
 *         {@link IOException} occurred (the exception is printed)
 */
public static boolean saveDataNoZip(boolean[][][][][] data, String fileName){
  try {
    // Plain Java serialization; the gzip layer is intentionally left disabled.
    FileOutputStream fos = new FileOutputStream(fileName); // Save to file
    try {
      // GZIPOutputStream gzos = new GZIPOutputStream(fos); // Compressed
      ObjectOutputStream out = new ObjectOutputStream(fos); // Save objects
      out.writeObject(data); // Write the mix of grammars
      out.flush(); // Push everything buffered by the object stream into the file stream.
    } finally {
      // Closing the file stream releases the descriptor even if the
      // ObjectOutputStream could not be created or writeObject failed.
      fos.close();
    }
  } catch (IOException e) {
    System.out.println("IOException: "+e);
    return false;
  }
  return true;
}
/** Relative tolerance used by {@link #matches(double, double)}. */
private static final double TOL = 1e-5;

/**
 * Tells whether two numbers are equal up to a relative tolerance.
 * <p>
 * The absolute difference is divided by the sum of the magnitudes (plus
 * {@code 1e-10}, which keeps the denominator non-zero when both are 0) and
 * compared against {@link #TOL}.
 *
 * @param x first value
 * @param y second value
 * @return {@code true} if the relative difference is strictly below {@link #TOL}
 */
protected static boolean matches(double x, double y) {
  double difference = Math.abs(x - y);
  double magnitude = Math.abs(x) + Math.abs(y) + 1e-10;
  double relative = difference / magnitude;
  return relative < TOL;
}
}