/* * Copyright 2014 Alpha Cephei Inc. * All Rights Reserved. Use is subject to license terms. * * See the file "license.terms" for information on usage and * redistribution of this file, and for a DISCLAIMER OF ALL * WARRANTIES. */ package edu.cmu.sphinx.api; import java.io.IOException; import java.io.InputStream; import java.net.MalformedURLException; import java.net.URL; import java.util.ArrayDeque; import java.util.ArrayList; import java.util.LinkedList; import java.util.List; import java.util.Map; import java.util.Queue; import java.util.TreeMap; import java.util.logging.Logger; import edu.cmu.sphinx.alignment.LongTextAligner; import edu.cmu.sphinx.alignment.SimpleTokenizer; import edu.cmu.sphinx.alignment.TextTokenizer; import edu.cmu.sphinx.linguist.language.grammar.AlignerGrammar; import edu.cmu.sphinx.linguist.language.ngram.DynamicTrigramModel; import edu.cmu.sphinx.recognizer.Recognizer; import edu.cmu.sphinx.result.Result; import edu.cmu.sphinx.result.WordResult; import edu.cmu.sphinx.util.Range; import edu.cmu.sphinx.util.TimeFrame; public class SpeechAligner { private final Logger logger = Logger.getLogger(getClass().getSimpleName()); private static final int TUPLE_SIZE = 3; private final Context context; private final Recognizer recognizer; private final AlignerGrammar grammar; private final DynamicTrigramModel languageModel; private TextTokenizer tokenizer; public SpeechAligner(String amPath, String dictPath, String g2pPath) throws MalformedURLException, IOException { Configuration configuration = new Configuration(); configuration.setAcousticModelPath(amPath); configuration.setDictionaryPath(dictPath); context = new Context(configuration); if (g2pPath != null) { context.setLocalProperty("dictionary->g2pModelPath", g2pPath); context.setLocalProperty("dictionary->g2pMaxPron", "2"); } context.setLocalProperty("lexTreeLinguist->languageModel", "dynamicTrigramModel"); recognizer = context.getInstance(Recognizer.class); grammar = context.getInstance(AlignerGrammar.class); languageModel = context.getInstance(DynamicTrigramModel.class); setTokenizer(new SimpleTokenizer()); } public List<WordResult> align(URL audioUrl, String transcript) throws IOException { return align(audioUrl, getTokenizer().expand(transcript)); } /** * Align audio to sentence transcript * * @param audioUrl audio file URL to process * @param sentenceTranscript cleaned transcript * @return List of aligned words with timings * @throws IOException if IO went wrong */ public List<WordResult> align(URL audioUrl, List<String> sentenceTranscript) throws IOException { List<String> transcript = sentenceToWords(sentenceTranscript); LongTextAligner aligner = new LongTextAligner(transcript, TUPLE_SIZE); Map<Integer, WordResult> alignedWords = new TreeMap<Integer, WordResult>(); Queue<Range> ranges = new LinkedList<Range>(); Queue<List<String>> texts = new ArrayDeque<List<String>>(); Queue<TimeFrame> timeFrames = new ArrayDeque<TimeFrame>(); ranges.offer(new Range(0, transcript.size())); texts.offer(transcript); TimeFrame totalTimeFrame = TimeFrame.INFINITE; timeFrames.offer(totalTimeFrame); long lastFrame = TimeFrame.INFINITE.getEnd(); languageModel.setText(sentenceTranscript); for (int i = 0; i < 4; ++i) { if (i == 1) { context.setLocalProperty("decoder->searchManager", "alignerSearchManager"); } while (!texts.isEmpty()) { assert texts.size() == ranges.size(); assert texts.size() == timeFrames.size(); List<String> text = texts.poll(); TimeFrame frame = timeFrames.poll(); Range range = ranges.poll(); logger.info("Aligning frame " + frame + " to text " + text + " range " + range); recognizer.allocate(); if (i >= 1) { grammar.setWords(text); } InputStream stream = audioUrl.openStream(); context.setSpeechSource(stream, frame); List<WordResult> hypothesis = new ArrayList<WordResult>(); Result result; while (null != (result = recognizer.recognize())) { logger.info("Utterance result " + result.getTimedBestResult(true)); hypothesis.addAll(result.getTimedBestResult(false)); } if (i == 0) { if (hypothesis.size() > 0) { lastFrame = hypothesis.get(hypothesis.size() - 1).getTimeFrame().getEnd(); } } List<String> words = new ArrayList<String>(); for (WordResult wr : hypothesis) { words.add(wr.getWord().getSpelling()); } int[] alignment = aligner.align(words, range); List<WordResult> results = hypothesis; logger.info("Decoding result is " + results); // dumpAlignment(transcript, alignment, results); dumpAlignmentStats(transcript, alignment, results); for (int j = 0; j < alignment.length; j++) { if (alignment[j] != -1) { alignedWords.put(alignment[j], hypothesis.get(j)); } } stream.close(); recognizer.deallocate(); } scheduleNextAlignment(transcript, alignedWords, ranges, texts, timeFrames, lastFrame); } return new ArrayList<WordResult>(alignedWords.values()); } public List<String> sentenceToWords(List<String> sentenceTranscript) { ArrayList<String> transcript = new ArrayList<String>(); for (String sentence : sentenceTranscript) { String[] words = sentence.split("\\s+"); for (String word : words) { if (word.length() > 0) transcript.add(word); } } return transcript; } private void dumpAlignmentStats(List<String> transcript, int[] alignment, List<WordResult> results) { int insertions = 0; int deletions = 0; int size = transcript.size(); int[] aid = alignment; int lastId = -1; for (int ij = 0; ij < aid.length; ++ij) { if (aid[ij] == -1) { insertions++; } else { if (aid[ij] - lastId > 1) { deletions += aid[ij] - lastId; } lastId = aid[ij]; } } if (lastId >= 0 && transcript.size() - lastId > 1) { deletions += transcript.size() - lastId; } logger.info(String.format("Size %d deletions %d insertions %d error rate %.2f", size, insertions, deletions, (insertions + deletions) / ((float) size) * 100f)); } private void scheduleNextAlignment(List<String> transcript, Map<Integer, WordResult> alignedWords, Queue<Range> ranges, Queue<List<String>> texts, Queue<TimeFrame> timeFrames, long lastFrame) { int prevKey = 0; long prevStart = 0; for (Map.Entry<Integer, WordResult> e : alignedWords.entrySet()) { if (e.getKey() - prevKey > 1) { checkedOffer(transcript, texts, timeFrames, ranges, prevKey, e.getKey() + 1, prevStart, e.getValue() .getTimeFrame().getEnd()); } prevKey = e.getKey(); prevStart = e.getValue().getTimeFrame().getStart(); } if (transcript.size() - prevKey > 1) { checkedOffer(transcript, texts, timeFrames, ranges, prevKey, transcript.size(), prevStart, lastFrame); } } public void dumpAlignment(List<String> transcript, int[] alignment, List<WordResult> results) { logger.info("Alignment"); int[] aid = alignment; int lastId = -1; for (int ij = 0; ij < aid.length; ++ij) { if (aid[ij] == -1) { logger.info(String.format("+ %s", results.get(ij))); } else { if (aid[ij] - lastId > 1) { for (String result1 : transcript.subList(lastId + 1, aid[ij])) { logger.info(String.format("- %-25s", result1)); } } else { logger.info(String.format(" %-25s", transcript.get(aid[ij]))); } lastId = aid[ij]; } } if (lastId >= 0 && transcript.size() - lastId > 1) { for (String result1 : transcript.subList(lastId + 1, transcript.size())) { logger.info(String.format("- %-25s", result1)); } } } private void checkedOffer(List<String> transcript, Queue<List<String>> texts, Queue<TimeFrame> timeFrames, Queue<Range> ranges, int start, int end, long timeStart, long timeEnd) { double wordDensity = ((double) (timeEnd - timeStart)) / (end - start); // Skip range if it's too short, average word is less than 10 // milliseconds if (wordDensity < 10.0 && (end - start) > 3) { logger.info("Skipping text range due to a high density " + transcript.subList(start, end).toString()); return; } texts.offer(transcript.subList(start, end)); timeFrames.offer(new TimeFrame(timeStart, timeEnd)); ranges.offer(new Range(start, end - 1)); } public TextTokenizer getTokenizer() { return tokenizer; } public void setTokenizer(TextTokenizer wordExpander) { this.tokenizer = wordExpander; } }