// Copyright 2015 Thomas Müller
// This file is part of MarMoT, which is licensed under GPLv3.
package marmot.test.util.edit;
import static org.junit.Assert.*;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.logging.Logger;
import lemming.lemma.LemmaInstance;
import lemming.lemma.toutanova.EditTreeAligner;
import lemming.lemma.toutanova.Aligner.Pair;
import marmot.morph.io.SentenceReader;
import marmot.util.Numerics;
import marmot.util.edit.EditTree;
import marmot.util.edit.EditTreeBuilder;
import marmot.util.edit.EditTreeBuilderTrainer;
import org.junit.Test;
/**
 * JUnit tests for {@link EditTreeBuilderTrainer}.
 *
 * <p>The tests train an {@link EditTreeBuilder} on the {@code trn_mod.tsv}
 * resource and then check (a) the segment alignments produced by an
 * {@link EditTreeAligner} built on top of it, (b) that trees built for
 * analogous form/lemma pairs are equal and hash alike, and (c) how often the
 * set of trees seen in training fails to produce the correct lemma.
 */
public class EditTreeBuilderTrainerTest {

    /**
     * Option prefix that is prepended to every resource path handed to the
     * instance reader. NOTE(review): presumably selects the TSV columns that
     * hold form, lemma and tag -- confirm against the reader's option parser.
     */
    private static final String INDEXES = "form-index=4,lemma-index=5,tag-index=2,";

    /**
     * Checks the alignments computed for three German participle/infinitive
     * pairs against hand-written expected segmentations.
     */
    @Test
    public void test() {
        EditTreeBuilder builder = trainBuilder(readTrainingInstances());
        EditTreeAligner aligner = new EditTreeAligner(builder, true);
        testAligner(aligner, "umgezogen", "umziehen", Arrays.asList("u", "m", "ge", "z", "og", "e", "n"), Arrays.asList("u", "m", "", "z", "ieh", "e", "n"));
        testAligner(aligner, "gebissen", "beißen", Arrays.asList("ge", "b", "i", "ss", "e", "n" ), Arrays.asList("", "be", "i", "ß", "e", "n" ));
        testAligner(aligner, "gebogen", "biegen", Arrays.asList("ge", "b", "o", "g", "e", "n"), Arrays.asList("", "b", "ie", "g", "e", "n"));
    }

    /**
     * Checks tree equality for analogous word pairs, verifies that every tree
     * built from a training instance reproduces that instance's lemma, and
     * bounds the miss rate of the collected trees on the training and the
     * {@code dev.tsv} data.
     */
    @Test
    public void testApply() {
        List<LemmaInstance> instances = readTrainingInstances();
        EditTreeBuilder builder = trainBuilder(instances);

        testHashAndEquals(builder, "loves", "love", "hates", "hate", true);
        testHashAndEquals(builder, "lachen", "gelacht", "machen", "gemacht", true);
        testHashAndEquals(builder, "lachen", "gelacht", "aaaaaaaaen", "geaaaaaaaat", true);

        // Group the training instances by the edit tree that maps form -> lemma.
        Map<EditTree, List<LemmaInstance>> map = new HashMap<>();
        for (LemmaInstance instance : instances) {
            String input = instance.getForm();
            String output = instance.getLemma();
            EditTree tree = builder.build(input, output);

            // A tree built from (form, lemma) must map the form back to the lemma.
            String predicted = tree.apply(input, 0, input.length());
            assertEquals("Tree built for form '" + input + "' does not reproduce its lemma",
                    output, predicted);

            List<LemmaInstance> list = map.get(tree);
            if (list == null) {
                list = new LinkedList<>();
                map.put(tree, list);
            }
            list.add(instance);
        }

        // Trees were collected from the training data, so nothing may be missed there.
        applyTest(map, instances, false, 0.0);
        applyTest(map, LemmaInstance.getInstances(INDEXES + getResourceFile("dev.tsv")), false, 0.02526);
    }

    /**
     * Reads the training instances from the {@code trn_mod.tsv} resource.
     *
     * @return the instances returned by {@link LemmaInstance#getInstances}
     */
    private List<LemmaInstance> readTrainingInstances() {
        String trainfile = INDEXES + getResourceFile("trn_mod.tsv");
        return LemmaInstance.getInstances(new SentenceReader(trainfile));
    }

    /**
     * Trains an edit tree builder with the fixed seed and trainer settings
     * shared by all tests in this class.
     *
     * @param instances training instances
     * @return the trained builder
     */
    private EditTreeBuilder trainBuilder(List<LemmaInstance> instances) {
        // Fixed seed keeps the trained builder (and thus the expectations) reproducible.
        EditTreeBuilderTrainer trainer = new EditTreeBuilderTrainer(new Random(42), 1, -1);
        return trainer.train(instances);
    }

    /**
     * Applies every tree in {@code map} to every instance and asserts that the
     * fraction of instances whose lemma is produced by no tree does not
     * exceed {@code expected_miss_rate}.
     *
     * @param map                trees to apply; only the key set is used
     * @param instances          instances to evaluate; must not be empty
     * @param log_missed_outputs whether to log each instance whose lemma was missed
     * @param expected_miss_rate upper bound on the tolerated miss rate
     */
    private void applyTest(Map<EditTree, List<LemmaInstance>> map,
            List<LemmaInstance> instances, boolean log_missed_outputs, double expected_miss_rate) {
        // Guard the division below: an empty data set would yield NaN and an
        // undefined comparison result instead of a clear failure.
        assertFalse("No instances to evaluate", instances.isEmpty());

        Logger logger = Logger.getLogger(getClass().getName());
        int missedOutputs = 0;

        for (LemmaInstance instance : instances) {
            String input = instance.getForm();
            String output = instance.getLemma();

            // Collect every candidate lemma; apply() returning null means the
            // tree is not applicable to this form.
            Set<String> outputs = new HashSet<>();
            for (EditTree tree : map.keySet()) {
                String candidate = tree.apply(input, 0, input.length());
                if (candidate != null) {
                    outputs.add(candidate);
                }
            }

            if (!outputs.contains(output)) {
                missedOutputs++;
                if (log_missed_outputs) {
                    logger.info(String.format("Missed: %s", instance));
                }
            }

            // At least one tree must leave the form unchanged.
            // NOTE(review): presumably relies on an identity tree being among
            // the collected trees -- confirm.
            assertTrue("No tree maps form '" + input + "' to itself", outputs.contains(input));
        }

        double missedRate = (double) missedOutputs / instances.size();
        logger.info(Double.toString(missedRate));
        assertTrue("Miss rate " + missedRate + " exceeds expected " + expected_miss_rate,
                Numerics.approximatelyLesserEqual(missedRate, expected_miss_rate));
    }

    /**
     * Builds one tree per (input, output) pair and asserts that their equality
     * matches {@code result}. When the trees are expected to be equal, their
     * hash codes must be equal as well; unequal trees are allowed to collide,
     * so no hash assertion is made in that case.
     *
     * @param builder  builder used to create both trees
     * @param input_a  input of the first pair
     * @param output_a output of the first pair
     * @param input_b  input of the second pair
     * @param output_b output of the second pair
     * @param result   whether the two trees are expected to be equal
     */
    private void testHashAndEquals(EditTreeBuilder builder, String input_a,
            String output_a, String input_b, String output_b, boolean result) {
        EditTree treeA = builder.build(input_a, output_a);
        EditTree treeB = builder.build(input_b, output_b);

        assertEquals("Unexpected equals() result for " + input_a + "->" + output_a
                + " vs " + input_b + "->" + output_b, result, treeA.equals(treeB));
        if (result) {
            assertEquals("Equal trees must have equal hash codes",
                    treeA.hashCode(), treeB.hashCode());
        }
    }

    /**
     * Aligns {@code input} with {@code output} and asserts that the resulting
     * pairs split both strings into the expected segments.
     *
     * @param aligner         aligner under test
     * @param input           input string
     * @param output          output string
     * @param input_segments  expected input-side segments, in order
     * @param output_segments expected output-side segments, in order
     */
    public void testAligner(EditTreeAligner aligner, String input,
            String output, List<String> input_segments, List<String> output_segments) {
        List<Integer> indexes = aligner.align(input, output);
        List<Pair> pairs = Pair.toPairs(input, output, indexes);

        List<String> actualInputSegments = new LinkedList<>();
        List<String> actualOutputSegments = new LinkedList<>();
        for (Pair pair : pairs) {
            actualInputSegments.add(pair.getInputSegment());
            actualOutputSegments.add(pair.getOutputSegment());
        }

        assertEquals("Input segments of '" + input + "'", input_segments, actualInputSegments);
        assertEquals("Output segments of '" + output + "'", output_segments, actualOutputSegments);
    }

    /**
     * Builds the resource URL for a test data file in the
     * {@code marmot/test/lemma} resource directory.
     *
     * @param name file name inside the resource directory
     * @return a string of the form {@code res:///marmot/test/lemma/<name>}
     */
    protected String getResourceFile(String name) {
        return String.format("res:///%s/%s", "marmot/test/lemma", name);
    }
}