// Copyright 2015 Thomas Müller
// This file is part of MarMoT, which is licensed under GPLv3.
package lemming.test.lemma.toutanova;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import java.util.List;
import java.util.logging.Logger;
import junit.framework.AssertionFailedError;
import lemming.lemma.LemmaInstance;
import lemming.lemma.toutanova.Decoder;
import lemming.lemma.toutanova.NbestDecoder;
import lemming.lemma.toutanova.Result;
import lemming.lemma.toutanova.ToutanovaInstance;
import lemming.lemma.toutanova.ToutanovaLemmatizer;
import lemming.lemma.toutanova.ToutanovaModel;
import lemming.lemma.toutanova.ToutanovaTrainer;
import lemming.lemma.toutanova.ZeroOrderDecoder;
import lemming.lemma.toutanova.ZeroOrderNbestDecoder;
import marmot.morph.io.SentenceReader;
import marmot.util.Numerics;
import org.junit.Test;
public class NbestDecoderTest {
private static final double DELTA = 1e-2;
public void trainDecodeTest(String trainfile, String devfile, int num_iters, int rank_max) {
// Train a standard Toutanova model.
List<LemmaInstance> train_instances = LemmaInstance.getInstances(new SentenceReader(trainfile));
ToutanovaTrainer trainer = new ToutanovaTrainer();
ToutanovaLemmatizer lemmatizer = (ToutanovaLemmatizer) trainer.train(train_instances, null);
testDecoder(lemmatizer, devfile, rank_max);
}
private void testDecoder(ToutanovaLemmatizer lemmatizer, String devfile, int rank_max) {
ToutanovaModel model = lemmatizer.getModel();
Decoder decoder = new ZeroOrderDecoder();
decoder.init(model);
NbestDecoder nbest_decoder = new ZeroOrderNbestDecoder(rank_max);
nbest_decoder.init(model);
List<LemmaInstance> test_instances = LemmaInstance.getInstances(new SentenceReader(devfile));
int correct = 0;
int nbest_correct = 0;
int total = 0;
for (LemmaInstance instance : test_instances) {
ToutanovaInstance tinstance = new ToutanovaInstance(instance, null);
model.addIndexes(tinstance, false);
Result result = decoder.decode(tinstance);
double expected_score = model.getScore(tinstance, result);
double first_best_score = result.getScore();
assertEquals(expected_score, first_best_score, DELTA);
List<Result> nbest_results = nbest_decoder.decode(tinstance);
assertTrue(!nbest_results.isEmpty());
Result first_nbest_result = nbest_results.get(0);
assertEquals(result.getOutput(), first_nbest_result.getOutput());
assertEquals(first_best_score, first_nbest_result.getScore(), DELTA);
Result last_result = null;
boolean found_lemma = false;
for (Result nbest_result : nbest_results) {
assertEquals(model.getScore(tinstance, nbest_result), nbest_result.getScore(), DELTA);
if (last_result != null) {
if (!Numerics.approximatelyLesserEqual(nbest_result.getScore(), last_result.getScore())) {
throw new AssertionFailedError(String.format("%g <= %g", nbest_result.getScore(), last_result.getScore()));
}
}
last_result = nbest_result;
if (nbest_result.getOutput().equals(instance.getLemma())) {
found_lemma = true;
}
}
if (found_lemma) {
nbest_correct += instance.getCount();
}
if (result.getOutput().equals(instance.getLemma())) {
correct += instance.getCount();
}
total += instance.getCount();
}
Logger logger = Logger.getLogger(getClass().getName());
logger.info(String.format("One-best : %5d %5d = %g", correct, total, correct * 100. / total));
logger.info(String.format("N-best : %5d %5d = %g", nbest_correct, total, nbest_correct * 100. / total));
}
@Test
public void test() {
String indexes = "form-index=4,lemma-index=5,tag-index=2,";
String train_sml = indexes + getResourceFile("trn_mod.tsv");
String dev = indexes + getResourceFile("dev.tsv");
trainDecodeTest(train_sml, train_sml, 1, 5);
trainDecodeTest(train_sml, dev, 10, 10);
}
protected String getResourceFile(String name) {
return String.format("res:///%s/%s", "marmot/test/lemma", name);
}
}