// Copyright 2015 Thomas Müller
// This file is part of MarMoT, which is licensed under GPLv3.

package lemming.lemma.toutanova;

import java.util.Collections;
import java.util.LinkedList;
import java.util.List;
import java.util.logging.Logger;

import lemming.lemma.LemmaInstance;
import lemming.lemma.LemmaOptions;
import lemming.lemma.LemmatizerGenerator;
import lemming.lemma.LemmatizerGeneratorTrainer;
import marmot.util.DynamicWeights;

public class ToutanovaTrainer implements LemmatizerGeneratorTrainer {

    public static class ToutanovaOptions extends LemmaOptions {

        private static final long serialVersionUID = 1L;
        public static final String FILTER_ALPHABET = "filter-alphabet";
        public static final String ALIGNER_TRAINER = "aligner-trainer";
        public static final String DECODER = "decoder";
        public static final String MAX_COUNT = "max-count";
        public static final String NBEST_RANK = "nbest-rank";
        public static final String WINDOW_SIZE = "window-size";

        public ToutanovaOptions() {
            super();
            map_.put(FILTER_ALPHABET, 5);
            map_.put(ALIGNER_TRAINER, EditTreeAlignerTrainer.class);
            map_.put(DECODER, ZeroOrderDecoder.class);
            map_.put(MAX_COUNT, 1);
            map_.put(NBEST_RANK, 50);
            map_.put(WINDOW_SIZE, 2);
        }

        public static ToutanovaOptions newInstance() {
            return new ToutanovaOptions();
        }

        public int getFilterAlphabet() {
            return (Integer) getOption(FILTER_ALPHABET);
        }

        public AlignerTrainer getAligner() {
            return (AlignerTrainer) getInstance(ALIGNER_TRAINER);
        }

        public Decoder getDecoderInstance() {
            return (Decoder) getInstance(DECODER);
        }

        public int getMaxCount() {
            return (Integer) getOption(MAX_COUNT);
        }

        public int getNbestRank() {
            return (Integer) getOption(NBEST_RANK);
        }

        public int getMaxWindowSize() {
            return (Integer) getOption(WINDOW_SIZE);
        }
    }

    private ToutanovaOptions options_;

    public ToutanovaTrainer() {
        options_ = new ToutanovaOptions();
    }

    public static List<ToutanovaInstance> createToutanovaInstances(
            List<LemmaInstance> instances, Aligner aligner) {
        List<ToutanovaInstance> new_instances = new LinkedList<>();

        for (LemmaInstance instance : instances) {
            List<Integer> alignment = null;
            if (aligner != null) {
                alignment = aligner.align(instance.getForm(),
                        instance.getLemma());
                assert alignment != null;
            }
            new_instances.add(new ToutanovaInstance(instance, alignment));
        }

        return new_instances;
    }

    @Override
    public LemmatizerGenerator train(List<LemmaInstance> train_instances,
            List<LemmaInstance> dev_instances) {

        AlignerTrainer aligner_trainer = options_.getAligner();
        Aligner aligner = aligner_trainer.train(train_instances);

        List<ToutanovaInstance> new_train_instances = createToutanovaInstances(
                train_instances, aligner);

        List<ToutanovaInstance> new_dev_instances = null;
        if (dev_instances != null) {
            new_dev_instances = createToutanovaInstances(dev_instances, null);
        }

        return trainAligned(new_train_instances, new_dev_instances);
    }

    public LemmatizerGenerator trainAligned(
            List<ToutanovaInstance> train_instances,
            List<ToutanovaInstance> dev_instances) {

        Logger logger = Logger.getLogger(getClass().getName());

        ToutanovaModel model = new ToutanovaModel();
        model.init(options_, train_instances, dev_instances);

        DynamicWeights weights = model.getWeights();
        DynamicWeights sum_weights = null;
        if (options_.getAveraging()) {
            sum_weights = new DynamicWeights(null);
        }

        Decoder decoder = options_.getDecoderInstance();
        decoder.init(model);

        double correct;
        double total;
        int number;
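        // Expand the type-level training set into a token-level one: each
        // non-rare instance is repeated once per corpus occurrence, capped
        // at MAX_COUNT, so frequent types receive proportionally more
        // perceptron updates.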
        List<ToutanovaInstance> token_instances = new LinkedList<>();
        for (ToutanovaInstance instance : train_instances) {
            if (!instance.isRare()) {
                for (int i = 0; i < Math.min(options_.getMaxCount(), instance
                        .getInstance().getCount()); i++) {
                    token_instances.add(instance);
                }
            }
        }

        for (int iter = 0; iter < options_.getNumIterations(); iter++) {

            logger.info(String.format("Iter: %3d / %3d", iter + 1,
                    options_.getNumIterations()));

            correct = 0;
            total = 0;
            number = 0;

            Collections.shuffle(token_instances, options_.getRandom());
            for (ToutanovaInstance instance : token_instances) {
                Result result = decoder.decode(instance);
                String output = result.getOutput();

                if (!output.equals(instance.getInstance().getLemma())) {
                    // Standard perceptron update: penalize the predicted
                    // analysis, reward the gold one.
                    model.update(instance, result, -1.);
                    model.update(instance, instance.getResult(), +1.);

                    if (sum_weights != null) {
                        // Lazy averaging: weight the accumulator update by
                        // the number of remaining steps in this iteration.
                        double amount = token_instances.size() - number;
                        assert amount > 0;
                        model.setWeights(sum_weights);
                        sum_weights = model.getWeights();
                        model.update(instance, result, -amount);
                        model.update(instance, instance.getResult(), +amount);
                        model.setWeights(weights);
                        weights = model.getWeights();
                    }
                } else {
                    correct++;
                }

                total++;
                number++;

                if (number % 1000 == 0 && options_.getVerbosity() > 0) {
                    logger.info(String.format("Processed: %3d / %3d", number,
                            token_instances.size()));
                }
            }

            if (sum_weights != null) {
                // Replace the working weights with the running average and
                // rescale the accumulator so the next iteration keeps
                // extending the same average.
                double weights_scaling = 1. / ((iter + 1.) * token_instances
                        .size());
                double sum_weights_scaling = (iter + 2.) / (iter + 1.);
                for (int i = 0; i < weights.getLength(); i++) {
                    weights.set(i, sum_weights.get(i) * weights_scaling);
                    sum_weights.set(i, sum_weights.get(i)
                            * sum_weights_scaling);
                }
            }

            logger.info(String.format("Train Accuracy: %g / %g = %g%%",
                    correct, total, correct * 100. / total));
        }

        return new ToutanovaLemmatizer(options_, model);
    }

    @Override
    public LemmaOptions getOptions() {
        return options_;
    }
}
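
/*
 * Usage sketch (hedged): readInstances(...) is a hypothetical loader standing
 * in for whatever corpus-reading code the caller already has, and
 * lemmatize(...) is assumed from the Lemmatizer interface; only
 * ToutanovaTrainer, LemmaInstance and LemmatizerGenerator are types from this
 * codebase.
 *
 *   List<LemmaInstance> train = readInstances("train.conll"); // hypothetical
 *   List<LemmaInstance> dev = readInstances("dev.conll");     // hypothetical
 *
 *   ToutanovaTrainer trainer = new ToutanovaTrainer();
 *   LemmatizerGenerator lemmatizer = trainer.train(train, dev);
 *   String lemma = lemmatizer.lemmatize(train.get(0)); // assumed method
 */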