// Copyright 2014 Thomas Müller // This file is part of HMMLA, which is licensed under GPLv3. package hmmla; import hmmla.decode.CoarseToFineDecoder; import hmmla.decode.Decoder; import hmmla.eval.Eval; import hmmla.eval.Result; import hmmla.hmm.HmmTrainer; import hmmla.hmm.HmmTrainerFactory; import hmmla.hmm.Model; import hmmla.hmm.Smoother; import hmmla.hmm.SmootherFactory; import hmmla.io.PosReader; import hmmla.io.Sentence; import hmmla.splitmerge.ApproximativeLossEstimator; import hmmla.splitmerge.ConcurrentEmTrainer; import hmmla.splitmerge.EmTrainer; import hmmla.splitmerge.ExactLossEstimator; import hmmla.splitmerge.LossEstimator; import hmmla.splitmerge.Merger; import hmmla.splitmerge.SimpleEmTrainer; import hmmla.splitmerge.Splitter; import hmmla.util.BufferedIterable; import hmmla.util.Ling; import hmmla.util.Mapping; import hmmla.util.RandomIterable; import hmmla.util.SuffixTrie; import java.util.Random; public class Trainer { private Properties props_; private Random rng_; private Merger merger_; private Model model_; private Splitter splitter_; private EmTrainer em_trainer_; private HmmTrainer hmm_trainer_; private Iterable<Sentence> train_reader_; private Iterable<Sentence> test_reader_; private Smoother smoother_; public Trainer(Properties props) { props_ = props; rng_ = new Random(props.getSeed()); Mapping map = null; if (props_.getUniversalPos()) { map = new Mapping(props_.getUniversalPosFile()); } train_reader_ = new BufferedIterable<Sentence>(new PosReader( props_.getTrainFile(), map)); if (props_.getTest()) { test_reader_ = new BufferedIterable<Sentence>(new PosReader( props_.getTestFile(), map)); } model_ = new Model(train_reader_, props_); if (!props.getLanguage().equals("en")) { SuffixTrie trie = Ling.getSuffixes(new PosReader(props_ .getTrainFile())); model_.setSuffixTrie(trie); } // Splitter setup. splitter_ = new Splitter(props_.getRandomness(), rng_); // EM trainer setup. if (props_.getNumThreads() == 1) { em_trainer_ = new SimpleEmTrainer(); } else { em_trainer_ = new ConcurrentEmTrainer(props_.getNumThreads()); } hmm_trainer_ = HmmTrainerFactory.getTrainer(props_); // Merger setup. LossEstimator estimator; if (props_.getExactLoss()) { estimator = new ExactLossEstimator(em_trainer_, hmm_trainer_); } else { estimator = new ApproximativeLossEstimator(hmm_trainer_); } merger_ = new Merger(estimator); smoother_ = SmootherFactory.getSmoother(props_); } private void runEm() { Iterable<Sentence> em_reader; if (props_.getSample()) { em_reader = new RandomIterable<Sentence>(train_reader_, props_.getSamplingFraction(), rng_); } else { em_reader = train_reader_; } int step = 0; while (step < props_.getEmSteps()) { em_trainer_.estep(hmm_trainer_, model_, em_reader); model_.setStatistics(smoother_.smooth(model_)); step += 1; } } private void split() { splitter_.split(model_); runEm(); } private void merge() { merger_.merge(model_, train_reader_, props_.getMergeFactor()); runEm(); } private void run() { eval(); while (model_.getNumTags() < props_.getNumTags()) { split(); if (props_.getSample()) { em_trainer_.estep(hmm_trainer_, model_, train_reader_); } if (props_.getMerge()) { merge(); } if (props_.getSample()) { em_trainer_.estep(hmm_trainer_, model_, train_reader_); } eval(); if (props_.getDumpIntermediateModels()) { model_.saveToFile(props_.getIntermediateModelName(model_)); } } model_.saveToFile(props_.getModelFile()); } private void eval() { Decoder decoder = new CoarseToFineDecoder(model_, hmm_trainer_, true, false); int tagsize = model_.getTagTable().size() - 1; System.err.format("Tag size: %d", tagsize); if (test_reader_ != null) { Result result = Eval.eval(decoder, test_reader_, model_); System.err.format(" Acc: %s", result.toString()); } System.err.format("\n"); } public static void main(String[] args) { Properties props = new Properties(); if (args.length == 0) { props.usage(); return; } props.setPropertiesFromStrings(args); props.check(Trainer.class.getSimpleName()); props.writePropertiesToFile(props.getModelFile() + ".props"); Trainer pipeline = new Trainer(props); pipeline.run(); } }