package cc.mallet.fst.semi_supervised; import java.util.ArrayList; import cc.mallet.fst.CRF; import cc.mallet.fst.CRFCacheStaleIndicator; import cc.mallet.fst.CRFOptimizableByBatchLabelLikelihood; import cc.mallet.fst.CRFOptimizableByGradientValues; import cc.mallet.fst.CRFOptimizableByLabelLikelihood; import cc.mallet.fst.CRFTrainerByLabelLikelihood; import cc.mallet.fst.CRFTrainerByThreadedLabelLikelihood; import cc.mallet.fst.ThreadedOptimizable; import cc.mallet.fst.Transducer; import cc.mallet.fst.TransducerTrainer; import cc.mallet.fst.semi_supervised.CRFOptimizableByGE; import cc.mallet.fst.semi_supervised.StateLabelMap; import cc.mallet.fst.semi_supervised.constraints.GEConstraint; import cc.mallet.optimize.LimitedMemoryBFGS; import cc.mallet.optimize.Optimizable; import cc.mallet.types.InstanceList; public class CRFTrainerByLikelihoodAndGE extends TransducerTrainer { private boolean initSupervised; private boolean converged; private double geWeight; private double gpv; private int supIterations; private int numThreads; private int iteration; private CRF crf; private ArrayList<GEConstraint> constraints; private StateLabelMap map; public CRFTrainerByLikelihoodAndGE(CRF crf, ArrayList<GEConstraint> constraints, StateLabelMap map) { this.crf = crf; this.constraints = constraints; this.map = map; this.iteration = 0; this.converged = false; this.geWeight = 1.0; this.initSupervised = false; this.gpv = 10.0; this.numThreads = 1; this.supIterations = Integer.MAX_VALUE; } public void setGEWeight(double weight) { this.geWeight = weight; } public void setGaussianPriorVariance(double gpv) { this.gpv = gpv; } public void setInitSupervised(boolean flag) { this.initSupervised = flag; } public void setSupervisedIterations(int iterations) { this.supIterations = iterations; } public void setNumThreads(int numThreads) { this.numThreads = numThreads; } @Override public Transducer getTransducer() { return crf; } @Override public int getIteration() { return iteration; } @Override public boolean isFinishedTraining() { return converged; } public boolean train(InstanceList trainingSet, InstanceList unlabeledSet, int numIterations) { System.err.println(trainingSet.size()); System.err.println(unlabeledSet.size()); if (initSupervised) { // train supervised if (numThreads == 1) { CRFTrainerByLabelLikelihood trainer = new CRFTrainerByLabelLikelihood(crf); trainer.setAddNoFactors(true); trainer.setGaussianPriorVariance(gpv); trainer.train(trainingSet,supIterations); } else { CRFTrainerByThreadedLabelLikelihood trainer = new CRFTrainerByThreadedLabelLikelihood(crf,numThreads); trainer.setAddNoFactors(true); trainer.setGaussianPriorVariance(gpv); trainer.train(trainingSet,supIterations); trainer.shutdown(); } runEvaluators(); } // train semi-supervised Optimizable.ByGradientValue optLikelihood; if (numThreads == 1) { optLikelihood = new CRFOptimizableByLabelLikelihood(crf,trainingSet); ((CRFOptimizableByLabelLikelihood)optLikelihood).setGaussianPriorVariance(gpv); } else { CRFOptimizableByBatchLabelLikelihood likelihood = new CRFOptimizableByBatchLabelLikelihood(crf,trainingSet,numThreads); optLikelihood = new ThreadedOptimizable(likelihood,trainingSet,crf.getParameters().getNumFactors(), new CRFCacheStaleIndicator(crf)); likelihood.setGaussianPriorVariance(gpv); } CRFOptimizableByGE ge = new CRFOptimizableByGE(crf,constraints,unlabeledSet,map,numThreads,geWeight); // turn off the prior... it already appears above! ge.setGaussianPriorVariance(Double.POSITIVE_INFINITY); CRFOptimizableByGradientValues opt = new CRFOptimizableByGradientValues(crf,new Optimizable.ByGradientValue[] { optLikelihood, ge }); LimitedMemoryBFGS optimizer = new LimitedMemoryBFGS(opt); try { converged = optimizer.optimize(numIterations); } catch (Exception e) { e.printStackTrace(); } optimizer.reset(); try { converged = optimizer.optimize(numIterations); } catch (Exception e) { e.printStackTrace(); } if (numThreads > 1) { ((ThreadedOptimizable)optLikelihood).shutdown(); ge.shutdown(); } return converged; } @Override public boolean train(InstanceList trainingSet, int numIterations) { throw new RuntimeException("Must use train(InstanceList trainingSet, InstanceList unlabeledSet, int numIterations) instead"); } }