// Copyright 2015 Thomas Müller
// This file is part of MarMoT, which is licensed under GPLv3.
package lemming.test.lemma.toutanova;
import static org.junit.Assert.*;
import java.util.Arrays;
import java.util.LinkedList;
import java.util.List;
import lemming.lemma.LemmaInstance;
import lemming.lemma.toutanova.Decoder;
import lemming.lemma.toutanova.FirstOrderDecoder;
import lemming.lemma.toutanova.Result;
import lemming.lemma.toutanova.ToutanovaInstance;
import lemming.lemma.toutanova.ToutanovaModel;
import lemming.lemma.toutanova.ToutanovaTrainer;
import org.junit.Test;
/**
 * Unit test for {@link FirstOrderDecoder}.
 *
 * <p>Initializes a {@link ToutanovaModel} from six toy form/lemma pairs (each lemma is its
 * form minus the final character). It then checks the outputs and inputs the decoder returns
 * for every training instance.
 */
public class DecoderTest {

	/**
	 * Decodes each training instance and compares the result against the expected output
	 * indexes and input positions.
	 */
	@Test
	public void test() {
		ToutanovaModel model = new ToutanovaModel();
		List<ToutanovaInstance> train_instances = new LinkedList<>();
		train_instances.add(newInstance("aaae", "aaa"));
		train_instances.add(newInstance("bbbe", "bbb"));
		train_instances.add(newInstance("ccce", "ccc"));
		train_instances.add(newInstance("aaaf", "aaa"));
		train_instances.add(newInstance("bbbf", "bbb"));
		train_instances.add(newInstance("cccf", "ccc"));
		model.init(ToutanovaTrainer.ToutanovaOptions.newInstance(), train_instances, null);

		Decoder decoder = new FirstOrderDecoder();
		decoder.init(model);

		int a_index = model.getOutputTable().toIndex("a");
		int b_index = model.getOutputTable().toIndex("b");
		int c_index = model.getOutputTable().toIndex("c");

		// NOTE(review): an earlier version of this test set transition and pair feature
		// weights by hand (via model.setWeight) before decoding. That code was commented out,
		// so the expectations below rely only on the state produced by model.init(...).
		// Confirm this is intended.
		List<Integer> expected_inputs = Arrays.asList(1, 2, 4);

		assertResultEquals(Arrays.asList(a_index, a_index, a_index), expected_inputs,
				decoder.decode(train_instances.get(0)));
		assertResultEquals(Arrays.asList(a_index, a_index, a_index), expected_inputs,
				decoder.decode(train_instances.get(3)));
		assertResultEquals(Arrays.asList(b_index, b_index, b_index), expected_inputs,
				decoder.decode(train_instances.get(1)));
		assertResultEquals(Arrays.asList(b_index, b_index, b_index), expected_inputs,
				decoder.decode(train_instances.get(4)));
		assertResultEquals(Arrays.asList(c_index, c_index, c_index), expected_inputs,
				decoder.decode(train_instances.get(2)));
		assertResultEquals(Arrays.asList(c_index, c_index, c_index), expected_inputs,
				decoder.decode(train_instances.get(5)));
	}

	/**
	 * Builds a training instance for {@code form}/{@code lemma}.
	 *
	 * <p>Every instance gets its own fresh copy of the index list (1, 1, 1, 1, 2, 1), which all
	 * instances in this test share. Presumably this list is the form/lemma alignment; confirm
	 * against ToutanovaInstance.
	 *
	 * @param form  the inflected form
	 * @param lemma the expected lemma
	 * @return a new instance wrapping a LemmaInstance with null remaining fields
	 */
	private static ToutanovaInstance newInstance(String form, String lemma) {
		return new ToutanovaInstance(new LemmaInstance(form, lemma, null, null),
				Arrays.asList(1, 1, 1, 1, 2, 1));
	}

	/**
	 * Asserts that a decoding result has the expected outputs and inputs.
	 *
	 * @param outputs expected value of {@code result.getOutputs()}
	 * @param inputs  expected value of {@code result.getInputs()}
	 * @param result  the decoder result under test
	 */
	private void assertResultEquals(List<Integer> outputs, List<Integer> inputs,
			Result result) {
		assertEquals(outputs, result.getOutputs());
		assertEquals(inputs, result.getInputs());
	}
}