// Copyright 2014 Thomas Müller
// This file is part of HMMLA, which is licensed under GPLv3.
package hmmla.hmm;
import java.util.Arrays;
import java.util.LinkedList;
import java.util.List;
/**
 * Smoother that walks a tag tree from the root towards the leaves and, at
 * every step, interpolates the frequency vector of a child node with the
 * (already smoothed) frequency vector of its parent. The smoothed vectors of
 * the leaf tags are finally written out through {@code setFreq}.
 *
 * <p>The interpolation weight is one plus the sum of {@code min(1, count)}
 * over the child's counts. NOTE(review): the class name suggests this is
 * meant as Witten-Bell smoothing - confirm against the project docs.
 */
public class WbSmoother extends Smoother {

    private static final long serialVersionUID = 1L;

    /**
     * Smooths the frequencies of the given type for the whole tree rooted at
     * {@code tag}.
     *
     * @param tag        root of the tag tree to process
     * @param model      model supplying the tag table and the unsmoothed
     *                   statistics
     * @param statistics target handed on to {@code setFreq}
     * @param type       kind of frequency being smoothed; also determines the
     *                   vector length via {@code getNumber}
     */
    protected void smooth(Tree tag, Model model, Statistics statistics, Type type) {
        double[] rootFreqs = new double[getNumber(model, type)];
        collectFreqs(tag, model, type, rootFreqs);
        smooth(tag, model, statistics, type, rootFreqs);
    }

    /**
     * Overwrites {@code freqs} with the sum, over all leaves below
     * {@code tag}, of the unsmoothed frequencies read from
     * {@code model.getStatistics()}.
     *
     * @param tag   node whose leaves are aggregated
     * @param model source of the tag table and the unsmoothed statistics
     * @param type  kind of frequency to read
     * @param freqs output buffer; reset to zero before accumulation
     */
    private void collectFreqs(Tree tag, Model model, Type type, double[] freqs) {
        Statistics rawStatistics = model.getStatistics();
        List<Tree> leafTags = new LinkedList<Tree>();
        tag.getLeaves(leafTags);
        Arrays.fill(freqs, 0.0);
        final int slots = freqs.length;
        for (Tree leafTag : leafTags) {
            int leafIndex = model.getTagTable().toIndex(leafTag.getName(), false);
            for (int slot = 0; slot < slots; slot++) {
                freqs[slot] += getFreq(rawStatistics, leafIndex, slot, type);
            }
        }
    }

    /**
     * Recursive step: {@code freqs} is the final (smoothed) vector of
     * {@code tag}. For a leaf it is stored via {@code setFreq}; otherwise
     * every existing child is smoothed against it, left child first.
     *
     * @param tag        current node
     * @param model      model supplying the tag table and raw statistics
     * @param statistics target handed on to {@code setFreq}
     * @param type       kind of frequency being smoothed
     * @param freqs      smoothed frequencies of {@code tag}; only read here
     */
    private void smooth(Tree tag, Model model, Statistics statistics,
            Type type, double[] freqs) {
        Tree left = tag.getLeft();
        Tree right = tag.getRight();

        if (left == null && right == null) {
            // Leaf: persist the smoothed vector under the leaf's tag index.
            int leafIndex = model.getTagTable().toIndex(tag.getName());
            for (int slot = 0; slot < freqs.length; slot++) {
                setFreq(model, statistics, leafIndex, slot, type, freqs[slot]);
            }
            return;
        }

        // One buffer serves both children: the recursion below a child only
        // reads it, and collectFreqs clears it before it is refilled.
        double[] childFreqs = new double[freqs.length];
        if (left != null) {
            smoothChild(left, model, statistics, type, freqs, childFreqs);
        }
        if (right != null) {
            smoothChild(right, model, statistics, type, freqs, childFreqs);
        }
    }

    /**
     * Collects the raw frequencies of {@code child} into {@code childFreqs},
     * interpolates them with {@code parentFreqs} and then descends into the
     * child's subtree.
     */
    private void smoothChild(Tree child, Model model, Statistics statistics,
            Type type, double[] parentFreqs, double[] childFreqs) {
        collectFreqs(child, model, type, childFreqs);
        smooth(parentFreqs, childFreqs, type);
        smooth(child, model, statistics, type, childFreqs);
    }

    /**
     * Returns the backoff weight for a frequency vector: the sum of
     * {@code min(1, freq)} over all entries, plus one.
     */
    private double computeBackoffFactor(double[] freqs) {
        double seen = 0.0;
        for (int slot = 0; slot < freqs.length; slot++) {
            seen += Math.min(1., freqs[slot]);
        }
        // The +1 is applied after the summation to keep the rounding order.
        return seen + 1.;
    }

    /**
     * Interpolates {@code fine_freqs} in place with the distribution given
     * by {@code freqs}. With {@code B = computeBackoffFactor(fine_freqs)}
     * and {@code T} the sum of {@code fine_freqs}, each entry becomes
     * {@code T * (fine + B * p) / (T + B)}, where {@code p} is the relative
     * frequency of the entry in {@code freqs}, or uniform when {@code freqs}
     * sums to less than {@code 1e-10}.
     *
     * @param freqs      coarse (parent) frequencies; only read
     * @param fine_freqs fine (child) frequencies; overwritten with the result
     * @param type       not used by this method
     */
    private void smooth(double[] freqs, double[] fine_freqs, Type type) {
        double coarseTotal = 0.0;
        for (int slot = 0; slot < freqs.length; slot++) {
            coarseTotal += freqs[slot];
        }

        double fineTotal = 0.0;
        for (int slot = 0; slot < fine_freqs.length; slot++) {
            fineTotal += fine_freqs[slot];
        }

        // Without usable parent mass, back off to a uniform distribution.
        final boolean backOffUniformly = coarseTotal < 1e-10;

        // Both values must be taken before fine_freqs is overwritten below.
        final double backoffFactor = computeBackoffFactor(fine_freqs);
        final double denominator = fineTotal + backoffFactor;

        final int size = fine_freqs.length;
        for (int slot = 0; slot < size; slot++) {
            double backoffProb = backOffUniformly
                    ? 1. / (double) size
                    : freqs[slot] / coarseTotal;
            double mixed = fine_freqs[slot] + (backoffFactor * backoffProb);
            double prob = mixed / denominator;
            // Rescale so the vector keeps its original total mass.
            fine_freqs[slot] = prob * fineTotal;
        }
    }
}