// Copyright 2015 Thomas Müller
// This file is part of MarMoT, which is licensed under GPLv3.
package lemming.test.lemma.toutanova;
import java.util.Collection;
import java.util.LinkedList;
import java.util.List;
import junit.framework.AssertionFailedError;
import lemming.lemma.LemmaInstance;
import lemming.lemma.LemmaResult;
import lemming.lemma.Lemmatizer;
import lemming.lemma.LemmatizerTrainer;
import lemming.lemma.SimpleLemmatizerTrainer;
import lemming.lemma.SimpleLemmatizerTrainer.SimpleLemmatizerTrainerOptions;
import marmot.morph.io.SentenceReader;
import marmot.util.Copy;
import marmot.util.Numerics;
import org.junit.Test;
/**
 * Tests for {@link SimpleLemmatizerTrainer} plus shared helpers for trainer tests.
 *
 * <p>Each test trains a lemmatizer on a resource file and then checks the token
 * accuracy reported by {@link LemmaResult} against a minimum threshold, both on
 * the training data and on held-out data. The helper methods are
 * {@code protected} so that they can be reused from subclasses.
 */
public class SimpleTrainerTest {

    /** Reader options selecting the form, lemma and tag columns of the resource files. */
    private static final String pos_indexes = "form-index=4,lemma-index=5,tag-index=2,";

    /** Same as {@link #pos_indexes}, but additionally selecting the morph column. */
    private static final String morph_indexes = "form-index=4,lemma-index=5,tag-index=2,morph-index=3,";

    /** Trains with the trainer's default options on the moderate training set. */
    @Test
    public void moderateTest() {
        runModerateTest(new SimpleLemmatizerTrainer(), 98.41, 64.47);
    }

    /** Trains with {@code HANDLE_UNSEEN} enabled and {@code USE_POS} disabled. */
    @Test
    public void moderateUnseenTest() {
        SimpleLemmatizerTrainer trainer = new SimpleLemmatizerTrainer();
        trainer.getOptions().setOption(SimpleLemmatizerTrainerOptions.HANDLE_UNSEEN, true);
        trainer.getOptions().setOption(SimpleLemmatizerTrainerOptions.USE_POS, false);
        runModerateTest(trainer, 99.48, 86.63);
    }

    /** Trains with both {@code HANDLE_UNSEEN} and {@code USE_POS} enabled. */
    @Test
    public void moderateUnseenPosTest() {
        SimpleLemmatizerTrainer trainer = new SimpleLemmatizerTrainer();
        trainer.getOptions().setOption(SimpleLemmatizerTrainerOptions.HANDLE_UNSEEN, true);
        trainer.getOptions().setOption(SimpleLemmatizerTrainerOptions.USE_POS, true);
        runModerateTest(trainer, 99.96, 86.84);
    }

    /**
     * Builds the resource path of a test data file.
     *
     * @param name file name inside the {@code marmot/test/lemma} resource directory
     * @return the path in the form {@code res:///marmot/test/lemma/<name>}
     */
    protected String getResourceFile(String name) {
        return String.format("res:///%s/%s", "marmot/test/lemma", name);
    }

    /**
     * Selects the instances whose form is equal to their lemma.
     *
     * @param instances instances to filter; the list itself is not modified
     * @return a new list holding, in the original order, every instance whose
     *         form equals its lemma
     */
    protected List<LemmaInstance> getCopyInstances(List<LemmaInstance> instances) {
        List<LemmaInstance> new_instances = new LinkedList<>();
        for (LemmaInstance instance : instances) {
            if (instance.getForm().equals(instance.getLemma())) {
                new_instances.add(instance);
            }
        }
        return new_instances;
    }

    /**
     * Runs {@link #runSmallTest(LemmatizerTrainer, double, double, boolean)}
     * without morph indexes.
     *
     * @param trainer   trainer under test
     * @param train_acc minimum accuracy required on the training data
     * @param test_acc  minimum accuracy required on {@code dev.tsv}
     */
    protected void runSmallTest(LemmatizerTrainer trainer, double train_acc, double test_acc) {
        runSmallTest(trainer, train_acc, test_acc, false);
    }

    /**
     * Runs the train/evaluate cycle using {@code trn_sml.tsv} as training file.
     *
     * @param trainer           trainer under test
     * @param train_acc         minimum accuracy required on the training data
     * @param test_acc          minimum accuracy required on {@code dev.tsv}
     * @param add_morph_indexes whether the morph column is read as well
     */
    protected void runSmallTest(LemmatizerTrainer trainer, double train_acc, double test_acc, boolean add_morph_indexes) {
        runTest(trainer, train_acc, test_acc, "trn_sml.tsv", add_morph_indexes);
    }

    /**
     * Runs {@link #runModerateTest(LemmatizerTrainer, double, double, boolean)}
     * without morph indexes.
     *
     * @param trainer   trainer under test
     * @param train_acc minimum accuracy required on the training data
     * @param test_acc  minimum accuracy required on {@code dev.tsv}
     */
    protected void runModerateTest(LemmatizerTrainer trainer, double train_acc, double test_acc) {
        runModerateTest(trainer, train_acc, test_acc, false);
    }

    /**
     * Runs the train/evaluate cycle using {@code trn_mod.tsv} as training file.
     *
     * @param trainer           trainer under test
     * @param train_acc         minimum accuracy required on the training data
     * @param test_acc          minimum accuracy required on {@code dev.tsv}
     * @param add_morph_indexes whether the morph column is read as well
     */
    protected void runModerateTest(LemmatizerTrainer trainer, double train_acc, double test_acc, boolean add_morph_indexes) {
        runTest(trainer, train_acc, test_acc, "trn_mod.tsv", add_morph_indexes);
    }

    /**
     * Runs {@link #runTest(LemmatizerTrainer, double, double, String, boolean)}
     * without morph indexes.
     *
     * @param trainer        trainer under test
     * @param train_acc      minimum accuracy required on the training data
     * @param test_acc       minimum accuracy required on {@code dev.tsv}
     * @param trainfile_name name of the training resource file
     */
    protected void runTest(LemmatizerTrainer trainer, double train_acc, double test_acc, String trainfile_name) {
        runTest(trainer, train_acc, test_acc, trainfile_name, false);
    }

    /**
     * Trains a lemmatizer and checks its accuracy on three data sets: the
     * training data itself, {@code dev.tsv} and {@code dev.tsv.morfette}.
     *
     * @param trainer           trainer under test
     * @param train_acc         minimum accuracy required on the training data
     * @param test_acc          minimum accuracy required on {@code dev.tsv}
     * @param trainfile_name    name of the training resource file
     * @param add_morph_indexes if {@code true} the files are read with
     *                          {@link #morph_indexes}, otherwise with {@link #pos_indexes}
     */
    protected void runTest(LemmatizerTrainer trainer, double train_acc, double test_acc, String trainfile_name, boolean add_morph_indexes) {
        String indexes = add_morph_indexes ? morph_indexes : pos_indexes;

        String trainfile = indexes + getResourceFile(trainfile_name);
        List<LemmaInstance> training_instances = LemmaInstance.getInstances(new SentenceReader(trainfile));
        Lemmatizer lemmatizer = trainer.train(training_instances, null);
        assertAccuracy(lemmatizer, training_instances, train_acc);

        String testfile = indexes + getResourceFile("dev.tsv");
        List<LemmaInstance> instances = LemmaInstance.getInstances(new SentenceReader(testfile));
        assertAccuracy(lemmatizer, instances, test_acc);

        // The morfette file is only checked against a fixed threshold of 1.0;
        // NOTE(review): presumably a smoke test that the file can be processed - confirm.
        testfile = indexes + getResourceFile("dev.tsv.morfette");
        instances = LemmaInstance.getInstances(new SentenceReader(testfile));
        assertAccuracy(lemmatizer, instances, 1.);
    }

    /**
     * Trains a lemmatizer on {@code trn_sml.tsv} and passes it to
     * {@link Copy#clone}. NOTE(review): presumably {@code Copy.clone} fails if
     * the lemmatizer cannot be serialized - confirm against its implementation.
     *
     * @param trainer trainer under test
     */
    protected void testIfLemmatizerIsSerializable(LemmatizerTrainer trainer) {
        String trainfile = pos_indexes + getResourceFile("trn_sml.tsv");
        Copy.clone(trainer.train(LemmaInstance.getInstances(trainfile), null));
    }

    /**
     * Evaluates the lemmatizer on the given instances, logs the accuracy and
     * fails if the token accuracy is not approximately greater than or equal to
     * the required minimum.
     *
     * @param lemmatizer   lemmatizer to evaluate
     * @param instances    instances to evaluate on
     * @param min_accuracy minimum token accuracy that must be reached
     * @throws AssertionFailedError if the measured accuracy is below the minimum
     */
    protected void assertAccuracy(Lemmatizer lemmatizer, Collection<LemmaInstance> instances, double min_accuracy) {
        LemmaResult result = LemmaResult.test(lemmatizer, instances);
        double accuracy = result.getTokenAccuracy();
        result.logAccuracy();
        //result.logErrors(50);
        if (!Numerics.approximatelyGreaterEqual(accuracy, min_accuracy)) {
            // Report the actual relation: we only get here when accuracy is too low.
            throw new AssertionFailedError(String.format(
                    "accuracy %g is below the required minimum %g", accuracy, min_accuracy));
        }
    }
}