/*
 * Copyright 1999-2002 Carnegie Mellon University.
 * Portions Copyright 2002 Sun Microsystems, Inc.
 * Portions Copyright 2002 Mitsubishi Electric Research Laboratories.
 * All Rights Reserved.  Use is subject to license terms.
 *
 * See the file "license.terms" for information on usage and
 * redistribution of this file, and for a DISCLAIMER OF ALL
 * WARRANTIES.
 *
 */

package edu.cmu.sphinx.trainer;

import edu.cmu.sphinx.frontend.*;
import edu.cmu.sphinx.frontend.util.StreamCepstrumSource;
import edu.cmu.sphinx.linguist.acoustic.HMMState;
import edu.cmu.sphinx.linguist.acoustic.tiedstate.SenoneHMM;
import edu.cmu.sphinx.linguist.acoustic.tiedstate.SenoneHMMState;
import edu.cmu.sphinx.linguist.acoustic.tiedstate.trainer.TrainerScore;
import edu.cmu.sphinx.util.LogMath;
import edu.cmu.sphinx.util.props.PropertyException;
import edu.cmu.sphinx.util.props.PropertySheet;
import edu.cmu.sphinx.util.props.S4Component;

import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.logging.Level;
import java.util.logging.Logger;

/**
 * Provides mechanisms for computing statistics given a set of states and input data.
 *
 * <p>The learner reads feature frames from a front end, runs a forward (alpha) pass over the
 * utterance graph for every frame, and then hands the per-frame scores back one frame at a time
 * through {@link #getScore()}, computing the backward (beta) values as it goes.
 */
public class BaumWelchLearner implements Learner {

    /** Property name of the front end component. */
    @S4Component(type = FrontEnd.class)
    public static final String FRONT_END = "frontend";
    private FrontEnd frontEnd;

    /** Property name of the cepstrum data source component. */
    @S4Component(type = StreamCepstrumSource.class)
    public static final String DATA_SOURCE = "source";
    private StreamCepstrumSource dataSource;

    private LogMath logMath;

    /* The logger for this class */
    private static final Logger logger = Logger.getLogger("edu.cmu.sphinx.trainer.BaumWelch");

    /** The feature frame most recently read by {@link #getFeature()}. */
    private Data curFeature;
    private UtteranceGraph graph;
    /** Per-frame, per-state scores built by the forward pass; null until the next pass is needed. */
    private TrainerScore[][] scoreArray;
    /** Number of feature frames read during the last forward pass. */
    private int lastFeatureIndex;
    /** Index of the frame that {@link #getScore()} returned last (counts down). */
    private int currentFeatureIndex;
    /** Allocated with one slot per graph node; only its length is read in this class. */
    private float[] betas;
    private float[] outputProbs;
    /** Component scores of the state scored last by {@link #calculateScores(int)}; may be null. */
    private float[] componentScores;
    /** Alpha (during the forward pass) or beta (during the backward pass) of the current frame. */
    private float[] probCurrentFrame;
    private float totalLogScore;


    /**
     * Reads the data source and front end from the property sheet and connects them.
     *
     * @param ps the property sheet holding the {@link #DATA_SOURCE} and {@link #FRONT_END} components
     * @throws PropertyException if a component cannot be obtained
     */
    public void newProperties(PropertySheet ps) throws PropertyException {
        logMath = LogMath.getLogMath();

        dataSource = (StreamCepstrumSource) ps.getComponent(DATA_SOURCE);

        frontEnd = (FrontEnd) ps.getComponent(FRONT_END);
        frontEnd.setDataSource(dataSource);
    }


    /**
     * Returns the front end that was configured in {@link #newProperties(PropertySheet)}.
     *
     * @return the front end, or null if the learner has not been configured yet
     */
    protected FrontEnd getFrontEnd() {
        return frontEnd;
    }


    /**
     * Sets the learner to use a utterance.
     *
     * <p>The utterance's string form is used as the name of the file to open.
     *
     * @param utterance the utterance
     * @throws IOException if error occurred
     */
    public void setUtterance(Utterance utterance) throws IOException {
        String file = utterance.toString();

        // NOTE(review): the stream is handed over to the data source and is not closed
        // here; presumably the data source owns it from this point on - confirm.
        InputStream is = new FileInputStream(file);
        dataSource.setInputStream(is, false);
    }


    /**
     * Gets a single frame of speech and stores it in {@code curFeature}.
     *
     * <p>A leading {@code DataStartSignal} is skipped. Reaching a {@code DataEndSignal}, running
     * out of data, or a {@code DataProcessingException} all end the stream of frames.
     *
     * @return true if {@code curFeature} holds a content frame, false otherwise
     */
    private boolean getFeature() {
        try {
            curFeature = frontEnd.getData();
            if (curFeature == null) {
                return false;
            }

            if (curFeature instanceof DataStartSignal) {
                // Skip the start marker and move on to the first real frame.
                curFeature = frontEnd.getData();
                if (curFeature == null) {
                    return false;
                }
            }

            if (curFeature instanceof DataEndSignal) {
                return false;
            }

            if (curFeature instanceof Signal) {
                throw new Error("Can't score non-content feature");
            }
        } catch (DataProcessingException dpe) {
            // Report through the class logger, keeping the cause and its stack trace.
            logger.log(Level.WARNING, "DataProcessingException " + dpe, dpe);
            return false;
        }

        return true;
    }


    /** Starts the Learner. */
    public void start() {
    }


    /** Stops the Learner. */
    public void stop() {
    }


    /**
     * Initializes computation for current utterance and utterance graph.
     *
     * @param utterance the current utterance
     * @param graph     the current utterance graph
     * @throws IOException if exception occurred
     */
    public void initializeComputation(Utterance utterance, UtteranceGraph graph) throws IOException {
        setUtterance(utterance);
        setGraph(graph);
    }


    /**
     * Implements the setGraph method.
     *
     * @param graph the graph
     */
    public void setGraph(UtteranceGraph graph) {
        this.graph = graph;
    }


    /**
     * Prepares the learner for returning scores, one at a time. To do so, it performs the full
     * forward pass, but returns the scores for the backward pass one feature frame at a time.
     *
     * <p>On return, {@code lastFeatureIndex} holds the number of frames read and
     * {@code probCurrentFrame} has been re-seeded for the beta computation.
     *
     * @return one row per feature frame, each row holding one {@code TrainerScore} per graph node
     */
    private TrainerScore[][] prepareScore() {
        // scoreList will contain a list of score, which in turn are a
        // vector of TrainerScore elements.
        List<TrainerScore[]> scoreList = new ArrayList<TrainerScore[]>();
        int numStates = graph.size();
        betas = new float[numStates];
        outputProbs = new float[numStates];

        // First we do the forward pass. We need this before we can
        // return any probability. When we're doing the backward pass,
        // we can finally return a score for each call of this method.
        probCurrentFrame = new float[numStates];

        // Initialization of probCurrentFrame for the alpha computation
        Node initialNode = graph.getInitialNode();
        int indexInitialNode = graph.indexOf(initialNode);
        Arrays.fill(probCurrentFrame, LogMath.LOG_ZERO);
        // Overwrite in the right position
        probCurrentFrame[indexInitialNode] = 0.0f;

        for (initialNode.startOutgoingEdgeIterator(); initialNode.hasMoreOutgoingEdges();) {
            Edge edge = initialNode.nextOutgoingEdge();
            Node node = edge.getDestination();
            int index = graph.indexOf(node);
            if (!node.isType("STATE")) {
                // Certainly non-emitting, if it's not in an HMM.
                probCurrentFrame[index] = 0.0f;
            } else {
                // See if it's the last state in the HMM, i.e., if
                // it's non-emitting.
                HMMState state = (HMMState) node.getObject();
                if (!state.isEmitting()) {
                    probCurrentFrame[index] = 0.0f;
                }
                // With assertions enabled, a STATE node directly after the initial node
                // is treated as an error.
                assert false;
            }
        }

        // If getFeature() is true, curFeature contains a valid
        // Feature. If not, a problem or EOF was encountered.
        lastFeatureIndex = 0;
        while (getFeature()) {
            // Each frame needs its own array: forwardPass() fills every slot, so reusing
            // one array would leave all rows of the result pointing at the last frame.
            TrainerScore[] score = new TrainerScore[numStates];
            forwardPass(score);
            scoreList.add(score);
            lastFeatureIndex++;
        }
        logger.info("Feature frames read: " + lastFeatureIndex);

        // Prepare for beta computation
        Arrays.fill(probCurrentFrame, LogMath.LOG_ZERO);
        Node finalNode = graph.getFinalNode();
        int indexFinalNode = graph.indexOf(finalNode);
        // Overwrite in the right position
        probCurrentFrame[indexFinalNode] = 0.0f;

        for (finalNode.startIncomingEdgeIterator(); finalNode.hasMoreIncomingEdges();) {
            Edge edge = finalNode.nextIncomingEdge();
            Node node = edge.getSource();
            int index = graph.indexOf(node);
            if (!node.isType("STATE")) {
                // Certainly non-emitting, if it's not in an HMM.
                probCurrentFrame[index] = 0.0f;
                // With assertions enabled, a non-STATE node directly before the final
                // node is treated as an error.
                assert false;
            } else {
                // See if it's the last state in the HMM, i.e., if
                // it's non-emitting.
                HMMState state = (HMMState) node.getObject();
                if (!state.isEmitting()) {
                    probCurrentFrame[index] = 0.0f;
                }
            }
        }

        return scoreList.toArray(new TrainerScore[scoreList.size()][]);
    }


    /**
     * Gets the TrainerScore for the next frame.
     *
     * <p>The first call after a reset runs the complete forward pass. Frames are then returned
     * from the last frame to the first, with the backward pass applied to each frame just before
     * it is returned. The total of the gammas of the last frame is recorded as the log
     * likelihood; for every other frame a warning is logged when its total differs from that
     * value by more than the magnitude of the value itself.
     *
     * @return the TrainerScore, or null if EOF was found
     */
    public TrainerScore[] getScore() {
        TrainerScore[] score;
        if (scoreArray == null) {
            // Do the forward pass, and create the necessary arrays
            scoreArray = prepareScore();
            currentFeatureIndex = lastFeatureIndex;
        }

        currentFeatureIndex--;
        if (currentFeatureIndex >= 0) {
            float logScore = LogMath.LOG_ZERO;
            score = scoreArray[currentFeatureIndex];
            assert score.length == betas.length;
            backwardPass(score);
            for (int i = 0; i < betas.length; i++) {
                score[i].setGamma();
                logScore = logMath.addAsLinear(logScore, score[i].getGamma());
            }

            if (currentFeatureIndex == lastFeatureIndex - 1) {
                // First frame handed out (the last one in time): remember its total.
                TrainerScore.setLogLikelihood(logScore);
                totalLogScore = logScore;
            } else {
                if (Math.abs(totalLogScore - logScore) > Math.abs(totalLogScore)) {
                    logger.warning("WARNING: log probabilities differ: "
                            + totalLogScore + " and " + logScore);
                }
            }
            return score;
        } else {
            // We need to clear this, so we start the next iteration
            // on a clean plate.
            scoreArray = null;
            return null;
        }
    }


    /**
     * Computes the acoustic scores using the current Feature and a given node in the graph.
     *
     * <p>As a side effect, {@code componentScores} is set to the state's component scores, or to
     * null when the node carries no emitting state.
     *
     * @param index the graph index
     * @return the overall acoustic score, or 0.0f (log scale) for a node without an emitting state
     */
    private float calculateScores(int index) {
        float logScore;
        // Find the HMM state for this node
        SenoneHMMState state = (SenoneHMMState) graph.getNode(index).getObject();
        if ((state != null) && (state.isEmitting())) {
            // Compute the scores for each mixture component in this state
            componentScores = state.calculateComponentScore(curFeature);
            // Compute the overall score for this state
            logScore = state.getScore(curFeature);
            // For CI models, for now, we only try to use mixtures
            // with one component
            assert componentScores.length == 1;
        } else {
            componentScores = null;
            logScore = 0.0f;
        }
        return logScore;
    }


    /**
     * Does the forward pass, one frame at a time.
     *
     * <p>Every slot of {@code score} is replaced by a new {@code TrainerScore} for the current
     * feature, and {@code probCurrentFrame} is replaced by the alphas of this frame.
     *
     * @param score the objects transferring info to the buffers; must have one slot per graph node
     */
    private void forwardPass(TrainerScore[] score) {
        // Let's precompute the acoustic probabilities and create the
        // score object, one for each state
        for (int i = 0; i < graph.size(); i++) {
            outputProbs[i] = calculateScores(i);
            score[i] = new TrainerScore(curFeature,
                    outputProbs[i],
                    (HMMState) graph.getNode(i).getObject(),
                    componentScores);
            score[i].setAlpha(probCurrentFrame[i]);
        }

        // Now, the forward pass.
        float[] probPreviousFrame = probCurrentFrame;
        probCurrentFrame = new float[graph.size()];

        // First, the emitting states. We have to do this because the
        // emitting states use probabilities from the previous
        // frame. The non-emitting states, however, since they don't
        // consume frames, use probabilities from the current frame
        for (int indexNode = 0; indexNode < graph.size(); indexNode++) {
            Node node = graph.getNode(indexNode);
            // Treat dummy node (and initial and final nodes) the same
            // as non-emitting
            if (!node.isType("STATE")) {
                continue;
            }
            SenoneHMMState state = (SenoneHMMState) node.getObject();
            SenoneHMM hmm = (SenoneHMM) state.getHMM();
            if (!state.isEmitting()) {
                continue;
            }

            // Initialize the current frame probability with linear 0, i.e. LOG_ZERO
            probCurrentFrame[indexNode] = LogMath.LOG_ZERO;
            for (node.startIncomingEdgeIterator(); node.hasMoreIncomingEdges();) {
                // Finds out what the previous node and previous state are
                Node previousNode = node.nextIncomingEdge().getSource();
                int indexPreviousNode = graph.indexOf(previousNode);
                HMMState previousState = (HMMState) previousNode.getObject();
                float logTransitionProbability;
                // previous state could have an associated hmm state...
                if (previousState != null) {
                    // Make sure that the transition happened from a state
                    // that either is in the same model, or was a
                    // non-emitting state
                    assert ((!previousState.isEmitting())
                            || (previousState.getHMM() == hmm));
                    if (!previousState.isEmitting()) {
                        logTransitionProbability = 0.0f;
                    } else {
                        logTransitionProbability = hmm.getTransitionProbability(
                                previousState.getState(), state.getState());
                    }
                } else {
                    // Previous state is a dummy state or beginning of
                    // utterance.
                    logTransitionProbability = 0.0f;
                }
                // Adds the alpha and transition from the previous
                // state into the current alpha
                probCurrentFrame[indexNode] = logMath.addAsLinear(
                        probCurrentFrame[indexNode],
                        probPreviousFrame[indexPreviousNode] + logTransitionProbability);
            }
            // Finally, multiply by this state's output probability for the
            // current Feature (add in log scale)
            probCurrentFrame[indexNode] += outputProbs[indexNode];
            score[indexNode].setAlpha(probCurrentFrame[indexNode]);
        }

        // Finally, the non-emitting states
        for (int indexNode = 0; indexNode < graph.size(); indexNode++) {
            Node node = graph.getNode(indexNode);
            HMMState state = null;
            SenoneHMM hmm = null;
            if (node.isType("STATE")) {
                state = (HMMState) node.getObject();
                hmm = (SenoneHMM) state.getHMM();
                if (state.isEmitting()) {
                    continue;
                }
            } else if (graph.isInitialNode(node)) {
                score[indexNode].setAlpha(LogMath.LOG_ZERO);
                probCurrentFrame[indexNode] = LogMath.LOG_ZERO;
                continue;
            }

            // Initialize the current frame probability with linear 0, i.e. LOG_ZERO
            probCurrentFrame[indexNode] = LogMath.LOG_ZERO;
            for (node.startIncomingEdgeIterator(); node.hasMoreIncomingEdges();) {
                float logTransitionProbability;
                // Finds out what the previous node and previous state are
                Node previousNode = node.nextIncomingEdge().getSource();
                int indexPreviousNode = graph.indexOf(previousNode);
                if (previousNode.isType("STATE")) {
                    HMMState previousState = (HMMState) previousNode.getObject();
                    // Make sure that the transition happened from a
                    // state that either is in the same model, or was
                    // a non-emitting state
                    assert ((!previousState.isEmitting())
                            || (previousState.getHMM() == hmm));
                    if (!previousState.isEmitting()) {
                        logTransitionProbability = 0.0f;
                    } else {
                        // previousState == state
                        logTransitionProbability = hmm.getTransitionProbability(
                                previousState.getState(), state.getState());
                    }
                } else {
                    logTransitionProbability = 0.0f;
                }

                // Adds the alpha and transition from the previous
                // state into the current alpha. Note that this reads the
                // current frame, since non-emitting states consume no frame.
                probCurrentFrame[indexNode] = logMath.addAsLinear(
                        probCurrentFrame[indexNode],
                        probCurrentFrame[indexPreviousNode] + logTransitionProbability);
            }

            // Non-emitting states have the equivalent of output
            // probability of 1.0. In log scale, this is the same as
            // adding 0.0f, or doing nothing.
            score[indexNode].setAlpha(probCurrentFrame[indexNode]);
        }
    }


    /**
     * Does the backward pass, one frame at a time.
     *
     * <p>The output probabilities are taken from {@code score}; the betas are written back into
     * it, and {@code probCurrentFrame} is replaced by the betas of this frame.
     *
     * @param score the per-node scores of the frame being processed
     */
    private void backwardPass(TrainerScore[] score) {
        // Now, the backward pass.
        for (int i = 0; i < graph.size(); i++) {
            outputProbs[i] = score[i].getScore();
            score[i].setBeta(probCurrentFrame[i]);
        }
        float[] probNextFrame = probCurrentFrame;
        probCurrentFrame = new float[graph.size()];

        // First, the emitting states
        for (int indexNode = 0; indexNode < graph.size(); indexNode++) {
            Node node = graph.getNode(indexNode);
            // Treat dummy node (and initial and final nodes) the same
            // as non-emitting
            if (!node.isType("STATE")) {
                continue;
            }
            HMMState state = (HMMState) node.getObject();
            SenoneHMM hmm = (SenoneHMM) state.getHMM();
            if (!state.isEmitting()) {
                continue;
            }

            // Initialize the current frame probability with log
            // probability of log(0f)
            probCurrentFrame[indexNode] = LogMath.LOG_ZERO;
            for (node.startOutgoingEdgeIterator(); node.hasMoreOutgoingEdges();) {
                float logTransitionProbability;
                // Finds out what the next node and next state are
                Node nextNode = node.nextOutgoingEdge().getDestination();
                int indexNextNode = graph.indexOf(nextNode);
                HMMState nextState = (HMMState) nextNode.getObject();
                if (nextState != null) {
                    // Make sure that the transition happened to a
                    // non-emitting state, or to the same model
                    assert ((!nextState.isEmitting())
                            || (nextState.getHMM() == hmm));
                    if (nextState.getHMM() != hmm) {
                        logTransitionProbability = 0.0f;
                    } else {
                        logTransitionProbability = hmm.getTransitionProbability(
                                state.getState(), nextState.getState());
                    }
                } else {
                    // Next state is a dummy state or beginning of
                    // utterance.
                    logTransitionProbability = 0.0f;
                }
                // Adds the beta, the output prob, and the transition
                // from the next state into the current beta
                probCurrentFrame[indexNode] = logMath.addAsLinear(
                        probCurrentFrame[indexNode],
                        probNextFrame[indexNextNode]
                                + logTransitionProbability
                                + outputProbs[indexNextNode]);
            }
            score[indexNode].setBeta(probCurrentFrame[indexNode]);
        }

        // Now, the non-emitting states

        // We have to go backwards because for non-emitting states we
        // use the current frame probability, and we need to refer to
        // states that are downstream in the graph
        for (int indexNode = graph.size() - 1; indexNode >= 0; indexNode--) {
            Node node = graph.getNode(indexNode);
            HMMState state = null;
            if (node.isType("STATE")) {
                state = (HMMState) node.getObject();
                if (state.isEmitting()) {
                    continue;
                }
            } else if (graph.isFinalNode(node)) {
                score[indexNode].setBeta(LogMath.LOG_ZERO);
                probCurrentFrame[indexNode] = LogMath.LOG_ZERO;
                continue;
            }

            // Initialize the current frame probability with log(0f)
            probCurrentFrame[indexNode] = LogMath.LOG_ZERO;
            for (node.startOutgoingEdgeIterator(); node.hasMoreOutgoingEdges();) {
                float logTransitionProbability;
                // Finds out what the next node and next state are
                Node nextNode = node.nextOutgoingEdge().getDestination();
                int indexNextNode = graph.indexOf(nextNode);
                if (nextNode.isType("STATE")) {
                    HMMState nextState = (HMMState) nextNode.getObject();
                    // Make sure that the transition happened to a
                    // state that either is the same, or is emitting
                    assert ((nextState.isEmitting()) || (nextState == state));
                    // In any case, the transition (at this point) is
                    // assumed to be 1.0f, or 0.0f in log scale.
                    logTransitionProbability = 0.0f;
                    /*
                    if (!nextState.isEmitting()) {
                        logTransitionProbability = 0.0f;
                    } else {
                        logTransitionProbability =
                            hmm.getTransitionProbability(state.getState(),
                                                         nextState.getState());
                    }
                    */
                } else {
                    logTransitionProbability = 0.0f;
                }

                // Adds the beta and the transition from the next state into
                // the current beta. No output prob is added here.
                probCurrentFrame[indexNode] = logMath.addAsLinear(
                        probCurrentFrame[indexNode],
                        probCurrentFrame[indexNextNode] + logTransitionProbability);
            }
            score[indexNode].setBeta(probCurrentFrame[indexNode]);
        }
    }


    /* Pseudo code:
    forward pass:
        token = maketoken(initialstate);
        List initialTokenlist = new List;
        newtokenlist.add(token);

        // Initial token is on a nonemitting state; no need to score;
        List newList = expandToEmittingStateList(initialTokenList){

        while (morefeatures){
           scoreTokenList(emittingTokenList, featurevector[timestamp]);
           pruneTokenList(emittingTokenList);
           List newList = expandToEmittingStateList(emittingTokenList){
           timestamp++;
        }
        // Some logic to expand to a final nonemitting state (how)?
        expandToNonEmittingStates(emittingTokenList);
    */

    /*
    private void forwardPass() {
        ActiveList activelist = new FastActiveList(createInitialToken());
        AcousticScorer acousticScorer = new ThreadedAcousticScorer();
        FeatureFrame featureFrame = frontEnd.getFeatureFrame(1, "");
        Pruner pruner = new SimplePruner();

        // Initialization code pushing initial state to emitting state here

        while ((featureFrame.getFeatures() != null)) {
            ActiveList nextActiveList = new FastActiveList();

            // At this point we have only emitting states. We score
            // and prune them
            ActiveList emittingStateList = new FastActiveList();
                         // activelist.getEmittingStateList();
            acousticScorer.calculateScores(emittingStateList.getTokens());
            // The pruner must clear up references to pruned objects
            emittingStateList = pruner.prune( emittingStateList);

            expandStateList(emittingStateList, nextActiveList);

            while (nextActiveList.hasNonEmittingStates()){
                // extractNonEmittingStateList will pull out the list
                // of nonemitting states completely from the
                // nextActiveList. At this point nextActiveList does
                // not have a list of nonemitting states and must
                // instantiate a new one.
                ActiveList nonEmittingStateList =
                    nextActiveList.extractNonEmittingStateList();
                nonEmittingStateList = pruner.prune(nonEmittingStateList);
                expandStateList(nonEmittingStateList, nextActiveList);
            }
            activeList = newActiveList;
        }
    }
    */

    /* Pseudo code
    backward pass:
       state = finaltoken.state.wholelistofeverythingthatcouldbefinal;
        while (moreTokensAtCurrentTime) {
            Token token = nextToken();
            State state = token.state;
            state.gamma = state.logalpha + state.logbeta - logtotalprobability;
            SentenceHMM.updateState(state,state.gamma,vector[state.timestamp]);
            // state.update (state.gamma, vector[state.timestamp], updatefunction());
            while token.hasMoreIncomingEdges() {
                Edge transition = token.nextIncomingEdge();
                double logalpha = transition.source.alpha;
                double logbeta  = transition.destination.beta;
                double logtransition = transition.transitionprob;
                // transition.posterior = alpha*transition*beta /
                //                           totalprobability;
                double localtransitionbetascore = logtransition + logbeta +
                                          transition.destination.logscore;
                double transition.posterior = localtransitionbetascore +
                                          logalpha - logtotalprobability;
                transition.updateaccumulator(transition.posterior);
                // transition.updateaccumulator(transition.posterior, updatefunction());
                SentenceHMM.updateTransition(transition, transitionstate,state.gamma);
                transition.source.beta = Logadd(transition.source.beta,
                                                localtransitionbetascore);
            }
        }
    */

    /*
    private void expandStateList(ActiveList stateList,
                                 ActiveList nextActiveList) {
        while (stateList.hasMoreTokens()) {
            Token token = emittingStateList.getNextToken();

            // First get list of links to possible future states
            List successorList = getSuccessors(token);
            while (successorList.hasMoreEntries()) {
                UtteranceGraphEdge edge = successorList.getNextEntry();

                // create a token for the future state, if its not
                // already in active list; The active list will check
                // for the key "edge.destination()" in both of its
                // lists
                if (nextActiveList.hasState(edge.destination())) {
                    Token newToken =
                        nextActiveList.getTokenForState(edge.destination());
                } else {
                    Token newToken = new Token(edge.destination());
                }

                // create a link between current state and future state
                TrainerLink newlink = new TrainerLink(edge, token, newToken);
                newlink.logScore = token.logScore + edge.transition.logprob();

                // add link to the appropriate lists for source and
                // destination tokens
                token.addOutGoingLink(newlink);

                newToken.addIncomingLink(newlink);
                newToken.alpha = logAdd(newToken.alpha, newlink.logScore);

                // At this point, we have placed a new token in the
                // successor state, and linked the token at the
                // current state to the token at the non-emitting
                // states.

                // Add token to appropriate active list
                nextActiveList.add(newToken);
            }
        }
    }
    */

    /*
    private void expandToEmittingStateList(List tokenList){
        List emittingTokenList = new List();
        do {
            List nonEmittingTokenList = new List();
            expandtokens(newtokenlist, emittingTokenList,
                         nonemittingTokenList);
            while (nonEmittingTokenList.length() != 0);
        return emittingTokenList;
        }
    }
    */

    /*
    private void expandtokens(List tokens, List nonEmittingStateList,
                              List EmittingStateList){
        while (moreTokens){
            sucessorlist = SentenceHMM.gettransitions(nextToken());
            while (moretransitions()){
                transition = successor;
                State destinationState = successor.state;
                newtoken = gettokenfromHash(destinationState, currenttimestamp);
                newtoken.logscore = Logadd(newtoken.logscore,
                                    token.logscore + transition.logscore);
                // Add transition to newtoken predecessor list?
                // Add transition to token sucessor list
                // Should we define a token "arc" for this. ??
                if (state.isemitting)
                    EmittingStateList.add(newtoken);
                else
                    nonEmittingStateList.add(newtoken);
            }
        }
    }
    */
}