// Copyright 2014 Thomas Müller // This file is part of HMMLA, which is licensed under GPLv3. package hmmla.hmm; import java.util.LinkedList; import java.util.List; public class LinearSmoother extends Smoother { private static final long serialVersionUID = 1L; private double param_; public LinearSmoother(double param) { param_ = param; } protected void smooth(Tree tag, Model model, Statistics statistics, Type type) { double[] backoff = collectFreqs(tag, model, type); List<Tree> leaves = new LinkedList<Tree>(); tag.getLeaves(leaves); for (Tree tree : leaves) { String tag_name = tree.getName(); int tag_index = model.getTagTable().toIndex(tag_name); smooth(tag_index, backoff, model, statistics, type); } } private void smooth(int tag_index, double[] backoff, Model model, Statistics statistics, Type type) { int number = getNumber(model, type); double total = 0; double total_backoff = 0; for (int index = 0; index < number; index ++) { total += getFreq(model.getStatistics(), tag_index, index, type); total_backoff += backoff[index]; } if (total < 1.e-20) { return; } assert (total_backoff > 1.e-20); for (int index = 0; index < number; index ++) { double prob = getFreq(model.getStatistics(), tag_index, index, type) / total; double backoff_prob = backoff[index] / total_backoff; prob = (1 - param_) * prob + (param_* backoff_prob); double freq = prob * total; assert !(Double.isNaN(freq) || Double.isInfinite(freq)); setFreq(model, statistics, tag_index, index, type, freq); } } private double[] collectFreqs(Tree tag, Model model, Type type) { int number = getNumber(model, type); List<Tree> leaves = new LinkedList<Tree>(); double[] freq = new double[number]; tag.getLeaves(leaves); for (Tree tree : leaves) { String tag_name = tree.getName(); int tag_index = model.getTagTable().toIndex(tag_name); for (int index = 0; index < number; index ++) { freq[index] += getFreq(model.getStatistics(), tag_index, index, type); } } return freq; } }