// Copyright 2015 Thomas Müller
// This file is part of MarMoT, which is licensed under GPLv3.
package lemming.lemma.toutanova;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.PriorityQueue;
import java.util.Queue;
import java.util.Set;
import marmot.util.HashableIntArray;
import marmot.util.Numerics;
public class ZeroOrderNbestDecoder implements NbestDecoder {
private static class State implements Comparable<State> {
private double score;
private int output;
private int index;
@Override
public int compareTo(State state) {
return -Double.compare(score, state.score);
}
}
private ToutanovaModel model_;
private int num_output_symbols_;
private int input_length_;
private ToutanovaInstance instance_;
private int rank_length_;
private State[][] state_array_;
private PriorityQueue<State> queue_;
private Queue<Result> result_queue_;
private Set<HashableIntArray> used_signatures_;
public ZeroOrderNbestDecoder(int queue_size) {
rank_length_ = queue_size;
queue_ = new PriorityQueue<>();
result_queue_ = new PriorityQueue<>();
used_signatures_ = new HashSet<>();
}
@Override
public void init(ToutanovaModel model) {
model_ = model;
num_output_symbols_ = model_.getOutputTable().size();
}
@Override
public List<Result> decode(ToutanovaInstance instance) {
assert model_ != null;
assert num_output_symbols_ > 0;
int max_input_segment_length = model_.getMaxInputSegmentLength();
input_length_ = instance.getFormCharIndexes().length;
instance_ = instance;
checkArraySize(input_length_);
for (int l = 1; l < input_length_ + 1; l++) {
queue_.clear();
for (int o = 0; o < num_output_symbols_; o++) {
for (int l_start = Math.max(0, l - max_input_segment_length); l_start < l; l_start++) {
double score = model_.getPairScore(instance, l_start, l, o);
if (l_start > 0) {
score += state_array_[(l_start - 1)][0].score;
}
State state = new State();
state.score = score;
state.output = o;
state.index = l_start;
queue_.add(state);
}
}
for (int rank = 0; rank < rank_length_; rank++) {
State state = queue_.poll();
assert state == null || state.index < l;
state_array_[l - 1][rank] = state;
}
}
return backtrace();
}
private Result bySignature(HashableIntArray signature) {
return bySignature(signature, false);
}
private Result bySignature(HashableIntArray signature, boolean debug) {
List<Integer> outputs = new LinkedList<>();
List<Integer> inputs = new LinkedList<>();
int end_index = input_length_;
double score = state_array_[input_length_ - 1][0].score;
int signature_index = 0;
int[] signature_array = signature.getArray();
if (debug)
System.err.println(score);
while (true) {
if (signature_index >= signature_array.length) {
System.err.println(signature);
System.err.println(instance_.getInstance().getForm());
}
int rank = signature_array[signature_index++];
State state = state_array_[end_index - 1][rank];
if (state == null) {
return null;
}
int start_index = state.index;
inputs.add(end_index);
int output = state.output;
outputs.add(output);
// Calculate difference to best score and substract it from current score.
double diff_to_best = state_array_[end_index - 1][0].score
- state.score;
assert diff_to_best >= 0.0;
score = score - diff_to_best;
if (debug)
System.err.println(score + " " + diff_to_best);
if (start_index == 0)
break;
end_index = start_index;
}
// If not all positive rank values of a signature are used
// then the signature is invalid as there is a second signature
// that was produced earlier:
// if the signature is 0 0 1 1, but the last 1 is not used
// then 0 0 1 is identical.
for (int i=signature_index; i<signature_array.length; i++) {
if (signature_array[i] > 0) {
return null;
}
}
Collections.reverse(outputs);
Collections.reverse(inputs);
return new Result(model_, outputs, inputs, instance_.getInstance()
.getForm(), score).setSignature(signature);
}
public List<Result> backtrace() {
List<Result> list = new LinkedList<Result>();
HashableIntArray signature = new HashableIntArray(
new int[input_length_]);
result_queue_.clear();
used_signatures_.clear();
result_queue_.add(bySignature(signature));
used_signatures_.add(signature);
while (list.size() < rank_length_) {
Result result = result_queue_.poll();
if (result == null) {
break;
}
signature = result.getSignature();
int[] signature_array = signature.getArray();
result.setSignature(null);
list.add(result);
for (int index = 0; index < result.getOutputs().size(); index++) {
int new_rank = signature_array[index] + 1;
if (new_rank >= rank_length_)
continue;
int[] new_signature_array = Arrays.copyOf(signature_array,
signature_array.length);
new_signature_array[index] = new_rank;
HashableIntArray new_signature = new HashableIntArray(
new_signature_array);
if (!used_signatures_.contains(new_signature)) {
used_signatures_.add(new_signature);
Result new_result = bySignature(new_signature);
if (new_result != null) {
if (!Numerics.approximatelyLesserEqual(
new_result.getScore(), result.getScore())) {
System.err.println(signature + " " + new_signature);
bySignature(signature, true);
bySignature(new_signature, true);
}
assert Numerics.approximatelyLesserEqual(
new_result.getScore(), result.getScore());
result_queue_.add(new_result);
}
}
}
}
return list;
}
private void checkArraySize(int required_length) {
if (state_array_ == null || state_array_.length < required_length) {
state_array_ = new State[required_length][rank_length_];
}
}
}