// Copyright 2014 Thomas Müller
// This file is part of HMMLA, which is licensed under GPLv3.
package hmmla.splitmerge;
import hmmla.hmm.HmmModel;
import hmmla.hmm.Model;
import hmmla.io.Token;
import hmmla.util.Numerics;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
public class ForwardChart {
protected double a[][];
protected int maxT;
protected int N;
protected int T;
protected HmmModel model;
public void init(int N, HmmModel model) {
maxT = 0;
this.N = N;
this.model = model;
}
public HmmModel getHmmModel() {
return model;
}
protected double _score(int i, int j) {
return model.getTransitions(i, j);
}
public synchronized void update(Iterable<Iterable<Integer>> tags,
List<Token> sentence) {
T = sentence.size() + 1;
if (T == 1) {
return;
}
if (T > maxT) {
a = new double[T][N];
maxT = T;
}
for (int t = 0; t < T; t++) {
Arrays.fill(a[t], Double.NEGATIVE_INFINITY);
}
double[] scores = new double[N];
Iterator<Iterable<Integer>> iterator = tags.iterator();
assert iterator.hasNext();
Iterable<Integer> last_tags = iterator.next();
_score(sentence, 0, scores);
for (Integer i : last_tags) {
a[0][i] = scores[i] + _score(Model.BorderIndex, i);
assert (a[0][i] != Double.NEGATIVE_INFINITY);
}
int t = 1;
while (iterator.hasNext()) {
Iterable<Integer> current_tags = iterator.next();
_score(sentence, t, scores);
for (Integer j : last_tags) {
if (Double.isInfinite(a[t - 1][j])) {
continue;
}
for (Integer i : current_tags) {
double score = a[t - 1][j] + scores[i] + _score(j, i);
a[t][i] = Numerics.sumLogProb(a[t][i], score);
}
}
last_tags = current_tags;
t++;
}
assert t == T - 1;
for (Integer i : last_tags) {
double score = a[t - 1][i] + _score(i, Model.BorderIndex);
a[t][Model.BorderIndex] = Numerics.sumLogProb(score,
a[t][Model.BorderIndex]);
}
}
protected void _score(List<Token> outputs, int t, double[] scores) {
Token token = outputs.get(t);
Arrays.fill(scores, 0.0);
model.getEmissions(token.getWordForm(), scores);
}
public synchronized double score(int index, int tag) {
return a[index][tag];
}
public synchronized double score() {
return score(T - 1, Model.BorderIndex);
}
}