// Copyright 2014 Thomas Müller
// This file is part of HMMLA, which is licensed under GPLv3.

package hmmla.splitmerge;

import hmmla.hmm.HmmModel;
import hmmla.hmm.HmmTrainer;
import hmmla.hmm.Model;
import hmmla.hmm.Statistics;
import hmmla.hmm.Tree;
import hmmla.io.Sentence;
import hmmla.io.Token;
import hmmla.util.Numerics;
import hmmla.util.SymbolTable;
import hmmla.util.Tuple;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;

public class ApproximativeLossEstimator implements LossEstimator {

    ForwardChart forward_;
    BackwardChart backward_;
    HmmTrainer trainer_;

    public ApproximativeLossEstimator(HmmTrainer trainer) {
        trainer_ = trainer;
    }

    private void calcPrior(Model model, double[] prior) {
        // prior[j] accumulates the total transition mass into tag j.
        SymbolTable<String> tagTable = model.getTagTable();
        Statistics statistics = model.getStatistics();
        for (int i = 0; i < tagTable.size(); i++) {
            for (int j = 0; j < tagTable.size(); j++) {
                prior[j] += statistics.getTransitions(i, j);
            }
        }
    }

    @Override
    public void estimateLosses(Model model, Iterable<Sentence> reader,
            List<Tuple<Integer, Double>> tuples) {
        SymbolTable<String> tagTable = model.getTagTable();
        int N = (tagTable.size() - 1) / 2;
        double[] loss = new double[1 + N];

        double[] prior = new double[tagTable.size()];
        calcPrior(model, prior);

        HmmModel hmm_model = trainer_.train(model);

        if (forward_ == null) {
            forward_ = new ForwardChart();
        }
        forward_.init(tagTable.size(), hmm_model);

        if (backward_ == null) {
            backward_ = new BackwardChart();
        }
        backward_.init(tagTable.size(), hmm_model);

        for (Sentence sentence : reader) {
            addLoss(model, sentence, loss, prior);
        }

        for (int i = 1; i < N + 1; i++) {
            tuples.add(new Tuple<Integer, Double>(i, loss[i]));
        }
    }

    protected void addLoss(Model model, Sentence sentence, double[] loss,
            double[] prior) {
        SymbolTable<String> tagTable = model.getTagTable();
        Map<String, Tree> topLevel = model.getTopLevel();

        int T = sentence.size();
        List<Iterable<Integer>> candidates = new ArrayList<Iterable<Integer>>(T);
        List<Iterable<Tree>> parents = new ArrayList<Iterable<Tree>>(T);

        for (Token token : sentence) {
            List<Tree> leaves = new LinkedList<Tree>();
            Tree tree = topLevel.get(token.getTag());
            tree.getTreesOverLeaves(leaves);

            List<Integer> ileaves = new ArrayList<Integer>(leaves.size() * 2);
            List<Tree> cparents = new ArrayList<Tree>(leaves.size());

            for (Tree leaf : leaves) {
                Tree left = leaf.getLeft();
                String lname = left.getName();
                ileaves.add(tagTable.toIndex(lname));

                Tree right = leaf.getRight();
                String rname = right.getName();
                ileaves.add(tagTable.toIndex(rname));

                cparents.add(leaf);
            }

            parents.add(cparents);
            candidates.add(ileaves);
        }

        forward_.update(candidates, sentence);
        // Run the backward pass on the reversed candidate lists, then restore
        // the original order.
        Collections.reverse(candidates);
        backward_.update(candidates, sentence);
        Collections.reverse(candidates);

        addStateLoss(model, loss, prior, parents, sentence);
    }

    protected void addStateLoss(Model model, double[] loss, double[] prior,
            List<Iterable<Tree>> parents, List<Token> sentence) {
        double[] scores = new double[model.getTagTable().size()];
        SymbolTable<String> tag_table = model.getTagTable();
        HmmModel normalizedStatistics = forward_.getHmmModel();

        int t = 0;
        for (Iterable<Tree> cparents : parents) {
            for (Tree parent : cparents) {
                // It is quite inefficient to do this computation here: we only
                // need to do it once per POS.
                Tree root = parent.getRoot();
                assert root != null;

                double sum = Double.NEGATIVE_INFINITY;
                List<Tree> leaves = new LinkedList<Tree>();
                root.getLeaves(leaves);
                for (Tree leaf : leaves) {
                    int leaf_index = tag_table.toIndex(leaf.getName());
                    double f = forward_.score(t, leaf_index);
                    double b = backward_.score(t, leaf_index);
                    sum = Numerics.sumLogProb(sum, f + b);
                }
                sum = Math.exp(sum);
                assert sum + 1e-5 > 0.;

                int lindex = tag_table.toIndex(parent.getLeft().getName());
                int rindex = tag_table.toIndex(parent.getRight().getName());

                Arrays.fill(scores, 0.0);
                normalizedStatistics.getEmissions(
                        sentence.get(t).getWordForm(), scores);

                double p_w_l = Math.exp(scores[lindex]);
                double p_w_r = Math.exp(scores[rindex]);

                double f_l = Math.exp(forward_.score(t, lindex));
                double f_r = Math.exp(forward_.score(t, rindex));
                double b_l = Math.exp(backward_.score(t, lindex));
                double b_r = Math.exp(backward_.score(t, rindex));

                double p_l = prior[lindex] / (prior[lindex] + prior[rindex]);
                double p_r = 1 - p_l;

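                // Merge-loss approximation (split-merge style): `premerge` is
                // the (unnormalized) posterior mass the two split states
                // contribute at position t (forward * backward for each
                // child); `postmerge` approximates the mass a single merged
                // state would contribute, with the child emissions averaged by
                // the prior weights p_l/p_r and factored out of the forward
                // scores before recombining. The resulting change in sentence
                // likelihood is accumulated per tag in log space below.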
                double premerge = (f_l * b_l) + (f_r * b_r);
                double postmerge = ((p_w_l * p_l + p_w_r * p_r)
                        * (f_l / p_w_l + f_r / p_w_r))
                        * (p_l * b_l + p_r * b_r);

                if (sum - premerge + sum * 1e-5 < 0.) {
                    assert false;
                }

                assert premerge >= 0;
                assert postmerge >= 0;

                double new_premerge = sum;
                double new_postmerge = sum - premerge + postmerge;

                if (new_premerge > 0.0 && new_postmerge > 0.0) {
                    double delta = new_postmerge / new_premerge;
                    loss[lindex] += Math.log(delta);
                }
            }
            t++;
        }
    }
}
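
// A minimal usage sketch (hypothetical; the variable names `trainer`, `model`
// and `sentences` are illustrative and not part of this class):
//
//     LossEstimator estimator = new ApproximativeLossEstimator(trainer);
//     List<Tuple<Integer, Double>> losses = new ArrayList<Tuple<Integer, Double>>();
//     estimator.estimateLosses(model, sentences, losses);
//     // Each tuple now pairs a candidate tag index i (1..N) with the estimated
//     // log-likelihood change of merging its two split states back together.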