// Copyright 2014 Thomas Müller
// This file is part of HMMLA, which is licensed under GPLv3.
package hmmla.splitmerge;
import hmmla.hmm.HmmModel;
import hmmla.hmm.Model;
import hmmla.hmm.Statistics;
import hmmla.hmm.Tree;
import hmmla.io.Sentence;
import hmmla.io.Token;
import hmmla.util.Numerics;
import hmmla.util.SymbolTable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
public class SimpleEmTrainer extends EmTrainer {
protected ForwardChart forward_;
protected BackwardChart backward_;
protected Model model_;
public double estep(Model model, HmmModel hmm_model,
Iterable<Sentence> reader) {
return estep(model, hmm_model, reader, true);
}
public double estep(Model model, HmmModel hmm_model,
Iterable<Sentence> reader, boolean update) {
reset(model, hmm_model);
if (update)
model.getStatistics().setZero();
double ll = 0.0;
for (Sentence sentence : reader) {
ll += estep(sentence, update);
}
return ll;
}
protected void reset(Model model, HmmModel hmm_model) {
SymbolTable<String> tag_table = model.getTagTable();
if (forward_ == null)
forward_ = new ForwardChart();
if (backward_ == null)
backward_ = new BackwardChart();
forward_.init(tag_table.size(), hmm_model);
backward_.init(tag_table.size(), hmm_model);
model_ = model;
}
protected double estep(Sentence sentence, boolean update) {
Map<String, Tree> topLevel = model_.getTopLevel();
SymbolTable<String> tag_table = model_.getTagTable();
int T = sentence.size();
List<Iterable<Integer>> candidates = new ArrayList<Iterable<Integer>>(T);
for (Token token : sentence) {
List<Tree> leaves = new LinkedList<Tree>();
Tree tree = topLevel.get(token.getTag());
tree.getLeaves(leaves);
List<Integer> ileaves = new LinkedList<Integer>();
for (Tree leaf : leaves) {
ileaves.add(tag_table.toIndex(leaf.getName()));
}
candidates.add(ileaves);
}
assert candidates.size() == sentence.size();
forward_.update(candidates, sentence);
Collections.reverse(candidates);
backward_.update(candidates, sentence);
Collections.reverse(candidates);
double logZ = forward_.score();
addStateScores(model_, sentence, candidates, logZ, update);
addTransitionScores(model_, sentence, candidates, logZ, update);
return logZ;
}
private void addTransitionScores(Model model, List<Token> sentence,
List<Iterable<Integer>> candidates, double logZ, boolean update) {
Statistics statistics = model.getStatistics();
HmmModel hmm_model = forward_.getHmmModel();
int t = 0;
Iterator<Iterable<Integer>> iterator = candidates.iterator();
assert iterator.hasNext();
Iterable<Integer> last_tags = iterator.next();
Token token = sentence.get(t);
double newLogZ = Double.NEGATIVE_INFINITY;
double[] scores = new double[model.getTagTable().size()];
Arrays.fill(scores, 0.0);
hmm_model.getEmissions(token.getWordForm(), scores);
for (Integer tag : last_tags) {
double score = scores[tag] +
+ hmm_model.getTransitions(Model.BorderIndex, tag)
+ backward_.score(t, tag);
double p = Math.exp(score - logZ);
if (update)
statistics.addTransitions(Model.BorderIndex, tag, p);
newLogZ = Numerics.sumLogProb(newLogZ, score);
}
t++;
while (iterator.hasNext()) {
Iterable<Integer> current_tags = iterator.next();
token = sentence.get(t);
Arrays.fill(scores, 0.0);
hmm_model.getEmissions(token.getWordForm(), scores);
newLogZ = Double.NEGATIVE_INFINITY;
for (Integer fromIndex : last_tags) {
for (Integer toIndex : current_tags) {
if (toIndex == Model.BorderIndex) {
continue;
}
double score = forward_.score(t - 1, fromIndex)
+ scores[toIndex]
+ hmm_model.getTransitions(fromIndex,
toIndex) + backward_.score(t, toIndex);
newLogZ = Numerics.sumLogProb(newLogZ, score);
double p = Math.exp(score - logZ);
if (update)
statistics.addTransitions(fromIndex, toIndex, p);
}
}
last_tags = current_tags;
t++;
}
newLogZ = Double.NEGATIVE_INFINITY;
for (int from_index : last_tags) {
double score = forward_.score(t - 1, from_index)
+ hmm_model.getTransitions(from_index, Model.BorderIndex);
double p = Math.exp(score - logZ);
if (update)
statistics.addTransitions(from_index, Model.BorderIndex, p);
newLogZ = Numerics.sumLogProb(newLogZ, score);
}
}
private void addStateScores(Model model, List<Token> outputs,
Iterable<Iterable<Integer>> candidates, double logZ, boolean update) {
Statistics statistics = model.getStatistics();
int t = 0;
for (Iterable<Integer> candidate : candidates) {
Token output = outputs.get(t);
int ioutput = model.getWordTable().toIndex(output.getWordForm());
double newLogZ = Double.NEGATIVE_INFINITY;
for (int index : candidate) {
double score = forward_.score(t, index)
+ backward_.score(t, index);
double p = Math.exp(score - logZ);
if (update) {
statistics.addEmissions(index, ioutput, p);
}
newLogZ = Numerics.sumLogProb(newLogZ, score);
}
t++;
}
}
}