/*
 * Copyright 1999-2002 Carnegie Mellon University.
 * Portions Copyright 2002 Sun Microsystems, Inc.
 * Portions Copyright 2002 Mitsubishi Electric Research Laboratories.
 * All Rights Reserved. Use is subject to license terms.
 *
 * See the file "license.terms" for information on usage and
 * redistribution of this file, and for a DISCLAIMER OF ALL
 * WARRANTIES.
 *
 */

package edu.cmu.sphinx.decoder.search;

// a test search manager.

import edu.cmu.sphinx.decoder.pruner.Pruner;
import edu.cmu.sphinx.decoder.scorer.AcousticScorer;
import edu.cmu.sphinx.frontend.Data;
import edu.cmu.sphinx.linguist.*;
import edu.cmu.sphinx.result.Result;
import edu.cmu.sphinx.util.LogMath;
import edu.cmu.sphinx.util.StatisticsVariable;
import edu.cmu.sphinx.util.Timer;
import edu.cmu.sphinx.util.TimerPool;
import edu.cmu.sphinx.util.props.*;

import java.io.IOException;
import java.util.*;
import java.util.logging.Level;
import java.util.logging.Logger;

/**
 * Provides the breadth first search. To perform recognition an application
 * should call initialize before recognition begins, and repeatedly call
 * <code> recognize </code> until Result.isFinal() returns true. Once a final
 * result has been obtained, <code> stopRecognition </code> should be called.
 * <p>
 * All scores and probabilities are maintained in the log math log domain.
 */
public class WordPruningBreadthFirstSearchManager extends TokenSearchManager {

    /**
     * The property that defines the name of the linguist to be used by this
     * search manager.
     */
    @S4Component(type = Linguist.class)
    public final static String PROP_LINGUIST = "linguist";

    /**
     * The property that defines the name of the pruner to be used by this
     * search manager.
     */
    @S4Component(type = Pruner.class)
    public final static String PROP_PRUNER = "pruner";

    /**
     * The property that defines the name of the scorer to be used by this
     * search manager.
     */
    @S4Component(type = AcousticScorer.class)
    public final static String PROP_SCORER = "scorer";

    /**
     * The property that, when set to <code>true</code>, will cause the
     * recognizer to count up all the tokens in the active list after every
     * frame.
     */
    @S4Boolean(defaultValue = false)
    public final static String PROP_SHOW_TOKEN_COUNT = "showTokenCount";

    /**
     * The property that controls the number of frames processed for every time
     * the decode growth step is skipped. Setting this property to zero disables
     * grow skipping. Setting this number to a small integer will increase the
     * speed of the decoder but will also decrease its accuracy. The higher the
     * number, the less often the grow code is skipped. Values like 6-8 is known
     * to be the good enough for large vocabulary tasks. That means that one of
     * 6 frames will be skipped.
     */
    @S4Integer(defaultValue = 0)
    public final static String PROP_GROW_SKIP_INTERVAL = "growSkipInterval";

    /** The property that defines the type of active list to use */
    @S4Component(type = ActiveListManager.class)
    public final static String PROP_ACTIVE_LIST_MANAGER = "activeListManager";

    /** The property for checking if the order of states is valid. */
    @S4Boolean(defaultValue = false)
    public final static String PROP_CHECK_STATE_ORDER = "checkStateOrder";

    /** The property that specifies the maximum lattice edges */
    @S4Integer(defaultValue = 100)
    public final static String PROP_MAX_LATTICE_EDGES = "maxLatticeEdges";

    /**
     * The property that controls the amount of simple acoustic lookahead
     * performed. Setting the property to zero (the default) disables simple
     * acoustic lookahead. The lookahead need not be an integer.
     */
    @S4Double(defaultValue = 0)
    public final static String PROP_ACOUSTIC_LOOKAHEAD_FRAMES = "acousticLookaheadFrames";

    /** The property that specifies the relative beam width */
    @S4Double(defaultValue = 0.0)
    // TODO: this should be a more meaningful default e.g. the common 1E-80
    public final static String PROP_RELATIVE_BEAM_WIDTH = "relativeBeamWidth";

    // -----------------------------------
    // Configured Subcomponents
    // -----------------------------------
    protected Linguist linguist; // Provides grammar/language info
    protected Pruner pruner; // used to prune the active list
    protected AcousticScorer scorer; // used to score the active list
    private ActiveListManager activeListManager;
    protected LogMath logMath;

    // -----------------------------------
    // Configuration data
    // -----------------------------------
    protected Logger logger;
    protected boolean showTokenCount;
    protected boolean checkStateOrder;
    private int growSkipInterval;
    protected float relativeBeamWidth;
    protected float acousticLookaheadFrames;
    private int maxLatticeEdges = 100;

    // -----------------------------------
    // Instrumentation
    // -----------------------------------
    protected Timer scoreTimer;
    protected Timer pruneTimer;
    protected Timer growTimer;
    protected StatisticsVariable totalTokensScored;
    protected StatisticsVariable curTokensScored;
    protected StatisticsVariable tokensCreated;
    private long tokenSum;
    private int tokenCount;

    // -----------------------------------
    // Working data
    // -----------------------------------
    protected int currentFrameNumber; // the current frame number
    protected long currentCollectTime; // the collect time of the current frame
    protected ActiveList activeList; // the list of active tokens
    protected List<Token> resultList; // the current set of results
    protected Map<SearchState, Token> bestTokenMap;
    protected AlternateHypothesisManager loserManager;
    private int numStateOrder;
    // private TokenTracker tokenTracker;
    // private TokenTypeTracker tokenTypeTracker;
    protected boolean streamEnd;

    /**
     * Creates a pruning manager with separate lists for tokens
     *
     * @param linguist a linguist for search space
     * @param pruner pruner to drop tokens
     * @param scorer scorer to estimate token probability
     * @param activeListManager active list manager to store tokens
     * @param showTokenCount show count during decoding
     * @param relativeWordBeamWidth relative beam for lookahead pruning
     * @param growSkipInterval skip interval for growth
     * @param checkStateOrder check order of states during growth
     * @param buildWordLattice build a lattice during decoding
     * @param maxLatticeEdges max edges to keep in lattice
     * @param acousticLookaheadFrames frames to do lookahead
     * @param keepAllTokens keep tokens including emitting tokens
     */
    public WordPruningBreadthFirstSearchManager(Linguist linguist, Pruner pruner, AcousticScorer scorer,
            ActiveListManager activeListManager, boolean showTokenCount, double relativeWordBeamWidth,
            int growSkipInterval, boolean checkStateOrder, boolean buildWordLattice, int maxLatticeEdges,
            float acousticLookaheadFrames, boolean keepAllTokens) {

        this.logger = Logger.getLogger(getClass().getName());
        this.logMath = LogMath.getLogMath();
        this.linguist = linguist;
        this.pruner = pruner;
        this.scorer = scorer;
        this.activeListManager = activeListManager;
        this.showTokenCount = showTokenCount;
        this.growSkipInterval = growSkipInterval;
        this.checkStateOrder = checkStateOrder;
        this.buildWordLattice = buildWordLattice;
        this.maxLatticeEdges = maxLatticeEdges;
        this.acousticLookaheadFrames = acousticLookaheadFrames;
        this.keepAllTokens = keepAllTokens;
        // beam widths are configured linear but compared in the log domain
        this.relativeBeamWidth = logMath.linearToLog(relativeWordBeamWidth);
    }

    public WordPruningBreadthFirstSearchManager() {
    }

    /*
     * (non-Javadoc)
     *
     * @see
     * edu.cmu.sphinx.util.props.Configurable#newProperties(edu.cmu.sphinx.util
     * .props.PropertySheet)
     */
    @Override
    public void newProperties(PropertySheet ps) throws PropertyException {
        super.newProperties(ps);

        logMath = LogMath.getLogMath();
        logger = ps.getLogger();

        linguist = (Linguist) ps.getComponent(PROP_LINGUIST);
        pruner = (Pruner) ps.getComponent(PROP_PRUNER);
        scorer = (AcousticScorer) ps.getComponent(PROP_SCORER);
        activeListManager = (ActiveListManager) ps.getComponent(PROP_ACTIVE_LIST_MANAGER);
        showTokenCount = ps.getBoolean(PROP_SHOW_TOKEN_COUNT);
        growSkipInterval = ps.getInt(PROP_GROW_SKIP_INTERVAL);
        checkStateOrder = ps.getBoolean(PROP_CHECK_STATE_ORDER);
        maxLatticeEdges = ps.getInt(PROP_MAX_LATTICE_EDGES);
        acousticLookaheadFrames = ps.getFloat(PROP_ACOUSTIC_LOOKAHEAD_FRAMES);
        relativeBeamWidth = logMath.linearToLog(ps.getDouble(PROP_RELATIVE_BEAM_WIDTH));
    }

    /*
     * (non-Javadoc)
     *
     * @see edu.cmu.sphinx.decoder.search.SearchManager#allocate()
     */
    public void allocate() {
        // tokenTracker = new TokenTracker();
        // tokenTypeTracker = new TokenTypeTracker();

        scoreTimer = TimerPool.getTimer(this, "Score");
        pruneTimer = TimerPool.getTimer(this, "Prune");
        growTimer = TimerPool.getTimer(this, "Grow");

        totalTokensScored = StatisticsVariable.getStatisticsVariable("totalTokensScored");
        curTokensScored = StatisticsVariable.getStatisticsVariable("curTokensScored");
        tokensCreated = StatisticsVariable.getStatisticsVariable("tokensCreated");

        try {
            linguist.allocate();
            pruner.allocate();
            scorer.allocate();
        } catch (IOException e) {
            throw new RuntimeException("Allocation of search manager resources failed", e);
        }
    }

    /*
     * (non-Javadoc)
     *
     * @see edu.cmu.sphinx.decoder.search.SearchManager#deallocate()
     */
    public void deallocate() {
        try {
            scorer.deallocate();
            pruner.deallocate();
            linguist.deallocate();
        } catch (IOException e) {
            throw new RuntimeException("Deallocation of search manager resources failed", e);
        }
    }

    /**
     * Called at the start of recognition. Gets the search manager ready to
     * recognize
     */
    public void startRecognition() {
        linguist.startRecognition();
        pruner.startRecognition();
        scorer.startRecognition();
        localStart();
    }

    /**
     * Performs the recognition for the given number of frames.
     *
     * @param nFrames
     *            the number of frames to recognize
     * @return the current result
     */
    public Result recognize(int nFrames) {
        boolean done = false;
        Result result = null;
        streamEnd = false;

        for (int i = 0; i < nFrames && !done; i++) {
            done = recognize();
        }

        // No result if the audio stream ended before any token was scored.
        if (!streamEnd) {
            result = new Result(loserManager, activeList, resultList, currentCollectTime, done, linguist
                    .getSearchGraph().getWordTokenFirst(), true);
        }

        // tokenTypeTracker.show();
        if (showTokenCount) {
            showTokenCount();
        }

        return result;
    }

    /**
     * Recognizes a single frame: scores the emitting list, prunes it, and
     * (unless this frame is skipped) grows the surviving branches.
     *
     * @return <code>true</code> if recognition is done (no more frames)
     */
    protected boolean recognize() {

        activeList = activeListManager.getEmittingList();
        boolean more = scoreTokens();

        if (more) {
            pruneBranches();
            currentFrameNumber++;
            if (growSkipInterval == 0 || (currentFrameNumber % growSkipInterval) != 0) {
                clearCollectors();
                growEmittingBranches();
                growNonEmittingBranches();
            }
        }
        return !more;
    }

    /**
     * Clears lists and maps before next expansion stage
     */
    private void clearCollectors() {
        resultList = new LinkedList<Token>();
        createBestTokenMap();
        activeListManager.clearEmittingList();
    }

    /**
     * creates a new best token map with the best size
     */
    protected void createBestTokenMap() {
        // sized generously (10x the active list) with a low load factor to
        // keep lookups cheap during token expansion
        int mapSize = activeList.size() * 10;
        if (mapSize == 0) {
            mapSize = 1;
        }
        bestTokenMap = new HashMap<SearchState, Token>(mapSize, 0.3F);
    }

    /** Terminates a recognition */
    public void stopRecognition() {
        localStop();
        scorer.stopRecognition();
        pruner.stopRecognition();
        linguist.stopRecognition();
    }

    /**
     * Gets the initial grammar node from the linguist and creates a
     * GrammarNodeToken
     */
    protected void localStart() {
        SearchGraph searchGraph = linguist.getSearchGraph();
        currentFrameNumber = 0;
        curTokensScored.value = 0;
        numStateOrder = searchGraph.getNumStateOrder();
        activeListManager.setNumStateOrder(numStateOrder);
        if (buildWordLattice) {
            loserManager = new AlternateHypothesisManager(maxLatticeEdges);
        }

        SearchState state = searchGraph.getInitialState();

        activeList = activeListManager.getEmittingList();
        activeList.add(new Token(state, -1));

        clearCollectors();

        growBranches();
        growNonEmittingBranches();
        // tokenTracker.setEnabled(false);
        // tokenTracker.startUtterance();
    }

    /** Local cleanup for this search manager */
    protected void localStop() {
        // tokenTracker.stopUtterance();
    }

    /**
     * Goes through the active list of tokens and expands each token, finding
     * the set of successor tokens until all the successor tokens are emitting
     * tokens.
     */
    protected void growBranches() {
        growTimer.start();
        float relativeBeamThreshold = activeList.getBeamThreshold();
        if (logger.isLoggable(Level.FINE)) {
            logger.fine("Frame: " + currentFrameNumber + " thresh : " + relativeBeamThreshold + " bs "
                    + activeList.getBestScore() + " tok " + activeList.getBestToken());
        }
        for (Token token : activeList) {
            if (token.getScore() >= relativeBeamThreshold && allowExpansion(token)) {
                collectSuccessorTokens(token);
            }
        }
        growTimer.stop();
    }

    /**
     * Grows the emitting branches. This version applies a simple acoustic
     * lookahead based upon the rate of change in the current acoustic score.
     */
    protected void growEmittingBranches() {
        if (acousticLookaheadFrames <= 0.0f) {
            growBranches();
            return;
        }

        growTimer.start();
        // project each token's score forward by acousticLookaheadFrames
        // using its last acoustic score as the rate of change
        float bestScore = -Float.MAX_VALUE;
        for (Token t : activeList) {
            float score = t.getScore() + t.getAcousticScore() * acousticLookaheadFrames;
            if (score > bestScore) {
                bestScore = score;
            }
        }
        float relativeBeamThreshold = bestScore + relativeBeamWidth;
        for (Token t : activeList) {
            if (t.getScore() + t.getAcousticScore() * acousticLookaheadFrames > relativeBeamThreshold)
                collectSuccessorTokens(t);
        }
        growTimer.stop();
    }

    /**
     * Grow the non-emitting branches, until the tokens reach an emitting state.
     */
    private void growNonEmittingBranches() {
        for (Iterator<ActiveList> i = activeListManager.getNonEmittingListIterator(); i.hasNext();) {
            activeList = i.next();
            if (activeList != null) {
                i.remove();
                pruneBranches();
                growBranches();
            }
        }
    }

    /**
     * Calculate the acoustic scores for the active list. The active list should
     * contain only emitting tokens.
     *
     * @return <code>true</code> if there are more frames to score, otherwise,
     *         false
     */
    protected boolean scoreTokens() {
        boolean moreTokens;
        scoreTimer.start();
        Data data = scorer.calculateScores(activeList.getTokens());
        scoreTimer.stop();

        Token bestToken = null;
        if (data instanceof Token) {
            bestToken = (Token) data;
        } else if (data == null) {
            // a null score result signals the end of the audio stream
            streamEnd = true;
        }

        if (bestToken != null) {
            currentCollectTime = bestToken.getCollectTime();
        }

        moreTokens = (bestToken != null);
        activeList.setBestToken(bestToken);

        // monitorWords(activeList);
        monitorStates(activeList);

        curTokensScored.value += activeList.size();
        totalTokensScored.value += activeList.size();

        return moreTokens;
    }

    /**
     * Keeps track of and reports all of the active word histories for the given
     * active list
     *
     * @param activeList
     *            the active list to track
     */
    @SuppressWarnings("unused")
    private void monitorWords(ActiveList activeList) {

        // WordTracker tracker1 = new WordTracker(currentFrameNumber);
        //
        // for (Token t : activeList) {
        // tracker1.add(t);
        // }
        // tracker1.dump();
        //
        // TokenTracker tracker2 = new TokenTracker();
        //
        // for (Token t : activeList) {
        // tracker2.add(t);
        // }
        // tracker2.dumpSummary();
        // tracker2.dumpDetails();
        //
        // TokenTypeTracker tracker3 = new TokenTypeTracker();
        //
        // for (Token t : activeList) {
        // tracker3.add(t);
        // }
        // tracker3.dump();

        // StateHistoryTracker tracker4 = new
        // StateHistoryTracker(currentFrameNumber);

        // for (Token t : activeList) {
        // tracker4.add(t);
        // }
        // tracker4.dump();
    }

    /**
     * Keeps track of and reports statistics about the number of active states
     *
     * @param activeList
     *            the active list of states
     */
    protected void monitorStates(ActiveList activeList) {
        tokenSum += activeList.size();
        tokenCount++;

        if ((tokenCount % 1000) == 0) {
            logger.info("Average Tokens/State: " + (tokenSum / tokenCount));
        }
    }

    /** Removes unpromising branches from the active list */
    protected void pruneBranches() {
        pruneTimer.start();
        activeList = pruner.prune(activeList);
        pruneTimer.stop();
    }

    /**
     * Gets the best token for this state
     *
     * @param state
     *            the state of interest
     * @return the best token
     */
    protected Token getBestToken(SearchState state) {
        return bestTokenMap.get(state);
    }

    /**
     * Sets the best token for a given state
     *
     * @param token
     *            the best token
     * @param state
     *            the state
     */
    protected void setBestToken(Token token, SearchState state) {
        bestTokenMap.put(state, token);
    }

    /**
     * Checks that the given two states are in legitimate order.
     *
     * @param fromState parent state
     * @param toState child state
     */
    protected void checkStateOrder(SearchState fromState, SearchState toState) {
        if (fromState.getOrder() == numStateOrder - 1) {
            return;
        }

        if (fromState.getOrder() > toState.getOrder()) {
            // NOTE(review): throwing a raw java.lang.Error is unusual; an
            // IllegalStateException would be more conventional, but the type
            // is kept for compatibility with existing callers.
            throw new Error("IllegalState order: from " + fromState.getClass().getName() + ' '
                    + fromState.toPrettyString() + " order: " + fromState.getOrder() + " to "
                    + toState.getClass().getName() + ' ' + toState.toPrettyString() + " order: "
                    + toState.getOrder());
        }
    }

    /**
     * Collects the next set of emitting tokens from a token and accumulates
     * them in the active or result lists
     *
     * @param token
     *            the token to collect successors from be immediately expanded
     *            are placed. Null if we should always expand all nodes.
     */
    protected void collectSuccessorTokens(Token token) {

        // tokenTracker.add(token);
        // tokenTypeTracker.add(token);

        // If this is a final state, add it to the final list
        if (token.isFinal()) {
            resultList.add(getResultListPredecessor(token));
            return;
        }

        // if this is a non-emitting token and we've already
        // visited the same state during this frame, then we
        // are in a grammar loop, so we don't continue to expand.
        // This check only works properly if we have kept all of the
        // tokens (instead of skipping the non-word tokens). Note that
        // certain linguists will never generate grammar loops (lextree
        // linguist for example). For these cases, it is perfectly fine to
        // disable this check by setting keepAllTokens to false
        if (!token.isEmitting() && (keepAllTokens && isVisited(token))) {
            return;
        }

        SearchState state = token.getSearchState();
        SearchStateArc[] arcs = state.getSuccessors();
        Token predecessor = getResultListPredecessor(token);

        // For each successor
        // calculate the entry score for the token based upon the
        // predecessor token score and the transition probabilities
        // if the score is better than the best score encountered for
        // the SearchState and frame then create a new token, add
        // it to the lattice and the SearchState.
        // If the token is an emitting token add it to the list,
        // otherwise recursively collect the new tokens successors.

        for (SearchStateArc arc : arcs) {
            SearchState nextState = arc.getState();

            if (checkStateOrder) {
                checkStateOrder(state, nextState);
            }

            // We're actually multiplying the variables, but since
            // these come in log(), multiply gets converted to add
            float logEntryScore = token.getScore() + arc.getProbability();

            Token bestToken = getBestToken(nextState);

            if (bestToken == null) {
                Token newBestToken = new Token(predecessor, nextState, logEntryScore,
                        arc.getInsertionProbability(), arc.getLanguageProbability(), currentCollectTime);
                tokensCreated.value++;
                setBestToken(newBestToken, nextState);
                activeListAdd(newBestToken);
            } else if (bestToken.getScore() < logEntryScore) {
                // better path to an already-seen state: update the existing
                // token in place and, when building a lattice, remember the
                // displaced predecessor as an alternate hypothesis
                Token oldPredecessor = bestToken.getPredecessor();
                bestToken.update(predecessor, nextState, logEntryScore, arc.getInsertionProbability(),
                        arc.getLanguageProbability(), currentCollectTime);
                if (buildWordLattice && nextState instanceof WordSearchState) {
                    loserManager.addAlternatePredecessor(bestToken, oldPredecessor);
                }
            } else if (buildWordLattice && nextState instanceof WordSearchState) {
                if (predecessor != null) {
                    loserManager.addAlternatePredecessor(bestToken, predecessor);
                }
            }
        }
    }

    /**
     * Determines whether or not we've visited the state associated with this
     * token since the previous frame.
     *
     * @param t token to check
     * @return true if we've visited the search state since the last frame
     */
    protected boolean isVisited(Token t) {
        SearchState curState = t.getSearchState();

        t = t.getPredecessor();

        while (t != null && !t.isEmitting()) {
            if (curState.equals(t.getSearchState())) {
                // FIX: was an unconditional System.out.println debug print;
                // route the diagnostic through the logger instead
                if (logger.isLoggable(Level.FINEST)) {
                    logger.finest("CS " + curState + " match " + t.getSearchState());
                }
                return true;
            }
            t = t.getPredecessor();
        }
        return false;
    }

    protected void activeListAdd(Token token) {
        activeListManager.add(token);
    }

    /**
     * Determine if the given token should be expanded
     *
     * @param t
     *            the token to test
     * @return <code>true</code> if the token should be expanded
     */
    protected boolean allowExpansion(Token t) {
        return true; // currently disabled
    }

    /**
     * Counts all the tokens in the active list (and displays them). This is an
     * expensive operation.
     */
    protected void showTokenCount() {
        Set<Token> tokenSet = new HashSet<Token>();

        for (Token token : activeList) {
            while (token != null) {
                tokenSet.add(token);
                token = token.getPredecessor();
            }
        }

        // FIX: report through the configured logger rather than System.out,
        // consistent with monitorStates
        logger.info("Token Lattice size: " + tokenSet.size());

        tokenSet = new HashSet<Token>();

        for (Token token : resultList) {
            while (token != null) {
                tokenSet.add(token);
                token = token.getPredecessor();
            }
        }

        logger.info("Result Lattice size: " + tokenSet.size());
    }

    /**
     * Returns the ActiveList.
     *
     * @return the ActiveList
     */
    public ActiveList getActiveList() {
        return activeList;
    }

    /**
     * Sets the ActiveList.
     *
     * @param activeList
     *            the new ActiveList
     */
    public void setActiveList(ActiveList activeList) {
        this.activeList = activeList;
    }

    /**
     * Returns the result list.
     *
     * @return the result list
     */
    public List<Token> getResultList() {
        return resultList;
    }

    /**
     * Sets the result list.
     *
     * @param resultList
     *            the new result list
     */
    public void setResultList(List<Token> resultList) {
        this.resultList = resultList;
    }

    /**
     * Returns the current frame number.
     *
     * @return the current frame number
     */
    public int getCurrentFrameNumber() {
        return currentFrameNumber;
    }

    /**
     * Returns the Timer for growing.
     *
     * @return the Timer for growing
     */
    public Timer getGrowTimer() {
        return growTimer;
    }

    /**
     * Returns the tokensCreated StatisticsVariable.
     *
     * @return the tokensCreated StatisticsVariable.
     */
    public StatisticsVariable getTokensCreated() {
        return tokensCreated;
    }
}