// Copyright 2013 Thomas Müller // This file is part of MarMoT, which is licensed under GPLv3. package marmot.core; import java.util.ArrayList; import java.util.Collection; import java.util.Collections; import java.util.List; import java.util.Random; import marmot.core.lattice.SumLattice; public class CrfTrainer implements Trainer { private double penalty_; private double step_width_ = .1; private double steps_; private boolean shuffle_; private boolean verbose_; private boolean very_verbose_; private double quadratic_penalty_; private long seed_; private boolean optimize_num_iterations_; @Override public void train(Tagger tagger, Collection<Sequence> in_sequences, Evaluator evaluator) { if (optimize_num_iterations_) { assert evaluator != null : "Set optimize_num_iterations but did not provide test data."; } Random rng = null; if (shuffle_) { if (seed_ == 0) { rng = new Random(); } else { rng = new Random(seed_); } } List<Sequence> sequences = new ArrayList<Sequence>(in_sequences); int fraction = Math.max(sequences.size() / 4, 1); int smaller_fraction = Math.max(sequences.size() / 4000, 1); int small_factor = 1; WeightVector weights = tagger.getWeightVector(); assert weights != null; double[] best_float_params = null; double[] best_params = null; double best_score = 0.0; double accumalted_penalty = 0; int number = 0; for (int step = 0; step < steps_; step++) { if (verbose_) System.err.println("step: " + step); if (shuffle_) Collections.shuffle(sequences, rng); int current_sentence = 0; long train_time = System.currentTimeMillis(); for (Sequence sequence : sequences) { double step_width = step_width_ / (1 + (number / (double) sequences.size())); double scale_factor = 1 - 2. * step_width * quadratic_penalty_ / sequences.size(); assert !Double.isNaN(scale_factor); assert !Double.isInfinite(scale_factor); assert scale_factor > 1e-10; assert scale_factor < 1 + 1e-10; step_width /= scale_factor; if (Math.abs(penalty_) > 1e-10) { accumalted_penalty += step_width * penalty_ / sequences.size(); weights.setPenalty(true, accumalted_penalty); } SumLattice lattice = tagger.getSumLattice(true, sequence); if (very_verbose_) { System.err.format("vv %d %d %d %d\n", number, lattice.getOrder() + lattice.getLevel() * (tagger.getModel().getOrder() + 1), lattice.getLevel(), lattice.getOrder() ); } assert lattice != null; lattice.update(weights, step_width); weights.scaleBy(scale_factor); current_sentence++; if (current_sentence % fraction == 0) { if (verbose_) System.err .format("Processed %d sentences at %g sentence/s \n", current_sentence, current_sentence / ((System.currentTimeMillis() - train_time) / 1000.)); if (small_factor < 100) { small_factor *= 10; smaller_fraction = Math.max(small_factor * sequences.size() / 400, 1); } } if (current_sentence % smaller_fraction == 0) { tagger.setThresholds(false); } number++; } if (evaluator != null && (verbose_ || optimize_num_iterations_)) { weights.setExtendFeatureSet(false); Result result = evaluator.eval(tagger); weights.setExtendFeatureSet(true); tagger.setResult(result); if (verbose_) System.err.println(result); if (optimize_num_iterations_) { double score = result.getScore(); if (score > best_score) { best_score = score; best_params = weights.getWeights().clone(); best_float_params = weights.getFloatWeights().clone(); } } } } weights.setPenalty(false, 0.0); weights.setExtendFeatureSet(false); if (optimize_num_iterations_) { if (best_params != null) { assert weights.getWeights().length == best_params.length; weights.setWeights(best_params); } if (best_float_params != null) { weights.setFloatWeights(best_float_params); } if (evaluator != null) { Result result = evaluator.eval(tagger); tagger.setResult(result); } } } @Override public void setOptions(Options options) { setOptions(options.getPenalty(), options.getQuadraticPenalty(), options.getNumIterations(), options .getShuffle(), options.getVerbose(), options.getVeryVerbose(), options.getSeed(), options.getOptimizeNumIterations()); } private void setOptions(double penalty, double quadratic_penalty, int steps, boolean shuffle, boolean verbose, boolean very_verbose, long seed, boolean optimize_num_iterations) { penalty_ = penalty; steps_ = steps; shuffle_ = shuffle; verbose_ = verbose; very_verbose_ = very_verbose; quadratic_penalty_ = quadratic_penalty; seed_ = seed; optimize_num_iterations_ = optimize_num_iterations; } }