/* * Copyright 2007 LORIA, France. * All Rights Reserved. Use is subject to license terms. * * See the file "license.terms" for information on usage and * redistribution of this file, and for a DISCLAIMER OF ALL * WARRANTIES. */ package edu.cmu.sphinx.linguist.acoustic.tiedstate.HTK; import java.io.BufferedReader; import java.io.FileReader; import java.io.IOException; import java.util.*; /** * * @author Christophe Cerisara */ public class HMMSet { private GMMDiag g; private int nGaussians; float[][] trans; /** * contains HMMState instances */ public final List<HMMState> states; public final List<float[][]> transitions = new ArrayList<float[][]>(); public final Map<String, Integer> transNames = new HashMap<String, Integer>(); public Iterator<SingleHMM> get1phIt() { Iterator<SingleHMM> it = new Iterator<SingleHMM>() { int cur; public void remove() { } public SingleHMM next() { for (;;) { if (cur >= hmms.size()) return null; SingleHMM hmm = hmms.get(cur++); if (hmm.getName().indexOf('-') >= 0 || hmm.getName().indexOf('+') >= 0) continue; return hmm; } } public boolean hasNext() { return false; } }; return it; } public Iterator<SingleHMM> get3phIt() { Iterator<SingleHMM> it = new Iterator<SingleHMM>() { int cur; public void remove() { } public SingleHMM next() { for (;;) { if (cur >= hmms.size()) return null; SingleHMM hmm = hmms.get(cur++); if (!(hmm.getName().indexOf('-') >= 0 || hmm.getName() .indexOf('+') >= 0)) continue; return hmm; } } public boolean hasNext() { return false; } }; return it; } public int getStateIdx(HMMState st) { return st.gmmidx; } public int getHMMidx(SingleHMM hmm) { for (int i = 0; i < hmms.size(); i++) { SingleHMM h = hmms.get(i); if (h == hmm) return i; } return -1; } public int getNstates() { return gmms.size(); } public String[] getHMMnames() { String[] rep = new String[hmms.size()]; for (int i = 0; i < rep.length; i++) { SingleHMM h = hmms.get(i); rep[i] = h.getName(); } return rep; } /** * contains GMMDiag instances */ public final List<GMMDiag> gmms; /** * contains HMM instances */ public final List<SingleHMM> hmms; public int getNhmms() { return hmms.size(); } public int getNhmmsMono() { int n = 0; for (SingleHMM hmm : hmms) { if (!(hmm.getName().indexOf('-') >= 0 || hmm.getName().indexOf('+') >= 0)) n++; } return n; } public int getNhmmsTri() { int n = 0; for (SingleHMM hmm : hmms) { if (hmm.getName().indexOf('-') >= 0 || hmm.getName().indexOf('+') >= 0) n++; } return n; } public int getHMMIndex(SingleHMM h) { return hmms.indexOf(h); } /** * @param hmmidx index of the HMM (begins at 0) * @param stateidx index of the state WITHIN the HMM ! (begins at 1, as in MMF) * @return index of the state in the vector of all the states of the HMMSet */ public int getStateIdx(int hmmidx, int stateidx) { // TODO: store a table not to recalculate every time SingleHMM hmm; int nEmittingStates = 0; for (int i = 0; i < hmmidx; i++) { hmm = hmms.get(i); nEmittingStates += hmm.getNbEmittingStates(); } hmm = hmms.get(hmmidx); for (int i = 1; i < stateidx; i++) { if (hmm.isEmitting(i)) nEmittingStates++; } if (hmm.isEmitting(stateidx)) return nEmittingStates; // Don't add 1 since states are counted from // 0 else return -1; } public SingleHMM getHMM(int idx) { return hmms.get(idx); } public SingleHMM getHMM(String nom) { SingleHMM h = null; for (SingleHMM hmm : hmms) { h = hmm; if (h.getName().equals(nom)) break; } return h; } public HMMSet() { states = new ArrayList<HMMState>(); hmms = new ArrayList<SingleHMM>(); gmms = new ArrayList<GMMDiag>(); } public void loadHTK(String nomFich) { try { BufferedReader f = new BufferedReader(new FileReader(nomFich)); String s; for (;;) { s = f.readLine(); if (s == null) break; if (s.startsWith("~s")) { String nomEtat = s.substring(s.indexOf('"') + 1, s .lastIndexOf('"')); loadState(f, nomEtat, null); } else if (s.startsWith("~v")) { // variance floor: bypass } else if (s.startsWith("~t")) { String nomTrans = s.substring(s.indexOf('"') + 1, s .lastIndexOf('"')); loadTrans(f, nomTrans, null); } else if (s.startsWith("~h")) { String nomHMM = s.substring(s.indexOf('"') + 1, s .lastIndexOf('"')); if (nomHMM.toUpperCase().equals(nomHMM)) { System.out .println("WARNING: HMM is in lowercase, converting to upper"); } hmms.add(loadHMM(f, nomHMM.toUpperCase(), gmms)); } } f.close(); } catch (IOException e) { e.printStackTrace(); } } private String[][] tiedHMMs; public void loadTiedList(String nomFich) { try { BufferedReader f = new BufferedReader(new FileReader(nomFich)); String s; String[] ss; int ntiedstates = 0; for (;;) { s = f.readLine(); if (s == null) break; ss = s.split(" "); if (ss.length >= 2) { // We have a tiedstate ntiedstates++; } } tiedHMMs = new String[ntiedstates][2]; f.close(); f = new BufferedReader(new FileReader(nomFich)); for (int i = 0;;) { s = f.readLine(); if (s == null) break; ss = s.split(" "); if (ss.length >= 2) { // We have a tiedstate tiedHMMs[i][0] = ss[0]; tiedHMMs[i++][1] = ss[1]; } } f.close(); } catch (IOException e) { e.printStackTrace(); } } /** * WARNING To be compliant with sphinx3 models, we remove the first * non-emitting state ! * * @throws IOException */ private SingleHMM loadHMM(BufferedReader f, String n, List<GMMDiag> autresEtats) throws IOException { GMMDiag e = null; int curstate; String name = n; String s = ""; while (!s.startsWith("<NUMSTATES>")) { s = f.readLine(); } int nstates = Integer.parseInt(s.substring(s.indexOf(' ') + 1)); // Compliance with sphinx3 nstates--; SingleHMM theHMM = new SingleHMM(nstates); theHMM.setName(n); theHMM.hmmset = this; while (!s.startsWith("<STATE>")) s = f.readLine(); while (s.startsWith("<STATE>")) { curstate = Integer.parseInt(s.substring(s.indexOf(' ') + 1)); // Compliance with sphinx3 curstate--; s = f.readLine(); int gmmidx = -1; if (s.startsWith("~s")) { String nomEtat = s.substring(s.indexOf('"') + 1, s .lastIndexOf('"')); int i; for (i = 0; i < autresEtats.size(); i++) { e = autresEtats.get(i); if (e.nom.equals(nomEtat)) break; } gmmidx = i; if (i == autresEtats.size()) { System.err.println("Error creatiing HMM : state " + name + " not found"); System.exit(1); } } else { loadState(f, "", s); gmmidx = gmms.size() - 1; e = gmms.get(gmms.size() - 1); } HMMState st = new HMMState(e, new Lab(name, curstate)); st.gmmidx = gmmidx; states.add(st); theHMM.setState(curstate - 1, st); // -1 because in HTK HMMs are counted from 1 s = f.readLine(); // t eliminates the gconst because it is then recalculated! if (s.startsWith("<GCONST>")) s = f.readLine(); } if (s.startsWith("~t")) { // simple application of the String nomTrans = s.substring(s.indexOf('"') + 1, s .lastIndexOf('"')); int tridx = getTrans(nomTrans); theHMM.setTrans(tridx); } else { // The transitions are explicit if (!s.startsWith("<TRANSP>")) { System.err.println("Error reading model: missing transitions." + s); System.exit(1); } loadTrans(f, null, s); theHMM.setTrans(trans); } s = f.readLine(); if (!s.startsWith("<ENDHMM>")) { System.err.println("Error reading model: missing ENDHMM." + s); System.exit(1); } return theHMM; } private int loadTrans(BufferedReader f, String nomEtat, String prem) throws IOException { String s; int nstates = 0; if (prem != null) s = prem; else s = f.readLine().trim(); if (s.startsWith("<TRANSP>")) { nstates = Integer.parseInt(s.substring(s.indexOf(' ') + 1)); // Compliance with sphinx3 nstates--; } else { System.err.println("ERROR no TRANSP !"); System.exit(1); } String[] ss; trans = new float[nstates][nstates]; // Compliance with sphinx3 f.readLine(); for (int i = 0; i < nstates; i++) { s = f.readLine().trim(); ss = s.split(" "); for (int j = 0; j < nstates; j++) { // Compliance with sphinx3 trans[i][j] = Float.parseFloat(ss[j + 1]); } } if (nomEtat != null) { int tridx = transitions.size(); transNames.put(nomEtat, tridx); transitions.add(trans); return tridx; } else { return -1; // Application can recover the transitions in the pool } } private int getTrans(String trnom) { int tridx = transNames.get(trnom); return tridx; } private void loadState(BufferedReader f, String nomEtat, String prem) throws IOException { nGaussians = 1; String s; if (prem != null) s = prem; else s = f.readLine().trim(); if (s.startsWith("<NUMMIXES>")) { nGaussians = Integer.parseInt(s.substring(s.indexOf(' ') + 1)); s = f.readLine().trim(); } g = null; if (!s.startsWith("<MIXTURE>")) { // This model has single mixture if (nGaussians != 1) { System.err.println("Error loading model: number of mixtures is " + nGaussians + " while state " + s + " has 1 mixture."); System.exit(1); } loadHTKGauss(f, 0, s); g.setWeight(0, 1f); } else { String[] ss; for (int i = 0; i < nGaussians; i++) { if (i > 0) s = f.readLine().trim(); // Don't load GCONST if (s.startsWith("<GCONST>")) s = f.readLine().trim(); ss = s.split(" "); if (Integer.parseInt(ss[1]) != i + 1) { System.err.println("Error reading model: mixture conflict " + i + ' ' + s); System.exit(1); } loadHTKGauss(f, i, null); g.setWeight(i, Float.parseFloat(ss[2])); } } g.precomputeDistance(); g.setNom(nomEtat); gmms.add(g); } /** * Read until the last line of the file but it may leave one last line * so it can loose GCONST. * @throws java.io.IOException */ private void loadHTKGauss(BufferedReader f, int n, String prem) throws IOException { String s; String[] ss; if (prem != null) { // First line is taken into account s = prem; } else s = f.readLine().trim(); if (s.startsWith("<GCONST>")) s = f.readLine().trim(); if (s.startsWith("<RCLASS>")) s = f.readLine().trim(); if (!s.startsWith("<MEAN>")) { System.err.println("Error loading model: can't find <MEAN> ! " + s); System.exit(1); } int ncoefs = Integer.parseInt(s.substring(s.indexOf(' ') + 1)); if (g == null) g = new GMMDiag(nGaussians, ncoefs); s = f.readLine().trim(); ss = s.split(" "); if (ss.length != ncoefs) { System.err.println("Error loading model: incorrect number of coefficients " + ncoefs + ' ' + s + ' ' + ss[0] + ' ' + ss[39]); System.exit(1); } for (int i = 0; i < ncoefs; i++) { g.setMean(n, i, Float.parseFloat(ss[i])); } s = f.readLine().trim(); if (!s.startsWith("<VARIANCE>")) { System.err.println("Error loading model: missing <VARIANCE> ! " + s); System.exit(1); } s = f.readLine().trim(); ss = s.split(" "); if (ss.length != ncoefs) { System.err.println("Error loading model: incorrect number of coefficients " + ncoefs + ' ' + s); System.exit(1); } for (int i = 0; i < ncoefs; i++) { g.setVar(n, i, Float.parseFloat(ss[i])); } } public GMMDiag findState(Lab l) { while (true) { HMMState s = null; int i; for (i = 0; i < states.size(); i++) { s = states.get(i); if (s.getLab().isEqual(l)) break; } if (i < states.size()) { return s.gmm; } else { if (tiedHMMs != null) { // May be that state appears in the tied states for (i = 0; i < tiedHMMs.length; i++) { if (tiedHMMs[i][0].equals(l.getName())) { break; } } if (i < tiedHMMs.length) { l = new Lab(tiedHMMs[i][1], l.getState()); continue; } } System.err.println("WARNING: state is not found in hmmset " + l); return null; } } } }