/*
 * Copyright 1999-2002 Carnegie Mellon University.
 * Portions Copyright 2002 Sun Microsystems, Inc.
 * Portions Copyright 2002 Mitsubishi Electric Research Laboratories.
 * All Rights Reserved.  Use is subject to license terms.
 *
 * See the file "license.terms" for information on usage and
 * redistribution of this file, and for a DISCLAIMER OF ALL
 * WARRANTIES.
 *
 */
package edu.cmu.sphinx.linguist.acoustic.tiedstate.trainer;

import edu.cmu.sphinx.frontend.FloatData;
import edu.cmu.sphinx.linguist.acoustic.HMMState;
import edu.cmu.sphinx.linguist.acoustic.HMM;
import edu.cmu.sphinx.linguist.acoustic.tiedstate.*;
import edu.cmu.sphinx.util.LogMath;

import java.io.IOException;
import java.util.HashMap;
import java.util.logging.Logger;

/**
 * Manages the HMM pools.
 *
 * <p>The manager takes the parameter pools (means, variances, transition matrices, mixture weights
 * and senones) from a {@link Loader}, keeps one accumulation {@link Buffer} per pool entry, and
 * offers the three training steps used by the trainer: {@code accumulate}, {@link #normalize()}
 * and {@link #update()}.
 */
class HMMPoolManager {

    /** The logger for this class */
    private static final Logger logger =
            Logger.getLogger("edu.cmu.sphinx.linguist.acoustic.HMMPoolManager");

    private final HMMManager hmmManager;

    /**
     * Maps a parameter array (a mean vector, a variance vector or a transition matrix) to its
     * index in the pool it was taken from. The keys are Java arrays, which do not override
     * {@code equals}/{@code hashCode}, so lookups succeed only for the very same array instance
     * that was registered while the buffers were created.
     */
    private final HashMap<Object, Integer> indexMap;

    private final Pool<float[]> meansPool;
    private final Pool<float[]> variancePool;
    private final Pool<float[][]> matrixPool;
    private final GaussianWeights mixtureWeights;
    private final Pool<Senone> senonePool;
    private final LogMath logMath;

    // The buffer pools are rebuilt by createBuffers(), hence not final.
    private Pool<Buffer> meansBufferPool;
    private Pool<Buffer> varianceBufferPool;
    private Pool<Buffer[]> matrixBufferPool;
    private Pool<Buffer> mixtureWeightsBufferPool;

    // NOTE(review): none of the three floors below is ever assigned in this class (the code
    // that read them from the configuration is disabled in the constructor), so they keep the
    // Java default of 0.0f. For the two log-domain floors that is presumably log(1.0) rather
    // than a small probability -- confirm whether this is intended before relying on
    // updateMixtureWeights() / updateTransitionMatrices().
    private float logMixtureWeightFloor;
    private float logTransitionProbabilityFloor;
    private float varianceFloor;

    /** Running total returned by {@link #normalize()}; see {@code accumulate}. */
    private float logLikelihood;

    /** Per-frame term subtracted from the log probabilities; currently always reset to zero. */
    private float currentLogLikelihood;

    /**
     * Constructor for this pool manager. It gets the pointers to the pools from a loader.
     *
     * @param loader the loader; {@code loader.load()} is invoked before the pools are fetched
     * @throws IOException declared for {@code loader.load()}
     */
    protected HMMPoolManager(Loader loader) throws IOException {
        loader.load();
        hmmManager = loader.getHMMManager();
        indexMap = new HashMap<Object, Integer>();
        meansPool = loader.getMeansPool();
        variancePool = loader.getVariancePool();
        mixtureWeights = loader.getMixtureWeights();
        matrixPool = loader.getTransitionMatrixPool();
        senonePool = loader.getSenonePool();

        // Disabled in the original code; kept for reference because it shows how the
        // floors were meant to be initialized:
        //
        // logMath = LogMath.getLogMath();
        // float mixtureWeightFloor =
        //     props.getFloat(TiedStateAcousticModel.PROP_MW_FLOOR);
        // logMixtureWeightFloor = logMath.linearToLog(mixtureWeightFloor);
        // float transitionProbabilityFloor =
        //     props.getFloat(TiedStateAcousticModel.PROP_TP_FLOOR);
        // logTransitionProbabilityFloor =
        //     logMath.linearToLog(transitionProbabilityFloor);
        // varianceFloor =
        //     props.getFloat(TiedStateAcousticModel.PROP_VARIANCE_FLOOR);

        createBuffers();
        logLikelihood = 0.0f;
        logMath = LogMath.getLogMath();
    }

    /** Recreates the buffers and resets the accumulated log likelihood to zero. */
    protected void resetBuffers() {
        createBuffers();
        logLikelihood = 0.0f;
    }

    /** Create buffers for all pools used by the trainer in this pool manager. */
    protected void createBuffers() {
        // The boolean option tells whether the buffer is in log scale (true) or not (false).
        meansBufferPool = create1DPoolBuffer(meansPool, false);
        varianceBufferPool = create1DPoolBuffer(variancePool, false);
        matrixBufferPool = create2DPoolBuffer(matrixPool, true);
        mixtureWeightsBufferPool = createWeightsPoolBuffer(mixtureWeights);
    }

    /**
     * Create buffers for a given pool of vectors, registering every vector in the index map.
     *
     * @param pool  the pool of parameter vectors
     * @param isLog whether the buffers are in log scale
     * @return a buffer pool with one buffer per pool entry, stored under the same index
     */
    private Pool<Buffer> create1DPoolBuffer(Pool<float[]> pool, boolean isLog) {
        Pool<Buffer> bufferPool = new Pool<Buffer>(pool.getName());
        for (int i = 0; i < pool.size(); i++) {
            float[] element = pool.get(i);
            indexMap.put(element, i);
            bufferPool.put(i, new Buffer(element.length, isLog, i));
        }
        return bufferPool;
    }

    /**
     * Create the log-scale buffers for the mixture weights: one buffer of {@code gauPerState}
     * entries per (stream, state) pair, stored under the id {@code stream * statesNum + state}.
     *
     * @param mixtureWeights the mixture weights that define the dimensions
     * @return the buffer pool
     */
    private Pool<Buffer> createWeightsPoolBuffer(GaussianWeights mixtureWeights) {
        Pool<Buffer> bufferPool = new Pool<Buffer>(mixtureWeights.getName());
        int statesNum = mixtureWeights.getStatesNum();
        int streamsNum = mixtureWeights.getStreamsNum();
        int gauPerState = mixtureWeights.getGauPerState();
        for (int i = 0; i < streamsNum; i++) {
            for (int j = 0; j < statesNum; j++) {
                int id = i * statesNum + j;
                bufferPool.put(id, new Buffer(gauPerState, true, id));
            }
        }
        return bufferPool;
    }

    /**
     * Create buffers for a given pool of matrices, registering every matrix in the index map.
     *
     * @param pool  the pool of matrices
     * @param isLog whether the buffers are in log scale
     * @return a buffer pool holding, for each matrix, one buffer per matrix row
     */
    private Pool<Buffer[]> create2DPoolBuffer(Pool<float[][]> pool, boolean isLog) {
        Pool<Buffer[]> bufferPool = new Pool<Buffer[]>(pool.getName());
        for (int i = 0; i < pool.size(); i++) {
            float[][] element = pool.get(i);
            indexMap.put(element, i);
            int poolSize = element.length;
            Buffer[] bufferArray = new Buffer[poolSize];
            for (int j = 0; j < poolSize; j++) {
                bufferArray[j] = new Buffer(element[j].length, isLog, j);
            }
            bufferPool.put(i, bufferArray);
        }
        return bufferPool;
    }

    /**
     * Looks up the pool index that was registered for a parameter array when the buffers were
     * created.
     *
     * <p>A plain {@code int i = indexMap.get(x)} would fail with a message-less
     * {@code NullPointerException} (null unboxing) for an unknown array; this helper reports
     * the problem explicitly instead.
     *
     * @param parameters  the exact array instance taken from one of the pools
     * @param description short name of the kind of array, used in the error message
     * @return the index of the array in its pool
     * @throws IllegalStateException if the array instance was never registered
     */
    private int lookupPoolIndex(Object parameters, String description) {
        Integer index = indexMap.get(parameters);
        if (index == null) {
            throw new IllegalStateException("The " + description
                    + " array is not registered in any pool of this HMMPoolManager");
        }
        return index;
    }

    /**
     * Accumulate the TrainerScore into the buffers.
     *
     * @param index the current index into the TrainerScore vector
     * @param score the TrainerScore
     */
    protected void accumulate(int index, TrainerScore[] score) {
        accumulate(index, score, null);
    }

    /**
     * Accumulate the TrainerScore into the buffers.
     *
     * @param index     the current index into the TrainerScore vector
     * @param score     the TrainerScore for the current frame
     * @param nextScore the TrainerScore for the next time frame, or {@code null} for the last
     *                  frame (transitions are then not accumulated)
     */
    protected void accumulate(int index, TrainerScore[] score, TrainerScore[] nextScore) {
        TrainerScore thisScore = score[index];

        // We should be doing this just once per utterance...
        // currentLogLikelihood = thisScore.getLogLikelihood();

        // Since we're scaling, the loglikelihood disappears...
        currentLogLikelihood = 0;
        // And the total becomes the sum of (-) scaling factors
        logLikelihood -= score[0].getScalingFactor();

        SenoneHMMState state = (SenoneHMMState) thisScore.getState();
        if (state == null) {
            // We only care about the case "all models"
            int senoneID = thisScore.getSenoneID();
            if (senoneID == TrainerAcousticModel.ALL_MODELS) {
                accumulateMean(senoneID, thisScore);
                accumulateVariance(senoneID, thisScore);
                accumulateMixture(senoneID, thisScore);
                accumulateTransition(senoneID, index, score, nextScore);
            }
        } else if (state.isEmitting()) {
            // If state is non-emitting, we presume there's only one
            // transition out of it. Therefore, we only accumulate
            // data for emitting states.
            int senoneID = senonePool.indexOf(state.getSenone());
            // Mean and variance accumulation is deliberately left disabled here, exactly as in
            // the original code:
            // accumulateMean(senoneID, score[index]);
            // accumulateVariance(senoneID, score[index]);
            accumulateMixture(senoneID, thisScore);
            accumulateTransition(senoneID, index, score, nextScore);
        }
    }

    /**
     * Accumulate the means.
     *
     * @param senone the senone index, or {@code TrainerAcousticModel.ALL_MODELS} to accumulate
     *               into every senone of the pool
     * @param score  the score carrying the feature data and the component gammas
     */
    private void accumulateMean(int senone, TrainerScore score) {
        if (senone == TrainerAcousticModel.ALL_MODELS) {
            for (int i = 0; i < senonePool.size(); i++) {
                accumulateMean(i, score);
            }
            return;
        }

        GaussianMixture gaussian = (GaussianMixture) senonePool.get(senone);
        MixtureComponent[] mix = gaussian.getMixtureComponents();
        for (int i = 0; i < mix.length; i++) {
            float[] mean = mix[i].getMean();
            int indexMean = lookupPoolIndex(mean, "mean");
            assert indexMean >= 0;
            // NOTE(review): this can only hold when a senone's means sit at the senone's own
            // pool index, i.e. presumably one Gaussian per senone -- confirm.
            assert indexMean == senone;
            Buffer buffer = meansBufferPool.get(indexMean);
            float[] feature = ((FloatData) score.getData()).getValues();
            double[] data = new double[feature.length];
            float prob = score.getComponentGamma()[i];
            prob -= currentLogLikelihood;
            double dprob = logMath.logToLinear(prob);
            // Weight the feature vector by the (linear) component posterior.
            for (int j = 0; j < data.length; j++) {
                data[j] = feature[j] * dprob;
            }
            buffer.accumulate(data, dprob);
        }
    }

    /**
     * Accumulate the variance.
     *
     * @param senone the senone index, or {@code TrainerAcousticModel.ALL_MODELS} to accumulate
     *               into every senone of the pool
     * @param score  the score carrying the feature data and the component gammas
     */
    private void accumulateVariance(int senone, TrainerScore score) {
        if (senone == TrainerAcousticModel.ALL_MODELS) {
            for (int i = 0; i < senonePool.size(); i++) {
                accumulateVariance(i, score);
            }
            return;
        }

        GaussianMixture gaussian = (GaussianMixture) senonePool.get(senone);
        MixtureComponent[] mix = gaussian.getMixtureComponents();
        for (int i = 0; i < mix.length; i++) {
            float[] mean = mix[i].getMean();
            float[] variance = mix[i].getVariance();
            int indexVariance = lookupPoolIndex(variance, "variance");
            Buffer buffer = varianceBufferPool.get(indexVariance);
            float[] feature = ((FloatData) score.getData()).getValues();
            double[] data = new double[feature.length];
            float prob = score.getComponentGamma()[i];
            prob -= currentLogLikelihood;
            double dprob = logMath.logToLinear(prob);
            // Squared deviation from the component's current mean, weighted by the posterior.
            for (int j = 0; j < data.length; j++) {
                data[j] = (feature[j] - mean[j]);
                data[j] *= data[j] * dprob;
            }
            buffer.accumulate(data, dprob);
        }
    }

    /**
     * Accumulate the mixture weights.
     *
     * @param senone the senone index, or {@code TrainerAcousticModel.ALL_MODELS} to accumulate
     *               into every senone of the pool
     * @param score  the score carrying the component gammas
     */
    private void accumulateMixture(int senone, TrainerScore score) {
        // The index into the senone pool and the mixture weight pool
        // is the same
        if (senone == TrainerAcousticModel.ALL_MODELS) {
            for (int i = 0; i < senonePool.size(); i++) {
                accumulateMixture(i, score);
            }
            return;
        }

        Buffer buffer = mixtureWeightsBufferPool.get(senone);
        // Both values are loop invariant, so read them once.
        int gauPerState = mixtureWeights.getGauPerState();
        float[] componentGamma = score.getComponentGamma();
        for (int i = 0; i < gauPerState; i++) {
            float prob = componentGamma[i];
            prob -= currentLogLikelihood;
            buffer.logAccumulate(prob, i, logMath);
        }
    }

    /**
     * Accumulate transitions from a given state.
     *
     * @param indexScore the current index into the TrainerScore
     * @param score      the score information
     * @param nextScore  the score information for the next frame
     */
    private void accumulateStateTransition(int indexScore, TrainerScore[] score,
                                           TrainerScore[] nextScore) {
        HMMState state = score[indexScore].getState();
        if (state == null) {
            // Non-emitting state
            return;
        }
        int indexState = state.getState();
        SenoneHMM hmm = (SenoneHMM) state.getHMM();
        float[][] matrix = hmm.getTransitionMatrix();

        // Find the index for current matrix in the transition matrix pool
        int indexMatrix = lookupPoolIndex(matrix, "transition matrix");
        // Find the corresponding buffer
        Buffer[] bufferArray = matrixBufferPool.get(indexMatrix);

        // Let's concentrate on the transitions *from* the current state
        float[] vector = matrix[indexState];
        for (int i = 0; i < vector.length; i++) {
            // Make sure this is a valid transition
            if (vector[i] == LogMath.LOG_ZERO) {
                continue;
            }

            // We're assuming that if the states have position "a"
            // and "b" in the HMM, they'll have positions "k+a"
            // and "k+b" in the graph, that is, their relative
            // position is the same.

            // Distance between current state and "to" state in
            // the HMM
            int dist = i - indexState;

            // "to" state in the graph
            int indexNextScore = indexScore + dist;

            // Make sure the next state is non-emitting (the last
            // in the HMM), or in the same HMM.
            assert ((nextScore[indexNextScore].getState() == null)
                    || (nextScore[indexNextScore].getState().getHMM() == hmm));

            float alpha = score[indexScore].getAlpha();
            float beta = nextScore[indexNextScore].getBeta();
            float transitionProb = vector[i];
            float outputProb = nextScore[indexNextScore].getScore();
            float prob = alpha + beta + transitionProb + outputProb;
            prob -= currentLogLikelihood;

            // i is the index into the next state.
            bufferArray[indexState].logAccumulate(prob, i, logMath);
        }
    }

    /**
     * Accumulate the same value into every valid transition leaving a given state.
     *
     * @param indexState the state index
     * @param hmm        the HMM
     * @param value      the value to accumulate
     */
    private void accumulateStateTransition(int indexState, SenoneHMM hmm, float value) {
        // Find the transition matrix in this hmm
        float[][] matrix = hmm.getTransitionMatrix();

        // Find the vector with transitions from the current state to
        // other states.
        float[] stateVector = matrix[indexState];

        // Find the index of the current transition matrix in the
        // transition matrix pool.
        int indexMatrix = lookupPoolIndex(matrix, "transition matrix");

        // Find the buffer for the transition matrix.
        Buffer[] bufferArray = matrixBufferPool.get(indexMatrix);

        // Accumulate for the transitions from current state
        for (int i = 0; i < stateVector.length; i++) {
            // Make sure we're not trying to accumulate in an invalid
            // transition.
            if (stateVector[i] != LogMath.LOG_ZERO) {
                bufferArray[indexState].logAccumulate(value, i, logMath);
            }
        }
    }

    /**
     * Accumulate the transition probabilities.
     *
     * @param indexHmm   {@code TrainerAcousticModel.ALL_MODELS} to add the current score to all
     *                   states of all models; any other value accumulates only the transitions
     *                   of the state found at {@code indexScore}
     * @param indexScore the current index into the TrainerScore
     * @param score      the score information
     * @param nextScore  the score information for the next frame, or {@code null} for the last
     *                   frame
     */
    private void accumulateTransition(int indexHmm, int indexScore, TrainerScore[] score,
                                      TrainerScore[] nextScore) {
        if (indexHmm == TrainerAcousticModel.ALL_MODELS) {
            // Well, special case... we want to add an amount to all
            // the states in all models
            for (HMM hmm : hmmManager) {
                for (int j = 0; j < hmm.getOrder(); j++) {
                    accumulateStateTransition(j, (SenoneHMM) hmm, score[indexScore].getScore());
                }
            }
        } else if (nextScore != null) {
            // For transition accumulation, we don't consider the last
            // time frame, since there's no transition from there to
            // anywhere...
            accumulateStateTransition(indexScore, score, nextScore);
        }
    }

    /**
     * Update the log likelihood. This method should be called for every utterance.
     *
     * <p>Currently a no-op: the only statement is disabled, the total is maintained by
     * {@code accumulate} instead.
     */
    protected void updateLogLikelihood() {
        // logLikelihood += currentLogLikelihood;
    }

    /**
     * Normalize the buffers.
     *
     * @return the log likelihood associated with the current training set
     */
    protected float normalize() {
        normalizePool(meansBufferPool);
        normalizePool(varianceBufferPool);
        logNormalizePool(mixtureWeightsBufferPool);
        logNormalize2DPool(matrixBufferPool, matrixPool);
        return logLikelihood;
    }

    /**
     * Normalize a single buffer pool. Buffers that were never used are left untouched.
     *
     * @param pool the buffer pool to normalize
     */
    private void normalizePool(Pool<Buffer> pool) {
        assert pool != null;
        for (int i = 0; i < pool.size(); i++) {
            Buffer buffer = pool.get(i);
            if (buffer.wasUsed()) {
                buffer.normalize();
            }
        }
    }

    /**
     * Normalize a single buffer pool in log scale. Buffers that were never used are left
     * untouched.
     *
     * @param pool the buffer pool to normalize
     */
    private void logNormalizePool(Pool<Buffer> pool) {
        assert pool != null;
        for (int i = 0; i < pool.size(); i++) {
            Buffer buffer = pool.get(i);
            if (buffer.wasUsed()) {
                buffer.logNormalize();
            }
        }
    }

    /**
     * Normalize a 2D buffer pool in log scale. Typically, this is the case with the transition
     * matrix, which also needs a mask for values that are allowed, and therefore have to be
     * updated, or not allowed, and should be ignored.
     *
     * @param pool     the buffer pool to normalize
     * @param maskPool pool containing a mask with zero/non-zero values.
     */
    private void logNormalize2DPool(Pool<Buffer[]> pool, Pool<float[][]> maskPool) {
        assert pool != null;
        for (int i = 0; i < pool.size(); i++) {
            Buffer[] bufferArray = pool.get(i);
            float[][] mask = maskPool.get(i);
            for (int j = 0; j < bufferArray.length; j++) {
                if (bufferArray[j].wasUsed()) {
                    bufferArray[j].logNormalizeNonZero(mask[j]);
                }
            }
        }
    }

    /**
     * Update the models: means, variances, the precomputed values of the mixture components,
     * mixture weights and transition matrices, in that order.
     */
    protected void update() {
        updateMeans();
        updateVariances();
        recomputeMixtureComponents();
        updateMixtureWeights();
        updateTransitionMatrices();
    }

    /**
     * Copy one vector onto another.
     *
     * @param in  the source vector
     * @param out the destination vector
     */
    private void copyVector(float[] in, float[] out) {
        assert in.length == out.length;
        System.arraycopy(in, 0, out, 0, in.length);
    }

    /** Update the means: copy every used means buffer into the pooled mean vector. */
    private void updateMeans() {
        assert meansPool.size() == meansBufferPool.size();
        for (int i = 0; i < meansPool.size(); i++) {
            float[] means = meansPool.get(i);
            Buffer buffer = meansBufferPool.get(i);
            if (buffer.wasUsed()) {
                copyVector(buffer.getValues(), means);
            } else {
                logger.info("Senone " + i + " not used.");
            }
        }
    }

    /**
     * Update the variances: subtract the squared mean from every used variance buffer, apply
     * the variance floor and copy the result into the pooled variance vector.
     */
    private void updateVariances() {
        assert variancePool.size() == varianceBufferPool.size();
        for (int i = 0; i < variancePool.size(); i++) {
            float[] means = meansPool.get(i);
            float[] variance = variancePool.get(i);
            Buffer buffer = varianceBufferPool.get(i);
            if (!buffer.wasUsed()) {
                continue;
            }
            float[] varianceBuffer = buffer.getValues();
            assert means.length == varianceBuffer.length;
            // NOTE(review): accumulateVariance() accumulates squared deviations from the
            // mean that was current at accumulation time, yet here the square of the (already
            // updated) mean is subtracted as if raw second moments had been accumulated --
            // confirm this is intended.
            for (int j = 0; j < means.length; j++) {
                varianceBuffer[j] -= means[j] * means[j];
                if (varianceBuffer[j] < varianceFloor) {
                    varianceBuffer[j] = varianceFloor;
                }
            }
            copyVector(varianceBuffer, variance);
        }
    }

    /** Recompute the precomputed values in all mixture components. */
    private void recomputeMixtureComponents() {
        for (int i = 0; i < senonePool.size(); i++) {
            GaussianMixture gMix = (GaussianMixture) senonePool.get(i);
            for (MixtureComponent component : gMix.getMixtureComponents()) {
                component.precomputeDistance();
            }
        }
    }

    /**
     * Update the mixture weights: floor every used buffer, renormalize it when the floor
     * changed something, and store the values back into the mixture weights.
     */
    private void updateMixtureWeights() {
        int statesNum = mixtureWeights.getStatesNum();
        int streamsNum = mixtureWeights.getStreamsNum();
        assert statesNum * streamsNum == mixtureWeightsBufferPool.size();
        for (int i = 0; i < streamsNum; i++) {
            for (int j = 0; j < statesNum; j++) {
                int id = i * statesNum + j;
                Buffer buffer = mixtureWeightsBufferPool.get(id);
                if (!buffer.wasUsed()) {
                    continue;
                }
                if (buffer.logFloor(logMixtureWeightFloor)) {
                    buffer.logNormalizeToSum(logMath);
                }
                mixtureWeights.put(j, i, buffer.getValues());
            }
        }
    }

    /**
     * Update the transition matrices: floor the valid entries of every used row buffer,
     * renormalize the row and copy it into the pooled matrix.
     */
    private void updateTransitionMatrices() {
        assert matrixPool.size() == matrixBufferPool.size();
        for (int i = 0; i < matrixPool.size(); i++) {
            float[][] matrix = matrixPool.get(i);
            Buffer[] bufferArray = matrixBufferPool.get(i);
            for (int j = 0; j < matrix.length; j++) {
                Buffer buffer = bufferArray[j];
                if (!buffer.wasUsed()) {
                    continue;
                }
                for (int k = 0; k < matrix[j].length; k++) {
                    float bufferValue = buffer.getValue(k);
                    if (bufferValue != LogMath.LOG_ZERO) {
                        // A non-zero accumulated value must belong to an allowed transition.
                        assert matrix[j][k] != LogMath.LOG_ZERO;
                        if (bufferValue < logTransitionProbabilityFloor) {
                            buffer.setValue(k, logTransitionProbabilityFloor);
                        }
                    }
                }
                buffer.logNormalizeToSum(logMath);
                copyVector(buffer.getValues(), matrix[j]);
            }
        }
    }
}