// Copyright 2013 Thomas Müller
// This file is part of MarMoT, which is licensed under GPLv3.
package marmot.core.lattice;
import java.util.Collections;
import java.util.List;
import marmot.core.State;
import marmot.util.Numerics;
public class ForwardSequenceLattice {
private double[][] lattice_;
private List<List<State>> candidates_;
private State boundary_;
public ForwardSequenceLattice(List<List<State>> candidates, State boundary) {
candidates_ = candidates;
boundary_ = boundary;
}
public void init() {
lattice_ = new double[candidates_.size()][];
List<State> previous_states = Collections.singletonList(boundary_);
for (int index = 0; index < candidates_.size(); index++) {
List<State> states = candidates_.get(index);
lattice_[index] = new double[states.size()];
int state_index = 0;
for (State state : states) {
double score_sum = Double.NEGATIVE_INFINITY;
for (int previous_state_index = 0; previous_state_index < previous_states
.size(); previous_state_index++) {
State transition = state.getTransition(previous_state_index);
if (transition == null) {
continue;
}
double score = state.getScore() + transition.getScore();
if (index > 0) {
score += lattice_[index - 1][previous_state_index];
}
score_sum = Numerics.sumLogProb(score_sum, score);
}
lattice_[index][state_index] = score_sum;
state_index ++;
}
previous_states = states;
}
assert lattice_[candidates_.size() - 1].length == 1;
}
double partitionFunction() {
return lattice_[candidates_.size() - 1][0];
}
public double get(int index, int state_index) {
if (index == -1) {
return 0;
}
return lattice_[index][state_index];
}
}