package cc.mallet.fst;

import java.util.ArrayList;
import java.util.Collections;

import cc.mallet.types.FeatureVectorSequence;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.Sequence;

import cc.mallet.fst.TransducerTrainer.ByInstanceIncrements;

/**
 * Trains a CRF by stochastic gradient descent. Most effective on large
 * training sets.
 *
 * @author kedarb
 */
public class CRFTrainerByStochasticGradient extends ByInstanceIncrements {
	protected CRF crf;
	// t is the decaying factor; lambda is a regularization constant that
	// depends on the training set size and the Gaussian prior.
	protected double learningRate, t, lambda;
	protected int iterationCount = 0;
	protected boolean converged = false;
	protected CRF.Factors expectations, constraints;

	public CRFTrainerByStochasticGradient(CRF crf, InstanceList trainingSample) {
		this.crf = crf;
		this.expectations = new CRF.Factors(crf);
		this.constraints = new CRF.Factors(crf);
		this.setLearningRateByLikelihood(trainingSample);
	}

	public CRFTrainerByStochasticGradient(CRF crf, double learningRate) {
		this.crf = crf;
		this.learningRate = learningRate;
		this.expectations = new CRF.Factors(crf);
		this.constraints = new CRF.Factors(crf);
	}

	public int getIteration() {
		return iterationCount;
	}

	public Transducer getTransducer() {
		return crf;
	}

	public boolean isFinishedTraining() {
		return converged;
	}

	// The best way to choose the learning rate is to run training on a sample
	// and set it to the rate that produces the maximum increase in likelihood
	// or accuracy. Then, to be conservative, halve that rate.
	// In general, eta = 1/(lambda*t) where
	// lambda = priorVariance * numTrainingInstances.
	// After an initial eta_0 is set, t_0 = 1/(lambda*eta_0).
	// After each training step eta = 1/(lambda*(t+t_0)), t = 0, 1, 2, ..., Infinity.
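	// Worked example of the schedule (illustrative numbers only, not used by
	// the code): with 1000 training instances, train(...) below uses
	// lambda = 1/1000 = 0.001. If eta_0 = 0.1, then
	// t_0 = 1/(lambda * eta_0) = 10000, and after k parameter updates
	// eta_k = 1/(lambda * (t_0 + k)) = 1000/(10000 + k), so the rate starts
	// at 0.1 and decays hyperbolically, roughly as 1/k.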
	/**
	 * Automatically sets the learning rate by trying doublings of a small
	 * initial rate on the sample and keeping the rate (halved, to be
	 * conservative) that most improved the likelihood.
	 */
	public void setLearningRateByLikelihood(InstanceList trainingSample) {
		int numIterations = 5; // was 10 -akm 1/25/08
		double bestLearningRate = Double.NEGATIVE_INFINITY;
		double bestLikelihoodChange = Double.NEGATIVE_INFINITY;
		double currLearningRate = 5e-11;
		while (currLearningRate < 1) {
			currLearningRate *= 2;
			crf.parameters.zero();
			double beforeLikelihood = computeLikelihood(trainingSample);
			double likelihoodChange = trainSample(trainingSample, numIterations,
					currLearningRate) - beforeLikelihood;
			System.out.println("likelihood change = " + likelihoodChange
					+ " for learning rate = " + currLearningRate);
			if (likelihoodChange > bestLikelihoodChange) {
				bestLikelihoodChange = likelihoodChange;
				bestLearningRate = currLearningRate;
			}
		}
		// reset the parameters
		crf.parameters.zero();
		// conservative estimate of the learning rate
		bestLearningRate /= 2;
		System.out.println("Setting learning rate to " + bestLearningRate);
		setLearningRate(bestLearningRate);
	}

	private double trainSample(InstanceList trainingSample, int numIterations,
			double rate) {
		double lambda = trainingSample.size();
		double t = 1 / (lambda * rate);
		double loglik = Double.NEGATIVE_INFINITY;
		for (int i = 0; i < numIterations; i++) {
			loglik = 0.0;
			for (int j = 0; j < trainingSample.size(); j++) {
				rate = 1 / (lambda * t);
				loglik += trainIncrementalLikelihood(trainingSample.get(j), rate);
				t += 1.0;
			}
		}
		return loglik;
	}

	private double computeLikelihood(InstanceList trainingSample) {
		double loglik = 0.0;
		for (int i = 0; i < trainingSample.size(); i++) {
			Instance trainingInstance = trainingSample.get(i);
			FeatureVectorSequence fvs = (FeatureVectorSequence) trainingInstance
					.getData();
			Sequence labelSequence = (Sequence) trainingInstance.getTarget();
			loglik += new SumLatticeDefault(crf, fvs, labelSequence, null)
					.getTotalWeight();
			loglik -= new SumLatticeDefault(crf, fvs, null, null)
					.getTotalWeight();
		}
		constraints.zero();
		expectations.zero();
		return loglik;
	}

	public void setLearningRate(double r) {
		this.learningRate = r;
	}

	public double getLearningRate() {
		return this.learningRate;
	}

	public boolean train(InstanceList trainingSet, int numIterations) {
		return train(trainingSet, numIterations, 1);
	}

	public boolean train(InstanceList trainingSet, int numIterations,
			int numIterationsBetweenEvaluation) {
		assert (expectations.structureMatches(crf.parameters));
		assert (constraints.structureMatches(crf.parameters));
		lambda = 1.0 / trainingSet.size();
		t = 1.0 / (lambda * learningRate);
		converged = false;
		ArrayList<Integer> trainingIndices = new ArrayList<Integer>();
		for (int i = 0; i < trainingSet.size(); i++)
			trainingIndices.add(i);
		double oldLoglik = Double.NEGATIVE_INFINITY;
		while (numIterations-- > 0) {
			iterationCount++;
			// shuffle the indices so each pass visits instances in a new order
			Collections.shuffle(trainingIndices);
			double loglik = 0.0;
			for (int i = 0; i < trainingSet.size(); i++) {
				learningRate = 1.0 / (lambda * t);
				loglik += trainIncrementalLikelihood(trainingSet
						.get(trainingIndices.get(i)));
				t += 1.0;
			}
			System.out.println("loglikelihood[" + numIterations + "] = " + loglik);
			if (Math.abs(loglik - oldLoglik) < 1e-3) {
				converged = true;
				break;
			}
			oldLoglik = loglik;
			Runtime.getRuntime().gc();
			if (iterationCount % numIterationsBetweenEvaluation == 0)
				runEvaluators();
		}
		return converged;
	}

	// TODO Add some way to train by batches of instances, where the batch
	// memberships are determined externally? Or provide some easy interface
	// for creating batches.
	public boolean trainIncremental(InstanceList trainingSet) {
		this.train(trainingSet, 1);
		return false;
	}

	public boolean trainIncremental(Instance trainingInstance) {
		assert (expectations.structureMatches(crf.parameters));
		trainIncrementalLikelihood(trainingInstance);
		return false;
	}

	/**
	 * Adjusts the parameters by the default learning rate according to the
	 * gradient of this single Instance, and returns the true label sequence
	 * likelihood.
	 */
	public double trainIncrementalLikelihood(Instance trainingInstance) {
		return trainIncrementalLikelihood(trainingInstance, learningRate);
	}

	/**
	 * Adjusts the parameters by the given learning rate according to the
	 * gradient of this single Instance, and returns the true label sequence
	 * likelihood.
	 */
	public double trainIncrementalLikelihood(Instance trainingInstance,
			double rate) {
		double singleLoglik;
		constraints.zero();
		expectations.zero();
		FeatureVectorSequence fvs = (FeatureVectorSequence) trainingInstance
				.getData();
		Sequence labelSequence = (Sequence) trainingInstance.getTarget();
		singleLoglik = new SumLatticeDefault(crf, fvs, labelSequence,
				constraints.new Incrementor()).getTotalWeight();
		singleLoglik -= new SumLatticeDefault(crf, fvs, null,
				expectations.new Incrementor()).getTotalWeight();
		// Calculate the parameter gradient for this instance:
		// (constraints - expectations)
		constraints.plusEquals(expectations, -1);
		// Move the parameters a little along this difference, obeying
		// weightsFrozen
		crf.parameters.plusEquals(constraints, rate, true);
		return singleLoglik;
	}
}
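
// ---------------------------------------------------------------------------
// Minimal usage sketch (illustrative only, not part of the trainer): the
// SimpleTagger-format file "train.txt", the initial learning rate of 0.0001,
// and this sketch class itself are all assumptions made for demonstration.
// In practice one might instead use the (CRF, InstanceList) constructor,
// which picks a learning rate via setLearningRateByLikelihood.
// ---------------------------------------------------------------------------
class CRFTrainerByStochasticGradientUsageSketch {
	public static void main(String[] args) throws Exception {
		// Pipe that turns SimpleTagger-format lines ("feature ... label",
		// with a blank line between sequences) into FeatureVectorSequence
		// data with LabelSequence targets.
		cc.mallet.pipe.Pipe pipe =
				new SimpleTagger.SimpleTaggerSentence2FeatureVectorSequence();
		InstanceList training = new InstanceList(pipe);
		training.addThruPipe(new cc.mallet.pipe.iterator.LineGroupIterator(
				new java.io.FileReader("train.txt"), // hypothetical file
				java.util.regex.Pattern.compile("^\\s*$"), true));

		// One CRF state per label, with the transitions seen in the data.
		CRF crf = new CRF(training.getPipe(), null);
		crf.addStatesForLabelsConnectedAsIn(training);

		CRFTrainerByStochasticGradient trainer =
				new CRFTrainerByStochasticGradient(crf, 0.0001); // assumed rate
		trainer.train(training, 10); // up to 10 passes over the data
	}
}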