/*
Copyright (c) 2009-2011
Speech Group at Informatik 5, Univ. Erlangen-Nuremberg, GERMANY
Korbinian Riedhammer
Tobias Bocklet
This file is part of the Java Speech Toolkit (JSTK).
The JSTK is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
The JSTK is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with the JSTK. If not, see <http://www.gnu.org/licenses/>.
*/
package de.fau.cs.jstk.decoder;
import java.util.Collections;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Stack;
import de.fau.cs.jstk.arch.TokenHierarchy;
import de.fau.cs.jstk.arch.TreeNode;
import de.fau.cs.jstk.exceptions.AlignmentException;
import de.fau.cs.jstk.stat.hmm.Alignment;
import de.fau.cs.jstk.stat.hmm.MetaAlignment;
import de.fau.cs.jstk.stat.hmm.State;
/**
* The ViterbiBeamSearch is a classic Viterbi beam search implementation with
* either a fixed maximum beam size or an adaptive beam that prunes hypotheses
* by their score distance to the best one (beam width).
*
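* A typical decoding loop might look as follows (a minimal sketch; the LST
* network, the feature source <code>frames</code> and the parameter values
* are assumptions, not prescribed by this class):
* <pre>
* ViterbiBeamSearch vbs = new ViterbiBeamSearch(root, 8., 0.01);
* vbs.initialize(1000, Double.MAX_VALUE, frames.get(0));
* for (int i = 1; i &lt; frames.size(); ++i)
*     vbs.step(frames.get(i));
* vbs.conclude();
* Hypothesis best = vbs.getBestHypothesis();
* </pre>
*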
* @author sikoried
*/
public class ViterbiBeamSearch {
/**
* Root node of the LST (lexical successor tree) network
*/
private TreeNode root;
/** word insertion penalty (logarithmic) */
private double wip;
/** language model weight */
private double lmwt;
/** beam size */
private int bs;
/** beam width (maximum score distance to the best hypothesis) */
private double bw = 0.;
/** list of active hypotheses */
private ViterbiList active = new ViterbiList();
/** list of expanded hypotheses */
private ViterbiList expanded = new ViterbiList();
/** size of the expanded beam of the last step */
private int lastExpanded = 0;
/**
* Create a new Decoder instance with the given LST network and default
* weights (language model weight 1, insertion penalty 1, i.e. no penalty).
* @param root root node of the LST network
*/
public ViterbiBeamSearch(TreeNode root) {
this(root, 1., 1.);
}
/**
* Create a new Decoder instance with the given LST network, language
* model weight and word insertion penalty.
* @param root root of the LST network
* @param lmWeight language model weight (linear)
* @param insertionPenalty word insertion penalty (0...1)
*/
public ViterbiBeamSearch(TreeNode root, double lmWeight, double insertionPenalty) {
this.root = root;
this.lmwt = lmWeight;
this.wip = Math.log(insertionPenalty);
}
/**
* Initialize the beam with the first observation
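* For an adaptive beam, pass a finite beam width, e.g. (sketch; the values
* are illustrative, not recommendations):
* <pre>
* double width = vbs.initialize(10000, 250., frames.get(0));
* </pre>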
* @param beamsize maximum size of the beam
* @param beamwidth beam width, Double.MAX_VALUE for infinite beam width
* @param x first observation
* @return current beam width
*/
public double initialize(int beamsize, double beamwidth, double [] x) {
this.bs = beamsize;
this.bw = beamwidth;
// make sure the lists are clear
active.clear();
expanded.clear();
// generate the initial active hypotheses
Hypothesis h0 = new Hypothesis(root);
for (TreeNode n : root.children) {
// sanity check: every child must provide a token with an attached HMM and states
if (n == null || n.token == null || n.token.hmm == null
|| n.token.hmm.s == null || n.token.hmm.s[0] == null)
throw new IllegalStateException("initialize(): LST network is missing a node, token, HMM or state");
expanded.add(new Hypothesis(h0, n, Math.log(n.token.hmm.s[0].emits(x)), lmwt, wip));
}
// sort and prune if necessary
Collections.sort(expanded);
double best = expanded.get(0).vs;
double width = 0;
while (expanded.size() > 0 && active.size() < bs) {
Hypothesis h = expanded.remove(0);
if ((width = best - h.vs) > bw)
break;
active.add(h);
}
// ready to go, clear expanded list
expanded.clear();
return width;
}
/**
* Feed the next observation to the Viterbi beam search
* @param x observation (feature vector)
* @return current beam width (best-worst score)
*/
public double step(double [] x) {
List<Hypothesis> nodeExpansions = new LinkedList<Hypothesis>();
while (active.size() > 0) {
Hypothesis h = active.remove(0);
int cs = h.s;
State [] s = h.node.token.hmm.s;
float [] a = h.node.token.hmm.a[cs];
// step 1: intra-node transitions
for (short i = 0; i < s.length; ++i) {
if (a[i] > 0.f)
expanded.vadd(new Hypothesis(h, i, Math.log(a[i]) + Math.log(s[i].emits(x))));
}
// step 2: final state: mark for inter-node transitions
if (cs == s.length - 1)
nodeExpansions.add(h);
}
// node expansions
while (nodeExpansions.size() > 0) {
Hypothesis h = nodeExpansions.remove(0);
for (TreeNode succ : h.node.children) {
if (succ.isWordNode()) {
// generate the null-hypothesis with the word
// this is not an active hypothesis!!
// h.p -> h -> token [null] -> word [null] ---> expansion
Hypothesis token = new Hypothesis(h, h.node, 0.);
Hypothesis word = new Hypothesis(token, succ, lmwt);
// iterate over the lexical successor trees linked with this word leaf
for (TreeNode lst : succ.children) {
for (TreeNode t : lst.children)
expanded.vadd(new Hypothesis(word, t, Math.log(t.token.hmm.s[0].emits(x)), lmwt, wip));
}
} else {
// generate the null-hypothesis with the current node (no lmwt!)
// this is not an active hypothesis!!
// h.p -> h -> token[null] ---> expansion
Hypothesis token = new Hypothesis(h, h.node, 0.);
// no word insertion penalty!
expanded.vadd(new Hypothesis(token, succ, Math.log(succ.token.hmm.s[0].emits(x)), lmwt, 0.));
}
}
}
// sort expanded hypotheses
Collections.sort(expanded);
// prune down to beam size
Iterator<Hypothesis> it = expanded.iterator();
double best = expanded.get(0).vs;
double width = 0.;
for (int i = 0; i < bs && it.hasNext(); ++i) {
Hypothesis h = it.next();
if ((width = best - h.vs) > bw)
break;
active.add(h);
}
// clear expanded hypotheses
lastExpanded = expanded.size();
expanded.clear();
return width;
}
/**
* Prune all hypotheses which are NOT in final state
*/
public void pruneActiveHypotheses() {
expanded.clear();
while (active.size() > 0) {
Hypothesis h = active.remove(0);
if (h.finalStateActive())
expanded.add(h);
}
active.addAll(expanded);
expanded.clear();
}
/**
* Get the current active beam size
* @return number of active hypotheses
*/
public int getCurrentBeamSize() {
return active.size();
}
/**
* Get the size of the expanded hypothesis list from the last step
* @return number of expanded hypotheses before pruning
*/
public int getCurrentExpandedSize() {
return lastExpanded;
}
/**
* Conclude the decoding by reducing the beam to hypotheses which are in a
* final state (and adding the respective null-hypotheses)
*/
public void conclude() {
while (active.size() > 0) {
Hypothesis h = active.remove(0);
// if h is in final state, add potential children
if (h.finalStateActive()) {
for (TreeNode succ : h.node.children) {
if (succ.isWordNode()) {
Hypothesis help = new Hypothesis(h, h.node, 0.);
expanded.vadd(new Hypothesis(help, succ, lmwt));
} else
expanded.vadd(new Hypothesis(h, h.node, 0.));
}
} else {
// track back to the last proper hypothesis, but maintain the
// viterbi score!
Hypothesis it = h;
while (it.p != null && !(it.nullhyp && it.node.isWordNode()))
it = it.p;
if (it.p != null) {
it.vs = h.vs;
expanded.vadd(it);
}
}
}
// sort expanded hypotheses
Collections.sort(expanded);
// prune down to beam size
Iterator<Hypothesis> it = expanded.iterator();
for (int i = 0; i < bs && it.hasNext(); ++i)
active.add(it.next());
// clear expanded hypotheses
expanded.clear();
}
/**
* Feed a list of observations to the Viterbi beam search
* @param list
*/
public void step(List<double []> list) {
for (double [] x : list)
step(x);
}
/**
* Get the best n active hypotheses
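* A typical use is printing an n-best list (sketch; <code>vbs</code> is
* assumed to be a concluded ViterbiBeamSearch):
* <pre>
* for (Hypothesis h : vbs.getBestHypotheses(10))
*     System.out.println(h.toCompactString());
* </pre>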
* @param n number of hypotheses to return; if n == 0 or n > active.size(), all active hypotheses are returned
* @return the best n active hypotheses, ranked by Viterbi score
*/
public List<Hypothesis> getBestHypotheses(int n) {
if (n == 0)
n = active.size();
return active.subList(0, n > active.size()? active.size() : n);
}
/**
* Return the best hypothesis surviving the beam.
* @return null if the beam is empty
*/
public Hypothesis getBestHypothesis() {
if (active.size() > 0)
return active.peek();
else
return null;
}
/**
* The ViterbiList extends the LinkedList by a vadd method to help the
* Viterbi beam search: if a matching Hypothesis already exists, only the
* better scoring one is kept; otherwise the Hypothesis is appended to the list.
*
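* A recombination sketch (assuming <code>h1</code> and <code>h2</code> are
* matching hypotheses in the sense of Hypothesis.equals):
* <pre>
* ViterbiList list = new ViterbiList();
* list.vadd(h1);  // appended, no matching hypothesis yet
* list.vadd(h2);  // replaces h1 if h2 scores better, discarded otherwise
* </pre>
*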
* @author sikoried
*/
public static final class ViterbiList extends LinkedList<Hypothesis> {
private static final long serialVersionUID = 1L;
/**
* Add the given Hypothesis: Insert if no matching Hypothesis, replace
* matching worse Hypothesis, discard if matching Hypothesis is better.
*
* @return true if the referenced Hypothesis was appended to the list
*/
public boolean vadd(Hypothesis h) {
Iterator<Hypothesis> it = iterator();
while (it.hasNext()) {
Hypothesis cand = it.next();
// hypotheses match
if (cand.equals(h)) {
if(cand.vs < h.vs) {
// new hyp is better than old -> replace!
it.remove();
add(h);
return true;
} else {
// new hyp is worse than old -> don't bother
return false;
}
}
}
// no matching hypothesis found, insert!
return add(h);
}
}
/**
* The Hypothesis couples all necessary information for the decoding process:
* Current Viterbi and acoustic score, node and (HMM) state as well as a pointer
* to the predecessor Hypothesis to track down the origin.
*
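* The predecessor chain can be walked back to the root, as in this sketch
* (<code>best</code> is assumed to be a decoded Hypothesis):
* <pre>
* for (Hypothesis it = best; it.p != null; it = it.p)
*     System.out.println(it.node + (it.nullhyp ? "" : ":" + it.s));
* </pre>
*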
* @author sikoried
*/
public static final class Hypothesis implements Comparable<Hypothesis> {
/** predecessor of this hypothesis */
public Hypothesis p = null;
/** associated lexical successor tree node */
public TreeNode node = null;
/** originating Hypothesis */
public Hypothesis origin = null;
/** current state */
public short s = 0;
/** Viterbi score */
public double vs = 0.;
/** acoustic score */
public double as = 0.;
/** is this node a null-hypothesis? */
public boolean nullhyp = false;
/**
* Generate a stub hypothesis as a root hypothesis for the initial
* active states
* @param root root node of a lexical successor tree network
*/
public Hypothesis(TreeNode root) {
p = null;
node = root;
nullhyp = true;
}
/**
* Clone the current hypothesis
*/
public Hypothesis clone() {
Hypothesis nh = new Hypothesis(null);
nh.node = node;
nh.p = p;
nh.s = s;
nh.as = as;
nh.vs = vs;
nh.nullhyp = nullhyp;
nh.origin = origin;
return nh;
}
/**
* Allocate a new Hypothesis modeling a node internal state transition
* to the given state with the given probability
* @param parent parent hypothesis
* @param state target state
* @param prob log-probability for given state transition: log(a[?][state]) + log(s[state].emits(x))
*/
public Hypothesis(Hypothesis parent, short state, double prob) {
this.p = parent;
this.node = parent.node;
this.origin = parent.origin;
s = state;
as = p.as + prob;
vs = p.vs + prob;
}
/**
* Allocate a new Hypothesis expanding to a new lexical TreeNode
* @param parent parent (null-)hypothesis
* @param n target TreeNode
* @param aprob acoustic log-probability for given state transition log(s[0].emits(x))
* @param lmwt language model weight
* @param wip word insertion penalty (logarithmic), set to 0. for intra-word
*/
public Hypothesis(Hypothesis parent, TreeNode n, double aprob, double lmwt, double wip) {
this.p = parent;
this.node = n;
this.origin = this;
as = aprob;
vs = p.vs + aprob + lmwt * n.f + wip;
}
/**
* Allocate a new Hypothesis as a null-hypothesis carrying the
* associated node (or word leaf) and the respective total acoustic score
* @param parent
* @param n word leaf
* @param lmwt language model weight (set to 0. for intra-word node)
*/
public Hypothesis(Hypothesis parent, TreeNode n, double lmwt) {
this.p = parent;
this.node = n;
this.origin = parent.origin;
// conclude the Viterbi score by adding the final LM weight
as = p.as;
vs = p.vs + lmwt * n.f;
// mark as null-hypothesis
nullhyp = true;
}
/**
* Two Hypotheses are considered equal if they are in the same state s at
* the same time t and share the previous word arc
* @param h
* @return true if states and node match
*/
public boolean equals(Hypothesis h) {
// same HMM state?
if (s != h.s)
return false;
// same originating hypothesis
if (origin == h.origin)
return true;
// the current and origin nodes match
if (node.equals(h.node) && origin.p.node.equals(h.origin.p.node)) {
// trace back the history of null-hypotheses
Hypothesis ita = origin;
Hypothesis itb = h.origin;
while (ita != null && itb != null) {
// get next null-hypotheses
ita = ita.getPreviousNullHypothesis();
itb = itb.getPreviousNullHypothesis();
// we reached the root asynchronously
if (ita == null ^ itb == null)
return false;
// both are at the root, that's fine
if (ita == null && itb == null)
return true;
// ah, same originating model
if (ita.origin == itb.origin)
return true;
// whoops, different history!
if (!ita.node.equals(itb.node))
return false;
}
}
return false;
}
/**
* Trace back to the previous null hypothesis
* @return the previous null-hypothesis, or null if there is none
*/
public Hypothesis getPreviousNullHypothesis() {
Hypothesis it = p;
while (it != null && !it.nullhyp)
it = it.p;
return it;
}
/**
* Rank hypotheses by their current Viterbi score (best first)
*/
public int compareTo(Hypothesis t) {
return (int) Math.signum(t.vs - vs);
}
/**
* Determine if the Hypothesis is in its final state (and thus marks a
* possible word or token end). Throws a RuntimeException if called on a
* null-hypothesis.
* @return true if the current state is the final state of the attached HMM
*/
public boolean finalStateActive() {
if (node.token == null)
throw new RuntimeException("A null-hypothesis does not have any attached HMM!");
return s == node.token.hmm.ns - 1;
}
/**
* Generate a simple String representation of the Hypothesis
*/
public String toString() {
StringBuffer sb = new StringBuffer();
Stack<String> trace = new Stack<String>();
Hypothesis it = this;
while (it.p != null) {
if (it.nullhyp) {
if (it.node.isWordNode())
trace.push(it.node.word.word);
else
trace.push(it.node.toString());
} else
trace.push(it.node.toString() + ":" + it.s);
it = it.p;
}
while (trace.size() > 0)
sb.append(trace.pop() + " ");
return sb.toString();
}
/**
* Get a String representation of this Hypothesis including acoustic scores
*/
public String toDetailedString() {
StringBuffer sb = new StringBuffer();
sb.append(vs + " ");
Stack<String> trace = new Stack<String>();
Hypothesis it = this;
while (it.p != null) {
if (it.nullhyp) {
if (it.node.isWordNode())
trace.push("[" + it.node.word.word + ", " + it.as + "]");
else
trace.push("(" + it.node.toString() + ", " + it.as + ")");
} else
trace.push(it.node.toString() + ":" + it.s);
it = it.p;
}
while (trace.size() > 0)
sb.append(trace.pop() + " ");
return sb.toString();
}
/**
* Generate a compact String representation where intra-node transitions
* are collapsed into frame counts per word and node
* @return compact String representation
*/
public String toCompactString() {
Stack<Hypothesis> trace1 = new Stack<Hypothesis>();
Stack<Integer> trace2 = new Stack<Integer>();
Stack<Integer> trace3 = new Stack<Integer>();
// follow trace, add word leaves
Hypothesis it = clone();
int nodec = 0;
int wordc = 0;
// make sure the trailing hypothesis is treated as a null-hypothesis
it.nullhyp = true;
// push the trailing null-hyps
while (it.nullhyp) {
trace1.push(it);
it = it.p;
}
// follow the trace
while (it.p != null) {
if (it.nullhyp) {
trace1.push(it);
if (it.node.isWordNode()) {
trace2.push(wordc);
wordc = 0;
} else {
trace3.push(nodec);
nodec = 0;
}
} else {
nodec++;
wordc++;
}
it = it.p;
}
trace2.push(wordc);
trace3.push(nodec);
StringBuffer sb = new StringBuffer();
sb.append(vs + " ");
while (trace1.size() > 0) {
Hypothesis h = trace1.pop();
if (h.node.isWordNode())
if (trace2.size() > 0)
sb.append("[" + h.node.word.word + ", " + trace2.pop() + ", " + h.as + "] ");
else
sb.append("[" + h.node.word.word + ", 0, " + h.as + "] ");
else {
if (trace3.size() > 0)
sb.append("(" + h.node.toString() + ", " + trace3.pop() + ", " + h.as + ") ");
else
sb.append("(" + h.node.toString() + ", 0, " + h.as + ") ");
}
}
return sb.toString();
}
/**
* Generate a MetaAlignment corresponding to this Hypothesis
* @param th TokenHierarchy used to construct the MetaAlignment
* @return MetaAlignment built from the Viterbi state sequence
* @throws AlignmentException
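* A typical use (sketch; <code>vbs</code> and <code>hierarchy</code> are
* assumed to be a concluded search and a matching TokenHierarchy):
* <pre>
* MetaAlignment ma = vbs.getBestHypothesis().toMetaAlignment(hierarchy);
* </pre>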
*/
public MetaAlignment toMetaAlignment(TokenHierarchy th) throws AlignmentException {
// reverse the hypothesis
Stack<Hypothesis> trace = new Stack<Hypothesis>();
Hypothesis it = this;
while (it.p != null) {
// skip null-hypotheses
if (!it.nullhyp)
trace.add(it);
it = it.p;
}
// build up the alignments
List<Alignment> algs = new LinkedList<Alignment>();
TreeNode node = trace.peek().node;
List<Integer> sseq = new LinkedList<Integer>();
while (trace.size() > 0) {
Hypothesis h = trace.pop();
if (!h.node.equals(node)) {
// build state sequence
int [] qstar = new int [sseq.size()];
for (int i = 0; i < qstar.length; ++i)
qstar[i] = sseq.remove(0);
if (qstar.length < node.token.hmm.ns) {
System.err.println("Error: Alignment is shorter than model!");
} else {
// generate alignment
Alignment a = new Alignment(node.token.hmm, null, qstar);
algs.add(a);
}
// reset the pointers
sseq = new LinkedList<Integer>();
node = h.node;
}
// build state sequence
sseq.add((int) h.s);
}
// there is an unfinished state sequence
if (sseq.size() > 0) {
// build state sequence
int [] qstar = new int [sseq.size()];
for (int i = 0; i < qstar.length; ++i)
qstar[i] = sseq.remove(0);
Alignment a = new Alignment(node.token.hmm, null, qstar);
algs.add(a);
}
MetaAlignment ma = new MetaAlignment(th, algs);
return ma;
}
/**
* Generate a list of Hypotheses containing only the word null-hypotheses
* @return word list in correct time order
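* The recognized word sequence can be printed as follows (sketch;
* <code>best</code> is assumed to be a decoded Hypothesis):
* <pre>
* for (Hypothesis w : best.extractWords())
*     System.out.println(w.node.word.word);
* </pre>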
*/
public synchronized List<Hypothesis> extractWords() {
LinkedList<Hypothesis> ws = new LinkedList<Hypothesis>();
Hypothesis it = this;
// follow trace, add word leaves
while (it.p != null) {
if (it.nullhyp && it.node.isWordNode())
ws.add(0, it);
it = it.p;
}
return ws;
}
/**
* Generate a list of Hypotheses containing only the Token null-hypotheses
* @return token list in correct time order
*/
public synchronized List<Hypothesis> extractTokens() {
LinkedList<Hypothesis> ts = new LinkedList<Hypothesis>();
Hypothesis it = this;
while (it.p != null) {
if (it.nullhyp && !it.node.isWordNode())
ts.add(0, it);
it = it.p;
}
return ts;
}
}
}