// Copyright 2014 Thomas Müller
// This file is part of HMMLA, which is licensed under GPLv3.
package hmmla.decode;
import hmmla.hmm.HmmModel;
import hmmla.hmm.HmmTrainer;
import hmmla.hmm.Model;
import hmmla.hmm.Statistics;
import hmmla.hmm.Tree;
import hmmla.io.Sentence;
import hmmla.io.Token;
import hmmla.splitmerge.BackwardChart;
import hmmla.splitmerge.ForwardChart;
import hmmla.util.Counter;
import hmmla.util.Numerics;
import hmmla.util.SymbolTable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Map.Entry;
public class CoarseToFineDecoder implements Decoder {
private Model[] models_;
private HmmTrainer trainer_;
private ForwardChart[] forwards_;
private BackwardChart[] backwards_;
private boolean decode_toplevel_;
private boolean product_;
public CoarseToFineDecoder(Model model, HmmTrainer trainer,
boolean decode_toplevel, boolean product) {
trainer_ = trainer;
decode_toplevel_ = decode_toplevel;
product_ = product;
init(model);
}
protected void init(Model model) {
int max_level = model.getLevel();
models_ = new Model[max_level + 1];
forwards_ = new ForwardChart[max_level + 1];
backwards_ = new BackwardChart[max_level + 1];
List<Tree> trees = new LinkedList<Tree>();
models_[max_level] = model;
for (int level = max_level - 1; level >= 0; level--) {
Model new_model = new Model(model);
// Get Tags at current level
trees.clear();
for (Tree tree : model.getTopLevel().values()) {
tree.getChildrenWithLevel(trees, level);
}
SymbolTable<String> tag_table = new SymbolTable<String>();
tag_table.toIndex(Model.BorderSymbol, true);
Map<String, Tree> clustering = new HashMap<String, Tree>();
for (Tree tree : trees) {
clustering.put(tree.getName(), tree);
tag_table.toIndex(tree.getName(), true);
}
Statistics statistics = new Statistics(tag_table.size(), model
.getWordTable().size());
new_model.setClustering(clustering);
new_model.setTagTable(tag_table);
new_model.setStatistics(statistics);
collectStatistics(new_model, model);
models_[level] = new_model;
model = new_model;
}
for (int level = max_level; level >= 0; level--) {
model = models_[level];
HmmModel hmm_model = trainer_.train(model);
SymbolTable<String> tag_table = model.getTagTable();
forwards_[level] = new ForwardChart();
forwards_[level].init(tag_table.size(), hmm_model);
backwards_[level] = new BackwardChart();
backwards_[level].init(tag_table.size(), hmm_model);
}
}
private void collectStatistics(Model new_model, Model model) {
Statistics new_statistics = new_model.getStatistics();
Statistics statistics = model.getStatistics();
SymbolTable<String> new_tag_table = new_model.getTagTable();
SymbolTable<String> tag_table = model.getTagTable();
for (Entry<String, Integer> entry : tag_table.entrySet()) {
int index = entry.getValue();
String name = entry.getKey();
int new_index = getNewIndex(name, new_tag_table,
model.getClustering());
for (int output = 0; output < statistics.getNumOutputs(); output++) {
double f = statistics.getEmissions(index, output);
new_statistics.addEmissions(new_index, output, f);
}
for (Entry<String, Integer> entry2 : tag_table.entrySet()) {
int index2 = entry2.getValue();
String name2 = entry2.getKey();
int new_index2 = getNewIndex(name2, new_tag_table,
model.getClustering());
double f = statistics.getTransitions(index, index2);
new_statistics.addTransitions(new_index, new_index2, f);
}
}
}
private int getNewIndex(String name, SymbolTable<String> tag_table,
Map<String, Tree> clustering) {
try {
return tag_table.toIndex(name);
} catch (NoSuchElementException e) {
}
Tree tree = clustering.get(name);
String parent_name = tree.getParent().getName();
try {
return tag_table.toIndex(parent_name);
} catch (NoSuchElementException e2) {
throw new IllegalStateException("Inconsistent Tree hierarchy!");
}
}
@Override
public List<String> bestPath(Sentence sentence) {
// We want indexes for the coarse model!
SymbolTable<String> tag_table = models_[0].getTagTable();
List<Iterable<Integer>> candidates = getInitialCandidates(sentence, tag_table);
for (int level = 0; level < models_.length; level++) {
ForwardChart forward = forwards_[level];
BackwardChart backward = backwards_[level];
forward.update(candidates, sentence);
Collections.reverse(candidates);
backward.update(candidates, sentence);
Collections.reverse(candidates);
double logZ = forward.score();
// Get new candidates!
if (level < models_.length - 1) {
tag_table = models_[level]
.getTagTable();
SymbolTable<String> next_tag_table = models_[level + 1]
.getTagTable();
for (int t = 0; t < sentence.size(); t++) {
List<Integer> new_candidates = new LinkedList<Integer>();
for (Integer index : candidates.get(t)) {
double log_prob = forward.score(t, index)
+ backward.score(t, index) - logZ;
double prob = Math.exp(log_prob);
if (prob > 1e-2) {
String name = tag_table.toSymbol(index);
Tree tree = models_[level].getClustering()
.get(name);
if (tree.getLevel() >= level + 1) {
int new_index = next_tag_table.toIndex(name);
new_candidates.add(new_index);
} else {
String child_name = tree.getLeft().getName();
int new_index = next_tag_table
.toIndex(child_name);
new_candidates.add(new_index);
child_name = tree.getRight().getName();
new_index = next_tag_table
.toIndex(child_name);
new_candidates.add(new_index);
}
}
}
candidates.set(t, new_candidates);
}
}
}
if (product_) {
return bestProductPath(models_[models_.length - 1],
forwards_[models_.length - 1],
backwards_[models_.length - 1], candidates, sentence);
}
return bestPath(models_[models_.length - 1],
forwards_[models_.length - 1], backwards_[models_.length - 1],
candidates, sentence.size());
}
private List<String> bestPath(Model model, ForwardChart forward,
BackwardChart backward, List<Iterable<Integer>> candidates, int T) {
SymbolTable<String> tag_table = model.getTagTable();
Map<String, Tree> clustering = model.getClustering();
List<String> path = new ArrayList<String>(T);
for (int t = 0; t < T; t++) {
Counter<String> counter = new Counter<String>(
Double.NEGATIVE_INFINITY);
double bestP = Double.NEGATIVE_INFINITY;
String bestName = null;
for (Integer candidate : candidates.get(t)) {
String name = tag_table.toSymbol(candidate);
double prob = forward.score(t, candidate)
+ backward.score(t, candidate);
if (decode_toplevel_) {
Tree tree = clustering.get(name);
Tree parent = tree.getRoot();
assert parent != null;
Double cached_prob = counter.count(parent.getName());
prob = Numerics.sumLogProb(prob, cached_prob);
counter.set(parent.getName(), prob);
name = parent.getName();
}
if (prob > bestP) {
bestP = prob;
bestName = name;
}
}
if (bestName == null) {
throw new RuntimeException("Didn't find candidate!");
}
path.add(bestName);
}
return path;
}
private List<String> bestProductPath(Model model, ForwardChart forward,
BackwardChart backward, List<Iterable<Integer>> candidates,
List<Token> sentence) {
double logZ = forward.score();
// Calculate posterior probabilities.
double[][] state_scores = new double[sentence.size()][];
double[][][] transition_scores = new double[sentence.size()][][];
SymbolTable<String> tag_table = model.getTagTable();
List<SymbolTable<String>> tables = new ArrayList<SymbolTable<String>>(
sentence.size());
HmmModel hmm_model = forward.getHmmModel();
double[] emmission_scores = new double[model.getTagTable().size()];
for (int t = 0; t < sentence.size(); t++) {
List<Integer> current_candidates = (List<Integer>) candidates
.get(t);
tables.add(new SymbolTable<String>());
state_scores[t] = new double[current_candidates.size()];
if (t > 0) {
List<Integer> last_candidates = (List<Integer>) candidates
.get(t - 1);
transition_scores[t] = new double[last_candidates.size()][current_candidates
.size()];
Arrays.fill(emmission_scores, 0.0);
hmm_model.getEmissions(sentence.get(t).getWordForm(), emmission_scores);
}
for (int candidate : candidates.get(t)) {
String name = tag_table.toSymbol(candidate);
Tree tree = model.getClustering().get(name);
String parent_name = tree.getRoot().getName();
int index = tables.get(t).toIndex(parent_name, true);
double log_prob = forward.score(t, candidate)
+ backward.score(t, candidate) - logZ;
double prob = Math.exp(log_prob);
assert prob > 0.0 && prob <= 1.0;
state_scores[t][index] += prob;
if (t > 0) {
for (int last_candidate : candidates.get(t - 1)) {
String last_name = tag_table.toSymbol(last_candidate);
Tree last_tree = model.getClustering().get(last_name);
String last_root_name = last_tree.getRoot().getName();
int last_index = tables.get(t - 1).toIndex(
last_root_name);
log_prob = forward.score(t - 1, last_candidate)
+ emmission_scores[candidate]
+ hmm_model.getTransitions(last_candidate,
candidate)
+ backward.score(t, candidate) - logZ;
prob = Math.exp(log_prob);
assert prob > 0.0 && prob <= 1.0;
transition_scores[t][last_index][index] += prob;
}
}
}
}
// Fill Viterbi charts.
double chart[][] = new double[sentence.size()][];
int backtrace[][] = new int[sentence.size()][];
for (int t = 0; t < sentence.size(); t++) {
chart[t] = new double[tables.get(t).size()];
backtrace[t] = new int[tables.get(t).size()];
Arrays.fill(backtrace[t], -1);
if (t > 0) {
for (int last_index = 0; last_index < tables.get(t - 1).size(); last_index++) {
for (int index = 0; index < tables.get(t).size(); index++) {
assert chart[t - 1][last_index] > 0.0;
double prob = chart[t - 1][last_index]
* transition_scores[t][last_index][index];
if (prob > chart[t][index]) {
chart[t][index] = prob;
backtrace[t][index] = last_index;
}
}
}
} else {
Arrays.fill(chart[0], 1.);
}
for (int index = 0; index < tables.get(t).size(); index++) {
chart[t][index] *= state_scores[t][index];
}
}
// Backtrace
List<String> path = new ArrayList<String>(sentence.size());
double best_prob = 0.0;
int best_index = -1;
int T = sentence.size();
for (int index = 0; index < tables.get(T - 1).size(); index++) {
double prob = chart[T - 1][index];
if (prob > best_prob) {
best_prob = prob;
best_index = index;
}
}
if (best_index == -1)
throw new RuntimeException();
path.add(tables.get(T - 1).toSymbol(best_index));
for (int t = T - 1; t > 0; t--) {
best_index = backtrace[t][best_index];
if (best_index == -1)
throw new RuntimeException();
path.add(tables.get(t - 1).toSymbol(best_index));
}
Collections.reverse(path);
return path;
}
protected List<Iterable<Integer>> getInitialCandidates(List<Token> sentence, SymbolTable<String> tag_table) {
List<Iterable<Integer>> candidates = new ArrayList<Iterable<Integer>>(
sentence.size());
// Only the last model (the original model) contains the candidate maps!
Model model = models_[models_.length - 1];
for (int t = 0; t < sentence.size(); t++) {
String word_form = sentence.get(t).getWordForm();
List<String> string_candidates = model.getCandidates(word_form);
candidates.add(tag_table.toIndexes(string_candidates));
}
return candidates;
}
}