/* * Copyright 1999-2004 Carnegie Mellon University. * Portions Copyright 2004 Sun Microsystems, Inc. * Portions Copyright 2004 Mitsubishi Electric Research Laboratories. * All Rights Reserved. Use is subject to license terms. * * See the file "license.terms" for information on usage and * redistribution of this file, and for a DISCLAIMER OF ALL * WARRANTIES. * */ package edu.cmu.sphinx.linguist.acoustic.tiedstate; import edu.cmu.sphinx.decoder.adaptation.ClusteredDensityFileData; import edu.cmu.sphinx.decoder.adaptation.Transform; import edu.cmu.sphinx.linguist.acoustic.*; import edu.cmu.sphinx.linguist.acoustic.tiedstate.tiedmixture.MixtureComponentSet; import edu.cmu.sphinx.linguist.acoustic.tiedstate.tiedmixture.PrunableMixtureComponent; import edu.cmu.sphinx.linguist.acoustic.tiedstate.tiedmixture.SetBasedGaussianMixture; import static edu.cmu.sphinx.linguist.acoustic.tiedstate.Pool.Feature.*; import edu.cmu.sphinx.util.ExtendedStreamTokenizer; import edu.cmu.sphinx.util.LogMath; import edu.cmu.sphinx.util.TimerPool; import edu.cmu.sphinx.util.Utilities; import edu.cmu.sphinx.util.props.*; import java.io.*; import java.net.MalformedURLException; import java.net.URISyntaxException; import java.net.URL; import java.util.ArrayList; import java.util.LinkedHashMap; import java.util.Map; import java.util.Properties; import java.util.logging.Level; import java.util.logging.Logger; /** * Loads a tied-state acoustic model generated by the Sphinx-3 trainer. * <p> * The acoustic model is stored as a directory specified by a URL. The * dictionary and language model files are not required to be in the package. * You can specify their locations separately. * <p> * Configuration file should set mandatory property of component: <b>location</b> - * this specifies the directory where the actual model * data files are. You can use <b>resource:</b> prefix to refer to files packed * inside jar or any other URI scheme. * The actual model data files are named "mdef", "means", "variances", * "transition_matrices", "mixture_weights". */ public class Sphinx3Loader implements Loader { /** * The unit manager */ @S4Component(type = UnitManager.class) public final static String PROP_UNIT_MANAGER = "unitManager"; /** * The root location of the model directory structure */ @S4String(mandatory = true) public final static String PROP_LOCATION = "location"; /** * The property specifying whether context-dependent units should be used. */ @S4Boolean(defaultValue = true) public final static String PROP_USE_CD_UNITS = "useCDUnits"; /** * Mixture component score floor. */ @S4Double(defaultValue = 0.0f) public final static String PROP_MC_FLOOR = "mixtureComponentScoreFloor"; /** * Variance floor. */ @S4Double(defaultValue = 0.0001f) public final static String PROP_VARIANCE_FLOOR = "varianceFloor"; /** * Mixture weight floor */ @S4Double(defaultValue = 1e-7f) public final static String PROP_MW_FLOOR = "mixtureWeightFloor"; /** * Number of top Gaussians to use in scoring */ @S4Integer(defaultValue = 4) public final static String PROP_TOPN = "topGaussiansNum"; protected final static String FILLER = "filler"; protected final static String SILENCE_CIPHONE = "SIL"; protected final static int BYTE_ORDER_MAGIC = 0x11223344; /** * Supports this version of the acoustic model */ public final static String MODEL_VERSION = "0.3"; private final static int CONTEXT_SIZE = 1; protected Properties modelProps; protected Pool<float[]> meansPool; protected Pool<float[]> variancePool; protected Pool<float[][]> transitionsPool; protected GaussianWeights mixtureWeights; private int numStates; private int numStreams; private int numBase; private int numGaussiansPerState; private int[] vectorLength; private int[] senone2ci; protected Pool<float[][]> meanTransformationMatrixPool; protected Pool<float[]> meanTransformationVectorPool; protected Pool<float[][]> varianceTransformationMatrixPool; protected Pool<float[]> varianceTransformationVectorPool; protected float[][] transformMatrix; private MixtureComponentSet[] phoneticTiedMixtures; protected Pool<Senone> senonePool; private Map<String, Unit> contextIndependentUnits; private HMMManager hmmManager; protected LogMath logMath; private UnitManager unitManager; private boolean swap; private final static String DENSITY_FILE_VERSION = "1.0"; private final static String MIXW_FILE_VERSION = "1.0"; private final static String TMAT_FILE_VERSION = "1.0"; private final static String TRANSFORM_FILE_VERSION = "0.1"; // -------------------------------------- // Configuration variables // -------------------------------------- protected Logger logger; private URL location; protected float distFloor; protected float mixtureWeightFloor; protected float varianceFloor; private int topGauNum; protected boolean useCDUnits; private boolean loaded; public Sphinx3Loader(URL location, UnitManager unitManager, float distFloor, float mixtureWeightFloor, float varianceFloor, int topGauNum, boolean useCDUnits) { init(location, unitManager, distFloor, mixtureWeightFloor, varianceFloor, topGauNum, useCDUnits, Logger.getLogger(getClass().getName())); } public Sphinx3Loader(String location, UnitManager unitManager, float distFloor, float mixtureWeightFloor, float varianceFloor, int topGauNum, boolean useCDUnits) throws MalformedURLException, ClassNotFoundException { init(ConfigurationManagerUtils.resourceToURL(location), unitManager, distFloor, mixtureWeightFloor, varianceFloor, topGauNum, useCDUnits, Logger.getLogger(getClass().getName())); } protected void init(URL location, UnitManager unitManager, float distFloor, float mixtureWeightFloor, float varianceFloor, int topGauNum, boolean useCDUnits, Logger logger) { logMath = LogMath.getLogMath(); this.location = location; this.logger = logger; this.unitManager = unitManager; this.distFloor = distFloor; this.mixtureWeightFloor = mixtureWeightFloor; this.varianceFloor = varianceFloor; this.topGauNum = topGauNum; this.useCDUnits = useCDUnits; } public Sphinx3Loader() { } public int getNumStates() { return numStates; } public int getNumStreams() { return numStreams; } public int getNumGaussiansPerState() { return numGaussiansPerState; } public int[] getVectorLength() { return vectorLength; } public int[] getSenone2Ci() { return senone2ci; } public String getLocation() { return this.location.getPath(); } public boolean hasTiedMixtures() { String modelType = modelProps.getProperty("-model", "cont"); return modelType.equals("ptm"); } public void newProperties(PropertySheet ps) throws PropertyException { init(ConfigurationManagerUtils.getResource(PROP_LOCATION, ps), (UnitManager) ps.getComponent(PROP_UNIT_MANAGER), ps.getFloat(PROP_MC_FLOOR), ps.getFloat(PROP_MW_FLOOR), ps.getFloat(PROP_VARIANCE_FLOOR), ps.getInt(PROP_TOPN), ps.getBoolean(PROP_USE_CD_UNITS), ps.getLogger()); } // This function is a bit different from the // ConfigurationManagerUtils.getResource // for compatibility reasons. By default it looks for the resources, not // for the files. protected InputStream getDataStream(String path) throws IOException, URISyntaxException { return new URL(Utilities.pathJoin(location.toString(), path)).openStream(); } public void load() throws IOException { if (!loaded) { TimerPool.getTimer(this, "Load AM").start(); hmmManager = new HMMManager(); contextIndependentUnits = new LinkedHashMap<String, Unit>(); // dummy pools for these elements meanTransformationMatrixPool = null; meanTransformationVectorPool = null; varianceTransformationMatrixPool = null; varianceTransformationVectorPool = null; transformMatrix = null; // do the actual acoustic model loading try { loadModelFiles(); } catch (URISyntaxException e) { throw new RuntimeException(e); } // done loaded = true; TimerPool.getTimer(this, "Load AM").stop(); } } /** * Return the HmmManager. * * @return the hmmManager */ protected HMMManager getHmmManager() { return hmmManager; } /** * Return the MatrixPool. * * @return the matrixPool */ protected Pool<float[][]> getMatrixPool() { return transitionsPool; } /** * Return the MixtureWeightsPool. * * @return the mixtureWeightsPool */ protected GaussianWeights getMixtureWeightsPool() { return mixtureWeights; } /** * Loads the AcousticModel from a directory in the file system. * @throws IOException IO went wrong * @throws URISyntaxException uri was incorrectly specified */ protected void loadModelFiles() throws IOException, URISyntaxException { meansPool = loadDensityFile("means", -Float.MAX_VALUE); variancePool = loadDensityFile("variances", varianceFloor); mixtureWeights = loadMixtureWeights("mixture_weights", mixtureWeightFloor); transitionsPool = loadTransitionMatrices("transition_matrices"); transformMatrix = loadTransformMatrix("feature_transform"); modelProps = loadModelProps("feat.params"); if (hasTiedMixtures()) { //create senone to CI mapping getSenoneToCIPhone(); //create tied senone pool senonePool = createTiedSenonePool(distFloor, varianceFloor); } else { //create regular senone poll senonePool = createSenonePool(distFloor, varianceFloor); } // load the HMM modelDef file InputStream modelStream = getDataStream("mdef"); if (modelStream == null) { throw new IOException("can't find model definition"); } loadHMMPool(useCDUnits, modelStream); } public Map<String, Unit> getContextIndependentUnits() { return contextIndependentUnits; } /** * Creates senone to CI phone mapping, reading model definition file */ private void getSenoneToCIPhone() throws IOException, URISyntaxException { InputStream inputStream = getDataStream("mdef"); if (inputStream == null) { throw new IOException("can't find model definition"); } ExtendedStreamTokenizer est = new ExtendedStreamTokenizer(inputStream, '#', false); logger.fine("Loading HMM file from " + location); est.expectString(MODEL_VERSION); numBase = est.getInt("numBase"); est.expectString("n_base"); int numTri = est.getInt("numTri"); est.expectString("n_tri"); int numStateMap = est.getInt("numStateMap"); est.expectString("n_state_map"); int numTiedState = est.getInt("numTiedState"); est.expectString("n_tied_state"); senone2ci = new int[numTiedState]; est.getInt("numContextIndependentTiedState"); est.expectString("n_tied_ci_state"); int numTiedTransitionMatrices = est.getInt("numTiedTransitionMatrices"); est.expectString("n_tied_tmat"); int numStatePerHMM = numStateMap / (numTri + numBase); assert numTiedState == mixtureWeights.getStatesNum(); assert numTiedTransitionMatrices == transitionsPool.size(); // Load the base phones for (int i = 0; i < numBase + numTri; i++) { //TODO name this magic const somehow for (int j = 0; j < 5; j++) est.getString(); int tmat = est.getInt("tmat"); for (int j = 0; j < numStatePerHMM - 1; j++) { senone2ci[est.getInt("j")] = tmat; } est.expectString("N"); assert tmat < numTiedTransitionMatrices; } est.close(); } /** * Creates the senone pool from the rest of the pools. * * @param distFloor * the lowest allowed score * @param varianceFloor * the lowest allowed variance * @return the senone pool */ protected Pool<Senone> createSenonePool(float distFloor, float varianceFloor) { Pool<Senone> pool = new Pool<Senone>("senones"); int numMeans = meansPool.size(); int numVariances = variancePool.size(); int numGaussiansPerSenone = mixtureWeights.getGauPerState(); int numSenones = mixtureWeights.getStatesNum(); int numStreams = mixtureWeights.getStreamsNum(); int whichGaussian = 0; logger.fine("Senones " + numSenones); logger.fine("Gaussians Per Senone " + numGaussiansPerSenone); logger.fine("Means " + numMeans); logger.fine("Variances " + numVariances); assert numGaussiansPerSenone > 0; assert numVariances == numSenones * numGaussiansPerSenone; assert numMeans == numSenones * numGaussiansPerSenone; float[][] meansTransformationMatrix = meanTransformationMatrixPool == null ? null : meanTransformationMatrixPool.get(0); float[] meansTransformationVector = meanTransformationVectorPool == null ? null : meanTransformationVectorPool.get(0); float[][] varianceTransformationMatrix = varianceTransformationMatrixPool == null ? null : varianceTransformationMatrixPool.get(0); float[] varianceTransformationVector = varianceTransformationVectorPool == null ? null : varianceTransformationVectorPool.get(0); for (int i = 0; i < numSenones; i++) { MixtureComponent[] mixtureComponents = new MixtureComponent[numGaussiansPerSenone * numStreams]; for (int j = 0; j < numGaussiansPerSenone; j++) { mixtureComponents[j] = new MixtureComponent( meansPool.get(whichGaussian), meansTransformationMatrix, meansTransformationVector, variancePool.get(whichGaussian), varianceTransformationMatrix, varianceTransformationVector, distFloor, varianceFloor); whichGaussian++; } Senone senone = new GaussianMixture(mixtureWeights, mixtureComponents, i); pool.put(i, senone); } return pool; } /** * Creates the tied senone pool from the rest of the pools. * * @param distFloor * the lowest allowed score * @param varianceFloor * the lowest allowed variance * @return the senone pool */ private Pool<Senone> createTiedSenonePool(float distFloor, float varianceFloor) { Pool<Senone> pool = new Pool<Senone>("senones"); int numMeans = meansPool.size(); int numVariances = variancePool.size(); int numGaussiansPerState = mixtureWeights.getGauPerState(); int numSenones = mixtureWeights.getStatesNum(); int numStreams = mixtureWeights.getStreamsNum(); logger.fine("Senones " + numSenones); logger.fine("Gaussians Per State " + numGaussiansPerState); logger.fine("Means " + numMeans); logger.fine("Variances " + numVariances); assert numGaussiansPerState > 0; assert numVariances == numBase * numGaussiansPerState * numStreams; assert numMeans == numBase * numGaussiansPerState * numStreams; float[][] meansTransformationMatrix = meanTransformationMatrixPool == null ? null : meanTransformationMatrixPool.get(0); float[] meansTransformationVector = meanTransformationVectorPool == null ? null : meanTransformationVectorPool.get(0); float[][] varianceTransformationMatrix = varianceTransformationMatrixPool == null ? null : varianceTransformationMatrixPool.get(0); float[] varianceTransformationVector = varianceTransformationVectorPool == null ? null : varianceTransformationVectorPool.get(0); phoneticTiedMixtures = new MixtureComponentSet[numBase]; for (int i = 0; i < numBase; i++) { ArrayList<PrunableMixtureComponent[]> mixtureComponents = new ArrayList<PrunableMixtureComponent[]>(); for (int j = 0; j < numStreams; j++) { PrunableMixtureComponent[] featMixtureComponents = new PrunableMixtureComponent[numGaussiansPerState]; for (int k = 0; k < numGaussiansPerState; k++) { int whichGaussian = i * numGaussiansPerState * numStreams + j * numGaussiansPerState + k; featMixtureComponents[k] = new PrunableMixtureComponent( meansPool.get(whichGaussian), meansTransformationMatrix, meansTransformationVector, variancePool.get(whichGaussian), varianceTransformationMatrix, varianceTransformationVector, distFloor, varianceFloor, k); } mixtureComponents.add(featMixtureComponents); } phoneticTiedMixtures[i] = new MixtureComponentSet(mixtureComponents, topGauNum); } for (int i = 0; i < numSenones; i++) { Senone senone = new SetBasedGaussianMixture(mixtureWeights, phoneticTiedMixtures[senone2ci[i]], i); pool.put(i, senone); } return pool; } /** * Loads the sphinx3 density file, a set of density arrays are created and * placed in the given pool. * * @param path * the name of the data * @param floor * the minimum density allowed * @return a pool of loaded densities * @throws FileNotFoundException * if a file cannot be found * @throws IOException * if an error occurs while loading the data * @throws URISyntaxException uri was incorrectly specified */ public Pool<float[]> loadDensityFile(String path, float floor) throws IOException, URISyntaxException { Properties props = new Properties(); int blockSize = 0; DataInputStream dis = readS3BinaryHeader(path, props); String version = props.getProperty("version"); if (version == null || !version.equals(DENSITY_FILE_VERSION)) { throw new IOException("Unsupported version in " + path); } String checksum = props.getProperty("chksum0"); boolean doCheckSum = (checksum != null && checksum.equals("yes")); resetChecksum(); int numStates = readInt(dis); int numStreams = readInt(dis); int numGaussiansPerState = readInt(dis); int[] vectorLength = new int[numStreams]; for (int i = 0; i < numStreams; i++) { vectorLength[i] = readInt(dis); } int rawLength = readInt(dis); logger.fine("Number of states " + numStates); logger.fine("Number of streams " + numStreams); logger.fine("Number of gaussians per state " + numGaussiansPerState); logger.fine("Vector length " + vectorLength.length); logger.fine("Raw length " + rawLength); for (int i = 0; i < numStreams; i++) { blockSize += vectorLength[i]; } assert rawLength == numGaussiansPerState * blockSize * numStates; Pool<float[]> pool = new Pool<float[]>(path); pool.setFeature(NUM_SENONES, numStates); pool.setFeature(NUM_STREAMS, numStreams); pool.setFeature(NUM_GAUSSIANS_PER_STATE, numGaussiansPerState); for (int i = 0; i < numStates; i++) { for (int j = 0; j < numStreams; j++) { for (int k = 0; k < numGaussiansPerState; k++) { float[] density = readFloatArray(dis, vectorLength[j]); Utilities.floorData(density, floor); pool.put(i * numStreams * numGaussiansPerState + j * numGaussiansPerState + k, density); } } } validateChecksum(dis, doCheckSum); dis.close(); this.numStates = numStates; this.numStreams = numStreams; this.numGaussiansPerState = numGaussiansPerState; this.vectorLength = vectorLength; return pool; } /** * Reads the S3 binary header from the given location + path. Adds header * information to the given set of properties. * * @param path * the name of the file * @param props * the properties * @return the input stream positioned after the header * @throws IOException * on error * @throws URISyntaxException uri was incorrectly specified */ public DataInputStream readS3BinaryHeader(String path, Properties props) throws IOException, URISyntaxException { InputStream inputStream = getDataStream(path); if (inputStream == null) { throw new IOException("Can't open " + path); } DataInputStream dis = new DataInputStream(new BufferedInputStream( inputStream)); String id = readWord(dis); if (!id.equals("s3")) { throw new IOException("Not proper s3 binary file " + path); } String name; while ((name = readWord(dis)) != null) { if (!name.equals("endhdr")) { String value = readWord(dis); props.setProperty(name, value); } else { break; } } int byteOrderMagic = dis.readInt(); if (byteOrderMagic == BYTE_ORDER_MAGIC) { logger.fine("Not swapping " + path); swap = false; } else if (Utilities.swapInteger(byteOrderMagic) == BYTE_ORDER_MAGIC) { logger.fine("Swapping " + path); swap = true; } else { throw new IOException("Corrupted S3 file " + path); } return dis; } /** * Reads the next word (text separated by whitespace) from the given stream. * * @param dis * the input stream * @return the next word * @throws IOException * on error */ String readWord(DataInputStream dis) throws IOException { StringBuilder sb = new StringBuilder(); char c; // skip leading whitespace do { c = readChar(dis); } while (Character.isWhitespace(c)); // read the word do { sb.append(c); c = readChar(dis); } while (!Character.isWhitespace(c)); return sb.toString(); } /** * Reads a single char from the stream. * * @param dis * the stream to read * @return the next character on the stream * @throws IOException * if an error occurs */ private char readChar(DataInputStream dis) throws IOException { return (char) dis.readByte(); } /* Stores checksum during loading */ private long calculatedCheckSum = 0; /** * Resets the checksum before loading a new chunk of data */ private void resetChecksum() { calculatedCheckSum = 0; } /** * Validates checksum in the stream * * @param dis * input stream * @param doCheckSum * validates * @throws IOException * on error **/ private void validateChecksum(DataInputStream dis, boolean doCheckSum) throws IOException { if (!doCheckSum) return; int oldCheckSum = (int) calculatedCheckSum; int checkSum = readInt(dis); if (checkSum != oldCheckSum) { throw new IOException("Invalid checksum " + Long.toHexString(calculatedCheckSum) + " must be " + Integer.toHexString(checkSum)); } } /** * Read an integer from the input stream, byte-swapping as necessary. * * @param dis * the input stream * @return an integer value * @throws IOException * on error */ public int readInt(DataInputStream dis) throws IOException { int val; if (swap) { val = Utilities.readLittleEndianInt(dis); } else { val = dis.readInt(); } calculatedCheckSum = ((calculatedCheckSum << 20 | calculatedCheckSum >> 12) + val) & 0xFFFFFFFFL; return val; } /** * Read a float from the input stream, byte-swapping as necessary. * * @param dis * the input stream * @return a floating pint value * @throws IOException * on error */ public float readFloat(DataInputStream dis) throws IOException { int val; if (swap) { val = Utilities.readLittleEndianInt(dis); } else { val = dis.readInt(); } calculatedCheckSum = ((calculatedCheckSum << 20 | calculatedCheckSum >> 12) + val) & 0xFFFFFFFFL; return Float.intBitsToFloat(val); } /** * Reads the given number of floats from the stream and returns them in an * array of floats. * * @param dis * the stream to read data from * @param size * the number of floats to read * @return an array of size float elements * @throws IOException * if an exception occurs */ public float[] readFloatArray(DataInputStream dis, int size) throws IOException { float[] data = new float[size]; for (int i = 0; i < size; i++) { data[i] = readFloat(dis); } return data; } /** * Loads the sphinx3 density file, a set of density arrays are created and * placed in the given pool. * * @param useCDUnits * if true, loads also the context dependent units * @param inputStream * the open input stream to use * @throws FileNotFoundException * if a file cannot be found * @throws IOException * if an error occurs while loading the data */ protected void loadHMMPool(boolean useCDUnits, InputStream inputStream) throws IOException { ExtendedStreamTokenizer est = new ExtendedStreamTokenizer(inputStream, '#', false); logger.fine("Loading HMM file from: " + location); est.expectString(MODEL_VERSION); int numBase = est.getInt("numBase"); est.expectString("n_base"); int numTri = est.getInt("numTri"); est.expectString("n_tri"); int numStateMap = est.getInt("numStateMap"); est.expectString("n_state_map"); int numTiedState = est.getInt("numTiedState"); est.expectString("n_tied_state"); int numContextIndependentTiedState = est .getInt("numContextIndependentTiedState"); est.expectString("n_tied_ci_state"); int numTiedTransitionMatrices = est.getInt("numTiedTransitionMatrices"); est.expectString("n_tied_tmat"); int numStatePerHMM = numStateMap / (numTri + numBase); assert numTiedState == mixtureWeights.getStatesNum(); assert numTiedTransitionMatrices == transitionsPool.size(); // Load the base phones for (int i = 0; i < numBase; i++) { String name = est.getString(); String left = est.getString(); String right = est.getString(); String position = est.getString(); String attribute = est.getString(); int tmat = est.getInt("tmat"); int[] stid = new int[numStatePerHMM - 1]; for (int j = 0; j < numStatePerHMM - 1; j++) { stid[j] = est.getInt("j"); assert stid[j] >= 0 && stid[j] < numContextIndependentTiedState; } est.expectString("N"); assert left.equals("-"); assert right.equals("-"); assert position.equals("-"); assert tmat < numTiedTransitionMatrices; Unit unit = unitManager.getUnit(name, attribute.equals(FILLER)); contextIndependentUnits.put(unit.getName(), unit); if (logger.isLoggable(Level.FINE)) { logger.fine("Loaded " + unit); } // The first filler if (unit.isFiller() && unit.getName().equals(SILENCE_CIPHONE)) { unit = UnitManager.SILENCE; } float[][] transitionMatrix = transitionsPool.get(tmat); SenoneSequence ss = getSenoneSequence(stid); HMM hmm = new SenoneHMM(unit, ss, transitionMatrix, HMMPosition.lookup(position)); hmmManager.put(hmm); } if (hmmManager.get(HMMPosition.UNDEFINED, UnitManager.SILENCE) == null) { throw new IOException("Could not find SIL unit in acoustic model"); } // Load the context dependent phones. If the useCDUnits // property is false, the CD phones will not be created, but // the values still need to be read in from the file. String lastUnitName = ""; Unit lastUnit = null; int[] lastStid = null; SenoneSequence lastSenoneSequence = null; for (int i = 0; i < numTri; i++) { String name = est.getString(); String left = est.getString(); String right = est.getString(); String position = est.getString(); String attribute = est.getString(); int tmat = est.getInt("tmat"); int[] stid = new int[numStatePerHMM - 1]; for (int j = 0; j < numStatePerHMM - 1; j++) { stid[j] = est.getInt("j"); assert stid[j] >= numContextIndependentTiedState && stid[j] < numTiedState; } est.expectString("N"); assert !left.equals("-"); assert !right.equals("-"); assert !position.equals("-"); assert attribute.equals("n/a"); assert tmat < numTiedTransitionMatrices; if (useCDUnits) { Unit unit; String unitName = (name + ' ' + left + ' ' + right); if (unitName.equals(lastUnitName)) { unit = lastUnit; } else { Unit[] leftContext = new Unit[1]; leftContext[0] = contextIndependentUnits.get(left); Unit[] rightContext = new Unit[1]; rightContext[0] = contextIndependentUnits.get(right); Context context = LeftRightContext.get(leftContext, rightContext); unit = unitManager.getUnit(name, false, context); } lastUnitName = unitName; lastUnit = unit; if (logger.isLoggable(Level.FINE)) { logger.fine("Loaded " + unit); } float[][] transitionMatrix = transitionsPool.get(tmat); SenoneSequence ss = lastSenoneSequence; if (ss == null || !sameSenoneSequence(stid, lastStid)) { ss = getSenoneSequence(stid); } lastSenoneSequence = ss; lastStid = stid; HMM hmm = new SenoneHMM(unit, ss, transitionMatrix, HMMPosition.lookup(position)); hmmManager.put(hmm); } } est.close(); } /** * Returns true if the given senone sequence IDs are the same. * * @param ssid1 ids of first senone sequence * @param ssid2 ids of second senone sequence * @return true if the given senone sequence IDs are the same, false * otherwise */ protected boolean sameSenoneSequence(int[] ssid1, int[] ssid2) { if (ssid1.length == ssid2.length) { for (int i = 0; i < ssid1.length; i++) { if (ssid1[i] != ssid2[i]) { return false; } } return true; } else { return false; } } /** * Gets the senone sequence representing the given senones. * * @param stateid * is the array of senone state ids * @return the senone sequence associated with the states */ protected SenoneSequence getSenoneSequence(int[] stateid) { Senone[] senones = new Senone[stateid.length]; for (int i = 0; i < stateid.length; i++) { senones[i] = senonePool.get(stateid[i]); } return new SenoneSequence(senones); } /** * Loads the mixture weights (Binary). * * @param path * the path to the mixture weight file * @param floor * the minimum mixture weight allowed * @return a pool of mixture weights * @throws IOException * if an error occurs while loading the data * @throws URISyntaxException uri was incorrectly specified */ protected GaussianWeights loadMixtureWeights(String path, float floor) throws IOException, URISyntaxException { logger.fine("Loading mixture weights from: " + path); Properties props = new Properties(); DataInputStream dis = readS3BinaryHeader(path, props); String version = props.getProperty("version"); if (version == null || !version.equals(MIXW_FILE_VERSION)) { throw new IOException("Unsupported version in " + path); } String checksum = props.getProperty("chksum0"); boolean doCheckSum = (checksum != null && checksum.equals("yes")); resetChecksum(); int numStates = readInt(dis); int numStreams = readInt(dis); int numGaussiansPerState = readInt(dis); int numValues = readInt(dis); GaussianWeights mixtureWeights = new GaussianWeights(path, numStates, numGaussiansPerState, numStreams); logger.fine("Number of states " + numStates); logger.fine("Number of streams " + numStreams); logger.fine("Number of gaussians per state " + numGaussiansPerState); assert numValues == numStates * numStreams * numGaussiansPerState; for (int i = 0; i < numStates; i++) { for (int j = 0; j < numStreams; j++) { float[] logStreamMixtureWeight = readFloatArray(dis, numGaussiansPerState); Utilities.normalize(logStreamMixtureWeight); Utilities.floorData(logStreamMixtureWeight, floor); logMath.linearToLog(logStreamMixtureWeight); mixtureWeights.put(i, j, logStreamMixtureWeight); } } validateChecksum(dis, doCheckSum); dis.close(); return mixtureWeights; } /** * Loads the transition matrices (Binary). * * @param path * the path to the transitions matrices * @return a pool of transition matrices * @throws IOException * if an error occurs while loading the data * @throws URISyntaxException uri was incorrectly specified */ protected Pool<float[][]> loadTransitionMatrices(String path) throws IOException, URISyntaxException { logger.fine("Loading transition matrices from: " + path); Properties props = new Properties(); DataInputStream dis = readS3BinaryHeader(path, props); String version = props.getProperty("version"); if (version == null || !version.equals(TMAT_FILE_VERSION)) { throw new IOException("Unsupported version in " + path); } String checksum = props.getProperty("chksum0"); boolean doCheckSum = (checksum != null && checksum.equals("yes")); resetChecksum(); Pool<float[][]> pool = new Pool<float[][]>(path); int numMatrices = readInt(dis); int numRows = readInt(dis); int numStates = readInt(dis); int numValues = readInt(dis); assert numValues == numStates * numRows * numMatrices; for (int i = 0; i < numMatrices; i++) { float[][] tmat = new float[numStates][]; // last row should be zeros tmat[numStates - 1] = new float[numStates]; logMath.linearToLog(tmat[numStates - 1]); for (int j = 0; j < numRows; j++) { tmat[j] = readFloatArray(dis, numStates); Utilities.nonZeroFloor(tmat[j], 0f); Utilities.normalize(tmat[j]); logMath.linearToLog(tmat[j]); } pool.put(i, tmat); } validateChecksum(dis, doCheckSum); dis.close(); return pool; } /** * Loads the transform matrices (Binary). * * @param path * the path to the transform matrix * @return a transform matrix * @throws java.io.FileNotFoundException * if a file cannot be found * @throws java.io.IOException * if an error occurs while loading the data */ protected float[][] loadTransformMatrix(String path) throws IOException { logger.fine("Loading transform matrix from: " + path); Properties props = new Properties(); DataInputStream dis; try { dis = readS3BinaryHeader(path, props); } catch (URISyntaxException e) { throw new RuntimeException(e); } catch (IOException e) { return null; } String version = props.getProperty("version"); if (version == null || !version.equals(TRANSFORM_FILE_VERSION)) { throw new IOException("Unsupported version in " + path); } String checksum = props.getProperty("chksum0"); boolean doCheckSum = (checksum != null && checksum.equals("yes")); resetChecksum(); readInt(dis); int numRows = readInt(dis); int numValues = readInt(dis); int num = readInt(dis); assert num == numRows * numValues; float[][] result = new float[numRows][]; for (int i = 0; i < numRows; i++) { result[i] = readFloatArray(dis, numValues); } validateChecksum(dis, doCheckSum); dis.close(); return result; } public void clearGauScores() { if (phoneticTiedMixtures == null) return; for (MixtureComponentSet mixture : phoneticTiedMixtures) mixture.clearStoredScores(); } public void setGauScoresQueueLength(int scoresQueueLen) { if (phoneticTiedMixtures == null) return; for (MixtureComponentSet mixture : phoneticTiedMixtures) mixture.setScoreQueueLength(scoresQueueLen); } public Pool<float[]> getMeansPool() { return meansPool; } public Pool<float[][]> getMeansTransformationMatrixPool() { return meanTransformationMatrixPool; } public Pool<float[]> getMeansTransformationVectorPool() { return meanTransformationVectorPool; } public Pool<float[]> getVariancePool() { return variancePool; } public Pool<float[][]> getVarianceTransformationMatrixPool() { return varianceTransformationMatrixPool; } public Pool<float[]> getVarianceTransformationVectorPool() { return varianceTransformationVectorPool; } public GaussianWeights getMixtureWeights() { return mixtureWeights; } public Pool<float[][]> getTransitionMatrixPool() { return transitionsPool; } public float[][] getTransformMatrix() { return transformMatrix; } public Pool<Senone> getSenonePool() { return senonePool; } public int getLeftContextSize() { return CONTEXT_SIZE; } public int getRightContextSize() { return CONTEXT_SIZE; } public HMMManager getHMMManager() { return hmmManager; } public void logInfo() { logger.info("Loading tied-state acoustic model from: " + location); meansPool.logInfo(logger); variancePool.logInfo(logger); transitionsPool.logInfo(logger); senonePool.logInfo(logger); if (meanTransformationMatrixPool != null) meanTransformationMatrixPool.logInfo(logger); if (meanTransformationVectorPool != null) meanTransformationVectorPool.logInfo(logger); if (varianceTransformationMatrixPool != null) varianceTransformationMatrixPool.logInfo(logger); if (varianceTransformationVectorPool != null) varianceTransformationVectorPool.logInfo(logger); mixtureWeights.logInfo(logger); senonePool.logInfo(logger); logger.info("Context Independent Unit Entries: " + contextIndependentUnits.size()); hmmManager.logInfo(logger); } public Properties getProperties() { return modelProps; } protected Properties loadModelProps(String path) throws MalformedURLException, IOException, URISyntaxException { Properties props = new Properties(); BufferedReader reader = new BufferedReader(new InputStreamReader( getDataStream(path))); String line; while ((line = reader.readLine()) != null) { String[] tokens = line.split(" "); props.put(tokens[0], tokens[1]); } return props; } public void update(Transform transform, ClusteredDensityFileData clusters) { for (int index = 0; index < meansPool.size(); index++) { int transformClass = clusters.getClassIndex(index); float[] tmean = new float[getVectorLength()[0]]; float[] mean = meansPool.get(index); for (int i = 0; i < numStreams; i++) { for (int l = 0; l < getVectorLength()[i]; l++) { tmean[l] = 0; for (int m = 0; m < getVectorLength()[i]; m++) { tmean[l] += transform.getAs()[transformClass][i][l][m] * mean[m]; } tmean[l] += transform.getBs()[transformClass][i][l]; } System.arraycopy(tmean, 0, mean, 0, tmean.length); } } } }