// Copyright 2013 Thomas Müller
// This file is part of MarMoT, which is licensed under GPLv3.
package marmot.core.lattice;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.PriorityQueue;
import java.util.Set;
import lemming.lemma.ranker.RankerCandidate;
import marmot.core.State;
import marmot.core.Transition;
import marmot.util.HashableIntArray;
/**
 * Beam-limited Viterbi lattice over a sequence of candidate state lists.
 *
 * <p>For every position {@code index} and every candidate state
 * {@code state_index} the lattice stores up to {@code beam_size} entries in
 * {@code lattice_[index][state_index]}. Each entry records the index of a
 * predecessor state and the score of reaching the current state through that
 * predecessor. Entries are stored in the order in which a
 * {@link PriorityQueue} of {@code LatticeEntry} yields them, i.e. by the
 * natural ordering of {@code LatticeEntry} (presumably best score first --
 * confirm against {@code LatticeEntry.compareTo}).
 *
 * <p>A path through the lattice is addressed by a <em>signature</em>: an int
 * array with one rank per position {@code 1..n-1}; the all-zero signature
 * follows the first entry of every visited cell.
 *
 * <p>The lattice is built lazily by {@link #init()}; every public query
 * method triggers it.
 */
public class SequenceViterbiLattice implements ViterbiLattice {

    /** {@code lattice_[position][state][rank]}; populated by {@link #init()}. */
    private LatticeEntry[][][] lattice_;
    private final List<List<State>> candidates_;
    private final State boundary_;
    private final int beam_size_;
    private boolean initialized_;
    private final boolean marginalize_lemmas_;

    /**
     * Creates an uninitialized lattice; no scores are computed until
     * {@link #init()} runs.
     *
     * @param candidates candidate states per position; the list is stored,
     *        not copied
     * @param boundary state used as the single predecessor of position 0
     * @param beam_size maximum number of entries kept per lattice cell and
     *        maximum number of hypotheses returned by
     *        {@link #getNbestSequences()}
     * @param marginalize_lemmas if {@code false}, states that carry lemma
     *        candidates are scored with their best lemma candidate instead of
     *        {@code State.getScore()}
     */
    public SequenceViterbiLattice(List<List<State>> candidates, State boundary, int beam_size,
            boolean marginalize_lemmas) {
        candidates_ = candidates;
        boundary_ = boundary;
        beam_size_ = beam_size;
        initialized_ = false;
        marginalize_lemmas_ = marginalize_lemmas;
    }

    /**
     * Fills the lattice. Subsequent calls are no-ops.
     *
     * <p>For each state, every non-null transition from a predecessor state is
     * scored as state score + transition score + score of the predecessor's
     * first lattice entry (omitted at position 0), and the first
     * {@code beam_size} entries polled from the queue are kept.
     *
     * <p>The method asserts that every state has at least one non-null
     * transition.
     */
    public void init() {
        if (initialized_) {
            return;
        }
        initialized_ = true;

        lattice_ = new LatticeEntry[candidates_.size()][][];
        PriorityQueue<LatticeEntry> queue = new PriorityQueue<LatticeEntry>();
        // Position 0 has exactly one predecessor: the boundary state.
        List<State> previous_states = Collections.singletonList(boundary_);

        for (int index = 0; index < candidates_.size(); index++) {
            List<State> states = candidates_.get(index);
            lattice_[index] = new LatticeEntry[states.size()][];

            int state_index = 0;
            for (State state : states) {
                queue.clear();
                double state_score = computeStateScore(state);

                for (int previous_state_index = 0; previous_state_index < previous_states.size();
                        previous_state_index++) {
                    State transition = state.getTransition(previous_state_index);
                    if (transition == null) {
                        // No transition from this predecessor (e.g. removed by pruning).
                        continue;
                    }
                    double score = state_score + transition.getScore();
                    if (index > 0) {
                        score += lattice_[index - 1][previous_state_index][0].getScore();
                    }
                    queue.add(new LatticeEntry(score, previous_state_index));
                }

                int length = Math.min(beam_size_, queue.size());
                assert length > 0;
                // length <= queue.size(), so poll() cannot return null here.
                LatticeEntry[] entries = new LatticeEntry[length];
                for (int rank = 0; rank < length; rank++) {
                    entries[rank] = queue.poll();
                }
                lattice_[index][state_index] = entries;
                state_index++;
            }
            previous_states = states;
        }
    }

    /**
     * Returns the score used for {@code state} in the lattice.
     *
     * <p>Normally this is {@code state.getScore()}. If lemmas are not
     * marginalized and the zero-order state has lemma candidates, the
     * zero-order score is replaced by the zero-order "real" score plus the
     * score of the best lemma candidate.
     */
    private double computeStateScore(State state) {
        double state_score = state.getScore();
        State zero_order_state = state.getZeroOrderState();
        if (zero_order_state.getLemmaCandidates() == null || marginalize_lemmas_) {
            return state_score;
        }

        double base_score = state_score - zero_order_state.getScore()
                + zero_order_state.getRealScore();
        double best_score = Double.NEGATIVE_INFINITY;
        for (RankerCandidate candidate : zero_order_state.getLemmaCandidates()) {
            best_score = Math.max(best_score, base_score + candidate.getScore());
        }
        return best_score;
    }

    /**
     * Returns the hypothesis for the all-zero signature, i.e. the path that
     * follows the first entry of every cell.
     *
     * @return the hypothesis, or {@code null} if there are fewer than two
     *         positions
     */
    public Hypothesis getViterbiSequence() {
        init();
        if (candidates_.isEmpty()) {
            // Avoids allocating an array of size -1 below.
            return null;
        }
        int[] signature_array = new int[candidates_.size() - 1];
        return getSequenceBySignature(new HashableIntArray(signature_array));
    }

    /**
     * Backtraces the path selected by {@code signature}.
     *
     * <p>The backtrace starts at state 0 of the last position (presumably the
     * closing boundary state -- confirm with callers) and, at position
     * {@code index}, follows the entry of rank {@code signature[index - 1]}.
     *
     * <p>The score is the score of the entry chosen at the last position plus,
     * for every earlier position with a non-zero rank, the difference between
     * the chosen entry and the first entry of that cell.
     *
     * @param signature one rank per position {@code 1..n-1}
     * @return the hypothesis (state indices in sequence order), or
     *         {@code null} if a rank is out of range for its cell or if there
     *         are fewer than two positions
     */
    public Hypothesis getSequenceBySignature(HashableIntArray signature) {
        init();
        int[] signature_array = signature.getArray();

        List<Integer> path = new ArrayList<Integer>();
        int state_index = 0;
        path.add(state_index);

        double score = 0.0;
        boolean has_score = false;

        for (int index = candidates_.size() - 1; index >= 1; index--) {
            LatticeEntry[] entries = lattice_[index][state_index];
            int rank = signature_array[index - 1];
            if (rank < 0 || rank >= entries.length) {
                return null;
            }
            LatticeEntry entry = entries[rank];
            if (entry == null) {
                return null;
            }

            if (!has_score) {
                // The entry chosen at the last position already is the exact
                // score of "best path into its predecessor + this step".
                // Adding the rank delta here as well would count the
                // deviation twice.
                score = entry.getScore();
                has_score = true;
            } else if (rank != 0) {
                // The running score assumed the first entry of this cell;
                // correct it for the entry actually chosen.
                score += entry.getScore() - entries[0].getScore();
            }

            state_index = entry.getPreviousStateIndex();
            path.add(state_index);
        }

        if (!has_score) {
            return null;
        }

        // States were collected back to front.
        Collections.reverse(path);
        return new Hypothesis(path, score, signature);
    }

    /**
     * Enumerates up to {@code beam_size} hypotheses in the order in which a
     * {@link PriorityQueue} of {@code Hypothesis} yields them.
     *
     * <p>Starting from the all-zero signature, each polled hypothesis is
     * expanded by incrementing every signature position by one; signatures
     * are generated at most once.
     *
     * @return the hypotheses found; empty if no initial hypothesis exists
     *         (fewer than two positions)
     */
    public List<Hypothesis> getNbestSequences() {
        init();
        List<Hypothesis> nbest = new ArrayList<Hypothesis>();
        if (candidates_.isEmpty()) {
            return nbest;
        }

        HashableIntArray signature = new HashableIntArray(new int[candidates_.size() - 1]);
        Hypothesis best = getSequenceBySignature(signature);
        if (best == null) {
            // PriorityQueue rejects null elements, so bail out explicitly.
            return nbest;
        }

        PriorityQueue<Hypothesis> queue = new PriorityQueue<Hypothesis>();
        Set<HashableIntArray> used_signatures = new HashSet<HashableIntArray>();
        queue.add(best);
        used_signatures.add(signature);

        while (nbest.size() < beam_size_) {
            Hypothesis h = queue.poll();
            if (h == null) {
                break;
            }
            nbest.add(h);

            int[] signature_array = h.getSignature().getArray();
            for (int index = 0; index < signature_array.length; index++) {
                int[] new_signature_array = Arrays.copyOf(signature_array, signature_array.length);
                new_signature_array[index]++;
                HashableIntArray new_signature = new HashableIntArray(new_signature_array);

                // add() returns false if the signature was generated before.
                if (used_signatures.add(new_signature)) {
                    Hypothesis successor = getSequenceBySignature(new_signature);
                    if (successor != null) {
                        queue.add(successor);
                    }
                }
            }
        }
        return nbest;
    }

    /**
     * Diagnostic check: for every step of {@code path}, verifies that the
     * lattice cell of the current state contains an entry pointing to the
     * path's previous state, and reports missing steps on {@code System.err}.
     *
     * @param path one state index per position; asserted to have as many
     *        elements as there are positions
     */
    public void findGoldSequence(List<Integer> path) {
        init();
        assert path.size() == candidates_.size();
        assert path.size() == lattice_.length;

        for (int index = path.size() - 1; index > 0; index--) {
            int state_index = path.get(index);
            int real_previous_state_index = path.get(index - 1);

            boolean found_index = false;
            for (LatticeEntry entry : lattice_[index][state_index]) {
                if (entry == null) {
                    break;
                }
                if (entry.getPreviousStateIndex() == real_previous_state_index) {
                    found_index = true;
                    break;
                }
            }

            if (!found_index) {
                System.err.format("%s index = %d p_index = %d lattice entries = %s\n",
                        candidates_.get(index).get(state_index), index, real_previous_state_index,
                        Arrays.toString(lattice_[index][state_index]));
            }
        }
    }

    /**
     * Builds a reduced candidate list that contains only the states and
     * transitions used by the hypotheses of {@link #getNbestSequences()}.
     *
     * <p>States at position 0 are reused as-is. States at later positions are
     * copied and receive a fresh transition array sized to the pruned previous
     * position; only transitions that occur in some n-best hypothesis are
     * filled in, all others stay {@code null}.
     *
     * @return the pruned candidates, one list per position
     */
    public List<List<State>> prune() {
        init();
        List<List<State>> candidates = getCandidates();

        // Per position: set of used (state, previous state) pairs, encoded as
        // state_index * previous_num_candidates + previous_state_index.
        List<Set<Integer>> candidate_sets = new ArrayList<Set<Integer>>(candidates.size());
        for (int index = 0; index < candidates.size(); index++) {
            candidate_sets.add(new HashSet<Integer>());
        }

        for (Hypothesis h : getNbestSequences()) {
            int index = 0;
            int previous_state_index = 0;
            for (int state_index : h.getStates()) {
                int previous_num_candidates = (index > 0) ? candidates.get(index - 1).size() : 1;
                candidate_sets.get(index).add(
                        state_index * previous_num_candidates + previous_state_index);
                previous_state_index = state_index;
                index++;
            }
        }

        List<List<State>> new_candidates = new ArrayList<List<State>>(candidates.size());
        // Maps old state indices of the previous position to their new
        // indices; -1 marks states that were pruned away.
        int[] index_map = null;

        for (int index = 0; index < candidates.size(); index++) {
            Set<Integer> candidate_set = candidate_sets.get(index);
            int previous_num_candidates = (index > 0) ? candidates.get(index - 1).size() : 1;

            int[] new_index_map = new int[candidates.get(index).size()];
            Arrays.fill(new_index_map, -1);
            List<State> states = new ArrayList<State>(candidate_set.size());

            for (int encoded_indexes : candidate_set) {
                int state_index = encoded_indexes / previous_num_candidates;
                int previous_state_index = encoded_indexes % previous_num_candidates;

                int new_state_index = new_index_map[state_index];
                if (new_state_index < 0) {
                    // First time this state is seen: add it to the pruned list.
                    new_state_index = states.size();
                    new_index_map[state_index] = new_state_index;

                    State state = candidates.get(index).get(state_index);
                    if (index > 0) {
                        state = state.copy();
                        Transition[] new_transitions = new Transition[new_candidates.get(
                                index - 1).size()];
                        state.setTransitions(new_transitions);
                    }
                    states.add(state);
                }

                if (index > 0) {
                    // Carry the used transition over, re-indexed to the pruned
                    // previous position.
                    State old_state = candidates.get(index).get(state_index);
                    State[] transitions = old_state.getTransitions();
                    State state = states.get(new_state_index);
                    State[] new_transitions = state.getTransitions();
                    new_transitions[index_map[previous_state_index]] = transitions[previous_state_index];
                }
            }

            new_candidates.add(states);
            index_map = new_index_map;
        }
        return new_candidates;
    }

    /** Returns the candidate lists this lattice was created with (not a copy). */
    @Override
    public List<List<State>> getCandidates() {
        return candidates_;
    }
}