/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package opennlp.maxent; import java.io.IOException; import java.util.ArrayList; import java.util.List; import java.util.concurrent.Callable; import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.Future; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import opennlp.model.DataIndexer; import opennlp.model.EvalParameters; import opennlp.model.EventStream; import opennlp.model.MutableContext; import opennlp.model.OnePassDataIndexer; import opennlp.model.Prior; import opennlp.model.UniformPrior; /** * An implementation of Generalized Iterative Scaling. The reference paper * for this implementation was Adwait Ratnaparkhi's tech report at the * University of Pennsylvania's Institute for Research in Cognitive Science, * and is available at <a href ="ftp://ftp.cis.upenn.edu/pub/ircs/tr/97-08.ps.Z"><code>ftp://ftp.cis.upenn.edu/pub/ircs/tr/97-08.ps.Z</code></a>. * * The slack parameter used in the above implementation has been removed by default * from the computation and a method for updating with Gaussian smoothing has been * added per Investigating GIS and Smoothing for Maximum Entropy Taggers, Clark and Curran (2002). * <a href="http://acl.ldc.upenn.edu/E/E03/E03-1071.pdf"><code>http://acl.ldc.upenn.edu/E/E03/E03-1071.pdf</code></a> * The slack parameter can be used by setting <code>useSlackParameter</code> to true. * Gaussian smoothing can be used by setting <code>useGaussianSmoothing</code> to true. * * A prior can be used to train models which converge to the distribution which minimizes the * relative entropy between the distribution specified by the empirical constraints of the training * data and the specified prior. By default, the uniform distribution is used as the prior. */ public class GISTrainer { private static final Log LOG = LogFactory.getLog(GISTrainer.class); private String currentMessage = ""; /** * Specifies whether unseen context/outcome pairs should be estimated as occur very infrequently. */ private boolean useSimpleSmoothing = false; /** * Specified whether parameter updates should prefer a distribution of parameters which * is gaussian. */ private boolean useGaussianSmoothing = false; private double sigma = 2.0; // If we are using smoothing, this is used as the "number" of // times we want the trainer to imagine that it saw a feature that it // actually didn't see. Defaulted to 0.1. private double _smoothingObservation = 0.1; private final boolean printMessages; /** * Number of unique events which occured in the event set. */ private int numUniqueEvents; /** * Number of predicates. */ private int numPreds; /** * Number of outcomes. */ private int numOutcomes; /** * Records the array of predicates seen in each event. */ private int[][] contexts; /** * The value associated with each context. If null then context values are assumes to be 1. */ private float[][] values; /** * List of outcomes for each event i, in context[i]. */ private int[] outcomeList; /** * Records the num of times an event has been seen for each event i, in context[i]. */ private int[] numTimesEventsSeen; /** * The number of times a predicate occured in the training data. */ private int[] predicateCounts; private int cutoff; /** * Stores the String names of the outcomes. The GIS only tracks outcomes as * ints, and so this array is needed to save the model to disk and thereby * allow users to know what the outcome was in human understandable terms. */ private String[] outcomeLabels; /** * Stores the String names of the predicates. The GIS only tracks predicates * as ints, and so this array is needed to save the model to disk and thereby * allow users to know what the outcome was in human understandable terms. */ private String[] predLabels; /** * Stores the observed expected values of the features based on training data. */ private MutableContext[] observedExpects; /** * Stores the estimated parameter value of each predicate during iteration */ private MutableContext[] params; /** * Stores the expected values of the features based on the current models */ private MutableContext[][] modelExpects; /** * This is the prior distribution that the model uses for training. */ private Prior prior; private static final double LLThreshold = 0.0001; /** * Initial probability for all outcomes. */ private EvalParameters evalParams; /** * Creates a new <code>GISTrainer</code> instance which does not print * progress messages about training to STDOUT. * */ public GISTrainer() { printMessages = false; } /** * Creates a new <code>GISTrainer</code> instance. * * @param printMessages sends progress messages about training to * STDOUT when true; trains silently otherwise. */ public GISTrainer(boolean printMessages) { this.printMessages = printMessages; } /** * Sets whether this trainer will use smoothing while training the model. * This can improve model accuracy, though training will potentially take * longer and use more memory. Model size will also be larger. * * @param smooth true if smoothing is desired, false if not */ public void setSmoothing(boolean smooth) { useSimpleSmoothing = smooth; } /** * Sets whether this trainer will use smoothing while training the model. * This can improve model accuracy, though training will potentially take * longer and use more memory. Model size will also be larger. * * @param timesSeen the "number" of times we want the trainer to imagine * it saw a feature that it actually didn't see */ public void setSmoothingObservation(double timesSeen) { _smoothingObservation = timesSeen; } /** * Sets whether this trainer will use smoothing while training the model. * This can improve model accuracy, though training will potentially take * longer and use more memory. Model size will also be larger. * * @param smooth true if smoothing is desired, false if not */ public void setGaussianSigma(double sigmaValue) { useGaussianSmoothing = true; sigma = sigmaValue; } /** * Trains a GIS model on the event in the specified event stream, using the specified number * of iterations and the specified count cutoff. * @param eventStream A stream of all events. * @param iterations The number of iterations to use for GIS. * @param cutoff The number of times a feature must occur to be included. * @return A GIS model trained with specified */ public GISModel trainModel(EventStream eventStream, int iterations, int cutoff) throws IOException { return trainModel(iterations, new OnePassDataIndexer(eventStream,cutoff),cutoff); } /** * Train a model using the GIS algorithm. * * @param iterations The number of GIS iterations to perform. * @param di The data indexer used to compress events in memory. * @return The newly trained model, which can be used immediately or saved * to disk using an opennlp.maxent.io.GISModelWriter object. */ public GISModel trainModel(int iterations, DataIndexer di, int cutoff) { return trainModel(iterations,di,new UniformPrior(),cutoff,1); } /** * Train a model using the GIS algorithm. * * @param iterations The number of GIS iterations to perform. * @param di The data indexer used to compress events in memory. * @param modelPrior The prior distribution used to train this model. * @return The newly trained model, which can be used immediately or saved * to disk using an opennlp.maxent.io.GISModelWriter object. */ public GISModel trainModel(int iterations, DataIndexer di, Prior modelPrior, int cutoff, int threads) { if (threads <= 0) throw new IllegalArgumentException("threads must be at leat one or greater!"); modelExpects = new MutableContext[threads][]; /************** Incorporate all of the needed info ******************/ display("Incorporating indexed data for training... \n"); contexts = di.getContexts(); values = di.getValues(); this.cutoff = cutoff; predicateCounts = di.getPredCounts(); numTimesEventsSeen = di.getNumTimesEventsSeen(); numUniqueEvents = contexts.length; this.prior = modelPrior; //printTable(contexts); // determine the correction constant and its inverse double correctionConstant = 0; for (int ci = 0; ci < contexts.length; ci++) { if (values == null || values[ci] == null) { if (contexts[ci].length > correctionConstant) { correctionConstant = contexts[ci].length; } } else { float cl = values[ci][0]; for (int vi=1;vi<values[ci].length;vi++) { cl+=values[ci][vi]; } if (cl > correctionConstant) { correctionConstant = cl; } } } display("done.\n"); outcomeLabels = di.getOutcomeLabels(); outcomeList = di.getOutcomeList(); numOutcomes = outcomeLabels.length; predLabels = di.getPredLabels(); prior.setLabels(outcomeLabels,predLabels); numPreds = predLabels.length; display("\tNumber of Event Tokens: " + numUniqueEvents + "\n"); display("\t Number of Outcomes: " + numOutcomes + "\n"); display("\t Number of Predicates: " + numPreds + "\n"); // set up feature arrays float[][] predCount = new float[numPreds][numOutcomes]; for (int ti = 0; ti < numUniqueEvents; ti++) { for (int j = 0; j < contexts[ti].length; j++) { if (values != null && values[ti] != null) { predCount[contexts[ti][j]][outcomeList[ti]] += numTimesEventsSeen[ti]*values[ti][j]; } else { predCount[contexts[ti][j]][outcomeList[ti]] += numTimesEventsSeen[ti]; } } } //printTable(predCount); di = null; // don't need it anymore // A fake "observation" to cover features which are not detected in // the data. The default is to assume that we observed "1/10th" of a // feature during training. final double smoothingObservation = _smoothingObservation; // Get the observed expectations of the features. Strictly speaking, // we should divide the counts by the number of Tokens, but because of // the way the model's expectations are approximated in the // implementation, this is cancelled out when we compute the next // iteration of a parameter, making the extra divisions wasteful. params = new MutableContext[numPreds]; for (int i = 0; i< modelExpects.length; i++) modelExpects[i] = new MutableContext[numPreds]; observedExpects = new MutableContext[numPreds]; // The model does need the correction constant and the correction feature. The correction constant // is only needed during training, and the correction feature is not necessary. // For compatibility reasons the model contains form now on a correction constant of 1, // and a correction param 0. evalParams = new EvalParameters(params,0,1,numOutcomes); int[] activeOutcomes = new int[numOutcomes]; int[] outcomePattern; int[] allOutcomesPattern= new int[numOutcomes]; for (int oi = 0; oi < numOutcomes; oi++) { allOutcomesPattern[oi] = oi; } int numActiveOutcomes = 0; for (int pi = 0; pi < numPreds; pi++) { numActiveOutcomes = 0; if (useSimpleSmoothing) { numActiveOutcomes = numOutcomes; outcomePattern = allOutcomesPattern; } else { //determine active outcomes for (int oi = 0; oi < numOutcomes; oi++) { if (predCount[pi][oi] > 0 && predicateCounts[pi] >= cutoff) { activeOutcomes[numActiveOutcomes] = oi; numActiveOutcomes++; } } if (numActiveOutcomes == numOutcomes) { outcomePattern = allOutcomesPattern; } else { outcomePattern = new int[numActiveOutcomes]; for (int aoi=0;aoi<numActiveOutcomes;aoi++) { outcomePattern[aoi] = activeOutcomes[aoi]; } } } params[pi] = new MutableContext(outcomePattern,new double[numActiveOutcomes]); for (int i = 0; i< modelExpects.length; i++) modelExpects[i][pi] = new MutableContext(outcomePattern,new double[numActiveOutcomes]); observedExpects[pi] = new MutableContext(outcomePattern,new double[numActiveOutcomes]); for (int aoi=0;aoi<numActiveOutcomes;aoi++) { int oi = outcomePattern[aoi]; params[pi].setParameter(aoi, 0.0); for (int i = 0; i< modelExpects.length; i++) modelExpects[i][pi].setParameter(aoi, 0.0); if (predCount[pi][oi] > 0) { observedExpects[pi].setParameter(aoi, predCount[pi][oi]); } else if (useSimpleSmoothing) { observedExpects[pi].setParameter(aoi,smoothingObservation); } } } predCount = null; // don't need it anymore display("...done.\n"); /***************** Find the parameters ************************/ if (threads == 1) display("Computing model parameters ...\n"); else display("Computing model parameters in " + threads +" threads...\n"); findParameters(iterations, correctionConstant); /*************** Create and return the model ******************/ // To be compatible with old models the correction constant is always 1 return new GISModel(params, predLabels, outcomeLabels, 1, evalParams.getCorrectionParam()); } /* Estimate and return the model parameters. */ private void findParameters(int iterations, double correctionConstant) { double prevLL = 0.0; double currLL = 0.0; display("Performing " + iterations + " iterations.\n"); for (int i = 1; i <= iterations; i++) { if (i < 10) display(" " + i + ": "); else if (i < 100) display(" " + i + ": "); else display(i + ": "); currLL = nextIteration(correctionConstant); if (i > 1) { if (prevLL > currLL) { LOG.error("Model Diverging: loglikelihood decreased"); break; } if (currLL - prevLL < LLThreshold) { break; } } prevLL = currLL; } // kill a bunch of these big objects now that we don't need them observedExpects = null; modelExpects = null; numTimesEventsSeen = null; contexts = null; } //modeled on implementation in Zhang Le's maxent kit private double gaussianUpdate(int predicate, int oid, int n, double correctionConstant) { double param = params[predicate].getParameters()[oid]; double x0 = 0.0; double modelValue = modelExpects[0][predicate].getParameters()[oid]; double observedValue = observedExpects[predicate].getParameters()[oid]; for (int i = 0; i < 50; i++) { double tmp = modelValue * Math.exp(correctionConstant * x0); double f = tmp + (param + x0) / sigma - observedValue; double fp = tmp * correctionConstant + 1 / sigma; if (fp == 0) { break; } double x = x0 - f / fp; if (Math.abs(x - x0) < 0.000001) { x0 = x; break; } x0 = x; } return x0; } private class ModelExpactationComputeTask implements Callable<ModelExpactationComputeTask> { private final int startIndex; private final int length; private double loglikelihood = 0; private int numEvents = 0; private int numCorrect = 0; final private int threadIndex; // startIndex to compute, number of events to compute ModelExpactationComputeTask(int threadIndex, int startIndex, int length) { this.startIndex = startIndex; this.length = length; this.threadIndex = threadIndex; } public ModelExpactationComputeTask call() { final double[] modelDistribution = new double[numOutcomes]; for (int ei = startIndex; ei < startIndex + length; ei++) { // TODO: check interruption status here, if interrupted set a poisoned flag and return if (values != null) { prior.logPrior(modelDistribution, contexts[ei], values[ei]); GISModel.eval(contexts[ei], values[ei], modelDistribution, evalParams); } else { prior.logPrior(modelDistribution,contexts[ei]); GISModel.eval(contexts[ei], modelDistribution, evalParams); } for (int j = 0; j < contexts[ei].length; j++) { int pi = contexts[ei][j]; if (predicateCounts[pi] >= cutoff) { int[] activeOutcomes = modelExpects[threadIndex][pi].getOutcomes(); for (int aoi=0;aoi<activeOutcomes.length;aoi++) { int oi = activeOutcomes[aoi]; // numTimesEventsSeen must also be thread safe if (values != null && values[ei] != null) { modelExpects[threadIndex][pi].updateParameter(aoi,modelDistribution[oi] * values[ei][j] * numTimesEventsSeen[ei]); } else { modelExpects[threadIndex][pi].updateParameter(aoi,modelDistribution[oi] * numTimesEventsSeen[ei]); } } } } loglikelihood += Math.log(modelDistribution[outcomeList[ei]]) * numTimesEventsSeen[ei]; numEvents += numTimesEventsSeen[ei]; if (printMessages) { int max = 0; for (int oi = 1; oi < numOutcomes; oi++) { if (modelDistribution[oi] > modelDistribution[max]) { max = oi; } } if (max == outcomeList[ei]) { numCorrect += numTimesEventsSeen[ei]; } } } return this; } synchronized int getNumEvents() { return numEvents; } synchronized int getNumCorrect() { return numCorrect; } synchronized double getLoglikelihood() { return loglikelihood; } } /* Compute one iteration of GIS and retutn log-likelihood.*/ private double nextIteration(double correctionConstant) { // compute contribution of p(a|b_i) for each feature and the new // correction parameter double loglikelihood = 0.0; int numEvents = 0; int numCorrect = 0; int numberOfThreads = modelExpects.length; ExecutorService executor = Executors.newFixedThreadPool(numberOfThreads); int taskSize = numUniqueEvents / numberOfThreads; int leftOver = numUniqueEvents % numberOfThreads; List<Future<?>> futures = new ArrayList<Future<?>>(); for (int i = 0; i < numberOfThreads; i++) { if (i != numberOfThreads - 1) futures.add(executor.submit(new ModelExpactationComputeTask(i, i*taskSize, taskSize))); else futures.add(executor.submit(new ModelExpactationComputeTask(i, i*taskSize, taskSize + leftOver))); } for (Future<?> future : futures) { ModelExpactationComputeTask finishedTask = null; try { finishedTask = (ModelExpactationComputeTask) future.get(); } catch (InterruptedException e) { // TODO: We got interrupted, but that is currently not really supported! // For now we just print the exception and fail hard. We hopefully soon // handle this case properly! e.printStackTrace(); throw new IllegalStateException("Interruption is not supported!", e); } catch (ExecutionException e) { // Only runtime exception can be thrown during training, if one was thrown // it should be re-thrown. That could for example be a NullPointerException // which is caused through a bug in our implementation. throw new RuntimeException(e.getCause()); } // When they are done, retrieve the results ... numEvents += finishedTask.getNumEvents(); numCorrect += finishedTask.getNumCorrect(); loglikelihood += finishedTask.getLoglikelihood(); } executor.shutdown(); display("."); // merge the results of the two computations for (int pi = 0; pi < numPreds; pi++) { int[] activeOutcomes = params[pi].getOutcomes(); for (int aoi=0;aoi<activeOutcomes.length;aoi++) { for (int i = 1; i < modelExpects.length; i++) { modelExpects[0][pi].updateParameter(aoi, modelExpects[i][pi].getParameters()[aoi]); } } } display("."); // compute the new parameter values for (int pi = 0; pi < numPreds; pi++) { double[] observed = observedExpects[pi].getParameters(); double[] model = modelExpects[0][pi].getParameters(); int[] activeOutcomes = params[pi].getOutcomes(); for (int aoi=0;aoi<activeOutcomes.length;aoi++) { if (useGaussianSmoothing) { params[pi].updateParameter(aoi,gaussianUpdate(pi,aoi,numEvents,correctionConstant)); } else { if (model[aoi] == 0) { LOG.error("Model expects == 0 for "+predLabels[pi]+" "+outcomeLabels[aoi]); } //params[pi].updateParameter(aoi,(Math.log(observed[aoi]) - Math.log(model[aoi]))); params[pi].updateParameter(aoi,((Math.log(observed[aoi]) - Math.log(model[aoi]))/correctionConstant)); } for (int i = 0; i< modelExpects.length; i++) modelExpects[i][pi].setParameter(aoi,0.0); // re-initialize to 0.0's } } display(". loglikelihood=" + loglikelihood + "\t" + ((double) numCorrect / numEvents) + "\n"); return loglikelihood; } private void display(String s) { if (printMessages) { currentMessage += s; if (s.endsWith("\n")) { LOG.debug(currentMessage.substring(0, currentMessage.length()-1)); currentMessage = ""; } } } }