/*
 * Copyright 1999-2002 Carnegie Mellon University.
 * Portions Copyright 2002 Sun Microsystems, Inc.
 * Portions Copyright 2002 Mitsubishi Electric Research Laboratories.
 * All Rights Reserved. Use is subject to license terms.
 *
 * See the file "license.terms" for information on usage and
 * redistribution of this file, and for a DISCLAIMER OF ALL
 * WARRANTIES.
 *
 */
package edu.cmu.sphinx.linguist.acoustic.tiedstate.trainer;

import edu.cmu.sphinx.decoder.adaptation.ClusteredDensityFileData;
import edu.cmu.sphinx.decoder.adaptation.Transform;
import edu.cmu.sphinx.linguist.acoustic.HMM;
import edu.cmu.sphinx.linguist.acoustic.HMMPosition;
import edu.cmu.sphinx.linguist.acoustic.Unit;
import edu.cmu.sphinx.linguist.acoustic.UnitManager;
import edu.cmu.sphinx.linguist.acoustic.tiedstate.*;
import static edu.cmu.sphinx.linguist.acoustic.tiedstate.Pool.Feature.*;
import edu.cmu.sphinx.util.ExtendedStreamTokenizer;
import edu.cmu.sphinx.util.LogMath;
import edu.cmu.sphinx.util.StreamFactory;
import edu.cmu.sphinx.util.props.*;

import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.io.StreamCorruptedException;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import java.util.logging.Level;
import java.util.logging.Logger;

/**
 * An acoustic model loader that initializes models instead of reading trained
 * parameters: it reads a phone list (phone names and number of states per
 * phone) and builds zero-valued densities, uniform transition matrices,
 * mixture weights, senones and HMMs for every listed phone.
 * <p>
 * Mixture weights and transition probabilities are converted with
 * {@link LogMath#linearToLog(float[])} before being stored.
 */
public class ModelInitializerLoader implements Loader {

    private final static String SILENCE_CIPHONE = "SIL";

    public final static String MODEL_VERSION = "0.3";

    private final static int CONTEXT_SIZE = 1;

    private Pool<float[]> meansPool;
    private Pool<float[]> variancePool;
    private Pool<float[][]> matrixPool;
    private Pool<float[][]> meanTransformationMatrixPool;
    private Pool<float[]> meanTransformationVectorPool;
    private Pool<float[][]> varianceTransformationMatrixPool;
    private Pool<float[]> varianceTransformationVectorPool;
    private GaussianWeights mixtureWeights;
    private Pool<Senone> senonePool;

    /** Length of every density vector and dimension of the dummy transforms. */
    private int vectorLength = 39;

    private Map<String, Unit> contextIndependentUnits;
    private Map<String, Integer> phoneList;
    private HMMManager hmmManager;

    @S4String(defaultValue = "model")
    public static final String MODEL_NAME = "modelName";

    @S4String(defaultValue = ".")
    public static final String LOCATION = "location";

    @S4String(defaultValue = "phonelist")
    public static final String PHONE_LIST = "phones";

    @S4String(defaultValue = "data")
    public static final String DATA_DIR = "dataDir";

    @S4String(defaultValue = "model.props")
    public static final String PROP_FILE = "propsFile";

    @S4Component(type = UnitManager.class)
    public final static String PROP_UNIT_MANAGER = "unitManager";
    private UnitManager unitManager;

    @S4Boolean(defaultValue = false)
    public final static String PROP_USE_CD_UNITS = "useCDUnits";

    @S4Double(defaultValue = 0.0001f)
    public final static String PROP_VARIANCE_FLOOR = "varianceFloor";

    /** Mixture component score floor. */
    @S4Double(defaultValue = 0.0)
    public final static String PROP_MC_FLOOR = "MixtureComponentScoreFloor";

    /** Mixture weight floor. */
    @S4Double(defaultValue = 1e-7f)
    public final static String PROP_MW_FLOOR = "mixtureWeightFloor";

    private LogMath logMath;

    /** The logger for this class */
    private Logger logger;

    /**
     * Configures this loader from the property sheet and immediately builds
     * the initial model from the configured phone list.
     * <p>
     * Failures while reading the phone list are reported (help text for a
     * malformed file, a logged error for other I/O problems) but are not
     * propagated, so the loader may be left without a usable model.
     *
     * @param ps the property sheet holding this component's configuration
     * @throws PropertyException if a property cannot be resolved
     */
    public void newProperties(PropertySheet ps) throws PropertyException {
        logMath = LogMath.getLogMath();
        logger = ps.getLogger();

        unitManager = (UnitManager) ps.getComponent(PROP_UNIT_MANAGER);

        hmmManager = new HMMManager();
        contextIndependentUnits = new LinkedHashMap<String, Unit>();
        phoneList = new LinkedHashMap<String, Integer>();

        // Identity matrices / zero vectors standing in for real transforms.
        meanTransformationMatrixPool = createDummyMatrixPool("meanTransformationMatrix");
        meanTransformationVectorPool = createDummyVectorPool("meanTransformationVector");
        varianceTransformationMatrixPool = createDummyMatrixPool("varianceTransformationMatrix");
        varianceTransformationVectorPool = createDummyVectorPool("varianceTransformationVector");

        String modelName = ps.getString(MODEL_NAME);
        String location = ps.getString(LOCATION);
        String phone = ps.getString(PHONE_LIST);
        String dataDir = ps.getString(DATA_DIR);

        logger.info("Creating Sphinx3 acoustic model: " + modelName);
        logger.info(" Path : " + location);
        logger.info(" phonelist : " + phone);
        logger.info(" dataDir : " + dataDir);

        // load the HMM model file
        boolean useCDUnits = ps.getBoolean(PROP_USE_CD_UNITS);

        // Only context independent units can be initialized from a phone list.
        assert !useCDUnits;

        String phoneListPath = location + File.separator + phone;
        try {
            loadPhoneList(ps, useCDUnits,
                    StreamFactory.getInputStream(location, phone),
                    phoneListPath);
        } catch (StreamCorruptedException sce) {
            logger.log(Level.WARNING, "Malformed phone list file: " + phoneListPath, sce);
            printPhoneListHelp();
        } catch (IOException e) {
            logger.log(Level.SEVERE, "Unable to load phone list file: " + phoneListPath, e);
        }
    }

    /** Prints out a help message with format of phone list. */
    private void printPhoneListHelp() {
        System.out.println("The format for the phone list file is:");
        System.out.println("\tversion 0.1");
        System.out.println("\tsame_sized_models yes");
        System.out.println("\tn_state 3");
        System.out.println("\ttmat_skip (no|yes)");
        System.out.println("\tAA");
        System.out.println("\tAE");
        System.out.println("\tAH");
        System.out.println("\t...");
        System.out.println("Or:");
        System.out.println("\tversion 0.1");
        System.out.println("\tsame_sized_models no");
        System.out.println("\ttmat_skip (no|yes)");
        System.out.println("\tAA 5");
        System.out.println("\tAE 3");
        System.out.println("\tAH 4");
        System.out.println("\t...");
    }

    /**
     * Returns the context independent units created from the phone list,
     * keyed by unit name in phone-list order.
     *
     * @return the map of context independent units
     */
    public Map<String, Unit> getContextIndependentUnits() {
        return contextIndependentUnits;
    }

    /**
     * Adds the senones of one model to the senone pool. Each senone is a
     * Gaussian mixture whose components are taken from the means and variance
     * pools and use the dummy (pool entry 0) transformations.
     * <p>
     * The {@code mixtureWeights} field must already be initialized, because it
     * supplies the number of Gaussians per senone and is handed to every
     * mixture created here.
     *
     * @param pool          the senone pool
     * @param stateID       vector with senone ID for an HMM
     * @param distFloor     the lowest allowed score
     * @param varianceFloor the lowest allowed variance
     */
    private void addModelToSenonePool(Pool<Senone> pool, int[] stateID,
                                      float distFloor, float varianceFloor) {
        assert pool != null;
        assert mixtureWeights != null;

        int numGaussiansPerSenone = mixtureWeights.getGauPerState();
        assert numGaussiansPerSenone > 0;

        for (int state : stateID) {
            MixtureComponent[] mixtureComponents =
                    new MixtureComponent[numGaussiansPerSenone];
            for (int j = 0; j < numGaussiansPerSenone; j++) {
                // Same id scheme as addModelToDensityPool.
                int whichGaussian = state * numGaussiansPerSenone + j;
                mixtureComponents[j] = new MixtureComponent(
                        meansPool.get(whichGaussian),
                        meanTransformationMatrixPool.get(0),
                        meanTransformationVectorPool.get(0),
                        variancePool.get(whichGaussian),
                        varianceTransformationMatrixPool.get(0),
                        varianceTransformationVectorPool.get(0),
                        distFloor,
                        varianceFloor);
            }

            Senone senone = new GaussianMixture(mixtureWeights, mixtureComponents, state);
            pool.put(state, senone);
        }
    }

    /**
     * Adds a set of density arrays to a given pool. The densities are freshly
     * allocated and therefore all zero; they are stored under the id
     * {@code state * numGaussiansPerState + gaussianIndex}.
     *
     * @param pool                 the pool to add densities to
     * @param stateID              a vector with the senone id of the states in a model
     * @param numStreams           the number of streams
     * @param numGaussiansPerState the number of Gaussians per state
     * @throws IOException if an error occurs while loading the data
     */
    private void addModelToDensityPool(Pool<float[]> pool, int[] stateID,
                                       int numStreams, int numGaussiansPerState)
            throws IOException {
        assert pool != null;
        assert stateID != null;

        int numStates = stateID.length;

        // Accumulate the senone count across successive calls.
        int numInPool = pool.getFeature(NUM_SENONES, 0);
        pool.setFeature(NUM_SENONES, numStates + numInPool);

        // Stream and Gaussian counts are set once and must then stay constant.
        numInPool = pool.getFeature(NUM_STREAMS, -1);
        if (numInPool == -1) {
            pool.setFeature(NUM_STREAMS, numStreams);
        } else {
            assert numInPool == numStreams;
        }

        numInPool = pool.getFeature(NUM_GAUSSIANS_PER_STATE, -1);
        if (numInPool == -1) {
            pool.setFeature(NUM_GAUSSIANS_PER_STATE, numGaussiansPerState);
        } else {
            assert numInPool == numGaussiansPerState;
        }

        // TODO: numStreams should be any number > 0, but for now....
        assert numStreams == 1;

        for (int state : stateID) {
            for (int j = 0; j < numGaussiansPerState; j++) {
                // We're creating densities here, so it's ok if values
                // are all zero.
                float[] density = new float[vectorLength];
                int id = state * numGaussiansPerState + j;
                pool.put(id, density);
            }
        }
    }

    /**
     * If a data point is below 'floor' make it equal to floor.
     *
     * @param data  the data to floor
     * @param floor the floored value
     */
    private void floorData(float[] data, float floor) {
        for (int i = 0; i < data.length; i++) {
            if (data[i] < floor) {
                data[i] = floor;
            }
        }
    }

    /**
     * Normalize the given data in place so that it sums to one. Data whose sum
     * is exactly zero is left untouched.
     *
     * @param data the data to normalize
     */
    private void normalize(float[] data) {
        float sum = 0;
        for (float val : data) {
            sum += val;
        }

        if (sum != 0.0f) {
            // Invert, so we multiply instead of dividing inside the loop
            float scale = 1.0f / sum;
            for (int i = 0; i < data.length; i++) {
                data[i] = data[i] * scale;
            }
        }
    }

    /**
     * Loads the phone list, which possibly contains the sizes (number of
     * states) of models, and builds the initial model for every phone in it.
     * <p>
     * The file is parsed completely before any model structure is built: the
     * mixture weights need the total number of states, and the senones need
     * the mixture weights.
     *
     * @param ps          the property sheet supplying the floor values
     * @param useCDUnits  if true, uses context dependent units (currently unused)
     * @param inputStream the open input stream to use
     * @param path        the path of the phone list file, used for logging only
     * @throws StreamCorruptedException if the file does not follow the
     *                                  expected format or declares a model
     *                                  with fewer than one state
     * @throws IOException              if an error occurs while loading the data
     */
    private void loadPhoneList(PropertySheet ps, boolean useCDUnits,
                               InputStream inputStream, String path)
            throws IOException {
        // TODO: this should be flexible, but we're hardwiring for now
        int numStreams = 1;
        // Since we're initializing, we start simple.
        int numGaussiansPerState = 1;

        // Initialize the pools we'll need.
        meansPool = new Pool<float[]>("means");
        variancePool = new Pool<float[]>("variances");
        matrixPool = new Pool<float[][]>("transitionmatrices");
        senonePool = new Pool<Senone>("senones");

        float distFloor = ps.getFloat(PROP_MC_FLOOR);
        float mixtureWeightFloor = ps.getFloat(PROP_MW_FLOOR);
        float transitionProbabilityFloor = 0;
        float varianceFloor = ps.getFloat(PROP_VARIANCE_FLOOR);

        logger.info("Loading phone list file from: ");
        logger.info(path);

        // Phones and sizes in file order; parallel lists so that repeated
        // phone names are processed once per occurrence.
        List<String> phones = new ArrayList<String>();
        List<Integer> sizes = new ArrayList<Integer>();
        int totalStates = 0;
        boolean tmatSkip;

        ExtendedStreamTokenizer est =
                new ExtendedStreamTokenizer(inputStream, '#', false);
        try {
            // At this point, we only accept version 0.1
            String version = "0.1";
            est.expectString("version");
            est.expectString(version);

            est.expectString("same_sized_models");
            // Constant first: assumes getString() may return null at end of
            // input (the phone loop below relies on checking isEOF()).
            boolean sameSizedModels = "yes".equals(est.getString());

            int numState = 0;
            if (sameSizedModels) {
                est.expectString("n_state");
                numState = est.getInt("numBase");
            }

            // for this phone list version, let's assume left-to-right
            // models, with optional state skip.
            est.expectString("tmat_skip");
            tmatSkip = "yes".equals(est.getString());

            // Load the phones with sizes
            while (true) {
                String phone = est.getString();
                if (est.isEOF()) {
                    break;
                }
                int size = sameSizedModels ? numState : est.getInt("ModelSize");
                if (size < 1) {
                    throw new StreamCorruptedException(
                            "Invalid number of states (" + size
                                    + ") for phone " + phone + " in " + path);
                }
                phoneList.put(phone, size);
                logger.fine("Phone: " + phone + " size: " + size);

                phones.add(phone);
                sizes.add(size);
                totalStates += size;
            }
            // If we want to use this code to load sizes/create models for
            // CD units, we need to find another way of establishing the
            // number of CI models, instead of just reading until the end
            // of file.
        } finally {
            est.close();
        }

        // Mixture weights - all at once. They must exist before any senone is
        // created, since every GaussianMixture references them.
        mixtureWeights = initMixtureWeights(totalStates, numStreams,
                numGaussiansPerState, mixtureWeightFloor);

        String position = "-";

        // stateIndex contains the absolute state index, that is, a
        // unique index in the senone pool.
        int stateIndex = 0;
        for (int unitCount = 0; unitCount < phones.size(); unitCount++) {
            String phone = phones.get(unitCount);
            int size = sizes.get(unitCount);

            int[] stid = new int[size];
            for (int j = 0; j < size; j++, stateIndex++) {
                stid[j] = stateIndex;
            }

            Unit unit = unitManager.getUnit(phone, phone.equals(SILENCE_CIPHONE));
            contextIndependentUnits.put(unit.getName(), unit);

            if (logger.isLoggable(Level.FINE)) {
                logger.fine("Loaded " + unit + " with " + size + " states");
            }

            // Means
            addModelToDensityPool(meansPool, stid, numStreams, numGaussiansPerState);

            // Variances
            addModelToDensityPool(variancePool, stid, numStreams, numGaussiansPerState);

            // Transition matrix
            addModelToTransitionMatrixPool(matrixPool, unitCount, stid.length,
                    transitionProbabilityFloor, tmatSkip);

            // After creating all pools, we create the senone pool.
            addModelToSenonePool(senonePool, stid, distFloor, varianceFloor);

            // With the senone pool in place, create the HMM for this unit.
            float[][] transitionMatrix = matrixPool.get(unitCount);
            SenoneSequence ss = getSenoneSequence(stid);

            HMM hmm = new SenoneHMM(unit, ss, transitionMatrix,
                    HMMPosition.lookup(position));
            hmmManager.put(hmm);
        }
    }

    /**
     * Gets the senone sequence representing the given senones.
     *
     * @param stateid is the array of senone state ids
     * @return the senone sequence associated with the states
     */
    private SenoneSequence getSenoneSequence(int[] stateid) {
        Senone[] senones = new Senone[stateid.length];
        for (int i = 0; i < stateid.length; i++) {
            senones[i] = senonePool.get(stateid[i]);
        }

        // TODO: Is there any advantage in trying to pool these?
        return new SenoneSequence(senones);
    }

    /**
     * Creates the mixture weights for all states. Every state gets the same
     * weights: each entry starts at {@code floor}, the vector is normalized
     * and then converted with {@link LogMath#linearToLog(float[])}.
     *
     * @param numStates            the number of states
     * @param numStreams           the number of streams
     * @param numGaussiansPerState the number of Gaussians per state
     * @param floor                the minimum mixture weight allowed
     * @return mixtureWeights the gaussian weights holder
     */
    private GaussianWeights initMixtureWeights(int numStates, int numStreams,
                                               int numGaussiansPerState, float floor) {
        // TODO: allow any number for numStreams
        assert numStreams == 1;

        GaussianWeights weights = new GaussianWeights("mixtureweights",
                numStates, numGaussiansPerState, numStreams);

        for (int i = 0; i < numStates; i++) {
            float[] logMixtureWeight = new float[numGaussiansPerState];
            // Initialize the weights with the same value, e.g. floor
            floorData(logMixtureWeight, floor);
            // Normalize, so the numbers are not all too low
            normalize(logMixtureWeight);
            logMath.linearToLog(logMixtureWeight);
            weights.put(i, 0, logMixtureWeight);
        }
        return weights;
    }

    /**
     * Adds transition matrix to the transition matrices pool. The matrix is
     * left-to-right: each emitting state may stay, move to the next state
     * and, when {@code skip} is set, jump over one state. Allowed transitions
     * of a row are equally likely; every row is converted with
     * {@link LogMath#linearToLog(float[])}.
     *
     * @param pool              the pool to add matrix to
     * @param hmmId             current HMM's id
     * @param numEmittingStates number of states in current HMM
     * @param floor             the transition probability floor; used as the
     *                          pre-normalization value of allowed transitions
     *                          when positive
     * @param skip              if true, states can be skipped
     * @throws IOException if an error occurs while loading the data
     */
    private void addModelToTransitionMatrixPool(Pool<float[][]> pool, int hmmId,
                                                int numEmittingStates, float floor,
                                                boolean skip)
            throws IOException {
        assert pool != null;

        // Add one to account for the last, non-emitting, state
        int numStates = numEmittingStates + 1;

        // The value assigned could be anything, provided we normalize it -
        // but it has to be positive, otherwise the row sums to zero,
        // normalization is a no-op and every transition ends up as log(0).
        float seed = floor > 0.0f ? floor : 1.0f;

        // New arrays are zero-filled, so only allowed transitions are set.
        float[][] tmat = new float[numStates][numStates];

        for (int j = 0; j < numStates; j++) {
            // the last row is just zeros, so we just do
            // the first (numStates - 1) rows
            if (j < numStates - 1) {
                // Usual case: state can transition to itself
                // or the next state.
                tmat[j][j] = seed;
                tmat[j][j + 1] = seed;
                // If we can skip, we can also transition to
                // the state after the next one, when it exists.
                if (skip && j + 2 < numStates) {
                    tmat[j][j + 2] = seed;
                }
            }
            normalize(tmat[j]);
            logMath.linearToLog(tmat[j]);
        }
        pool.put(hmmId, tmat);
    }

    /**
     * Creates a pool with a single identity matrix in it.
     *
     * @param name the name of the pool
     * @return the pool with the matrix
     */
    private Pool<float[][]> createDummyMatrixPool(String name) {
        Pool<float[][]> pool = new Pool<float[][]>(name);
        logger.info("creating dummy matrix pool " + name);

        // Off-diagonal entries keep the default value of zero.
        float[][] matrix = new float[vectorLength][vectorLength];
        for (int i = 0; i < vectorLength; i++) {
            matrix[i][i] = 1.0F;
        }

        pool.put(0, matrix);
        return pool;
    }

    /**
     * Creates a pool with a single zero vector in it.
     *
     * @param name the name of the pool
     * @return the pool with the vector
     */
    private Pool<float[]> createDummyVectorPool(String name) {
        logger.info("creating dummy vector pool " + name);
        Pool<float[]> pool = new Pool<float[]>(name);

        // A new float array is already filled with zeros.
        float[] vector = new float[vectorLength];

        pool.put(0, vector);
        return pool;
    }

    /**
     * Does nothing: the model is built in
     * {@link #newProperties(PropertySheet)}.
     *
     * @throws IOException never thrown by this implementation
     */
    public void load() throws IOException {
    }

    public Pool<float[]> getMeansPool() {
        return meansPool;
    }

    public Pool<float[][]> getMeansTransformationMatrixPool() {
        return meanTransformationMatrixPool;
    }

    public Pool<float[]> getMeansTransformationVectorPool() {
        return meanTransformationVectorPool;
    }

    public Pool<float[]> getVariancePool() {
        return variancePool;
    }

    public Pool<float[][]> getVarianceTransformationMatrixPool() {
        return varianceTransformationMatrixPool;
    }

    public Pool<float[]> getVarianceTransformationVectorPool() {
        return varianceTransformationVectorPool;
    }

    public GaussianWeights getMixtureWeights() {
        return mixtureWeights;
    }

    public Pool<float[][]> getTransitionMatrixPool() {
        return matrixPool;
    }

    /**
     * This loader has no feature transform.
     *
     * @return always {@code null}
     */
    public float[][] getTransformMatrix() {
        return null;
    }

    public Pool<Senone> getSenonePool() {
        return senonePool;
    }

    public int getLeftContextSize() {
        return CONTEXT_SIZE;
    }

    public int getRightContextSize() {
        return CONTEXT_SIZE;
    }

    public HMMManager getHMMManager() {
        return hmmManager;
    }

    /** Logs a summary of every pool held by this loader. */
    public void logInfo() {
        logger.info("ModelInitializerLoader");
        meansPool.logInfo(logger);
        variancePool.logInfo(logger);
        matrixPool.logInfo(logger);
        senonePool.logInfo(logger);
        meanTransformationMatrixPool.logInfo(logger);
        meanTransformationVectorPool.logInfo(logger);
        varianceTransformationMatrixPool.logInfo(logger);
        varianceTransformationVectorPool.logInfo(logger);
        mixtureWeights.logInfo(logger);
        logger.info("Context Independent Unit Entries: "
                + contextIndependentUnits.size());
        hmmManager.logInfo(logger);
    }

    /**
     * Returns the model properties; this loader has none.
     *
     * @return a new, empty {@link Properties} object
     */
    public Properties getProperties() {
        return new Properties();
    }

    public void update(Transform transform, ClusteredDensityFileData clusters) {
        // TODO Not implemented yet
    }
}