// Copyright 2015 Thomas Müller
// This file is part of MarMoT, which is licensed under GPLv3.

package lemming.lemma.ranker;

import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Random;
import java.util.logging.Level;
import java.util.logging.Logger;

import lemming.lemma.LemmaCandidateGenerator;
import lemming.lemma.LemmaCandidateGeneratorTrainer;
import lemming.lemma.LemmaInstance;
import lemming.lemma.LemmaOptions;
import lemming.lemma.LemmatizerGenerator;
import lemming.lemma.LemmatizerGeneratorTrainer;
import lemming.lemma.SimpleLemmatizerTrainer;
import lemming.lemma.edit.EditTreeGeneratorTrainer;
import lemming.lemma.edit.EditTreeGeneratorTrainer.EditTreeGeneratorTrainerOptions;
import lemming.lemma.ranker.RankerTrainer.RankerTrainerOptions;
import lemming.lemma.toutanova.EditTreeAligner;
import lemming.lemma.toutanova.EditTreeAlignerTrainer;
import marmot.util.Sys;
import cc.mallet.optimize.LimitedMemoryBFGS;
import cc.mallet.optimize.Optimizable.ByGradientValue;
import cc.mallet.optimize.OptimizationException;
import cc.mallet.optimize.Optimizer;

public class RankerTrainer implements LemmatizerGeneratorTrainer {

    public static class RankerTrainerOptions extends LemmaOptions {

        private static final long serialVersionUID = 1L;

        public static final String GENERATOR_TRAINERS = "generator-trainers";
        public static final String USE_PERCEPTRON = "use-perceptron";
        public static final String QUADRATIC_PENALTY = "quadratic-penalty";
        public static final String UNIGRAM_FILE = "unigram-file";
        public static final String USE_SHAPE_LEXICON = "use-shape-lexicon";
        public static final String ASPELL_LANG = "aspell-lang";
        public static final String ASPELL_PATH = "aspell-path";
        public static final String USE_CORE_FEATURES = "use-core-features";
        public static final String USE_ALIGNMENT_FEATURES = "use-alignment-features";
        public static final String IGNORE_FEATURES = "ignore-features";
        public static final String NUM_EDIT_TREE_STEPS = "num-edit-tree-steps";
        public static final String COPY_CONJUNCTONS = "copy-conjunctions";
        public static final String TAG_DEPENDENT = "tag-dependent";
        public static final String EDIT_TREE_MIN_COUNT = "edit-tree-min-count";
        public static final String EDIT_TREE_MAX_DEPTH = "edit-tree-max-depth";
        public static final String USE_HASH_FEATURE_TABLE = "use-hash-feature-table";
        public static final String USE_MALLET = "use-mallet";
        public static final String OFFLINE_FEATURE_EXTRACTION = "offline-feature-extraction";
        public static final String CLUSTER_FILE = "cluster-file";

        public RankerTrainerOptions() {
            map_.put(GENERATOR_TRAINERS, Arrays.asList(
                    SimpleLemmatizerTrainer.class, EditTreeGeneratorTrainer.class));
            map_.put(USE_PERCEPTRON, true);
            map_.put(QUADRATIC_PENALTY, 0.00);
            map_.put(UNIGRAM_FILE, Arrays.asList(""));
            map_.put(USE_SHAPE_LEXICON, false);
            map_.put(ASPELL_LANG, "");
            map_.put(ASPELL_PATH, "");
            map_.put(USE_CORE_FEATURES, true);
            map_.put(USE_ALIGNMENT_FEATURES, true);
            map_.put(IGNORE_FEATURES, "");
            map_.put(NUM_EDIT_TREE_STEPS, 1);
            map_.put(COPY_CONJUNCTONS, false);
            map_.put(USE_HASH_FEATURE_TABLE, false);
            map_.put(TAG_DEPENDENT, false);
            map_.put(EDIT_TREE_MIN_COUNT, 0);
            map_.put(EDIT_TREE_MAX_DEPTH, -1);
            map_.put(USE_MALLET, true);
            map_.put(OFFLINE_FEATURE_EXTRACTION, true);
            map_.put(CLUSTER_FILE, "");
        }

        public RankerTrainerOptions(RankerTrainerOptions roptions) {
            map_ = new HashMap<>(roptions.map_);
        }
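        // Typed accessors for the entries of the untyped options map.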
        @SuppressWarnings("unchecked")
        public List<Object> getUnigramFile() {
            return (List<Object>) getOption(UNIGRAM_FILE);
        }

        @SuppressWarnings("unchecked")
        public List<Object> getGeneratorTrainers() {
            return (List<Object>) getOption(GENERATOR_TRAINERS);
        }

        public boolean getUsePerceptron() {
            return (Boolean) getOption(USE_PERCEPTRON);
        }

        public double getQuadraticPenalty() {
            return (Double) getOption(QUADRATIC_PENALTY);
        }

        public List<LemmaCandidateGenerator> getGenerators(
                List<LemmaInstance> instances) {
            List<LemmaCandidateGenerator> generators = new LinkedList<>();
            for (Object trainer_class : getGeneratorTrainers()) {
                LemmaCandidateGeneratorTrainer trainer = (LemmaCandidateGeneratorTrainer) toInstance((Class<?>) trainer_class);

                // Forward the edit-tree options to the edit-tree generator.
                if (trainer instanceof EditTreeGeneratorTrainer) {
                    trainer.getOptions().setOption(
                            EditTreeGeneratorTrainerOptions.NUM_STEPS,
                            getNumEditTreeSteps());
                    trainer.getOptions().setOption(
                            EditTreeGeneratorTrainerOptions.TAG_DEPENDENT,
                            getTagDependent());
                    trainer.getOptions().setOption(
                            EditTreeGeneratorTrainerOptions.MIN_COUNT,
                            getEditTreeMinCount());
                    trainer.getOptions().setOption(
                            EditTreeGeneratorTrainerOptions.MAX_DEPTH,
                            getEditTreeMaxDepth());
                }

                generators.add(trainer.train(instances, null));
            }
            return generators;
        }

        private Integer getEditTreeMaxDepth() {
            return (Integer) getOption(EDIT_TREE_MAX_DEPTH);
        }

        private Integer getEditTreeMinCount() {
            return (Integer) getOption(EDIT_TREE_MIN_COUNT);
        }

        public boolean getTagDependent() {
            return (Boolean) getOption(TAG_DEPENDENT);
        }

        public boolean getUseShapeLexicon() {
            return (Boolean) getOption(USE_SHAPE_LEXICON);
        }

        public String getAspellPath() {
            return (String) getOption(ASPELL_PATH);
        }

        public String getAspellLang() {
            return (String) getOption(ASPELL_LANG);
        }

        public boolean getUseCoreFeatures() {
            return (Boolean) getOption(USE_CORE_FEATURES);
        }

        public boolean getUseAlignmentFeatures() {
            return (Boolean) getOption(USE_ALIGNMENT_FEATURES);
        }

        public String getIgnoreFeatures() {
            return (String) getOption(IGNORE_FEATURES);
        }

        public int getNumEditTreeSteps() {
            return (Integer) getOption(NUM_EDIT_TREE_STEPS);
        }

        public boolean getCopyConjunctions() {
            return (Boolean) getOption(COPY_CONJUNCTONS);
        }

        public boolean getUseHashFeatureTable() {
            return (Boolean) getOption(USE_HASH_FEATURE_TABLE);
        }

        public boolean getUseMallet() {
            return (Boolean) getOption(USE_MALLET);
        }

        public boolean getUseOfflineFeatureExtraction() {
            return (Boolean) getOption(OFFLINE_FEATURE_EXTRACTION);
        }

        public String getClusterFile() {
            return (String) getOption(CLUSTER_FILE);
        }
    }

    private RankerTrainerOptions options_;
    private static final int MAX_NUM_DUPLICATES_ = 3;

    public RankerTrainer() {
        options_ = new RankerTrainerOptions();
    }

    @Override
    public LemmatizerGenerator train(List<LemmaInstance> train_instances,
            List<LemmaInstance> test_instances) {
        List<LemmaCandidateGenerator> generators = options_
                .getGenerators(train_instances);
        return trainReranker(generators, train_instances);
    }

    private LemmatizerGenerator trainReranker(
            List<LemmaCandidateGenerator> generators,
            List<LemmaInstance> simple_instances) {

        List<RankerInstance> instances = RankerInstance.getInstances(
                simple_instances, generators);

        RankerModel model = new RankerModel();

        EditTreeAligner aligner = (EditTreeAligner) new EditTreeAlignerTrainer(
                options_.getRandom(), false, 1, -1).train(simple_instances);

        Logger logger = Logger.getLogger(getClass().getName());
        logger.info("Extracting features");
        model.init(options_, instances, aligner);

        if (options_.getUsePerceptron()) {
            runPerceptron(model, instances);
        } else {
            runMaxEnt(model, instances);
        }

        return new Ranker(model, generators);
    }
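    // Maximum-entropy training of the ranker: delegates to Mallet's L-BFGS
    // optimizer by default, or to the simple SGD loop below when the
    // "use-mallet" option is disabled.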
    private void runMaxEnt(RankerModel model, List<RankerInstance> instances) {
        if (options_.getUseMallet()) {
            runMallet(model, instances);
        } else {
            runSgd(model, instances);
        }
    }

    private void runSgd(RankerModel model, List<RankerInstance> instances) {
        // Duplicate frequent instances (up to MAX_NUM_DUPLICATES_ copies) so
        // that token frequency is reflected in the stochastic updates.
        List<RankerInstance> duplicates = new LinkedList<>();
        for (RankerInstance instance : instances) {
            double count = instance.getInstance().getCount();
            int number = Math.min(MAX_NUM_DUPLICATES_, (int) count);
            for (int i = 0; i < number; i++) {
                duplicates.add(instance);
            }
        }

        Logger logger = Logger.getLogger(getClass().getName());
        logger.info(String.format(
                "Created duplicates. Increased num instances from %d to %d.",
                instances.size(), duplicates.size()));
        instances = duplicates;

        double initial_step_width = 0.1;
        RankerObjective objective = new RankerObjective(options_, model,
                instances, MAX_NUM_DUPLICATES_);
        Random random = options_.getRandom();

        int number = 0;
        for (int step = 0; step < options_.getNumIterations(); step++) {
            logger.info("SGD step: " + step);
            Collections.shuffle(instances, random);
            for (RankerInstance instance : instances) {
                // Step width decays with the number of completed epochs:
                // eta = eta_0 / (1 + number / |instances|).
                double step_width = initial_step_width
                        / (1 + (number / (double) instances.size()));
                objective.update(instance, true, step_width);
                number++;
            }
        }
    }

    private void runMallet(RankerModel model, List<RankerInstance> instances) {
        Logger logger = Logger.getLogger(getClass().getName());

        double memory_used_before_optimization = Sys.getUsedMemoryInMegaBytes();
        double memory_usage_of_one_weights_array = (double) model.getWeights().length
                * (double) Double.SIZE / (8. * 1024. * 1024.);
        logger.info(String.format("Memory usage of weights array: %g (%g) MB",
                Sys.getUsedMemoryInMegaBytes(model.getWeights(), false),
                memory_usage_of_one_weights_array));
        logger.info(String.format("Memory usage: %g / %g MB",
                memory_used_before_optimization,
                Sys.getMaxHeapSizeInMegaBytes()));
        logger.info("Start optimization");

        ByGradientValue objective = new RankerObjective(options_, model,
                instances);
        Optimizer optimizer = new LimitedMemoryBFGS(objective);
        // Optimizer optimizer = new ConjugateGradient(objective);
        Logger.getLogger(optimizer.getClass().getName()).setLevel(Level.OFF);
        objective.setParameters(model.getWeights());

        try {
            optimizer.optimize(1);
            double memory_usage_during_optimization = Sys
                    .getUsedMemoryInMegaBytes();
            logger.info(String.format(
                    "Memory usage after first iteration: %g / %g MB",
                    memory_usage_during_optimization,
                    Sys.getMaxHeapSizeInMegaBytes()));

            for (int i = 0; i < 200 && !optimizer.isConverged(); i++) {
                optimizer.optimize(1);
                logger.info(String.format("Iteration: %3d / %3d", i + 1, 200));
            }
        } catch (IllegalArgumentException e) {
            // Mallet's L-BFGS can throw near the optimum (e.g., when the line
            // search fails); abort and keep the current weights.
        } catch (OptimizationException e) {
            // Same as above: treat as converged.
        }

        logger.info("Finished optimization");
    }

    private void runPerceptron(RankerModel model, List<RankerInstance> instances) {
        Logger logger = Logger.getLogger(getClass().getName());

        double[] weights = model.getWeights();
        double[] sum_weights = null;
        if (options_.getAveraging()) {
            sum_weights = new double[weights.length];
        }

        for (int iter = 0; iter < options_.getNumIterations(); iter++) {
            double error = 0;
            double total = 0;
            int number = 0;

            Collections.shuffle(instances, options_.getRandom());
            for (RankerInstance instance : instances) {
                String lemma = model.select(instance);

                if (!lemma.equals(instance.getInstance().getLemma())) {
                    model.update(instance, lemma, -1);
                    model.update(instance, instance.getInstance().getLemma(), +1);

                    if (sum_weights != null) {
                        double amount = instances.size() - number;
                        assert amount > 0;
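                        // Lazy averaging trick: rather than adding the full
                        // weight vector to sum_weights after every instance,
                        // apply this single update to sum_weights scaled by
                        // the number of instances remaining in the epoch. The
                        // rescaling at the end of the epoch makes the result
                        // equal to the running average of all weight vectors.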
                        model.setWeights(sum_weights);
                        model.update(instance, lemma, -amount);
                        model.update(instance, instance.getInstance()
                                .getLemma(), +amount);
                        model.setWeights(weights);
                    }

                    error += instance.getInstance().getCount();
                }

                total += instance.getInstance().getCount();
                number++;
            }

            if (sum_weights != null) {
                double weights_scaling = 1. / ((iter + 1.) * instances.size());
                double sum_weights_scaling = (iter + 2.) / (iter + 1.);
                for (int i = 0; i < weights.length; i++) {
                    // Expose the averaged weights for decoding and rescale the
                    // running sum so that the next epoch's updates are
                    // weighted consistently.
                    weights[i] = sum_weights[i] * weights_scaling;
                    sum_weights[i] = sum_weights[i] * sum_weights_scaling;
                }
            }

            logger.info(String.format("Train Accuracy: %g / %g = %g%%",
                    total - error, total, (total - error) * 100. / total));
        }
    }

    @Override
    public LemmaOptions getOptions() {
        return options_;
    }

    public void setOptions(RankerTrainerOptions roptions) {
        options_ = roptions;
    }
}
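
// Hypothetical usage sketch (not part of MarMoT): train a reranking
// lemmatizer with L-BFGS instead of the default averaged perceptron.
//
//   RankerTrainer trainer = new RankerTrainer();
//   trainer.getOptions().setOption(RankerTrainerOptions.USE_PERCEPTRON, false);
//   LemmatizerGenerator lemmatizer = trainer.train(trainInstances, null);
//
// where trainInstances is an assumed List<LemmaInstance> of gold
// word-form/lemma training pairs.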