// Copyright 2014 Thomas Müller
// This file is part of HMMLA, which is licensed under GPLv3.
package hmmla.hmm;
import hmmla.util.SymbolTable;
import java.util.Map.Entry;
/**
 * {@link HmmTrainer} that converts the raw counts stored in a {@link Model}
 * into additively smoothed, natural-log probabilities.
 *
 * <p>Every transition count is increased by {@code delta_t} and every emission
 * count by {@code delta_e} before the counts of one tag are normalized. Counts
 * and smoothing constants are assumed to be non-negative.
 */
public class SimpleHmmTrainer implements HmmTrainer {

    /** Smoothing constant added to every emission count. */
    private final double delta_e;

    /** Smoothing constant added to every transition count. */
    private final double delta_t;

    /**
     * Creates a trainer with the given smoothing constants.
     *
     * @param delta_t value added to every transition count
     * @param delta_e value added to every emission count
     */
    public SimpleHmmTrainer(double delta_t, double delta_e) {
        this.delta_e = delta_e;
        this.delta_t = delta_t;
    }

    /**
     * Builds a {@link SimpleHmmModel} from the counts held by {@code model}.
     *
     * @param model source of the tag table, word table and count statistics
     * @return a model wrapping the normalized log-probabilities and {@code model}
     */
    @Override
    public HmmModel train(Model model) {
        int num_tags = model.getTagTable().size();
        int num_outputs = model.getWordTable().size();
        Statistics normalized_stats = new Statistics(num_tags, num_outputs);
        setTransmissionProbabilities(model, normalized_stats);
        setEmissionProbabilities(model, normalized_stats);
        return new SimpleHmmModel(normalized_stats, model);
    }

    /**
     * Fills the transition table of {@code normalized_stats} with
     * {@code log((count + delta_t) / (rowSum + delta_t * num_tags))} for every
     * pair of tags.
     *
     * <p>If the smoothed row sum of a tag is zero (no observed transitions and
     * {@code delta_t == 0}), every transition out of that tag is set to
     * {@link Double#NEGATIVE_INFINITY} rather than the {@code NaN} that
     * {@code -log(0) + log(0)} would yield.
     *
     * @param model            source of the tag table and the transition counts
     * @param normalized_stats destination of the log-probabilities
     */
    protected void setTransmissionProbabilities(Model model,
            Statistics normalized_stats) {
        Statistics statistics = model.getStatistics();
        int num_tags = model.getTagTable().size();

        for (int fromIndex = 0; fromIndex < num_tags; fromIndex++) {
            double row_sum = 0.0;
            for (int toIndex = 0; toIndex < num_tags; toIndex++) {
                row_sum += statistics.getTransitions(fromIndex, toIndex);
            }

            double denominator = row_sum + delta_t * num_tags;
            if (denominator == 0.0) {
                // No probability mass at all: 0/0 is undefined, so mark every
                // transition as impossible instead of storing NaN.
                for (int toIndex = 0; toIndex < num_tags; toIndex++) {
                    normalized_stats.setTransitions(fromIndex, toIndex,
                            Double.NEGATIVE_INFINITY);
                }
                continue;
            }

            double prior = Math.log(denominator);
            for (int toIndex = 0; toIndex < num_tags; toIndex++) {
                double p = -prior
                        + Math.log(statistics.getTransitions(fromIndex, toIndex)
                                + delta_t);
                normalized_stats.setTransitions(fromIndex, toIndex, p);
            }
        }
    }

    /**
     * Fills the emission table of {@code normalized_stats} with
     * {@code log((count + delta_e) / sum(count + delta_e))} for every tag and
     * output symbol.
     *
     * <p>The tag at {@link Model#BorderIndex} gets
     * {@link Double#NEGATIVE_INFINITY} for every entry of the word table.
     * The same value is used for a tag whose smoothed emission total is zero
     * (no observed emissions and {@code delta_e == 0}), where the original
     * division {@code 0 / 0} would yield {@code NaN}.
     *
     * @param model            source of the tag table, word table and the
     *                         emission counts
     * @param normalized_stats destination of the log-probabilities
     */
    protected void setEmissionProbabilities(Model model,
            Statistics normalized_stats) {
        SymbolTable<String> tag_table = model.getTagTable();
        SymbolTable<String> word_table = model.getWordTable();
        Statistics statistics = model.getStatistics();
        int num_outputs = word_table.size();
        int num_tags = tag_table.size();

        for (int tag = 0; tag < num_tags; tag++) {
            if (tag == Model.BorderIndex) {
                // NOTE(review): presumably the sentence-boundary tag, which
                // never emits a word - confirm against Model.
                for (Entry<String, Integer> entry : word_table.entrySet()) {
                    normalized_stats.setEmissions(tag, entry.getValue(),
                            Double.NEGATIVE_INFINITY);
                }
                continue;
            }

            double total = 0;
            for (int output = 0; output < num_outputs; output++) {
                total += statistics.getEmissions(tag, output) + delta_e;
            }

            if (total == 0.0) {
                // No probability mass at all: avoid log(0 / 0) == NaN.
                for (int output = 0; output < num_outputs; output++) {
                    normalized_stats.setEmissions(tag, output,
                            Double.NEGATIVE_INFINITY);
                }
                continue;
            }

            for (int output = 0; output < num_outputs; output++) {
                double freq = statistics.getEmissions(tag, output) + delta_e;
                double prob = freq / total;
                normalized_stats.setEmissions(tag, output, Math.log(prob));
            }
        }
    }
}