/* Copyright (C) 2011 Univ. of Massachusetts Amherst, Computer Science Dept.
   This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
   http://www.cs.umass.edu/~mccallum/mallet
   This software is provided under the terms of the Common Public License,
   version 1.0, as published by http://www.opensource.org.  For further
   information, see the file `LICENSE' included with this distribution. */

package cc.mallet.fst.semi_supervised.pr;

import java.util.ArrayList;
import java.util.BitSet;

import cc.mallet.fst.CRF;
import cc.mallet.fst.Transducer;
import cc.mallet.fst.TransducerTrainer;
import cc.mallet.fst.semi_supervised.StateLabelMap;
import cc.mallet.fst.semi_supervised.pr.constraints.PRConstraint;
import cc.mallet.optimize.LimitedMemoryBFGS;
import cc.mallet.optimize.Optimizer;
import cc.mallet.types.InstanceList;

/**
 * Posterior regularization trainer.
 *
 * <p>Each outer iteration performs two L-BFGS optimizations in sequence: first
 * the auxiliary model ("q", via {@link ConstraintsOptimizableByPR}), then the
 * CRF itself ("p", via {@link CRFOptimizableByKL}) using the dot products
 * cached by the first step. Iteration stops when the total objective changes
 * by less than the configured tolerance or the iteration budget is exhausted.
 *
 * @author Gregory Druck
 */
public class CRFTrainerByPR extends TransducerTrainer implements TransducerTrainer.ByOptimization {

  /** True only if the most recent {@code train} call stopped on the tolerance test. */
  private boolean converged;
  /** Number of outer iterations completed so far, accumulated across {@code train} calls. */
  private int iter;
  /** Thread count handed to both optimizable objects. */
  private final int numThreads;
  /** Gaussian prior variance applied to the p-step objective. */
  private double pGpv;
  /** Relative tolerance used by the stopping test. */
  private double tolerance;
  /** Total objective of the last completed iteration (p value minus q value). */
  private double value;
  /** Complete value contribution reported by the last q step. */
  private double qValue;
  private final ArrayList<PRConstraint> constraints;
  /** Optimizer of the most recent q step; {@code null} until training has started. */
  private LimitedMemoryBFGS bfgs;
  private final CRF crf;
  private StateLabelMap stateLabelMap;

  /**
   * Creates a single-threaded trainer.
   *
   * @param crf the CRF to train
   * @param constraints the posterior regularization constraints
   */
  public CRFTrainerByPR(CRF crf, ArrayList<PRConstraint> constraints) {
    this(crf, constraints, 1);
  }

  /**
   * Creates a trainer with default settings: Gaussian prior variance 10 and
   * tolerance 0.001. A default {@link StateLabelMap} is built from the CRF's
   * output alphabet; replace it with {@link #setStateLabelMap(StateLabelMap)}.
   *
   * @param crf the CRF to train
   * @param constraints the posterior regularization constraints
   * @param numThreads number of threads passed to the optimizable objects
   */
  public CRFTrainerByPR(CRF crf, ArrayList<PRConstraint> constraints, int numThreads) {
    this.crf = crf;
    this.iter = 0;
    this.value = Double.NEGATIVE_INFINITY;
    this.constraints = constraints;
    this.pGpv = 10;
    this.tolerance = 0.001;
    this.numThreads = numThreads;
    this.stateLabelMap = new StateLabelMap(crf.getOutputAlphabet(), true);
  }

  /** Returns the number of outer iterations completed so far. */
  @Override
  public int getIteration() {
    return iter;
  }

  /** Returns the CRF being trained. */
  @Override
  public Transducer getTransducer() {
    return crf;
  }

  /**
   * Returns whether the most recent {@code train} call stopped because the
   * objective change fell below the tolerance.
   */
  @Override
  public boolean isFinishedTraining() {
    return converged;
  }

  /**
   * Sets the map between states in the CRF FST and labels. The map is handed
   * to every constraint at the start of each {@code train} call.
   *
   * @param map the state-to-label map to use
   */
  public void setStateLabelMap(StateLabelMap map) {
    this.stateLabelMap = map;
  }

  /**
   * Sets the Gaussian prior variance used when optimizing the CRF (p step).
   *
   * @param pGpv the prior variance
   */
  public void setPGaussianPriorVariance(double pGpv) {
    this.pGpv = pGpv;
  }

  /**
   * Sets the relative tolerance of the stopping test.
   *
   * @param tolerance the tolerance
   */
  public void setTolerance(double tolerance) {
    this.tolerance = tolerance;
  }

  /**
   * Trains for at most {@code numIterations} outer iterations with no minimum
   * iteration count.
   *
   * @param train the training instances
   * @param numIterations maximum number of outer iterations for this call
   * @return {@code true} if the stopping test was satisfied
   */
  @Override
  public boolean train(InstanceList train, int numIterations) {
    return train(train, 0, numIterations);
  }

  /**
   * Trains with no limit on the number of L-BFGS iterations per step.
   *
   * @param train the training instances
   * @param minIter iteration index before which the stopping test is not applied
   * @param maxIter maximum number of outer iterations for this call
   * @return {@code true} if the stopping test was satisfied
   */
  public boolean train(InstanceList train, int minIter, int maxIter) {
    return train(train, minIter, maxIter, Integer.MAX_VALUE);
  }

  /**
   * Runs the alternating q/p optimization.
   *
   * <p>Instances that no constraint flags during preprocessing are dropped
   * before training. Exceptions thrown by either optimizer are printed and
   * training continues with the parameters reached so far.
   *
   * <p>The stopping test compares each iteration's objective with the previous
   * one; the comparison baseline starts at 0 on every call.
   *
   * @param train the training instances
   * @param minIter iteration index (as counted by {@link #getIteration()})
   *     before which the stopping test is not applied
   * @param maxIter maximum number of outer iterations for this call
   * @param maxIterPerStep maximum number of L-BFGS iterations in each q and p step
   * @return {@code true} if the stopping test was satisfied, {@code false} if
   *     the iteration budget ran out first
   */
  public boolean train(InstanceList train, int minIter, int maxIter, int maxIterPerStep) {
    // Report convergence for this call only; it is set again below when the test passes.
    converged = false;

    double oldValue = 0;
    // Saturate so that a huge maxIter on a resumed run cannot overflow to a negative bound.
    int max = (maxIter > Integer.MAX_VALUE - iter) ? Integer.MAX_VALUE : iter + maxIter;

    BitSet constrainedInstances = new BitSet();
    for (PRConstraint constraint : constraints) {
      constrainedInstances.or(constraint.preProcess(train));
      constraint.setStateLabelMap(stateLabelMap);
    }

    InstanceList constrained = retainConstrained(train, constrainedInstances);
    int removed = train.size() - constrained.size();
    System.err.println("Removed " + removed + " instances that do not contain constraints.");

    PRAuxiliaryModel model = new PRAuxiliaryModel(crf, constraints);

    for (; iter < max; iter++) {
      long startTime = System.currentTimeMillis();

      // train q
      ConstraintsOptimizableByPR opt =
          new ConstraintsOptimizableByPR(crf, constrained, model, numThreads);
      bfgs = new LimitedMemoryBFGS(opt);
      try {
        bfgs.optimize(maxIterPerStep);
      } catch (Exception e) {
        // Best effort: keep whatever parameters the optimizer reached.
        e.printStackTrace();
      } finally {
        opt.shutdown();
      }

      qValue = opt.getCompleteValueContribution();
      assert (qValue > 0);

      // use to train p
      CRFOptimizableByKL optP =
          new CRFOptimizableByKL(crf, constrained, model, opt.getCachedDots(), numThreads, 1);
      optP.setGaussianPriorVariance(pGpv);
      LimitedMemoryBFGS bfgsP = new LimitedMemoryBFGS(optP);
      try {
        bfgsP.optimize(maxIterPerStep);
      } catch (Exception e) {
        // Best effort: keep whatever parameters the optimizer reached.
        e.printStackTrace();
      } finally {
        optP.shutdown();
      }

      double pValue = optP.getValue();
      value = pValue - qValue;
      assert (value < 0);
      System.err.println(
          "Total value = " + value + " (pValue = " + pValue + ") (qValue = " + (-qValue) + ")");
      System.err.println("Time for iteration "
          + String.format("%.2f", ((System.currentTimeMillis() - startTime) / 1000.)) + "s");

      // stopping criteria from BFGS
      if ((iter >= minIter) && isBelowTolerance(value, oldValue)) {
        System.err.println("AP value difference below tolerance (oldValue: " + oldValue
            + ", newValue: " + value + ")");
        converged = true;
        break;
      }
      oldValue = value;

      runEvaluators();
    }

    return converged;
  }

  /** Returns the total objective (p value minus q value) of the last completed iteration. */
  public double getTotalValue() {
    return value;
  }

  /** Returns the complete value contribution of the last q step. */
  public double getQValue() {
    return qValue;
  }

  /**
   * Returns the optimizer used for the most recent q step, or {@code null} if
   * training has not started.
   */
  public Optimizer getOptimizer() {
    return bfgs;
  }

  /** Copies into an empty clone of {@code data} the instances whose index is set in {@code keep}. */
  private static InstanceList retainConstrained(InstanceList data, BitSet keep) {
    InstanceList kept = data.cloneEmpty();
    for (int i = 0; i < data.size(); i++) {
      if (keep.get(i)) {
        kept.add(data.get(i));
      }
    }
    return kept;
  }

  /** Relative-change stopping test: true when the objective moved less than the tolerance. */
  private boolean isBelowTolerance(double newValue, double previousValue) {
    return 2.0 * Math.abs(newValue - previousValue)
        <= tolerance * (Math.abs(newValue) + Math.abs(previousValue) + 1e-5);
  }
}