// Copyright 2013 Thomas Müller
// This file is part of MarMoT, which is licensed under GPLv3.
package marmot.core.lattice;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.PriorityQueue;
import java.util.Set;
import lemming.lemma.ranker.RankerCandidate;
import marmot.core.State;
import marmot.util.HashableIntArray;
/**
 * Lattice in which every sequence position is scored independently of its
 * neighbors. For each position the candidate states are ranked and at most
 * {@code beam_size} of them are kept; a sequence is identified by a
 * "signature", an int array holding the chosen rank at every position.
 *
 * <p>
 * The lattice is built lazily on the first call that needs it (see
 * {@link #init()}).
 */
public class ZeroOrderViterbiLattice implements ViterbiLattice {

	/** Candidate states per position, exactly as passed to the constructor. */
	private final List<List<State>> candidates_;

	/** Maximal number of entries kept per position and of n-best sequences. */
	private final int beam_size_;

	/**
	 * If false, states that carry lemma candidates are scored using their
	 * best lemma candidate instead of {@code State.getScore()}.
	 */
	private final boolean marginalize_lemmas_;

	/** lattice_[position][rank]; filled by {@link #init()}. */
	private LatticeEntry[][] lattice_;

	/** Whether {@link #init()} has already run. */
	private boolean initialized_;

	/**
	 * @param candidates
	 *            candidate states for every position of the sequence
	 * @param beam_size
	 *            number of entries to keep per position; also the maximal
	 *            number of hypotheses returned by {@link #getNbestSequences()}
	 * @param marginalize_lemmas
	 *            if false, a state with lemma candidates is scored as the
	 *            score of its best lemma candidate plus the state's real
	 *            score
	 */
	public ZeroOrderViterbiLattice(List<List<State>> candidates, int beam_size, boolean marginalize_lemmas) {
		candidates_ = candidates;
		beam_size_ = beam_size;
		marginalize_lemmas_ = marginalize_lemmas;
		initialized_ = false;
	}

	/**
	 * Builds the lattice. Calling it again after the first call is a no-op.
	 *
	 * <p>
	 * For every position all candidate states are scored and pushed into a
	 * priority queue; the first {@code min(beam_size, #states)} entries polled
	 * from the queue are stored in rank order. Which entry counts as "first"
	 * is determined by the natural ordering of {@code LatticeEntry}
	 * (presumably highest score first -- confirm in LatticeEntry.compareTo).
	 */
	public void init() {
		if (initialized_) {
			return;
		}
		initialized_ = true;

		lattice_ = new LatticeEntry[candidates_.size()][];
		PriorityQueue<LatticeEntry> queue = new PriorityQueue<LatticeEntry>();

		int index = 0;
		for (List<State> states : candidates_) {
			queue.clear();

			int state_index = 0;
			for (State state : states) {
				double score = state.getScore();
				if (state.getLemmaCandidates() != null && !marginalize_lemmas_) {
					// Without marginalization only the best lemma candidate
					// contributes to the score of this state.
					RankerCandidate candidate = RankerCandidate.bestCandidate(state.getLemmaCandidates());
					score = candidate.getScore() + state.getRealScore();
				}
				queue.add(new LatticeEntry(score, state_index));
				state_index++;
			}

			int length = Math.min(beam_size_, queue.size());
			assert length > 0;

			// length never exceeds queue.size(), so poll() cannot return null.
			LatticeEntry[] entries = new LatticeEntry[length];
			for (int rank = 0; rank < length; rank++) {
				entries[rank] = queue.poll();
			}
			lattice_[index] = entries;
			index++;
		}
	}

	/**
	 * Returns the hypothesis that uses the rank-0 entry at every position.
	 *
	 * @return the hypothesis for the all-zero signature, or null if some
	 *         position has no lattice entry
	 */
	public Hypothesis getViterbiSequence() {
		init();
		int[] signature_array = new int[candidates_.size()];
		return getSequenceBySignature(new HashableIntArray(signature_array));
	}

	/**
	 * Assembles the hypothesis described by a signature.
	 *
	 * @param signature
	 *            for every position the rank of the lattice entry to use
	 * @return a hypothesis holding the state index of every selected entry
	 *         and the sum of the entry scores, or null if a rank lies beyond
	 *         the entries stored for its position
	 */
	public Hypothesis getSequenceBySignature(HashableIntArray signature) {
		init();

		int[] signature_array = signature.getArray();
		List<Integer> list = new LinkedList<Integer>();
		double score = 0.;

		for (int index = 0; index < signature_array.length; index++) {
			int rank = signature_array[index];
			LatticeEntry[] entries = lattice_[index];
			if (rank >= entries.length) {
				return null;
			}

			LatticeEntry entry = entries[rank];
			if (entry == null) {
				return null;
			}

			score += entry.getScore();
			// The entry was constructed with the state index as its second
			// argument; getPreviousStateIndex() is used to read it back.
			list.add(entry.getPreviousStateIndex());
		}

		return new Hypothesis(list, score, signature);
	}

	/*
	 * Disabled implementation, kept for reference. Note that it uses
	 * java.util.Arrays, which this file does not import.
	 *
	 * public List<List<State>> filter() { List<List<State>> candidates =
	 * getCandidates();
	 *
	 * List<Set<Integer>> candidate_sets = new ArrayList<Set<Integer>>(
	 * candidates.size()); for (int index = 0; index < candidates.size();
	 * index++) { candidate_sets.add(new HashSet<Integer>()); }
	 *
	 * for (Hypothesis h : getNbestSequences()) { int index = 0; for (int
	 * state_index : h.getStates()) {
	 * candidate_sets.get(index).add(state_index); index++; } }
	 *
	 * List<List<State>> new_candidates = new ArrayList<List<State>>(
	 * candidates.size()); for (int index = 0; index < candidates.size();
	 * index++) { Set<Integer> candidate_set = candidate_sets.get(index); int[]
	 * new_index_map = new int[candidates.get(index).size()];
	 * Arrays.fill(new_index_map, -1);
	 *
	 * List<State> states = new ArrayList<State>(candidate_set.size()); for (int
	 * state_index : candidate_set) { State state =
	 * candidates.get(index).get(state_index); int new_state_index =
	 * states.size(); new_index_map[state_index] = new_state_index;
	 * states.add(state); } new_candidates.add(states); }
	 *
	 * return new_candidates; }
	 */

	/**
	 * Returns, for every position, only the states that made it into the
	 * lattice, in rank order.
	 *
	 * @return a new list of per-position state lists; the State objects
	 *         themselves are shared with the original candidates
	 */
	public List<List<State>> prune() {
		init();

		List<List<State>> candidates = new ArrayList<List<State>>(candidates_.size());
		for (int index = 0; index < candidates_.size(); index++) {
			LatticeEntry[] entries = lattice_[index];
			List<State> position_candidates = candidates_.get(index);

			List<State> states = new ArrayList<State>(entries.length);
			for (LatticeEntry entry : entries) {
				states.add(position_candidates.get(entry.getPreviousStateIndex()));
			}
			candidates.add(states);
		}

		assert candidates.size() > 0;
		return candidates;
	}

	/**
	 * Enumerates up to {@code beam_size} hypotheses, starting from the
	 * all-zero signature and repeatedly expanding the hypothesis polled from
	 * a priority queue: each expansion increments the rank at exactly one
	 * position. Hypotheses are returned in the order the queue yields them
	 * (natural ordering of {@code Hypothesis}).
	 *
	 * @return the enumerated hypotheses; empty if not even the all-zero
	 *         signature yields a hypothesis
	 */
	public List<Hypothesis> getNbestSequences() {
		init();

		List<Hypothesis> list = new LinkedList<Hypothesis>();

		HashableIntArray start_signature = new HashableIntArray(new int[candidates_.size()]);
		Hypothesis start = getSequenceBySignature(start_signature);
		if (start == null) {
			// Some position has no entries at all. PriorityQueue rejects null
			// elements, so report "no hypotheses" instead of failing.
			return list;
		}

		PriorityQueue<Hypothesis> queue = new PriorityQueue<Hypothesis>();
		Set<HashableIntArray> used_signatures = new HashSet<>();
		queue.add(start);
		used_signatures.add(start_signature);

		while (list.size() < beam_size_) {
			Hypothesis h = queue.poll();
			if (h == null) {
				break;
			}
			list.add(h);

			// Successors must be derived from the hypothesis that was just
			// polled. Expanding the initial all-zero array every time would
			// never reach a signature that differs from it in more than one
			// rank step.
			int[] signature_array = h.getSignature().getArray();

			for (int index = 0; index < signature_array.length; index++) {
				int[] new_signature_array = signature_array.clone();
				new_signature_array[index]++;
				HashableIntArray new_signature = new HashableIntArray(new_signature_array);

				// add() returns false if the signature was generated before.
				if (used_signatures.add(new_signature)) {
					Hypothesis successor = getSequenceBySignature(new_signature);
					if (successor != null) {
						queue.add(successor);
					}
				}
			}
		}

		return list;
	}

	/**
	 * @return the candidate lists this lattice was constructed with
	 *         (unpruned, not copied)
	 */
	@Override
	public List<List<State>> getCandidates() {
		return candidates_;
	}
}