package edu.cmu.sphinx.decoder.adaptation;
import edu.cmu.sphinx.api.SpeechResult;
import edu.cmu.sphinx.decoder.search.Token;
import edu.cmu.sphinx.frontend.FloatData;
import edu.cmu.sphinx.linguist.HMMSearchState;
import edu.cmu.sphinx.linguist.SearchState;
import edu.cmu.sphinx.linguist.acoustic.tiedstate.Loader;
import edu.cmu.sphinx.linguist.acoustic.tiedstate.Sphinx3Loader;
import edu.cmu.sphinx.util.LogMath;
/**
 * Accumulates, per Gaussian cluster and per feature stream, the statistics
 * used to estimate an MLLR mean transform: the G matrices ({@link #getRegLs()})
 * and the right-hand-side accumulators ({@link #getRegRs()}). The Gaussians
 * must have been clustered beforehand with {@link ClusteredDensityFileData}.
 */
public class Stats {

    /**
     * Minimum number of emitting frames per cluster needed before a transform
     * is estimated; {@link #createTransform()} requires at least
     * {@code MIN_FRAMES * nClusters} frames in total.
     */
    private static final int MIN_FRAMES = 300;

    /** Variances below this value (but still positive) are floored before inversion. */
    private static final float VAR_FLOOR = 1e-5f;

    private final ClusteredDensityFileData means;
    private final int nClusters;
    private final Sphinx3Loader loader;
    private final LogMath logMath = LogMath.getLogMath();
    private double[][][][][] regLs;
    private double[][][][] regRs;
    private int nFrames;

    /**
     * Creates an empty accumulator.
     *
     * NOTE(review): the constructor inverts the variances in place, writing
     * the results back into the arrays returned by the loader's variance pool.
     * Constructing a second {@code Stats} over the same loader would therefore
     * invert them again.
     *
     * @param loader
     *            acoustic model loader; must be a {@link Sphinx3Loader}
     *            (otherwise a {@link ClassCastException} is thrown)
     * @param means
     *            clustering of the model's Gaussians
     */
    public Stats(Loader loader, ClusteredDensityFileData means) {
        this.loader = (Sphinx3Loader) loader;
        this.nClusters = means.getNumberOfClusters();
        this.means = means;
        this.invertVariances();
        this.init();
        this.nFrames = 0;
    }

    /**
     * Allocates zeroed accumulators: for each cluster and stream of length
     * {@code len}, a {@code len x (len+1) x (len+1)} G array and a
     * {@code len x (len+1)} right-hand-side array.
     */
    private void init() {
        int numStreams = loader.getNumStreams();
        int[] vectorLength = loader.getVectorLength();
        this.regLs = new double[nClusters][numStreams][][][];
        this.regRs = new double[nClusters][numStreams][][];
        for (int i = 0; i < nClusters; i++) {
            for (int j = 0; j < numStreams; j++) {
                int len = vectorLength[j];
                this.regLs[i][j] = new double[len][len + 1][len + 1];
                this.regRs[i][j] = new double[len][len + 1];
            }
        }
    }

    /** @return the Gaussian clustering these statistics are collected for */
    public ClusteredDensityFileData getClusteredData() {
        return this.means;
    }

    /**
     * @return the G-matrix accumulators, indexed
     *         {@code [cluster][stream][row][p][q]}; this is the live internal
     *         array, not a copy
     */
    public double[][][][][] getRegLs() {
        return regLs;
    }

    /**
     * @return the right-hand-side accumulators, indexed
     *         {@code [cluster][stream][row][p]}; this is the live internal
     *         array, not a copy
     */
    public double[][][][] getRegRs() {
        return regRs;
    }

    /**
     * Flat pool index of a Gaussian. This matches the state-major,
     * stream-then-gaussian layout that {@link #collect(SpeechResult)} uses for
     * the means and variance pools.
     */
    private static int gaussianIndex(int state, int stream, int gaussian, int numStreams, int gauPerState) {
        return (state * numStreams + stream) * gauPerState + gaussian;
    }

    /**
     * Replaces every variance in the loader's variance pool by its inverse.
     * Non-positive variances become 0.5, and variances below
     * {@link #VAR_FLOOR} become {@code 1 / VAR_FLOOR}.
     *
     * Every stream of every state is visited, each with its own vector length.
     * Previously only stream-0-sized entries of the first
     * {@code numStates * gauPerState} pool slots were processed, which broke
     * multi-stream models.
     */
    private void invertVariances() {
        int numStates = loader.getNumStates();
        int numStreams = loader.getNumStreams();
        int gauPerState = loader.getNumGaussiansPerState();
        int[] vectorLength = loader.getVectorLength();
        for (int state = 0; state < numStates; state++) {
            for (int stream = 0; stream < numStreams; stream++) {
                for (int gau = 0; gau < gauPerState; gau++) {
                    float[] variance = loader.getVariancePool()
                            .get(gaussianIndex(state, stream, gau, numStreams, gauPerState));
                    for (int l = 0; l < vectorLength[stream]; l++) {
                        if (variance[l] <= 0.) {
                            variance[l] = (float) 0.5;
                        } else if (variance[l] < VAR_FLOOR) {
                            variance[l] = (float) (1. / VAR_FLOOR);
                        } else {
                            variance[l] = (float) (1. / variance[l]);
                        }
                    }
                }
            }
        }
    }

    /**
     * Converts log component scores into linear weights, one stream at a time.
     * Within each stream the scores are shifted by that stream's maximum
     * before {@link LogMath#logToLinear} is applied, so the weights are
     * relative to the best component. They are not normalised to sum to one.
     *
     * NOTE: the input array is overwritten in place and returned.
     *
     * @param componentScores
     *            log scores of all components, stream by stream
     * @param numStreams
     *            number of feature streams
     * @return {@code componentScores}, now holding the linear weights
     */
    private float[] computePosteriors(float[] componentScores, int numStreams) {
        float[] posteriors = componentScores;
        int step = componentScores.length / numStreams;
        for (int stream = 0, start = 0; stream < numStreams; stream++, start += step) {
            int end = start + step;
            float max = posteriors[start];
            for (int j = start + 1; j < end; j++) {
                if (posteriors[j] > max) {
                    max = posteriors[j];
                }
            }
            for (int j = start; j < end; j++) {
                posteriors[j] = (float) logMath.logToLinear(posteriors[j] - max);
            }
        }
        return posteriors;
    }

    /**
     * Walks the best-token path of {@code result} and adds the contribution
     * of every emitting frame to the per-cluster accumulators.
     *
     * @param result
     *            result object to collect counts from
     * @throws Exception
     *             (an {@link IllegalStateException}) if the result has no
     *             best token
     */
    public void collect(SpeechResult result) throws Exception {
        Token token = result.getResult().getBestToken();
        if (token == null)
            throw new IllegalStateException("Best token not found!");

        int[] len = loader.getVectorLength();
        int numStreams = loader.getNumStreams();
        int gauPerState = loader.getNumGaussiansPerState();

        for (; token != null; token = token.getPredecessor()) {
            SearchState ss = token.getSearchState();
            // Only emitting HMM states carry an observation to adapt on.
            if (!(ss instanceof HMMSearchState) || !ss.isEmitting())
                continue;

            // Cast only after the emitting check so non-emitting tokens never
            // have their data inspected.
            FloatData feature = (FloatData) token.getData();
            nFrames++;

            float[] posteriors = computePosteriors(token.calculateComponentScore(feature), numStreams);
            float[] featureVector = FloatData.toFloatData(feature).getValues();

            int mId = (int) ((HMMSearchState) ss).getHMMState().getMixtureId();
            if (loader.hasTiedMixtures())
                // use CI phone ID for tied mixture model
                mId = loader.getSenone2Ci()[mId];

            int featOffset = 0;
            for (int stream = 0; stream < numStreams; stream++) {
                for (int gau = 0; gau < gauPerState; gau++) {
                    float dnom = posteriors[stream * gauPerState + gau];
                    // Components with zero (or NaN) weight contribute nothing.
                    if (dnom > 0.) {
                        accumulateGaussian(gaussianIndex(mId, stream, gau, numStreams, gauPerState), stream, dnom,
                                featureVector, featOffset, len[stream]);
                    }
                }
                featOffset += len[stream];
            }
        }
    }

    /**
     * Adds one Gaussian's weighted contribution for the current frame to the
     * accumulators of the cluster that Gaussian belongs to. Only the upper
     * triangle of each G matrix is updated; see {@link #fillRegLowerPart()}.
     *
     * @param gaussian
     *            flat pool index of the Gaussian
     * @param stream
     *            stream the Gaussian belongs to
     * @param dnom
     *            component weight for this frame (strictly positive)
     * @param featureVector
     *            concatenated observation vector of all streams
     * @param featOffset
     *            start of this stream's slice inside {@code featureVector}
     * @param dim
     *            vector length of this stream
     */
    private void accumulateGaussian(int gaussian, int stream, float dnom, float[] featureVector, int featOffset,
            int dim) {
        int cluster = means.getClassIndex(gaussian);
        float[] tmean = loader.getMeansPool().get(gaussian);
        // Already inverted by invertVariances(); looked up once per Gaussian.
        float[] invVar = loader.getVariancePool().get(gaussian);
        double[][][] regL = regLs[cluster][stream];
        double[][] regR = regRs[cluster][stream];
        for (int k = 0; k < dim; k++) {
            float wtMeanVar = dnom * featureVector[k + featOffset] * invVar[k];
            float wtDcountVar = dnom * invVar[k];
            double[][] g = regL[k];
            double[] z = regR[k];
            for (int p = 0; p < dim; p++) {
                float wtDcountVarMean = wtDcountVar * tmean[p];
                for (int q = p; q < dim; q++) {
                    g[p][q] += wtDcountVarMean * tmean[q];
                }
                g[p][dim] += wtDcountVarMean;
                z[p] += wtMeanVar * tmean[p];
            }
            g[dim][dim] += wtDcountVar;
            z[dim] += wtMeanVar;
        }
    }

    /**
     * Mirrors the upper triangle of every accumulated G matrix into its lower
     * triangle. {@link #collect(SpeechResult)} only fills entries with
     * {@code q >= p}.
     */
    public void fillRegLowerPart() {
        int numStreams = loader.getNumStreams();
        int[] vectorLength = loader.getVectorLength();
        for (int i = 0; i < this.nClusters; i++) {
            for (int j = 0; j < numStreams; j++) {
                int len = vectorLength[j];
                for (int l = 0; l < len; l++) {
                    double[][] g = regLs[i][j][l];
                    for (int p = 0; p <= len; p++) {
                        for (int q = p + 1; q <= len; q++) {
                            g[q][p] = g[p][q];
                        }
                    }
                }
            }
        }
    }

    /**
     * Estimates a transform from the statistics collected so far.
     *
     * @return the estimated transform, or {@code null} if fewer than
     *         {@code MIN_FRAMES * nClusters} frames have been collected
     */
    public Transform createTransform() {
        if (nFrames < MIN_FRAMES * nClusters) {
            return null;
        }
        Transform transform = new Transform(loader, nClusters);
        transform.update(this);
        return transform;
    }

    /** @return number of emitting frames accumulated by {@link #collect(SpeechResult)} */
    public int getFrames() {
        return nFrames;
    }
}