/**
 * 
 * Copyright 1999-2012 Carnegie Mellon University.
 * Portions Copyright 2002 Sun Microsystems, Inc.
 * Portions Copyright 2002 Mitsubishi Electric Research Laboratories.
 * All Rights Reserved.  Use is subject to license terms.
 * 
 * See the file "license.terms" for information on usage and
 * redistribution of this file, and for a DISCLAIMER OF ALL
 * WARRANTIES.
 *
 */

package edu.cmu.sphinx.linguist.g2p;

import java.io.IOException;
import java.io.InputStream;
import java.net.URL;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.Queue;
import java.util.regex.Pattern;

import edu.cmu.sphinx.fst.Arc;
import edu.cmu.sphinx.fst.Fst;
import edu.cmu.sphinx.fst.ImmutableFst;
import edu.cmu.sphinx.fst.State;
import edu.cmu.sphinx.fst.operations.ArcSort;
import edu.cmu.sphinx.fst.operations.Compose;
import edu.cmu.sphinx.fst.operations.ILabelCompare;
import edu.cmu.sphinx.fst.operations.NShortestPaths;
import edu.cmu.sphinx.fst.operations.OLabelCompare;
import edu.cmu.sphinx.fst.operations.Project;
import edu.cmu.sphinx.fst.operations.ProjectType;
import edu.cmu.sphinx.fst.operations.RmEpsilon;
import edu.cmu.sphinx.fst.semiring.Semiring;
import edu.cmu.sphinx.fst.semiring.TropicalSemiring;
import edu.cmu.sphinx.fst.utils.Utils;

/**
 * The grapheme-to-phoneme (g2p) decoder.
 * <p>
 * A word is turned into a small acceptor over the model's input symbols,
 * composed with an epsilon filter and with the g2p model, projected on the
 * output side, and the cheapest paths of the result are read back as
 * pronunciations.
 * 
 * @author John Salatas
 */
public class G2PConverter {

    /**
     * When more than one pronunciation is requested, this many times more
     * shortest paths are extracted than were asked for, because several
     * paths may collapse to the same pronunciation once skipped symbols are
     * dropped.
     */
    private static final int NBEST_OVERSAMPLING = 10;

    // epsilon symbol
    String eps = "<eps>";

    // end sequence symbol
    String se = "</s>";

    // begin sequence symbol
    String sb = "<s>";

    // skip symbol
    String skip = "_";

    // separator symbol
    String tie = "|";

    // set containing sequences to ignore
    HashSet<String> skipSeqs = new HashSet<String>();

    // clusters, indexed by input symbol id; null where the symbol is not a
    // cluster
    ArrayList<String>[] clusters = null;

    // the g2p model
    ImmutableFst g2pmodel;

    // fst containing the epsilon filter for the compose operation
    Fst epsilonFilter;

    /**
     * Create a decoder by loading the serialized model from a specified URL.
     * <p>
     * The stream opened on the URL is always closed before this constructor
     * returns, including when loading fails.
     * 
     * @param g2pModelUrl
     *            the URL of the serialized model
     * @throws IOException
     *             when the model cannot be read, or when it refers to a
     *             class that cannot be found (the
     *             {@link ClassNotFoundException} is kept as the cause)
     */
    public G2PConverter(URL g2pModelUrl) throws IOException {
        InputStream in = g2pModelUrl.openStream();
        try {
            g2pmodel = ImmutableFst.loadModel(in);
        } catch (ClassNotFoundException e) {
            throw new IOException("Failed to load the model from "
                    + g2pModelUrl, e);
        } finally {
            closeQuietly(in);
        }
        init();
    }

    /**
     * Create a decoder by loading the serialized model from a specified
     * filename.
     * 
     * @param g2pmodel_file
     *            the filename of the serialized model
     */
    public G2PConverter(String g2pmodel_file) {
        g2pmodel = ImmutableFst.loadModel(g2pmodel_file);
        init();
    }

    /**
     * Closes a stream, ignoring any failure to do so.
     * 
     * @param in
     *            the stream to close
     */
    private static void closeQuietly(InputStream in) {
        try {
            in.close();
        } catch (IOException ignored) {
            // Best effort only: either the model has already been read, or a
            // more relevant exception is already propagating to the caller.
        }
    }

    /**
     * Initialize the decoder: register the symbols to drop from results,
     * prepare the model for composition, and build the cluster table and the
     * epsilon filter from the model's input symbols.
     */
    private void init() {
        skipSeqs.add(eps);
        skipSeqs.add(sb);
        skipSeqs.add(se);
        skipSeqs.add(skip);
        skipSeqs.add("-");

        // augment the model itself (in place) for the compose operation and
        // sort its arcs by input label
        Compose.augment(0, g2pmodel, g2pmodel.getSemiring());
        ArcSort.apply(g2pmodel, new ILabelCompare());

        String[] isyms = g2pmodel.getIsyms();
        loadClusters(isyms);

        // get epsilon filter for composition
        epsilonFilter = Compose.getFilter(isyms, g2pmodel.getSemiring());
        ArcSort.apply(epsilonFilter, new ILabelCompare());
    }

    /**
     * Phoneticize a word.
     * 
     * @param entry
     *            the word to phoneticize transformed to an ArrayList of
     *            Strings (each element holds a single character)
     * @param nbest
     *            the number of distinct pronunciations to return
     * @return the pronunciation(s) of the input word
     */
    public ArrayList<Path> phoneticize(ArrayList<String> entry, int nbest) {
        Fst efst = entryToFSA(entry);
        Semiring s = efst.getSemiring();

        Compose.augment(1, efst, s);
        ArcSort.apply(efst, new OLabelCompare());

        Fst result = Compose.compose(efst, epsilonFilter, s, true);
        ArcSort.apply(result, new OLabelCompare());
        result = Compose.compose(result, g2pmodel, s, true);
        Project.apply(result, ProjectType.OUTPUT);

        // A single best path needs no oversampling; otherwise ask for more
        // paths than requested, as several of them may resolve to the same
        // pronunciation (duplicates are merged in findAllPaths).
        int candidates = (nbest == 1) ? 1 : nbest * NBEST_OVERSAMPLING;
        result = NShortestPaths.get(result, candidates, false);

        result = RmEpsilon.get(result);
        return findAllPaths(result, nbest, skipSeqs, tie);
    }

    /**
     * Phoneticize a word.
     * <p>
     * The word is split into Unicode code points, so a character encoded as
     * a surrogate pair is looked up as one symbol. Characters that are not
     * among the model's input symbols are silently dropped.
     * 
     * @param word
     *            the word to phoneticize
     * @param nbest
     *            the number of distinct pronunciations to return
     * @return the pronunciation(s) of the input word
     */
    public ArrayList<Path> phoneticize(String word, int nbest) {
        String[] isyms = g2pmodel.getIsyms();
        ArrayList<String> entry = new ArrayList<String>(word.length());

        int i = 0;
        while (i < word.length()) {
            // an unpaired surrogate is returned as-is with a length of one
            int length = Character.charCount(word.codePointAt(i));
            String ch = word.substring(i, i + length);
            if (Utils.getIndex(isyms, ch) >= 0) {
                entry.add(ch);
            }
            i += length;
        }
        return phoneticize(entry, nbest);
    }

    /**
     * Transforms an input spelling/pronunciation into an equivalent FSA,
     * adding extra arcs as needed to accommodate clusters.
     * <p>
     * The basic FSA is a chain: the begin-sequence symbol, one arc per
     * element of {@code entry}, then the end-sequence symbol leading to the
     * only final state. All arcs carry a weight of 0.
     * 
     * @param entry
     *            the input vector
     * @return the created fst
     */
    private Fst entryToFSA(ArrayList<String> entry) {
        String[] isyms = g2pmodel.getIsyms();
        TropicalSemiring ts = new TropicalSemiring();
        Fst efst = new Fst(ts);

        State s = new State(ts.zero());
        efst.addState(s);
        efst.setStart(s);

        // Build the basic FSA
        int length = entry.size();
        for (int i = 0; i <= length; i++) {
            s = new State(ts.zero());
            efst.addState(s);

            if (i == 0) {
                // the first arc consumes the begin-sequence symbol
                int symIndex = Utils.getIndex(isyms, sb);
                efst.getStart().addArc(new Arc(symIndex, symIndex, 0.f, s));
            } else {
                int symIndex = Utils.getIndex(isyms, entry.get(i - 1));
                efst.getState(i).addArc(new Arc(symIndex, symIndex, 0.f, s));
            }

            if (i == length) {
                // close the chain with the end-sequence symbol
                State last = new State(ts.zero());
                efst.addState(last);
                int symIndex = Utils.getIndex(isyms, se);
                s.addArc(new Arc(symIndex, symIndex, 0.f, last));
                last.setFinalWeight(0.f);
            }
        }

        // Add any cluster arcs: each occurrence of a cluster in the entry
        // gets one arc, labelled with the cluster's symbol id, that bypasses
        // the per-element arcs it covers.
        for (int value = 0; value < clusters.length; value++) {
            ArrayList<String> cluster = clusters[value];
            if (cluster == null) {
                continue;
            }
            int start = 0;
            while (true) {
                // NOTE(review): the arithmetic below assumes Utils.search
                // returns an offset relative to 'start' (-1 when there is no
                // match) - confirm against Utils.
                int k = Utils.search(entry, cluster, start);
                if (k == -1) {
                    break;
                }
                // +1 accounts for the state that follows the begin symbol
                State from = efst.getState(start + k + 1);
                State to = efst.getState(start + k + cluster.size() + 1);
                from.addArc(new Arc(value, value, 0.f, to));
                start = start + k + cluster.size();
            }
        }

        efst.setIsyms(isyms);
        efst.setOsyms(isyms);

        return efst;
    }

    /**
     * Finds nbest paths in an Fst returned by NShortestPaths operation.
     * <p>
     * States are visited through a queue starting from the start state. The
     * output symbol of every arc is split on {@code tie} and the pieces not
     * contained in {@code skipSeqs} are appended to the path. Paths reaching
     * a final state are keyed by their symbol sequence; when two paths have
     * the same sequence only the cheaper one is kept.
     * 
     * @param fst
     *            the input fst
     * @param nbest
     *            the number of paths to return
     * @param skipSeqs
     *            the sequences to ignore
     * @param tie
     *            the separator symbol (taken literally, not as a regular
     *            expression)
     * @return the paths, sorted with {@link PathComparator} and truncated to
     *         at most {@code nbest} elements
     */
    private ArrayList<Path> findAllPaths(Fst fst, int nbest,
            HashSet<String> skipSeqs, String tie) {
        Semiring semiring = fst.getSemiring();
        String[] osyms = fst.getOsyms();
        // quote the separator so that it is matched literally whatever
        // characters it contains
        Pattern tieSplitter = Pattern.compile(Pattern.quote(tie));

        HashMap<String, Path> finalPaths = new HashMap<String, Path>();
        HashMap<State, Path> paths = new HashMap<State, Path>();
        Queue<State> queue = new LinkedList<State>();

        Path p = new Path(fst.getSemiring());
        p.setCost(semiring.one());
        paths.put(fst.getStart(), p);
        queue.add(fst.getStart());

        while (!queue.isEmpty()) {
            State s = queue.remove();
            Path currentPath = paths.get(s);

            if (s.getFinalWeight() != semiring.zero()) {
                String pathString = currentPath.getPath().toString();
                Path old = finalPaths.get(pathString);
                // keep the cheapest path for each distinct pronunciation
                if (old == null || old.getCost() > currentPath.getCost()) {
                    finalPaths.put(pathString, currentPath);
                }
            }

            int numArcs = s.getNumArcs();
            for (int j = 0; j < numArcs; j++) {
                Arc a = s.getArc(j);
                // looked up again on purpose: the entry for 's' may have
                // been replaced by a previous arc of this same state
                Path cur = paths.get(s);

                p = new Path(fst.getSemiring());
                p.setCost(cur.getCost());
                p.setPath(new ArrayList<String>(cur.getPath()));

                String sym = osyms[a.getOlabel()];
                String[] phones = tieSplitter.split(sym);
                for (int i = 0; i < phones.length; i++) {
                    String phone = phones[i];
                    if (!skipSeqs.contains(phone)) {
                        p.getPath().add(phone);
                    }
                }
                p.setCost(semiring.times(p.getCost(), a.getWeight()));

                State nextState = a.getNextState();
                paths.put(nextState, p);
                if (!queue.contains(nextState)) {
                    queue.add(nextState);
                }
            }
        }

        ArrayList<Path> res = new ArrayList<Path>(finalPaths.values());
        Collections.sort(res, new PathComparator());

        // drop the worst paths until at most nbest remain
        int numPaths = res.size();
        for (int i = nbest; i < numPaths; i++) {
            res.remove(res.size() - 1);
        }

        return res;
    }

    /**
     * Initialize clusters.
     * <p>
     * Every symbol containing the separator is split on it and stored, as
     * the list of its parts, at the symbol's own index; all other entries
     * stay {@code null}.
     * 
     * @param syms
     *            the input symbols of the model
     */
    @SuppressWarnings("unchecked")
    private void loadClusters(String[] syms) {
        // generic array creation is not possible, hence the raw type
        clusters = new ArrayList[syms.length];

        // the first two symbols are never treated as clusters
        // NOTE(review): presumably they are reserved symbols - confirm
        // against the model's symbol table.
        for (int i = 2; i < syms.length; i++) {
            String sym = syms[i];
            if (sym.contains(tie)) {
                String[] split = sym.split(Pattern.quote(tie));
                clusters[i] = new ArrayList<String>(Arrays.asList(split));
            }
        }
    }
}