/*
 * Copyright 2007 LORIA, France.
 * All Rights Reserved.  Use is subject to license terms.
 *
 * See the file "license.terms" for information on usage and
 * redistribution of this file, and for a DISCLAIMER OF ALL
 * WARRANTIES.
 */
package edu.cmu.sphinx.linguist.acoustic.tiedstate.HTK;

import edu.cmu.sphinx.util.LogMath;

import java.io.BufferedReader;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.StringTokenizer;

/**
 * Gaussian mixture model with diagonal covariances. It stores weights, means
 * and variances, can be saved to / loaded from several text formats, and
 * computes the log likelihood of an observation vector for each Gaussian.
 *
 * @author Christophe Cerisara
 */
public class GMMDiag {

    /** Floor applied to the log likelihood of a single Gaussian. */
    private static final float distFloor = -Float.MAX_VALUE;

    /**
     * Optional amount of training data: read from / written to the last line
     * of the proprietary format, and accumulated from the first column by
     * {@link #loadScaleKMeans(String)}.
     */
    public int nT;

    /** Free-form name of this GMM, see {@link #setNom(String)}. */
    public String nom;

    /** Log-domain helper used to convert weights and variances. */
    public LogMath logMath;

    private int ncoefs;
    private int ngauss;

    /** Mixture weights, stored as {@code logMath.linearToLog(weight)}. */
    protected float[] weights;

    /** Means, indexed {@code [gaussian][coefficient]}. */
    protected float[][] means;

    /**
     * Stores {@code -1 / (2 * variance)}, indexed
     * {@code [gaussian][coefficient]}, so that it can be multiplied directly
     * with the squared difference in {@link #computeLogLikes(float[])}.
     */
    protected float[][] covar;

    private float[] logPreComputedGaussianFactor;

    /** Per-Gaussian results of the last {@link #computeLogLikes(float[])}. */
    protected float[] loglikes;

    /** Creates an empty GMM; nothing is allocated until a load method runs. */
    public GMMDiag() {
    }

    /**
     * Creates a GMM with uniform weights and zero-initialised parameters.
     *
     * @param ng number of Gaussians
     * @param nc number of coefficients per Gaussian
     */
    public GMMDiag(int ng, int nc) {
        ngauss = ng;
        ncoefs = nc;
        allocate();
    }

    /**
     * @return number of Gaussians in the mixture
     */
    public int getNgauss() {
        return ngauss;
    }

    /**
     * @return number of coefficients per Gaussian
     */
    public int getNcoefs() {
        return ncoefs;
    }

    /**
     * @param i index of the Gaussian
     * @return weight of Gaussian {@code i} in the linear domain
     */
    public float getWeight(int i) {
        return (float) logMath.logToLinear(weights[i]);
    }

    /**
     * @param i index of the Gaussian
     * @param j index of the coefficient
     * @return variance recovered from the stored {@code -1 / (2 * variance)}
     */
    public float getVar(int i, int j) {
        return -1f / (2f * covar[i][j]);
    }

    /**
     * Sets the weight of one Gaussian.
     *
     * @param i index of the Gaussian
     * @param w weight in the linear domain; it is stored in the log domain
     */
    public void setWeight(int i, float w) {
        if (weights == null) {
            weights = new float[ngauss];
        }
        weights[i] = logMath.linearToLog(w);
    }

    /**
     * Sets one variance.
     *
     * @param i index of the Gaussian
     * @param j index of the coefficient
     * @param v variance; a non-positive value only triggers a warning
     */
    public void setVar(int i, int j, float v) {
        if (v <= 0) {
            // This is not an error, because you can use the GMM just to
            // store values and retrieve them later.
            // TODO: this is not very clean, because a real variance must
            // still be > 0
            System.err.println("WARNING: setVar " + v);
        }
        covar[i][j] = -1f / (2f * v);
    }

    /**
     * @param i index of the Gaussian
     * @param j index of the coefficient
     * @param v new mean value
     */
    public void setMean(int i, int j, float v) {
        means[i][j] = v;
    }

    /**
     * @param i index of the Gaussian
     * @param j index of the coefficient
     * @return mean value
     */
    public float getMean(int i, int j) {
        return means[i][j];
    }

    /**
     * Sets the name of this GMM.
     *
     * @param s new name
     */
    public void setNom(String s) {
        nom = s;
    }

    /**
     * Saves in proprietary format. I/O errors are printed, not propagated.
     *
     * @param name name of file to save
     */
    public void save(String name) {
        PrintWriter fout = null;
        try {
            fout = new PrintWriter(new FileWriter(name));
            fout.println(ngauss + " " + ncoefs);
            for (int i = 0; i < ngauss; i++) {
                fout.println("gauss " + i + ' ' + getWeight(i));
                for (int j = 0; j < ncoefs; j++) {
                    fout.print(means[i][j] + " ");
                }
                fout.println();
                for (int j = 0; j < ncoefs; j++) {
                    fout.print(getVar(i, j) + " ");
                }
                fout.println();
            }
            fout.println(nT);
        } catch (IOException e) {
            e.printStackTrace();
        } finally {
            // Close on every path so the file handle is never leaked.
            if (fout != null) {
                fout.close();
            }
        }
    }

    /**
     * Load from text proprietary format. I/O errors (including a truncated
     * file) are printed, not propagated; a malformed "gauss" header line
     * terminates the JVM.
     *
     * @param name filename to load from
     */
    public void load(String name) {
        BufferedReader fin = null;
        try {
            fin = new BufferedReader(new FileReader(name));
            String s = readRequiredLine(fin, name);
            String[] ss = s.split(" ");
            ngauss = Integer.parseInt(ss[0]);
            ncoefs = Integer.parseInt(ss[1]);
            allocate();
            for (int i = 0; i < ngauss; i++) {
                s = readRequiredLine(fin, name);
                ss = s.split(" ");
                if (!ss[0].equals("gauss") || Integer.parseInt(ss[1]) != i) {
                    System.err.println("Error loading GMM " + s + ' ' + i);
                    System.exit(1);
                }
                setWeight(i, Float.parseFloat(ss[2]));
                // means
                ss = readRequiredLine(fin, name).split(" ");
                for (int j = 0; j < ncoefs; j++) {
                    setMean(i, j, Float.parseFloat(ss[j]));
                }
                // variances
                ss = readRequiredLine(fin, name).split(" ");
                for (int j = 0; j < ncoefs; j++) {
                    setVar(i, j, Float.parseFloat(ss[j]));
                }
            }
            s = fin.readLine();
            if (s != null) {
                // can be added to store the amount of data on which the GMM
                // has been learned
                nT = Integer.parseInt(s);
            }
            precomputeDistance();
        } catch (IOException e) {
            e.printStackTrace();
        } finally {
            closeQuietly(fin);
        }
    }

    /**
     * Saves as a 3-state HTK HMM using the {@code <USER>} parameter kind.
     *
     * @param nomFich name of the file to write
     * @param nomHMM name of the HMM inside the file
     */
    public void saveHTK(String nomFich, String nomHMM) {
        saveHTK(nomFich, nomHMM, "<USER>");
    }

    /**
     * Creates the file and writes the HTK global options and regression tree
     * header.
     *
     * @param nomFich name of the file to write
     * @param parmKind HTK parameter kind inserted in the header
     * @return the open writer, which the caller must close, or {@code null}
     *         if the file could not be created
     */
    public PrintWriter saveHTKheader(String nomFich, String parmKind) {
        try {
            PrintWriter fout = new PrintWriter(new FileWriter(nomFich));
            writeHTKHeader(fout, parmKind);
            return fout;
        } catch (IOException e) {
            e.printStackTrace();
            return null;
        }
    }

    /**
     * Writes the mixtures of this GMM as the body of one HTK state.
     *
     * @param fout writer to append to
     */
    public void saveHTKState(PrintWriter fout) {
        writeHTKMixtures(fout);
    }

    /**
     * Writes a left-to-right {@code <TRANSP>} matrix of {@code nstates} rows
     * and columns followed by {@code <ENDHMM>}. Rows 1 to nstates-2 contain
     * 0.5 on the diagonal and 0.5 on the next column; the first and last rows
     * are all zeros.
     *
     * @param nstates number of states, including the first and last ones
     * @param fout writer to append to
     */
    public void saveHTKtailer(int nstates, PrintWriter fout) {
        fout.println("<TRANSP> " + nstates);
        // First state is non emitting
        // NOTE(review): saveHTK() writes "0 1 0" for this row; HTK presumably
        // expects the entry row to sum to 1 here as well - confirm before
        // relying on this output.
        if (nstates > 0) {
            fout.println(transitionRow(nstates, -1));
        }
        for (int i = 1; i < nstates - 1; i++) {
            fout.println(transitionRow(nstates, i));
        }
        if (nstates > 1) {
            // The last state has no outgoing transition.
            fout.println(transitionRow(nstates, -1));
        }
        fout.println("<ENDHMM>");
    }

    /**
     * Saves as a 3-state HTK HMM whose single emitting state holds this GMM.
     * I/O errors are printed, not propagated.
     *
     * @param nomFich name of the file to write
     * @param nomHMM name of the HMM inside the file
     * @param parmKind HTK parameter kind inserted in the header
     */
    public void saveHTK(String nomFich, String nomHMM, String parmKind) {
        PrintWriter fout = null;
        try {
            fout = new PrintWriter(new FileWriter(nomFich));
            writeHTKHeader(fout, parmKind);
            fout.println("~h \"" + nomHMM + '\"');
            fout.println("<BEGINHMM>");
            fout.println("<NUMSTATES> 3");
            fout.println("<STATE> 2");
            writeHTKMixtures(fout);
            fout.println("<TRANSP> 3");
            fout.println("0 1 0");
            fout.println("0 0.7 0.3");
            fout.println("0 0 0");
            fout.println("<ENDHMM>");
        } catch (IOException e) {
            e.printStackTrace();
        } finally {
            if (fout != null) {
                fout.close();
            }
        }
    }

    /**
     * Loads means and variances from an HTK text file. Every line containing
     * {@code <MEAN>} starts a new Gaussian; the number of coefficients is
     * taken from the first such line. Mixture weights in the file are not
     * read. I/O errors are printed, not propagated.
     *
     * @param nom name of the file to load
     */
    public void loadHTK(String nom) {
        try {
            String s;
            StringTokenizer st;

            // First pass: count the Gaussians and find the vector size.
            BufferedReader fin = new BufferedReader(new FileReader(nom));
            try {
                ngauss = 0;
                ncoefs = 0;
                while ((s = fin.readLine()) != null) {
                    if (s.contains("<MEAN>")) {
                        ngauss++;
                        if (ncoefs == 0) {
                            st = new StringTokenizer(s);
                            st.nextToken();
                            ncoefs = Integer.parseInt(st.nextToken());
                        }
                    }
                }
            } finally {
                closeQuietly(fin);
            }
            allocate();

            // Second pass: read the actual values.
            fin = new BufferedReader(new FileReader(nom));
            try {
                int g = 0;
                while ((s = fin.readLine()) != null) {
                    if (!s.contains("<MEAN>")) {
                        continue;
                    }
                    st = new StringTokenizer(readRequiredLine(fin, nom));
                    for (int c = 0; st.hasMoreTokens(); c++) {
                        setMean(g, c, Float.parseFloat(st.nextToken()));
                    }
                    s = readRequiredLine(fin, nom);
                    if (!s.contains("<VARIANCE>")) {
                        throw new IOException(
                                "Expected <VARIANCE> after <MEAN> values in "
                                        + nom + " but found: " + s);
                    }
                    st = new StringTokenizer(readRequiredLine(fin, nom));
                    for (int c = 0; st.hasMoreTokens(); c++) {
                        setVar(g, c, Float.parseFloat(st.nextToken()));
                    }
                    g++;
                }
            } finally {
                closeQuietly(fin);
            }
            precomputeDistance();
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    /**
     * Loads a GMM from a text file made of line pairs: the first line of a
     * pair holds a count followed by the means, the second one holds the
     * variances. Weights are the counts divided by their (integer) sum, which
     * is also stored in {@link #nT}. I/O errors are printed, not propagated.
     *
     * @param nom name of the file to load
     */
    public void loadScaleKMeans(String nom) {
        try {
            // First pass: count the lines and remember the first one, which
            // gives the vector size.
            String firstLine = null;
            int nLines = 0;
            BufferedReader fin = new BufferedReader(new FileReader(nom));
            try {
                String s;
                while ((s = fin.readLine()) != null) {
                    if (nLines == 0) {
                        firstLine = s;
                    }
                    nLines++;
                }
            } finally {
                closeQuietly(fin);
            }
            ngauss = nLines / 2;
            if (firstLine == null) {
                throw new IOException("Empty k-means file " + nom);
            }
            ncoefs = firstLine.split(" ").length - 1;

            // Second pass: read counts, means and variances.
            fin = new BufferedReader(new FileReader(nom));
            try {
                allocate();
                nT = 0;
                for (int i = 0; i < ngauss; i++) {
                    String[] ss = readRequiredLine(fin, nom).split(" ");
                    // Raw count for now; normalised once nT is known.
                    weights[i] = Float.parseFloat(ss[0]);
                    nT += weights[i];
                    for (int j = 0; j < ncoefs; j++) {
                        setMean(i, j, Float.parseFloat(ss[j + 1]));
                    }
                    ss = readRequiredLine(fin, nom).split(" ");
                    for (int j = 0; j < ncoefs; j++) {
                        setVar(i, j, Float.parseFloat(ss[j]));
                    }
                }
                for (int i = 0; i < ngauss; i++) {
                    setWeight(i, weights[i] / nT);
                }
            } finally {
                closeQuietly(fin);
            }
            precomputeDistance();
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    /**
     * Recomputes, for each Gaussian, the constant term that
     * {@link #computeLogLikes(float[])} subtracts from the weighted squared
     * distance. Must be called after means or variances have been changed.
     */
    public void precomputeDistance() {
        for (int gidx = 0; gidx < ngauss; gidx++) {
            float fact = 0.0f;
            for (int i = 0; i < ncoefs; i++) {
                fact += logMath.linearToLog(getVar(gidx, i));
            }
            fact += logMath.linearToLog(2.0 * Math.PI) * ncoefs;
            logPreComputedGaussianFactor[gidx] = fact * 0.5f;
        }
    }

    /**
     * Computes the log likelihood of {@code data} for every Gaussian,
     * including its weight, and stores the results in {@link #loglikes}.
     *
     * @param data observation vector; every element is used, so it must not
     *        be longer than the number of coefficients
     */
    public void computeLogLikes(float[] data) {
        for (int gidx = 0; gidx < ngauss; gidx++) {
            float logDval1gauss = 0f;
            for (int i = 0; i < data.length; i++) {
                float logDiff = data[i] - means[gidx][i];
                logDval1gauss += logDiff * logDiff * covar[gidx][i];
            }
            logDval1gauss -= logPreComputedGaussianFactor[gidx];

            if (Float.isNaN(logDval1gauss)) {
                System.err.println("gs2 is Nan, converting to 0 debug " + gidx
                        + ' ' + logPreComputedGaussianFactor[gidx] + ' '
                        + means[gidx][0] + ' ' + covar[gidx][0]);
                logDval1gauss = LogMath.LOG_ZERO;
            }
            if (logDval1gauss < distFloor) {
                logDval1gauss = distFloor;
            }
            // Including a priori probability for each Gaussian
            loglikes[gidx] = weights[gidx] + logDval1gauss;
        }
    }

    /**
     * Calculates the log probability of the observation by summing the
     * per-Gaussian values in the linear domain. Must be called AFTER
     * {@link #computeLogLikes(float[])}.
     *
     * @return log likelihood
     */
    public float getLogLike() {
        float sc = loglikes[0];
        for (int i = 1; i < ngauss; i++) {
            sc = logMath.addAsLinear(sc, loglikes[i]);
        }
        return sc;
    }

    /**
     * Must be called AFTER {@link #computeLogLikes(float[])}.
     *
     * @return index of the Gaussian with the highest log likelihood
     */
    public int getWinningGauss() {
        int imax = 0;
        for (int i = 1; i < ngauss; i++) {
            if (loglikes[i] > loglikes[imax]) {
                imax = i;
            }
        }
        return imax;
    }

    /**
     * Builds a new GMM that keeps only the selected coefficients.
     *
     * @param mask one flag per coefficient; {@code true} keeps it
     * @return GMM with the same weights and the selected means and variances
     */
    public GMMDiag getMarginal(boolean[] mask) {
        int nc = 0;
        for (boolean flag : mask) {
            if (flag) {
                nc++;
            }
        }
        GMMDiag g = new GMMDiag(getNgauss(), nc);
        int curc = 0;
        for (int j = 0; j < ncoefs; j++) {
            if (mask[j]) {
                for (int i = 0; i < ngauss; i++) {
                    g.setMean(i, curc, getMean(i, j));
                    g.setVar(i, curc, getVar(i, j));
                }
                curc++;
            }
        }
        for (int i = 0; i < ngauss; i++) {
            g.setWeight(i, getWeight(i));
        }
        g.precomputeDistance();
        return g;
    }

    /**
     * Concatenates the Gaussians of this GMM and of {@code g}. The weights of
     * this GMM are scaled by {@code w1}, those of {@code g} by
     * {@code 1 - w1}.
     *
     * @param g second GMM for the merge
     * @param w1 weight of the first GMM for the merge
     * @return merged GMM
     */
    public GMMDiag merge(GMMDiag g, float w1) {
        GMMDiag res = new GMMDiag(getNgauss() + g.getNgauss(), getNcoefs());
        for (int i = 0; i < getNgauss(); i++) {
            System.arraycopy(means[i], 0, res.means[i], 0, getNcoefs());
            System.arraycopy(covar[i], 0, res.covar[i], 0, getNcoefs());
            res.setWeight(i, getWeight(i) * w1);
        }
        for (int i = 0; i < g.getNgauss(); i++) {
            System.arraycopy(g.means[i], 0, res.means[ngauss + i], 0,
                    getNcoefs());
            System.arraycopy(g.covar[i], 0, res.covar[ngauss + i], 0,
                    getNcoefs());
            res.setWeight(ngauss + i, g.getWeight(i) * (1f - w1));
        }
        res.precomputeDistance();
        return res;
    }

    /**
     * Extracts ONE Gaussian from the GMM; its weight is set to 1.
     *
     * @param i position
     * @return single-Gaussian GMM
     */
    public GMMDiag getGauss(int i) {
        GMMDiag res = new GMMDiag(1, getNcoefs());
        System.arraycopy(means[i], 0, res.means[0], 0, getNcoefs());
        System.arraycopy(covar[i], 0, res.covar[0], 0, getNcoefs());
        res.setWeight(0, 1);
        res.precomputeDistance();
        return res;
    }

    /**
     * Two GMMs are considered to be equal when they have the same dimensions
     * and none of their parameters differ by more than 1%.
     *
     * @param g second GMM to compare to
     * @return whether the GMMs are equal
     */
    public boolean isEqual(GMMDiag g) {
        if (getNgauss() != g.getNgauss()) {
            return false;
        }
        // Compare the coefficient counts of both models with each other.
        if (getNcoefs() != g.getNcoefs()) {
            return false;
        }
        for (int i = 0; i < getNgauss(); i++) {
            if (isDiff(getWeight(i), g.getWeight(i))) {
                return false;
            }
            for (int j = 0; j < getNcoefs(); j++) {
                if (isDiff(getMean(i, j), g.getMean(i, j))) {
                    return false;
                }
                if (isDiff(getVar(i, j), g.getVar(i, j))) {
                    return false;
                }
            }
        }
        return true;
    }

    /**
     * @return one line per Gaussian with its first mean and first variance
     */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < getNgauss(); i++) {
            sb.append(getMean(i, 0)).append(' ').append(getVar(i, 0))
                    .append('\n');
        }
        return sb.toString();
    }

    /** Returns true when b differs from a by more than 1% relative to a. */
    private static boolean isDiff(float a, float b) {
        return Math.abs(1 - b / a) > 0.01;
    }

    /** Creates the weight array with uniform weights and sets logMath. */
    private void allocateWeights() {
        logMath = LogMath.getLogMath();
        weights = new float[ngauss];
        for (int i = 0; i < ngauss; i++) {
            setWeight(i, 1f / ngauss);
        }
    }

    /**
     * Allocates the parameter arrays for the current dimensions. Existing
     * arrays are kept when they already have the right shape, and replaced
     * when the dimensions have changed (e.g. when another model is loaded
     * into this instance).
     */
    private void allocate() {
        if (weights == null || weights.length != ngauss) {
            allocateWeights();
        }
        if (!hasShape(means, ngauss, ncoefs)) {
            loglikes = new float[ngauss];
            means = new float[ngauss][ncoefs];
            covar = new float[ngauss][ncoefs];
            logPreComputedGaussianFactor = new float[ngauss];
        }
    }

    /** Tells whether the matrix is allocated with the given dimensions. */
    private static boolean hasShape(float[][] matrix, int rows, int cols) {
        if (matrix == null || matrix.length != rows) {
            return false;
        }
        return rows == 0 || (matrix[0] != null && matrix[0].length == cols);
    }

    /** Writes the HTK global options and the regression tree header. */
    private void writeHTKHeader(PrintWriter fout, String parmKind) {
        fout.println("~o");
        fout.println("<HMMSETID> tree");
        fout.println("<STREAMINFO> 1 " + getNcoefs());
        fout.println("<VECSIZE> " + getNcoefs() + "<NULLD>" + parmKind
                + "<DIAGC>");
        fout.println("~r \"rtree_1\"");
        fout.println("<REGTREE> 1");
        fout.println("<TNODE> 1 " + getNgauss());
    }

    /** Writes the mixture count followed by every mixture component. */
    private void writeHTKMixtures(PrintWriter fout) {
        fout.println("<NUMMIXES> " + getNgauss());
        // HTK mixture indices are written starting from 1.
        for (int i = 1; i <= getNgauss(); i++) {
            fout.println("<MIXTURE> " + i + ' ' + getWeight(i - 1));
            fout.println("<RCLASS> 1");
            fout.println("<MEAN> " + getNcoefs());
            for (int j = 0; j < getNcoefs(); j++) {
                fout.print(getMean(i - 1, j) + " ");
            }
            fout.println();
            fout.println("<VARIANCE> " + getNcoefs());
            for (int j = 0; j < getNcoefs(); j++) {
                fout.print(getVar(i - 1, j) + " ");
            }
            fout.println();
        }
    }

    /**
     * Builds one transition matrix row with 0.5 at columns {@code state} and
     * {@code state + 1}, and 0 elsewhere; a negative state gives all zeros.
     */
    private static String transitionRow(int nstates, int state) {
        StringBuilder row = new StringBuilder();
        for (int j = 0; j < nstates; j++) {
            if (j > 0) {
                row.append(' ');
            }
            boolean active = state >= 0 && (j == state || j == state + 1);
            row.append(active ? "0.5" : "0");
        }
        return row.toString();
    }

    /** Reads one line, failing with an IOException at end of file. */
    private static String readRequiredLine(BufferedReader reader,
            String source) throws IOException {
        String line = reader.readLine();
        if (line == null) {
            throw new IOException("Unexpected end of file in " + source);
        }
        return line;
    }

    /** Closes the reader if it was opened; a failure is only printed. */
    private static void closeQuietly(BufferedReader reader) {
        if (reader != null) {
            try {
                reader.close();
            } catch (IOException e) {
                e.printStackTrace();
            }
        }
    }
}