package edu.cmu.sphinx.linguist.acoustic.tiedstate.kaldi; import java.util.Arrays; import edu.cmu.sphinx.frontend.Data; import edu.cmu.sphinx.frontend.FloatData; import edu.cmu.sphinx.linguist.acoustic.tiedstate.MixtureComponent; import edu.cmu.sphinx.linguist.acoustic.tiedstate.ScoreCachingSenone; import edu.cmu.sphinx.util.LogMath; /** * Gaussian Mixture Model with diagonal covariances. * * @see DiagGmm class in Kaldi. */ @SuppressWarnings("serial") public class DiagGmm extends ScoreCachingSenone { private int id; private float[] gconsts; private float[] invVars; private float[] meansInvVars; /** * Constructs new mixture model. * * @param id identifier of this GMM as defined in the model * @param parser text format parser */ public DiagGmm(int id, KaldiTextParser parser) { this.id = id; parser.expectToken("<DiagGMM>"); parser.expectToken("<GCONSTS>"); gconsts = parser.getFloatArray(); parser.expectToken("<WEIGHTS>"); // Do not use weights as they are in gconsts. parser.getFloatArray(); parser.expectToken("<MEANS_INVVARS>"); meansInvVars = parser.getFloatArray(); parser.expectToken("<INV_VARS>"); invVars = parser.getFloatArray(); parser.expectToken("</DiagGMM>"); } /** * Convenient method if 32-bit ID is required. * * Kaldi model uses 32-bit integer to store GMM id while Senone contract * imposes long type. This method is present to avaoid type cast when * working in the Kaldi domain. * @return the ID of gmm */ public int getId() { return id; } @Override public float calculateScore(Data data) { float logTotal = LogMath.LOG_ZERO; LogMath logMath = LogMath.getLogMath(); for (Float mixtureScore : calculateComponentScore(data)) logTotal = logMath.addAsLinear(logTotal, mixtureScore); return logTotal; } public float[] calculateComponentScore(Data data) { float[] features = FloatData.toFloatData(data).getValues(); int dim = meansInvVars.length / gconsts.length; if (features.length != dim) { String fmt = "feature vector must be of length %d, got %d"; String msg = String.format(fmt, dim, features.length); throw new IllegalArgumentException(msg); } float[] likelihoods = Arrays.copyOf(gconsts, gconsts.length); for (int i = 0; i < likelihoods.length; ++i) { for (int j = 0; j < features.length; ++j) { int k = i * features.length + j; likelihoods[i] += meansInvVars[k] * features[j]; likelihoods[i] -= .5f * invVars[k] * features[j] * features[j]; } likelihoods[i] = LogMath.getLogMath().lnToLog(likelihoods[i]); } return likelihoods; } public long getID() { return id; } public void dump(String msg) { System.out.format("%s DiagGmm: ID %d\n", msg, id); } public MixtureComponent[] getMixtureComponents() { return null; } public float[] getLogMixtureWeights() { return null; } }