/* * Copyright 1999-2002 Carnegie Mellon University. * Portions Copyright 2002 Sun Microsystems, Inc. * Portions Copyright 2002 Mitsubishi Electric Research Laboratories. * All Rights Reserved. Use is subject to license terms. * * See the file "license.terms" for information on usage and * redistribution of this file, and for a DISCLAIMER OF ALL * WARRANTIES. * */ package edu.cmu.sphinx.decoder.search; import edu.cmu.sphinx.decoder.pruner.Pruner; import edu.cmu.sphinx.decoder.scorer.AcousticScorer; import edu.cmu.sphinx.frontend.Data; import edu.cmu.sphinx.linguist.Linguist; import edu.cmu.sphinx.linguist.SearchState; import edu.cmu.sphinx.linguist.SearchStateArc; import edu.cmu.sphinx.linguist.WordSearchState; import edu.cmu.sphinx.result.Result; import edu.cmu.sphinx.util.LogMath; import edu.cmu.sphinx.util.StatisticsVariable; import edu.cmu.sphinx.util.Timer; import edu.cmu.sphinx.util.TimerPool; import edu.cmu.sphinx.util.props.*; import java.util.*; import java.util.logging.Level; import java.util.logging.Logger; import java.io.IOException; /** * Provides the breadth first search. To perform recognition an application should call initialize before recognition * begins, and repeatedly call <code> recognize </code> until Result.isFinal() returns true. Once a final result has * been obtained, <code> terminate </code> should be called. * <p> * All scores and probabilities are maintained in the log math log domain. * <p> * For information about breadth first search please refer to "Spoken Language Processing", X. Huang, PTR */ // TODO - need to add in timing code. public class SimpleBreadthFirstSearchManager extends TokenSearchManager { /** The property that defines the name of the linguist to be used by this search manager. */ @S4Component(type = Linguist.class) public final static String PROP_LINGUIST = "linguist"; /** The property that defines the name of the linguist to be used by this search manager. */ @S4Component(type = Pruner.class) public final static String PROP_PRUNER = "pruner"; /** The property that defines the name of the scorer to be used by this search manager. */ @S4Component(type = AcousticScorer.class) public final static String PROP_SCORER = "scorer"; /** The property that defines the name of the active list factory to be used by this search manager. */ @S4Component(type = ActiveListFactory.class) public final static String PROP_ACTIVE_LIST_FACTORY = "activeListFactory"; /** * The property that when set to <code>true</code> will cause the recognizer to count up all the tokens in the * active list after every frame. */ @S4Boolean(defaultValue = false) public final static String PROP_SHOW_TOKEN_COUNT = "showTokenCount"; /** * The property that sets the minimum score relative to the maximum score in the word list for pruning. Words with a * score less than relativeBeamWidth * maximumScore will be pruned from the list */ @S4Double(defaultValue = 0.0) public final static String PROP_RELATIVE_WORD_BEAM_WIDTH = "relativeWordBeamWidth"; /** * The property that controls whether or not relative beam pruning will be performed on the entry into a * state. */ @S4Boolean(defaultValue = false) public final static String PROP_WANT_ENTRY_PRUNING = "wantEntryPruning"; /** * The property that controls the number of frames processed for every time the decode growth step is skipped. * Setting this property to zero disables grow skipping. Setting this number to a small integer will increase the * speed of the decoder but will also decrease its accuracy. The higher the number, the less often the grow code is * skipped. */ @S4Integer(defaultValue = 0) public final static String PROP_GROW_SKIP_INTERVAL = "growSkipInterval"; protected Linguist linguist; // Provides grammar/language info private Pruner pruner; // used to prune the active list private AcousticScorer scorer; // used to score the active list protected int currentFrameNumber; // the current frame number protected long currentCollectTime; // the current frame number protected ActiveList activeList; // the list of active tokens protected List<Token> resultList; // the current set of results protected LogMath logMath; private Logger logger; private String name; // ------------------------------------ // monitoring data // ------------------------------------ private Timer scoreTimer; // TODO move these timers out private Timer pruneTimer; protected Timer growTimer; private StatisticsVariable totalTokensScored; private StatisticsVariable tokensPerSecond; private StatisticsVariable curTokensScored; private StatisticsVariable tokensCreated; private StatisticsVariable viterbiPruned; private StatisticsVariable beamPruned; // ------------------------------------ // Working data // ------------------------------------ protected boolean showTokenCount; private boolean wantEntryPruning; protected Map<SearchState, Token> bestTokenMap; private float logRelativeWordBeamWidth; private int totalHmms; private double startTime; private float threshold; private float wordThreshold; private int growSkipInterval; protected ActiveListFactory activeListFactory; protected boolean streamEnd; public SimpleBreadthFirstSearchManager() { } /** * Creates a manager for simple search * * @param linguist linguist to configure search space * @param pruner pruner to prune extra paths * @param scorer scorer to estimate token probability * @param activeListFactory factory for list of tokens * @param showTokenCount show count of the tokens during decoding * @param relativeWordBeamWidth relative pruning beam for lookahead * @param growSkipInterval interval to skip growth step * @param wantEntryPruning entry pruning */ public SimpleBreadthFirstSearchManager(Linguist linguist, Pruner pruner, AcousticScorer scorer, ActiveListFactory activeListFactory, boolean showTokenCount, double relativeWordBeamWidth, int growSkipInterval, boolean wantEntryPruning) { this.name = getClass().getName(); this.logger = Logger.getLogger(name); this.logMath = LogMath.getLogMath(); this.linguist = linguist; this.pruner = pruner; this.scorer = scorer; this.activeListFactory = activeListFactory; this.showTokenCount = showTokenCount; this.growSkipInterval = growSkipInterval; this.wantEntryPruning = wantEntryPruning; this.logRelativeWordBeamWidth = logMath.linearToLog(relativeWordBeamWidth); this.keepAllTokens = true; } @Override public void newProperties(PropertySheet ps) throws PropertyException { super.newProperties(ps); logMath = LogMath.getLogMath(); logger = ps.getLogger(); name = ps.getInstanceName(); linguist = (Linguist) ps.getComponent(PROP_LINGUIST); pruner = (Pruner) ps.getComponent(PROP_PRUNER); scorer = (AcousticScorer) ps.getComponent(PROP_SCORER); activeListFactory = (ActiveListFactory) ps.getComponent(PROP_ACTIVE_LIST_FACTORY); showTokenCount = ps.getBoolean(PROP_SHOW_TOKEN_COUNT); double relativeWordBeamWidth = ps.getDouble(PROP_RELATIVE_WORD_BEAM_WIDTH); growSkipInterval = ps.getInt(PROP_GROW_SKIP_INTERVAL); wantEntryPruning = ps.getBoolean(PROP_WANT_ENTRY_PRUNING); logRelativeWordBeamWidth = logMath.linearToLog(relativeWordBeamWidth); this.keepAllTokens = true; } /** Called at the start of recognition. Gets the search manager ready to recognize */ public void startRecognition() { logger.finer("starting recognition"); linguist.startRecognition(); pruner.startRecognition(); scorer.startRecognition(); localStart(); if (startTime == 0.0) { startTime = System.currentTimeMillis(); } } /** * Performs the recognition for the given number of frames. * * @param nFrames the number of frames to recognize * @return the current result or null if there is no Result (due to the lack of frames to recognize) */ public Result recognize(int nFrames) { boolean done = false; Result result = null; streamEnd = false; for (int i = 0; i < nFrames && !done; i++) { done = recognize(); } // generate a new temporary result if the current token is based on a final search state // remark: the first check for not null is necessary in cases that the search space does not contain scoreable tokens. if (activeList.getBestToken() != null) { // to make the current result as correct as possible we undo the last search graph expansion here ActiveList fixedList = undoLastGrowStep(); // Now create the result using the fixed active-list. if (!streamEnd) result = new Result(fixedList, resultList, currentFrameNumber, done, linguist.getSearchGraph().getWordTokenFirst(), false); } if (showTokenCount) { showTokenCount(); } return result; } /** * Because the growBranches() is called although no data is left after the last speech frame, the ordering of the * active-list might depend on the transition probabilities and (penalty-scores) only. Therefore we need to undo the last * grow-step up to final states or the last emitting state in order to fix the list. * @return newly created list */ protected ActiveList undoLastGrowStep() { ActiveList fixedList = activeList.newInstance(); for (Token token : activeList) { Token curToken = token.getPredecessor(); // remove the final states that are not the real final ones because they're just hide prior final tokens: while (curToken.getPredecessor() != null && ( (curToken.isFinal() && curToken.getPredecessor() != null && !curToken.getPredecessor().isFinal()) || (curToken.isEmitting() && curToken.getData() == null) // the so long not scored tokens || (!curToken.isFinal() && !curToken.isEmitting()))) { curToken = curToken.getPredecessor(); } fixedList.add(curToken); } return fixedList; } /** Terminates a recognition */ public void stopRecognition() { localStop(); scorer.stopRecognition(); pruner.stopRecognition(); linguist.stopRecognition(); logger.finer("recognition stopped"); } /** * Performs recognition for one frame. Returns true if recognition has been completed. * * @return <code>true</code> if recognition is completed. */ protected boolean recognize() { boolean more = scoreTokens(); // score emitting tokens if (more) { pruneBranches(); // eliminate poor branches currentFrameNumber++; if (growSkipInterval == 0 || (currentFrameNumber % growSkipInterval) != 0) { growBranches(); // extend remaining branches } } return !more; } /** Gets the initial grammar node from the linguist and creates a GrammarNodeToken */ protected void localStart() { currentFrameNumber = 0; curTokensScored.value = 0; ActiveList newActiveList = activeListFactory.newInstance(); SearchState state = linguist.getSearchGraph().getInitialState(); newActiveList.add(new Token(state, -1)); activeList = newActiveList; growBranches(); } /** Local cleanup for this search manager */ protected void localStop() { } /** * Goes through the active list of tokens and expands each token, finding the set of successor tokens until all the * successor tokens are emitting tokens. */ protected void growBranches() { int mapSize = activeList.size() * 10; if (mapSize == 0) { mapSize = 1; } growTimer.start(); bestTokenMap = new HashMap<SearchState, Token>(mapSize); ActiveList oldActiveList = activeList; resultList = new LinkedList<Token>(); activeList = activeListFactory.newInstance(); threshold = oldActiveList.getBeamThreshold(); wordThreshold = oldActiveList.getBestScore() + logRelativeWordBeamWidth; for (Token token : oldActiveList) { collectSuccessorTokens(token); } growTimer.stop(); if (logger.isLoggable(Level.FINE)) { int hmms = activeList.size(); totalHmms += hmms; logger.fine("Frame: " + currentFrameNumber + " Hmms: " + hmms + " total " + totalHmms); } } /** * Calculate the acoustic scores for the active list. The active list should contain only emitting tokens. * * @return <code>true</code> if there are more frames to score, otherwise, false */ protected boolean scoreTokens() { boolean hasMoreFrames = false; scoreTimer.start(); Data data = scorer.calculateScores(activeList.getTokens()); scoreTimer.stop(); Token bestToken = null; if (data instanceof Token) { bestToken = (Token)data; } else if (data == null) { streamEnd = true; } if (bestToken != null) { hasMoreFrames = true; currentCollectTime = bestToken.getCollectTime(); activeList.setBestToken(bestToken); } // update statistics curTokensScored.value += activeList.size(); totalTokensScored.value += activeList.size(); tokensPerSecond.value = totalTokensScored.value / getTotalTime(); // if (logger.isLoggable(Level.FINE)) { // logger.fine(currentFrameNumber + " " + activeList.size() // + " " + curTokensScored.value + " " // + (int) tokensPerSecond.value); // } return hasMoreFrames; } /** * Returns the total time since we start4ed * * @return the total time (in seconds) */ private double getTotalTime() { return (System.currentTimeMillis() - startTime) / 1000.0; } /** Removes unpromising branches from the active list */ protected void pruneBranches() { int startSize = activeList.size(); pruneTimer.start(); activeList = pruner.prune(activeList); beamPruned.value += startSize - activeList.size(); pruneTimer.stop(); } /** * Gets the best token for this state * * @param state the state of interest * @return the best token */ protected Token getBestToken(SearchState state) { Token best = bestTokenMap.get(state); if (logger.isLoggable(Level.FINER) && best != null) { logger.finer("BT " + best + " for state " + state); } return best; } /** * Sets the best token for a given state * * @param token the best token * @param state the state * @return the previous best token for the given state, or null if no previous best token */ protected Token setBestToken(Token token, SearchState state) { return bestTokenMap.put(state, token); } public ActiveList getActiveList() { return activeList; } /** * Collects the next set of emitting tokens from a token and accumulates them in the active or result lists * * @param token the token to collect successors from */ protected void collectSuccessorTokens(Token token) { SearchState state = token.getSearchState(); // If this is a final state, add it to the final list if (token.isFinal()) { resultList.add(token); } if (token.getScore() < threshold) { return; } if (state instanceof WordSearchState && token.getScore() < wordThreshold) { return; } SearchStateArc[] arcs = state.getSuccessors(); // For each successor // calculate the entry score for the token based upon the // predecessor token score and the transition probabilities // if the score is better than the best score encountered for // the SearchState and frame then create a new token, add // it to the lattice and the SearchState. // If the token is an emitting token add it to the list, // otherwise recursively collect the new tokens successors. for (SearchStateArc arc : arcs) { SearchState nextState = arc.getState(); // We're actually multiplying the variables, but since // these come in log(), multiply gets converted to add float logEntryScore = token.getScore() + arc.getProbability(); if (wantEntryPruning) { // false by default if (logEntryScore < threshold) { continue; } if (nextState instanceof WordSearchState && logEntryScore < wordThreshold) { continue; } } Token predecessor = getResultListPredecessor(token); // if not emitting, check to see if we've already visited // this state during this frame. Expand the token only if we // haven't visited it already. This prevents the search // from getting stuck in a loop of states with no // intervening emitting nodes. This can happen with nasty // jsgf grammars such as ((foo*)*)* if (!nextState.isEmitting()) { Token newToken = new Token(predecessor, nextState, logEntryScore, arc.getInsertionProbability(), arc.getLanguageProbability(), currentCollectTime); tokensCreated.value++; if (!isVisited(newToken)) { collectSuccessorTokens(newToken); } continue; } Token bestToken = getBestToken(nextState); if (bestToken == null) { Token newToken = new Token(predecessor, nextState, logEntryScore, arc.getInsertionProbability(), arc.getLanguageProbability(), currentFrameNumber); tokensCreated.value++; setBestToken(newToken, nextState); activeList.add(newToken); } else { if (bestToken.getScore() <= logEntryScore) { bestToken.update(predecessor, nextState, logEntryScore, arc.getInsertionProbability(), arc.getLanguageProbability(), currentCollectTime); viterbiPruned.value++; } else { viterbiPruned.value++; } } } } /** * Determines whether or not we've visited the state associated with this token since the previous frame. * * @param t the token to check * @return true if we've visited the search state since the last frame */ private boolean isVisited(Token t) { SearchState curState = t.getSearchState(); t = t.getPredecessor(); while (t != null && !t.isEmitting()) { if (curState.equals(t.getSearchState())) { return true; } t = t.getPredecessor(); } return false; } /** Counts all the tokens in the active list (and displays them). This is an expensive operation. */ protected void showTokenCount() { if (logger.isLoggable(Level.INFO)) { Set<Token> tokenSet = new HashSet<Token>(); for (Token token : activeList) { while (token != null) { tokenSet.add(token); token = token.getPredecessor(); } } logger.info("Token Lattice size: " + tokenSet.size()); tokenSet = new HashSet<Token>(); for (Token token : resultList) { while (token != null) { tokenSet.add(token); token = token.getPredecessor(); } } logger.info("Result Lattice size: " + tokenSet.size()); } } /** * Returns the best token map. * * @return the best token map */ protected Map<SearchState, Token> getBestTokenMap() { return bestTokenMap; } /** * Sets the best token Map. * * @param bestTokenMap the new best token Map */ protected void setBestTokenMap(Map<SearchState, Token> bestTokenMap) { this.bestTokenMap = bestTokenMap; } /** * Returns the result list. * * @return the result list */ public List<Token> getResultList() { return resultList; } /** * Returns the current frame number. * * @return the current frame number */ public int getCurrentFrameNumber() { return currentFrameNumber; } /** * Returns the Timer for growing. * * @return the Timer for growing */ public Timer getGrowTimer() { return growTimer; } /** * Returns the tokensCreated StatisticsVariable. * * @return the tokensCreated StatisticsVariable. */ public StatisticsVariable getTokensCreated() { return tokensCreated; } /* * (non-Javadoc) * * @see edu.cmu.sphinx.decoder.search.SearchManager#allocate() */ public void allocate() { totalTokensScored = StatisticsVariable .getStatisticsVariable("totalTokensScored"); tokensPerSecond = StatisticsVariable .getStatisticsVariable("tokensScoredPerSecond"); curTokensScored = StatisticsVariable .getStatisticsVariable("curTokensScored"); tokensCreated = StatisticsVariable .getStatisticsVariable("tokensCreated"); viterbiPruned = StatisticsVariable .getStatisticsVariable("viterbiPruned"); beamPruned = StatisticsVariable.getStatisticsVariable("beamPruned"); try { linguist.allocate(); pruner.allocate(); scorer.allocate(); } catch (IOException e) { throw new RuntimeException("Allocation of search manager resources failed", e); } scoreTimer = TimerPool.getTimer(this, "Score"); pruneTimer = TimerPool.getTimer(this, "Prune"); growTimer = TimerPool.getTimer(this, "Grow"); } /* * (non-Javadoc) * * @see edu.cmu.sphinx.decoder.search.SearchManager#deallocate() */ public void deallocate() { try { scorer.deallocate(); pruner.deallocate(); linguist.deallocate(); } catch (IOException e) { throw new RuntimeException("Deallocation of search manager resources failed", e); } } @Override public String toString() { return name; } }