// Copyright 2015 Thomas Müller // This file is part of MarMoT, which is licensed under GPLv3. package lemming.lemma; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.logging.Logger; import lemming.lemma.toutanova.ToutanovaTrainer; public class BackupLemmatizerTrainer implements LemmatizerGeneratorTrainer { public static class BackupLemmatizerTrainerOptions extends LemmaOptions { private static final long serialVersionUID = 1L; public static final String LEMMATIZER_TRAINER = "lemmatizer-trainer"; public static final String BACKUP_TRAINER = "backup-trainer"; public LemmatizerGeneratorTrainer trainer_; public LemmatizerGeneratorTrainer backup_trainer_; public static final String TRAINER_PREF = "backup-lemmatizer-model-"; public static final String BACKUP_PREF = "backup-lemmatizer-backup-"; private Map<String, Object> model_options_; private Map<String, Object> backup_options_; public BackupLemmatizerTrainerOptions() { map_.put(LEMMATIZER_TRAINER, SimpleLemmatizerTrainer.class.getName()); map_.put(BACKUP_TRAINER, ToutanovaTrainer.class.getName()); model_options_ = new HashMap<>(); backup_options_ = new HashMap<>(); } @Override public LemmaOptions setOption(String name, Object value) { name = name.toLowerCase(); if (name.startsWith(TRAINER_PREF)) { model_options_.put(name.substring(TRAINER_PREF.length()), value); } else if (name.startsWith(BACKUP_PREF)) { backup_options_.put(name.substring(BACKUP_PREF.length()), value); } else { super.setOption(name, value); } return this; } public LemmatizerGeneratorTrainer getLemmatizerTrainer(String name, Map<String, Object> map) { String classname = (String) getOption(name); LemmatizerGeneratorTrainer trainer; try { trainer = (LemmatizerGeneratorTrainer) Class.forName(classname).newInstance(); } catch (InstantiationException e) { throw new RuntimeException(e); } catch (IllegalAccessException e) { throw new RuntimeException(e); } catch (ClassNotFoundException e) { throw new RuntimeException(e); } for (Map.Entry<String, Object> entry : map.entrySet()) { trainer.getOptions().setOption(entry.getKey(), entry.getValue()); } Logger logger = Logger.getLogger(getClass().getName()); logger.info(String.format("%s (%s) options:\n %s", name, classname, trainer.getOptions().report())); return trainer; } public LemmatizerGeneratorTrainer getLemmatizerTrainer() { return getLemmatizerTrainer(LEMMATIZER_TRAINER, model_options_); } public LemmatizerGeneratorTrainer getBackupTrainer() { return getLemmatizerTrainer(BACKUP_TRAINER, backup_options_); } } BackupLemmatizerTrainerOptions options_; private LemmatizerGeneratorTrainer standard_trainer_; private ToutanovaTrainer backup_trainer_; public BackupLemmatizerTrainer() { standard_trainer_ = null; backup_trainer_ = null; options_ = new BackupLemmatizerTrainerOptions(); } public BackupLemmatizerTrainer(LemmatizerGeneratorTrainer simple_trainer, ToutanovaTrainer trainer) { this(); standard_trainer_ = simple_trainer; backup_trainer_ = trainer; } @Override public LemmatizerGenerator train(List<LemmaInstance> instances, List<LemmaInstance> dev_instances) { LemmatizerGeneratorTrainer trainer; if (standard_trainer_ == null) trainer = options_.getLemmatizerTrainer(); else trainer = standard_trainer_; LemmatizerGenerator lemmatizer = trainer.train(instances, dev_instances); if (backup_trainer_ == null) { trainer = options_.getBackupTrainer(); } else { trainer = backup_trainer_; } LemmatizerGenerator backup = trainer.train(instances, dev_instances); return new BackupLemmatizer(lemmatizer, backup); } @Override public LemmaOptions getOptions() { return options_; } }