package edu.cmu.sphinx.linguist.acoustic.tiedstate.kaldi;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import edu.cmu.sphinx.util.LogMath;

/**
 * Single state of a phone topology: its number, its pdf class and the
 * numbers of the states it has transitions to.
 */
final class HmmState {

    private final int id;
    private final int pdfClass;
    private final List<Integer> transitions;

    /**
     * Creates a state description.
     *
     * @param id          state number as read from the topology
     * @param pdfClass    pdf class of the state
     * @param transitions target state numbers; the collection is copied
     */
    public HmmState(int id, int pdfClass, Collection<Integer> transitions) {
        this.id = id;
        this.pdfClass = pdfClass;
        this.transitions = new ArrayList<Integer>(transitions);
    }

    /** @return state number */
    public int getId() {
        return id;
    }

    /** @return pdf class of the state */
    public int getPdfClass() {
        return pdfClass;
    }

    /** @return target state numbers of the outgoing transitions */
    public List<Integer> getTransitions() {
        return transitions;
    }

    /** @return number of outgoing transitions */
    public int size() {
        return transitions.size();
    }

    @Override
    public String toString() {
        return String.format("HmmState {%d, %d, %s}", id, pdfClass, transitions);
    }
}

/**
 * Immutable (phone, HMM state, pdf) key used to look up the first transition
 * identifier of a transition state.
 */
final class Triple {

    // Final because instances serve as HashMap keys: a key whose hash could
    // change after insertion would become unreachable.
    private final int phone;
    private final int hmmState;
    private final int pdf;

    /**
     * Creates a key.
     *
     * @param phone    phone identifier
     * @param hmmState HMM state number
     * @param pdf      pdf identifier
     */
    public Triple(int phone, int hmmState, int pdf) {
        this.phone = phone;
        this.hmmState = hmmState;
        this.pdf = pdf;
    }

    @Override
    public boolean equals(Object object) {
        if (!(object instanceof Triple)) {
            return false;
        }

        Triple other = (Triple) object;
        return phone == other.phone
            && hmmState == other.hmmState
            && pdf == other.pdf;
    }

    @Override
    public int hashCode() {
        return 31 * (31 * phone + hmmState) + pdf;
    }

    @Override
    public String toString() {
        return String.format("Triple {%d, %d, %d}", phone, hmmState, pdf);
    }
}

/**
 * Represents transition model of a Kaldi acoustic model.
 */
public class TransitionModel {

    /** Phone identifier to the states of its topology entry. */
    private Map<Integer, List<HmmState>> phoneStates;

    /** Transition state to the identifier of its first transition. */
    private Map<Triple, Integer> transitionStates;

    /** Transition log-probabilities indexed by transition identifier. */
    private float[] logProbabilities;

    /**
     * Loads transition model using provided parser.
     *
     * Reads, in order, the topology, the list of triples and the array of
     * log-probabilities, then passes every log-probability through
     * {@link LogMath#lnToLog}.
     *
     * @param parser parser
     */
    public TransitionModel(KaldiTextParser parser) {
        parser.expectToken("<TransitionModel>");
        parseTopology(parser);

        parser.expectToken("<Triples>");
        transitionStates = new HashMap<Triple, Integer>();
        int numTriples = parser.getInt();
        // Transition identifiers are assigned sequentially starting from 1;
        // each triple owns as many consecutive identifiers as its state has
        // outgoing transitions.
        int transitionId = 1;

        for (int i = 0; i < numTriples; ++i) {
            int phone = parser.getInt();
            int hmmState = parser.getInt();
            int pdf = parser.getInt();
            Triple triple = new Triple(phone, hmmState, pdf);
            transitionStates.put(triple, transitionId);
            // The state list is indexed by position here, i.e. the state
            // number is expected to match its position in the list.
            transitionId +=
                phoneStates.get(phone).get(hmmState).getTransitions().size();
        }

        parser.expectToken("</Triples>");
        parser.expectToken("<LogProbs>");
        logProbabilities = parser.getFloatArray();
        parser.expectToken("</LogProbs>");
        parser.expectToken("</TransitionModel>");

        LogMath logMath = LogMath.getLogMath();
        for (int i = 0; i < logProbabilities.length; ++i) {
            logProbabilities[i] = logMath.lnToLog(logProbabilities[i]);
        }
    }

    /**
     * Reads the {@code <Topology>} section and fills {@link #phoneStates}.
     *
     * Only states that carry a {@code <PdfClass>} are recorded. All phones
     * listed in one topology entry share the same state list instance.
     *
     * @param parser parser positioned right before {@code <Topology>}
     */
    private void parseTopology(KaldiTextParser parser) {
        parser.expectToken("<Topology>");
        phoneStates = new HashMap<Integer, List<HmmState>>();
        String token;

        while ("<TopologyEntry>".equals(token = parser.getToken())) {
            parser.assertToken("<TopologyEntry>", token);
            parser.expectToken("<ForPhones>");

            List<Integer> phones = new ArrayList<Integer>();
            while (!"</ForPhones>".equals(token = parser.getToken())) {
                phones.add(Integer.parseInt(token));
            }

            List<HmmState> states = new ArrayList<HmmState>(3);
            while ("<State>".equals(token = parser.getToken())) {
                // State number; later used as the row index of the
                // transition matrix.
                int id = parser.getInt();
                token = parser.getToken();

                if ("<PdfClass>".equals(token)) {
                    int pdfClass = parser.getInt();
                    List<Integer> transitions = new ArrayList<Integer>();

                    while ("<Transition>".equals(token = parser.getToken())) {
                        transitions.add(parser.getInt());
                        // Skip initial probability. The values actually used
                        // are the ones read from <LogProbs>.
                        parser.getToken();
                    }

                    parser.assertToken("</State>", token);
                    states.add(new HmmState(id, pdfClass, transitions));
                }
            }

            for (Integer id : phones) {
                phoneStates.put(id, states);
            }
        }

        parser.assertToken("</Topology>", token);
    }

    /**
     * Returns transition matrix for the given context.
     *
     * Every cell that does not correspond to a transition of the phone
     * topology is set to {@link LogMath#LOG_ZERO}, including the rows of
     * states the topology does not define.
     *
     * @param phone central phone in the context
     * @param pdfs  array of pdf identifiers of the context units
     *
     * @return
     *     4 by 4 matrix where cell i,j contains probability in {@link LogMath}
     *     domain of transition from state i to state j
     *
     * @throws IllegalArgumentException
     *     if the phone is not in the topology or no transition state exists
     *     for one of its (phone, state, pdf) combinations
     */
    public float[][] getTransitionMatrix(int phone, int[] pdfs) {
        // TODO: use variable size
        float[][] transitionMatrix = new float[4][4];

        // Initialize every row. A row left at the default 0.0f would mean
        // probability one in the log domain for a state that is absent from
        // the topology.
        for (float[] row : transitionMatrix) {
            Arrays.fill(row, LogMath.LOG_ZERO);
        }

        List<HmmState> states = phoneStates.get(phone);
        if (states == null) {
            throw new IllegalArgumentException(
                String.format("Phone %d is not in the topology", phone));
        }

        for (HmmState state : states) {
            int stateId = state.getId();
            Triple triple = new Triple(phone, stateId, pdfs[stateId]);

            // Looked up as Integer first: unboxing a missing entry directly
            // would fail with an uninformative NullPointerException.
            Integer firstTransition = transitionStates.get(triple);
            if (firstTransition == null) {
                throw new IllegalArgumentException(
                    String.format("No transition state for %s", triple));
            }

            // Transitions of one state occupy consecutive identifiers.
            int i = firstTransition;
            for (Integer j : state.getTransitions()) {
                transitionMatrix[stateId][j] = logProbabilities[i++];
            }
        }

        return transitionMatrix;
    }
}