// Copyright 2014 Thomas Müller
// This file is part of HMMLA, which is licensed under GPLv3.
package hmmla.decode;
import hmmla.hmm.HmmModel;
import hmmla.hmm.Model;
import hmmla.hmm.Tree;
import hmmla.io.Sentence;
import hmmla.splitmerge.BackwardChart;
import hmmla.splitmerge.ForwardChart;
import hmmla.util.Counter;
import hmmla.util.Numerics;
import hmmla.util.SymbolTable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
/**
 * Decoder that picks, independently for every sentence position, the tag with
 * the highest combined forward and backward score.
 *
 * <p>When top-level decoding is enabled, the scores of all candidate tags that
 * share the same root in the clustering tree are combined with
 * {@link Numerics#sumLogProb} and the name of the root is emitted instead of
 * the name of the individual tag.
 *
 * <p>NOTE(review): the forward and backward charts are instance state that is
 * overwritten on every call, so an instance should presumably not be shared
 * between threads -- confirm against the chart implementations.
 */
public class SimpleDecoder implements Decoder {

  protected ForwardChart forward_;
  protected BackwardChart backward_;
  protected SymbolTable<String> tag_table_;
  protected SymbolTable<String> outputTable_;
  protected Map<String, Tree> clustering_;
  protected Map<String, Tree> top_level_;

  private final Model model_;
  private final boolean decode_top_level_;

  /**
   * Creates a decoder that emits top-level (root) tag names.
   *
   * @param model model providing the tag table, word table and tag trees
   * @param hmm_model HMM used to initialize the forward and backward charts
   */
  public SimpleDecoder(Model model, HmmModel hmm_model) {
    this(model, hmm_model, true);
  }

  /**
   * Creates a decoder.
   *
   * @param model model providing the tag table, word table and tag trees
   * @param hmm_model HMM used to initialize the forward and backward charts
   * @param top_level if {@code true}, scores of tags sharing a root are
   *     combined and the root name is emitted; otherwise the name of the
   *     best individual tag is emitted
   */
  public SimpleDecoder(Model model, HmmModel hmm_model,
      boolean top_level) {
    model_ = model;
    clustering_ = model.getClustering();
    tag_table_ = model.getTagTable();
    outputTable_ = model.getWordTable();
    forward_ = new ForwardChart();
    forward_.init(tag_table_.size(), hmm_model);
    backward_ = new BackwardChart();
    backward_.init(tag_table_.size(), hmm_model);
    decode_top_level_ = top_level;
    top_level_ = model.getTopLevel();
  }

  /**
   * Decodes a sentence using the candidate tags proposed by the model.
   *
   * @param sentence sentence to tag
   * @return one tag name per sentence position
   * @throws RuntimeException if no usable candidate exists at some position
   */
  public List<String> bestPath(Sentence sentence) {
    List<Iterable<Integer>> candidates = model_.getSentenceCandidates(sentence);
    return bestPath_(candidates, sentence);
  }

  /**
   * Updates both charts for the sentence and selects the best tag name for
   * every position.
   *
   * @param candidates candidate tag indexes, one entry per sentence position;
   *     temporarily reversed in place while the backward chart is updated
   * @param sentence sentence to tag
   * @return one tag name per sentence position
   * @throws RuntimeException if a position has no candidate other than the
   *     border tag, or if top-level decoding cannot resolve a tag to a root
   */
  private List<String> bestPath_(List<Iterable<Integer>> candidates,
      Sentence sentence) {
    forward_.update(candidates, sentence);

    // The backward chart is fed the candidates in reverse order. The list is
    // reversed in place, so restore the original order even if the update
    // fails; otherwise the caller's list would be left reversed.
    Collections.reverse(candidates);
    try {
      backward_.update(candidates, sentence);
    } finally {
      Collections.reverse(candidates);
    }

    List<String> path = new ArrayList<String>(sentence.size());
    for (int t = 0; t < sentence.size(); t++) {
      // Accumulated score per root name; the counter is constructed with
      // negative infinity so that roots not seen yet contribute nothing
      // (presumably the counter's default value -- confirm in Counter).
      Counter<String> counter = new Counter<String>(
          Double.NEGATIVE_INFINITY);
      double bestP = Double.NEGATIVE_INFINITY;
      String bestName = null;

      // Iterating with a primitive int guarantees a numeric comparison with
      // the border index regardless of how that constant is declared.
      for (int tagIndex : candidates.get(t)) {
        if (tagIndex == Model.BorderIndex) {
          continue;
        }

        String name = tag_table_.toSymbol(tagIndex);
        double prob = forward_.score(t, tagIndex) + backward_.score(t, tagIndex);

        if (decode_top_level_) {
          Tree tree = clustering_.get(name);
          if (tree == null) {
            throw new RuntimeException(String.format(
                "Tree is null! tag: %s", name));
          }
          Tree root = tree.getRoot();
          if (root == null) {
            throw new RuntimeException(String.format(
                "Root is null! tag: %s", name));
          }
          name = root.getName();

          // Fold this tag's score into the total already collected for its
          // root, and compare the running total against the best so far.
          Double cachedProb = counter.count(name);
          prob = Numerics.sumLogProb(prob, cachedProb);
          counter.set(name, prob);
        }

        if (prob > bestP) {
          bestP = prob;
          bestName = name;
        }
      }

      if (bestName == null) {
        throw new RuntimeException("Didn't find candidate!");
      }
      path.add(bestName);
    }
    return path;
  }

  /**
   * Decodes a sentence where the candidates of every position are restricted
   * to the leaves of one given top-level tag.
   *
   * @param candidates top-level tag names, one per sentence position
   * @param outputs sentence to tag
   * @return one tag name per sentence position
   * @throws RuntimeException if a candidate name has no top-level tree, or if
   *     no usable candidate exists at some position
   */
  public List<String> bestPath(List<String> candidates, Sentence outputs) {
    List<Iterable<Integer>> icandidates = new ArrayList<Iterable<Integer>>(
        outputs.size());
    // Scratch list reused for every candidate; cleared before each use.
    List<Tree> leaves = new ArrayList<Tree>();

    for (String candidate : candidates) {
      Tree tree = top_level_.get(candidate);
      if (tree == null) {
        throw new RuntimeException(String.format(
            "Tree is null! candidate: %s", candidate));
      }

      leaves.clear();
      tree.getLeaves(leaves);

      List<Integer> candidate_tags = new ArrayList<Integer>(leaves.size());
      for (Tree leaf : leaves) {
        Integer tag_index = tag_table_.toIndex(leaf.getName());
        candidate_tags.add(tag_index);
      }
      icandidates.add(candidate_tags);
    }
    return bestPath_(icandidates, outputs);
  }
}