// Copyright 2015 Thomas Müller // This file is part of MarMoT, which is licensed under GPLv3. package lemming.lemma.edit; import java.util.HashMap; import java.util.LinkedList; import java.util.List; import java.util.Map; import lemming.lemma.LemmaCandidateGenerator; import lemming.lemma.LemmaCandidateGeneratorTrainer; import lemming.lemma.LemmaInstance; import lemming.lemma.LemmaOptions; import marmot.util.Counter; import marmot.util.edit.EditTree; import marmot.util.edit.EditTreeBuilder; import marmot.util.edit.EditTreeBuilderTrainer; public class EditTreeGeneratorTrainer implements LemmaCandidateGeneratorTrainer { private EditTreeGeneratorTrainerOptions options_; public static class EditTreeGeneratorTrainerOptions extends LemmaOptions { private static final long serialVersionUID = 1L; public static final String MIN_COUNT = "min-count"; public static final String NUM_STEPS = "num-steps"; public static final String TAG_DEPENDENT = "tag-dependent"; public static final String UNKNOWN = "unknown"; public static final String MAX_DEPTH = "max-depth"; public EditTreeGeneratorTrainerOptions() { map_.put(MIN_COUNT, 1); map_.put(NUM_STEPS, 1); map_.put(TAG_DEPENDENT, false); map_.put(UNKNOWN, "<UNKNOWN>"); map_.put(MAX_DEPTH, -1); } public int getNumSteps() { return (Integer) getOption(NUM_STEPS); } public boolean getIsTagDependent() { return (Boolean) getOption(TAG_DEPENDENT); } public String getUnknown() { return (String) getOption(UNKNOWN); } public Integer getMinCount() { return (Integer) getOption(MIN_COUNT); } public int getMaxDepth() { return (Integer) getOption(MAX_DEPTH); } } public EditTreeGeneratorTrainer() { options_ = new EditTreeGeneratorTrainerOptions(); } @Override public LemmaCandidateGenerator train(List<LemmaInstance> instances, List<LemmaInstance> dev_instances) { EditTreeBuilder builder = new EditTreeBuilderTrainer( options_.getRandom(), options_.getNumSteps(), options_.getMaxDepth()).train(instances); Map<String, Counter<EditTree>> map = new HashMap<String, Counter<EditTree>>(); map.put(options_.getUnknown(), new Counter<EditTree>()); for (LemmaInstance instance : instances) { String form = instance.getForm(); String lemma = instance.getLemma(); EditTree tree = builder.build(form, lemma); Counter<EditTree> counter = map.get(options_.getUnknown()); counter.increment(tree, 1.0); if (options_.getIsTagDependent()) { String tag = instance.getPosTag(); if (tag != null) { counter = map.get(tag); if (counter == null) { counter = new Counter<>(); map.put(tag, counter); } counter.increment(tree, 1.0); } } } Map<String, List<EditTree>> list_map = new HashMap<>(); for (Map.Entry<String, Counter<EditTree>> map_entry : map.entrySet()) { List<EditTree> list = new LinkedList<>(); Counter<EditTree> counter = map_entry.getValue(); list_map.put(map_entry.getKey(), list); for (Map.Entry<EditTree, Double> entry : counter.entrySet()) { double count = entry.getValue(); if (count >= options_.getMinCount()) { EditTree tree = entry.getKey(); list.add(tree); } } } return new EditTreeGenerator(options_.getUnknown(), list_map); } @Override public LemmaOptions getOptions() { return options_; } }