// Copyright 2014 Thomas Müller
// This file is part of HMMLA, which is licensed under GPLv3.
package hmmla.hmm;
import hmmla.util.Ling;
import java.util.HashMap;
import java.util.Map;
import java.util.Map.Entry;
public class SignatureHmmTrainer implements HmmTrainer {
private SimpleHmmTrainer trainer_;
private double delta_e_;
private double delta_t_;
public SignatureHmmTrainer(double delta_t, double delta_e) {
trainer_ = new SimpleHmmTrainer(delta_t, delta_e);
delta_e_ = delta_e;
delta_t_ = delta_t;
}
public Map<String, double[]> getUnknownClassProbs(Model model) {
Map<String, double[]> map = new HashMap<String, double[]>();
Statistics statistics = model.getStatistics();
for (Entry<String, Integer> entry : model.getWordTable().entrySet()) {
String output = entry.getKey();
String signature = Ling.signature(output, model);
double[] freqs = map.get(signature);
if (freqs == null) {
freqs = new double[statistics.getNumTags()];
map.put(signature, freqs);
}
for (int index = 0; index < freqs.length; index++) {
freqs[index] += statistics
.getEmissions(index, entry.getValue());
}
}
for (Entry<String, double[]> entry : map.entrySet()) {
double total = 0.0;
double[] freqs = entry.getValue();
for (int tag = 0; tag < statistics.getNumTags(); tag ++) {
freqs[tag] += delta_e_;
total += freqs[tag];
}
for (int tag = 0; tag < statistics.getNumTags(); tag ++) {
freqs[tag] /= total;
}
}
return map;
}
@Override
public HmmModel train(Model model) {
SimpleHmmModel hmm_model = (SimpleHmmModel) trainer_.train(model);
Map<String, double[]> signature_map = getUnknownClassProbs(model);
smoothEmissionProbs(hmm_model.getStatistics(), model, signature_map);
return new SignatureHmmModel(hmm_model, signature_map, model);
}
private double[] getTagPrior(Statistics statistics) {
double[] tag_prior = new double[statistics.getNumTags()];
double total_freq = 0;
for (int tag = 0; tag < statistics.getNumTags(); tag++) {
for (int tag2 = 0; tag2 < statistics.getNumTags(); tag2++) {
tag_prior[tag] += statistics.getTransitions(tag, tag2) + delta_t_;
}
total_freq += tag_prior[tag];
}
assert total_freq > 0;
for (int tag = 0; tag < statistics.getNumTags(); tag++) {
tag_prior[tag] /= total_freq;
assert tag_prior[tag] > 0;
}
return tag_prior;
}
private void smoothEmissionProbs(Statistics output_statistics, Model model,
Map<String, double[]> signature_map) {
Statistics statistics = model.getStatistics();
double[] tag_prior = getTagPrior(statistics);
for (Map.Entry<String, Integer> form_entry : model.getWordTable()
.entrySet()) {
String word_form = form_entry.getKey();
String signature = Ling.signature(word_form, model);
int word_index = form_entry.getValue();
double[] backoff_log_probs = signature_map.get(signature);
assert backoff_log_probs != null;
double backoff_factor = 0;
double total_freq = 0;
if (!model.isKnown(word_form)) {
// This might happen during Jacknife training.
continue;
}
for (int tag = 0; tag < statistics.getNumTags(); tag++) {
double freq = statistics.getEmissions(tag, word_index);
total_freq += freq;
backoff_factor += (freq > 1.) ? 1. : freq;
}
if (total_freq == 0. && backoff_factor == 0.) {
// This might happen during EM-training with sampling
// A word has been seen in the training set, but
// not in the current sample.
// We just assign it the full backoff probability.
backoff_factor = 1.;
}
assert backoff_factor > 0;
double prob_sum = 0;
for (int tag = 0; tag < statistics.getNumTags(); tag++) {
double freq = statistics.getEmissions(tag, word_index);
double backoff_prob = backoff_log_probs[tag];
assert backoff_prob > 0;
double prob = (freq + backoff_factor * backoff_prob) / (total_freq + backoff_factor);
prob_sum += prob;
prob /= tag_prior[tag];
double log_prob = Math.log(prob);
assert log_prob != Double.NEGATIVE_INFINITY;
output_statistics.setEmissions(tag, word_index, log_prob);
}
assert Math.abs(prob_sum - 1.0) < 1e-6;
}
for (Entry<String, double[]> entry : signature_map.entrySet()) {
for (int tag = 0; tag < statistics.getNumTags(); tag ++) {
entry.getValue()[tag] = Math.log(entry.getValue()[tag] / tag_prior[tag]);
}
}
}
}