package edu.stanford.nlp.sentiment;

import edu.stanford.nlp.util.logging.Redwood;

import java.io.File;
import java.text.DecimalFormat;
import java.text.NumberFormat;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

import edu.stanford.nlp.trees.Tree;
import edu.stanford.nlp.util.Generics;
import edu.stanford.nlp.util.StringUtils;
import edu.stanford.nlp.util.Timing;

public class SentimentTraining {

  /** A logger for this class */
  private static Redwood.RedwoodChannels log = Redwood.channels(SentimentTraining.class);

  private static final NumberFormat NF = new DecimalFormat("0.00");
  private static final NumberFormat FILENAME = new DecimalFormat("0000");

  private SentimentTraining() {} // static methods

  public static void executeOneTrainingBatch(SentimentModel model, List<Tree> trainingBatch, double[] sumGradSquare) {
    SentimentCostAndGradient gcFunc = new SentimentCostAndGradient(model, trainingBatch);
    double[] theta = model.paramsToVector();

    // AdaGrad
    double eps = 1e-3;
    // TODO: do we want to iterate multiple times per batch?
    double[] gradf = gcFunc.derivativeAt(theta);
    double currCost = gcFunc.valueAt(theta);
    log.info("batch cost: " + currCost);
    for (int feature = 0; feature < gradf.length; feature++) {
      sumGradSquare[feature] = sumGradSquare[feature] + gradf[feature] * gradf[feature];
      theta[feature] = theta[feature] - (model.op.trainOptions.learningRate * gradf[feature] / (Math.sqrt(sumGradSquare[feature]) + eps));
    }

    model.vectorToParams(theta);
  }

  public static void train(SentimentModel model, String modelPath, List<Tree> trainingTrees, List<Tree> devTrees) {
    Timing timing = new Timing();
    long maxTrainTimeMillis = model.op.trainOptions.maxTrainTimeSeconds * 1000;
    int debugCycle = 0;
    // double bestAccuracy = 0.0;

    // train using AdaGrad (seemed to work best during the dvparser project)
    double[] sumGradSquare = new double[model.totalParamSize()];
    Arrays.fill(sumGradSquare, model.op.trainOptions.initialAdagradWeight);

    int numBatches = trainingTrees.size() / model.op.trainOptions.batchSize + 1;
    log.info("Training on " + trainingTrees.size() + " trees in " + numBatches + " batches");
    log.info("Times through each training batch: " + model.op.trainOptions.epochs);
    for (int epoch = 0; epoch < model.op.trainOptions.epochs; ++epoch) {
      log.info("======================================");
      log.info("Starting epoch " + epoch);
      if (epoch > 0 && model.op.trainOptions.adagradResetFrequency > 0 &&
          (epoch % model.op.trainOptions.adagradResetFrequency == 0)) {
        log.info("Resetting adagrad weights to " + model.op.trainOptions.initialAdagradWeight);
        Arrays.fill(sumGradSquare, model.op.trainOptions.initialAdagradWeight);
      }

      List<Tree> shuffledSentences = Generics.newArrayList(trainingTrees);
      if (model.op.trainOptions.shuffleMatrices) {
        Collections.shuffle(shuffledSentences, model.rand);
      }
      for (int batch = 0; batch < numBatches; ++batch) {
        log.info("======================================");
        log.info("Epoch " + epoch + " batch " + batch);

        // Each batch will be of the specified batch size, except the
        // last batch will include any leftover trees at the end of
        // the list
        int startTree = batch * model.op.trainOptions.batchSize;
        int endTree = (batch + 1) * model.op.trainOptions.batchSize;
        if (endTree > shuffledSentences.size()) {
          endTree = shuffledSentences.size();
        }

        executeOneTrainingBatch(model, shuffledSentences.subList(startTree, endTree), sumGradSquare);

        long totalElapsed = timing.report();
        log.info("Finished epoch " + epoch + " batch " + batch + "; total training time " + totalElapsed + " ms");
        if (maxTrainTimeMillis > 0 && totalElapsed > maxTrainTimeMillis) {
          // no need to debug output, we're done now
          break;
        }

        if (batch == (numBatches - 1) &&
            model.op.trainOptions.debugOutputEpochs > 0 &&
            (epoch + 1) % model.op.trainOptions.debugOutputEpochs == 0) {
          double score = 0.0;
          if (devTrees != null) {
            Evaluate eval = new Evaluate(model);
            eval.eval(devTrees);
            eval.printSummary();
            score = eval.exactNodeAccuracy() * 100.0;
          }

          // output an intermediate model
          if (modelPath != null) {
            String tempPath;
            if (modelPath.endsWith(".ser.gz")) {
              tempPath = modelPath.substring(0, modelPath.length() - 7) + "-" + FILENAME.format(debugCycle) + "-" + NF.format(score) + ".ser.gz";
            } else if (modelPath.endsWith(".gz")) {
              tempPath = modelPath.substring(0, modelPath.length() - 3) + "-" + FILENAME.format(debugCycle) + "-" + NF.format(score) + ".gz";
            } else {
              // for other extensions, append the debug cycle and score
              // instead of truncating the original path
              tempPath = modelPath + "-" + FILENAME.format(debugCycle) + "-" + NF.format(score);
            }
            model.saveSerialized(tempPath);
          }
          ++debugCycle;
        }
      }
      long totalElapsed = timing.report();
      if (maxTrainTimeMillis > 0 && totalElapsed > maxTrainTimeMillis) {
        log.info("Max training time exceeded, exiting");
        break;
      }
    }
  }

  public static boolean runGradientCheck(SentimentModel model, List<Tree> trees) {
    SentimentCostAndGradient gcFunc = new SentimentCostAndGradient(model, trees);
    return gcFunc.gradientCheck(model.totalParamSize(), 50, model.paramsToVector());
  }

  /**
   * Trains a sentiment model.
   * The -trainPath argument points to a labeled sentiment treebank.
   * The trees in this data will be used to train the model parameters (also to seed the model vocabulary).
   * The -devPath argument points to a second labeled sentiment treebank.
   * The trees in this data will be used to periodically evaluate the performance of the model.
   * We won't train on this data; it will only be used to test how well the model generalizes to unseen data.
   * The -model argument specifies where to save the learned sentiment model.
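   * <p>
   * An illustrative invocation (the file names here are placeholders; any
   * other training options are set through {@link RNNOptions}):
   * <pre>{@code
   * java edu.stanford.nlp.sentiment.SentimentTraining -train \
   *   -trainPath train.txt -devPath dev.txt -model model.ser.gz
   * }</pre>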
   *
   * @param args Command line arguments
   */
  public static void main(String[] args) {
    RNNOptions op = new RNNOptions();

    String trainPath = "sentimentTreesDebug.txt";
    String devPath = null;

    boolean runGradientCheck = false;
    boolean runTraining = false;
    boolean filterUnknown = false;

    String modelPath = null;

    for (int argIndex = 0; argIndex < args.length; ) {
      if (args[argIndex].equalsIgnoreCase("-train")) {
        runTraining = true;
        argIndex++;
      } else if (args[argIndex].equalsIgnoreCase("-gradientcheck")) {
        runGradientCheck = true;
        argIndex++;
      } else if (args[argIndex].equalsIgnoreCase("-trainpath")) {
        trainPath = args[argIndex + 1];
        argIndex += 2;
      } else if (args[argIndex].equalsIgnoreCase("-devpath")) {
        devPath = args[argIndex + 1];
        argIndex += 2;
      } else if (args[argIndex].equalsIgnoreCase("-model")) {
        modelPath = args[argIndex + 1];
        argIndex += 2;
      } else if (args[argIndex].equalsIgnoreCase("-filterUnknown")) {
        filterUnknown = true;
        argIndex++;
      } else {
        int newArgIndex = op.setOption(args, argIndex);
        if (newArgIndex == argIndex) {
          throw new IllegalArgumentException("Unknown argument " + args[argIndex]);
        }
        argIndex = newArgIndex;
      }
    }

    // read in the trees
    List<Tree> trainingTrees = SentimentUtils.readTreesWithGoldLabels(trainPath);
    log.info("Read in " + trainingTrees.size() + " training trees");
    if (filterUnknown) {
      trainingTrees = SentimentUtils.filterUnknownRoots(trainingTrees);
      log.info("Filtered training trees: " + trainingTrees.size());
    }

    List<Tree> devTrees = null;
    if (devPath != null) {
      devTrees = SentimentUtils.readTreesWithGoldLabels(devPath);
      log.info("Read in " + devTrees.size() + " dev trees");
      if (filterUnknown) {
        devTrees = SentimentUtils.filterUnknownRoots(devTrees);
        log.info("Filtered dev trees: " + devTrees.size());
      }
    }

    // TODO: binarize the trees, then collapse the unary chains.
    // Collapsed unary chains always have the label of the top node in
    // the chain.
    // Note: the sentiment training data already has this done.
    // However, when we handle trees given to us from the Stanford Parser,
    // we will have to perform this step.

    // build an uninitialized SentimentModel from the binary productions
    log.info("Sentiment model options:\n" + op);
    SentimentModel model = new SentimentModel(op, trainingTrees);

    if (op.trainOptions.initialMatrixLogPath != null) {
      StringUtils.printToFile(new File(op.trainOptions.initialMatrixLogPath), model.toString(), false, false, "utf-8");
    }

    // TODO: need to handle unk rules somehow... at test time the tree
    // structures might have something that we never saw at training
    // time.  for example, we could put a threshold on all of the
    // rules at training time and anything that doesn't meet that
    // threshold goes into the unk.  perhaps we could also use some
    // component of the accepted training rules to build up the "unk"
    // parameter in case there are no rules that don't meet the
    // threshold

    if (runGradientCheck) {
      runGradientCheck(model, trainingTrees);
    }

    if (runTraining) {
      train(model, modelPath, trainingTrees, devTrees);
      model.saveSerialized(modelPath);
    }
  }

}