// Copyright 2015 Thomas Müller
// This file is part of MarMoT, which is licensed under GPLv3.
package lemming.lemma.toutanova;
import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;
/**
 * First-order dynamic-programming decoder over segmentations of an input form.
 *
 * <p>For every input prefix length {@code l} and every output symbol {@code o}
 * the decoder stores the best score of a segmentation of the first {@code l}
 * input positions whose last segment is labeled {@code o}. A cell is scored as
 * the pair score of its last segment plus, unless that segment starts at
 * position 0, the transition score from the previous segment's output and the
 * best score of the previous cell.
 *
 * <p>The three tables are flat arrays addressed through
 * {@link #getIndex(int, int)}; they are kept in fields and reused between calls
 * to {@link #decode(ToutanovaInstance)}. Because {@code decode} writes these
 * fields without any synchronization, a single instance should not be used for
 * concurrent decoding.
 */
public class FirstOrderDecoder implements Decoder {

    private ToutanovaModel model_;

    /** Best score per (output, prefix length - 1) cell. */
    private double[] score_array_;

    /** Back pointer: output of the preceding segment, or -1 if the cell's segment starts at 0. */
    private int[] output_array_;

    /** Back pointer: start position of the cell's last segment, or -1 if the cell was never improved. */
    private int[] index_array_;

    /** Size of the model's output table, read once in {@link #init(ToutanovaModel)}. */
    private int num_output_symbols_;

    private int input_length_;

    private ToutanovaInstance instance_;

    /**
     * Binds this decoder to a model and caches the size of its output table.
     *
     * @param model the model providing pair and transition scores
     */
    public void init(ToutanovaModel model) {
        model_ = model;
        num_output_symbols_ = model_.getOutputTable().size();
    }

    /**
     * Finds the highest scoring segmentation and labeling of the instance.
     *
     * <p>Ties are resolved in favor of the candidate seen first: shorter
     * distance from the earliest allowed segment start, then lower output
     * index (candidates only replace the incumbent on a strictly greater
     * score).
     *
     * @param instance the instance to decode
     * @return the result built from the best output sequence, the matching
     *         segment end positions, the form and the best score
     * @throws IllegalStateException if {@link #init(ToutanovaModel)} has not
     *         been called, or if no labeling scores above negative infinity
     * @throws IllegalArgumentException if the instance has no input positions
     */
    public Result decode(ToutanovaInstance instance) {
        if (model_ == null) {
            throw new IllegalStateException(
                    "init(ToutanovaModel) must be called before decode(ToutanovaInstance)");
        }

        int max_input_segment_length = model_.getMaxInputSegmentLength();
        input_length_ = instance.getFormCharIndexes().length;
        instance_ = instance;

        // Previously an empty input only surfaced as an array access at index -1
        // during back tracing; fail early with the actual reason instead.
        if (input_length_ == 0) {
            throw new IllegalArgumentException(
                    "Cannot decode an instance without input positions");
        }

        int required_length = num_output_symbols_ * input_length_;
        checkArraySize(required_length);

        // The buffers may be larger than needed (they only ever grow), so reset
        // just the part that is addressed for this instance.
        Arrays.fill(score_array_, 0, required_length, Double.NEGATIVE_INFINITY);
        Arrays.fill(output_array_, 0, required_length, -1);
        Arrays.fill(index_array_, 0, required_length, -1);

        for (int l = 1; l <= input_length_; l++) {
            // Earliest start of a segment ending at l that respects the length limit.
            int first_start = Math.max(0, l - max_input_segment_length);

            for (int o = 0; o < num_output_symbols_; o++) {
                double best_score = Double.NEGATIVE_INFINITY;
                int best_output = -1;
                int best_index = -1;

                for (int l_start = first_start; l_start < l; l_start++) {
                    double pair_score = model_.getPairScore(instance, l_start, l, o);

                    if (l_start == 0) {
                        // The segment covers the whole prefix: no predecessor,
                        // hence no transition score and no previous output.
                        if (pair_score > best_score) {
                            best_score = pair_score;
                            best_output = -1;
                            best_index = l_start;
                        }
                        continue;
                    }

                    for (int last_o = 0; last_o < num_output_symbols_; last_o++) {
                        // Prefix of length l_start is stored at index l_start - 1.
                        double prev_score = score_array_[getIndex(last_o, l_start - 1)];
                        double transition_score = model_.getTransitionScore(instance, last_o, o,
                                l_start, l);
                        double score = pair_score + transition_score + prev_score;
                        if (score > best_score) {
                            best_score = score;
                            best_output = last_o;
                            best_index = l_start;
                        }
                    }
                }

                int cell = getIndex(o, l - 1);
                score_array_[cell] = best_score;
                output_array_[cell] = best_output;
                index_array_[cell] = best_index;
            }
        }

        return backTrace();
    }

    /**
     * Follows the back pointers from the best final cell to the first segment.
     *
     * <p>The collected input list holds the end position of every segment
     * (the last entry being the input length); the output list holds the
     * output symbol of every segment. Both are returned in left-to-right order.
     *
     * @return the result for the instance of the current {@code decode} call
     * @throws IllegalStateException if no final cell has a score above negative
     *         infinity, i.e. there is nothing to trace back from
     */
    private Result backTrace() {
        List<Integer> outputs = new LinkedList<>();
        List<Integer> inputs = new LinkedList<>();

        int end_index = input_length_;
        double best_score = Double.NEGATIVE_INFINITY;
        int end_output = -1;
        for (int o = 0; o < num_output_symbols_; o++) {
            double score = score_array_[getIndex(o, end_index - 1)];
            if (score > best_score) {
                best_score = score;
                end_output = o;
            }
        }

        // Without this check a missing winner (no output symbols, or only
        // negative-infinity/NaN scores) led to an array access at index -1.
        if (end_output < 0) {
            throw new IllegalStateException(
                    "No output symbol reaches the end of the input with a score above negative infinity"
                            + " (input length: " + input_length_
                            + ", output symbols: " + num_output_symbols_ + ")");
        }

        outputs.add(end_output);
        inputs.add(end_index);

        int cell = getIndex(end_output, end_index - 1);
        int previous_output = output_array_[cell];
        // A negative previous output marks a segment that starts at position 0.
        while (previous_output >= 0) {
            int previous_end = index_array_[cell];
            outputs.add(previous_output);
            inputs.add(previous_end);
            cell = getIndex(previous_output, previous_end - 1);
            previous_output = output_array_[cell];
        }

        // Segments were collected from last to first.
        Collections.reverse(outputs);
        Collections.reverse(inputs);

        return new Result(model_, outputs, inputs, instance_.getInstance().getForm(), best_score);
    }

    /**
     * Maps an (output, position) pair to its offset in the flat tables.
     *
     * @param output the output symbol index
     * @param index the zero-based position, i.e. prefix length minus one
     * @return the offset, with all positions of one output stored contiguously
     */
    private int getIndex(int output, int index) {
        return output * input_length_ + index;
    }

    /**
     * Makes sure the tables can hold at least {@code required_length} cells.
     * Existing tables are kept when they are already large enough; otherwise
     * all three are reallocated with exactly the required length.
     *
     * @param required_length the number of cells needed for the current instance
     */
    private void checkArraySize(int required_length) {
        if (score_array_ == null || score_array_.length < required_length) {
            score_array_ = new double[required_length];
            output_array_ = new int[required_length];
            index_array_ = new int[required_length];
        }
    }

    /**
     * Returns the order of this decoder.
     *
     * @return always 1
     */
    @Override
    public int getOrder() {
        return 1;
    }

}