// Copyright 2013 Thomas Müller
// This file is part of MarMoT, which is licensed under GPLv3.

package marmot.core.lattice;

import java.util.Collections;
import java.util.List;

import marmot.core.State;
import marmot.util.Numerics;

/**
 * Lattice of accumulated scores computed from the last candidate position
 * back to the first.
 *
 * <p>
 * Row {@code index} of the lattice has one cell per state of the position
 * preceding {@code index} (the single boundary state when {@code index == 0}).
 * Cell {@code [index][state_index]} combines, via
 * {@link Numerics#sumLogProb}, the value
 * {@code state.getScore() + transition.getScore() + lattice[index + 1][k]}
 * for every state {@code k} of {@code candidates.get(index)} whose
 * {@code getTransition(state_index)} is not {@code null}. The
 * {@code lattice[index + 1]} term is omitted for the last position.
 *
 * <p>
 * {@link #init()} or {@link #reinit()} must be called before the lattice is
 * read; the constructor does not allocate it.
 */
public class BackwardSequenceLattice {

	/** Score table; allocated and filled by {@link #init()} / {@link #reinit()}. */
	private double[][] lattice_;

	/** Candidate states per sequence position. */
	private final List<List<State>> candidates_;

	/** State used as the only predecessor of position 0. */
	private final State boundary_;

	/**
	 * Stores the candidate lists and the boundary state.
	 *
	 * <p>
	 * With assertions enabled this also checks that the last position holds
	 * exactly one state, that this state is its own zero-order state, and that
	 * its index equals the index of {@code boundary}.
	 *
	 * @param candidates
	 *            candidate states per position; the list is kept by reference
	 * @param boundary
	 *            boundary state preceding position 0
	 */
	public BackwardSequenceLattice(List<List<State>> candidates, State boundary) {
		candidates_ = candidates;
		boundary_ = boundary;

		List<State> last_states = candidates_.get(candidates_.size() - 1);
		assert last_states.size() == 1;
		State state = candidates_.get(candidates_.size() - 1).get(0);
		assert state == state.getZeroOrderState();
		assert state.getIndex() == boundary_.getIndex();
	}

	/**
	 * Returns the states that index the cells of lattice row {@code index}:
	 * the boundary singleton for row 0, otherwise the candidates of the
	 * preceding position.
	 */
	private List<State> statesBefore(int index) {
		if (index == 0) {
			return Collections.singletonList(boundary_);
		}
		return candidates_.get(index - 1);
	}

	/**
	 * Recursive, non-memoized computation of a backward score, used by
	 * {@link #reinit()} as a reference for the iterative computation.
	 *
	 * <p>
	 * Side effects: writes the result into
	 * {@code lattice_[candidates_.size() - candidates.size()][index]} and
	 * prints a line to {@code System.err} for every state without a
	 * transition (regardless of {@code print}).
	 *
	 * @param index
	 *            index of the preceding state whose transitions are followed
	 * @param candidates
	 *            remaining suffix of the candidate lists
	 * @param print
	 *            whether to print the summands of this (outermost) level
	 * @return {@code 0} for an empty suffix, otherwise the combined score
	 */
	private double test(int index, List<List<State>> candidates, boolean print) {
		if (candidates.isEmpty()) {
			return 0;
		}

		// subList is a view, so building it once is equivalent to building it
		// for every state.
		List<List<State>> remaining = candidates.subList(1, candidates.size());

		double score_sum = Double.NEGATIVE_INFINITY;
		int state_index = 0;
		for (State state : candidates.get(0)) {
			State transition = state.getTransition(index);
			if (transition == null) {
				System.err.format("%d null\n", state_index);
			} else {
				// Nested levels never print their summands.
				double rec = test(state_index, remaining, false);
				if (print) {
					System.err.format("%d %g %g %g\n", state_index,
							state.getScore(), transition.getScore(), rec);
				}
				double score = transition.getScore() + state.getScore() + rec;
				score_sum = Numerics.sumLogProb(score_sum, score);
			}
			state_index++;
		}

		lattice_[candidates_.size() - candidates.size()][index] = score_sum;
		return score_sum;
	}

	/**
	 * Computes one lattice cell from row {@code index + 1} (which must already
	 * be filled unless {@code index} is the last position).
	 *
	 * @param index
	 *            lattice row / candidate position
	 * @param state_index
	 *            cell within the row, i.e. index of the preceding state
	 * @param trace
	 *            whether to print the debug trace of {@link #reinit()} to
	 *            {@code System.err}
	 * @return the combined score, {@code Double.NEGATIVE_INFINITY} if no state
	 *         has a transition for {@code state_index}
	 */
	private double backwardScore(int index, int state_index, boolean trace) {
		boolean has_next_row = index + 1 < candidates_.size();
		double score_sum = Double.NEGATIVE_INFINITY;

		if (trace) {
			System.err.format("STATE\n");
		}

		int previous_state_index = 0;
		for (State previous_state : candidates_.get(index)) {
			State transition = previous_state.getTransition(state_index);

			if (transition == null) {
				if (trace) {
					// NOTE(review): prints state_index, whereas the non-null
					// branch prints previous_state_index; kept as is.
					System.err.format("%d null\n", state_index);
				}
			} else {
				double score = previous_state.getScore()
						+ transition.getScore();
				if (trace) {
					System.err.format("%d %g %g", previous_state_index,
							previous_state.getScore(), transition.getScore());
				}
				if (has_next_row) {
					score += lattice_[index + 1][previous_state_index];
					if (trace) {
						System.err.format(" %g\n",
								lattice_[index + 1][previous_state_index]);
					}
				}
				if (trace) {
					System.err.format("\n");
				}
				score_sum = Numerics.sumLogProb(score_sum, score);
			}

			previous_state_index++;
		}

		return score_sum;
	}

	/**
	 * Allocates the lattice and fills it row by row, from the last position
	 * down to position 0.
	 */
	public void init() {
		lattice_ = new double[candidates_.size()][];

		for (int index = candidates_.size() - 1; index >= 0; index--) {
			int num_states = statesBefore(index).size();
			lattice_[index] = new double[num_states];

			for (int state_index = 0; state_index < num_states; state_index++) {
				lattice_[index][state_index] = backwardScore(index,
						state_index, false);
			}
		}

		assert lattice_[0].length == 1;
	}

	/**
	 * Returns {@code lattice_[0][0]}. Requires a prior call to {@link #init()}
	 * or {@link #reinit()}.
	 */
	double partitionFunction() {
		return lattice_[0][0];
	}

	/**
	 * Returns the lattice cell {@code [index][state_index]}, or {@code 0} when
	 * {@code index} equals the number of positions (one past the last row).
	 */
	public double get(int index, int state_index) {
		if (index == candidates_.size()) {
			return 0;
		}
		return lattice_[index][state_index];
	}

	/**
	 * Debugging variant of {@link #init()}: fills the lattice with the
	 * recursive reference computation, then recomputes every cell iteratively,
	 * printing a trace to {@code System.err} and comparing both values with
	 * {@link #diffTest(double, double)}.
	 *
	 * @throws RuntimeException
	 *             rethrown from {@code diffTest} after dumping diagnostic
	 *             output when a cell differs
	 */
	public void reinit() {
		lattice_ = new double[candidates_.size()][];
		for (int index = candidates_.size() - 1; index >= 0; index--) {
			lattice_[index] = new double[statesBefore(index).size()];
		}

		// The recursive pass writes its results into lattice_ as a side effect.
		System.err.println(test(0, candidates_, false));
		System.err.println(partitionFunction());

		for (int index = candidates_.size() - 1; index >= 0; index--) {
			int num_states = statesBefore(index).size();

			for (int state_index = 0; state_index < num_states; state_index++) {
				double score_sum = backwardScore(index, state_index, true);

				try {
					diffTest(lattice_[index][state_index], score_sum);
				} catch (RuntimeException e) {
					reportMismatch(index, state_index);
					throw e;
				}

				// Replace the reference value by the iterative one so that the
				// next row is computed from iterative values only.
				lattice_[index][state_index] = score_sum;
			}
		}

		assert lattice_[0].length == 1;
	}

	/**
	 * Dumps diagnostic output to {@code System.err} for a cell whose reference
	 * and iterative values differ, re-running the recursive computation for
	 * that cell with printing enabled.
	 */
	private void reportMismatch(int index, int state_index) {
		System.err.println();
		System.err.println(candidates_.size());
		System.err.println(index);
		System.err.println(state_index);
		System.err.println();
		System.err.println();
		System.err.println(lattice_[index][state_index]);
		double f = test(state_index,
				candidates_.subList(index, candidates_.size()), true);
		System.err.println(lattice_[index][state_index]);
		System.err.println(f);
		System.err.println();
		System.err.println();
	}

	/**
	 * Checks that two values agree up to an absolute tolerance of
	 * {@code 1e-10}.
	 *
	 * <p>
	 * A NaN difference (for example when both arguments are the same infinity)
	 * is not greater than the tolerance, so it passes and NaN is returned.
	 *
	 * @param a
	 *            first value
	 * @param b
	 *            second value
	 * @return {@code Math.abs(a - b)}
	 * @throws RuntimeException
	 *             if the absolute difference exceeds the tolerance
	 */
	protected double diffTest(double a, double b) {
		double diff = Math.abs(a - b);
		if (diff > 1.e-10) {
			throw new RuntimeException(String.format(
					"test failed: %g %g : %g", a, b, diff));
		}
		return diff;
	}
}