package edu.cmu.sphinx.linguist.language.classes; import edu.cmu.sphinx.linguist.dictionary.Word; import edu.cmu.sphinx.util.LogMath; import edu.cmu.sphinx.util.props.*; import java.io.BufferedReader; import java.io.IOException; import java.io.InputStreamReader; import java.net.URL; import java.util.*; import java.util.logging.Level; import java.util.logging.Logger; /** * A component that knows how to map words to classes and vice versa. * * @author Tanel Alumae */ public class ClassMap implements Configurable { @S4String public final static String PROP_CLASS_DEFS_LOCATION = "classDefsLocation"; private Logger logger; private boolean allocated; private URL classDefsLocation; private LogMath logMath; /** * Maps class name to class as a Word */ private Map<String, Word> classVocabulary = new HashMap<String, Word>(); /** * Maps a word to it's class and the probability of the word being in this class */ private Map<String, ClassProbability> wordToClassProbabilities = new HashMap<String, ClassProbability>(); /** * Maps a class to a set of words that belong to this class */ private final HashMap<String, Set<String>> classToWord = new HashMap<String, Set<String>>(); public ClassMap(URL classDefsLocation) { this.logger = Logger.getLogger(getClass().getName()); this.classDefsLocation = classDefsLocation; logMath = LogMath.getLogMath(); } public ClassMap() { } /* (non-Javadoc) * @see edu.cmu.sphinx.util.props.Configurable#newProperties(edu.cmu.sphinx.util.props.PropertySheet) */ public void newProperties(PropertySheet ps) throws PropertyException { logger = ps.getLogger(); if (allocated) throw new RuntimeException("Can't change properties after allocation"); classDefsLocation = ConfigurationManagerUtils.getResource(PROP_CLASS_DEFS_LOCATION, ps); } /* * (non-Javadoc) * * @see edu.cmu.sphinx.linguist.language.ngram.LanguageModel#allocate() */ public void allocate() throws IOException { if (!allocated) { allocated = true; loadClassDefs(); } } /* * (non-Javadoc) * * @see edu.cmu.sphinx.linguist.language.ngram.LanguageModel#deallocate() */ public void deallocate() { allocated = false; wordToClassProbabilities = null; classVocabulary = null; } public ClassProbability getClassProbability(String word) { return wordToClassProbabilities.get(word); } public Word getClassAsWord(String text) { return classVocabulary.get(text); } public Set<String> getWordsInClass(String className) { return classToWord.get(className); } /** * Loads class definitions. * Class definitions should be in SRILM format: * <pre> * CLASS1 probability1 WORD1 * CLASS2 probability2 WORD2 * ... * </pre> * Probabilities should be given in linear domain. * * @throws java.io.IOException If an IO error occurs during loading Class definition resource. */ private void loadClassDefs() throws IOException { BufferedReader reader = new BufferedReader (new InputStreamReader(classDefsLocation.openStream())); String line; while ((line = reader.readLine()) != null) { StringTokenizer st = new StringTokenizer(line, " \t\n\r\f="); if (st.countTokens() != 3) { throw new IOException("corrupt word to class def: " + line + "; " + st.countTokens()); } String className = st.nextToken(); float linearProb = Float.parseFloat(st.nextToken()); String word = st.nextToken(); if (logger.isLoggable(Level.FINE)) { logger.fine(word + " --> " + className + " " + linearProb); } wordToClassProbabilities.put(word, new ClassProbability(className, logMath.linearToLog(linearProb))); classVocabulary.put(className, new Word(className, null, false)); addWordInClass(className, word); } reader.close(); checkClasses(); logger.info("Loaded word to class mappings for " + wordToClassProbabilities.size() + " words"); } /** * Checks that word probabilities in each class sum to 1. */ private void checkClasses() { Map<String, Float> sums = new HashMap<String, Float>(); for (ClassProbability cp : wordToClassProbabilities.values()) { Float sum = sums.get(cp.getClassName()); if (sum == null) { sums.put(cp.getClassName(), 0f); } else { sums.put(cp.getClassName(), (float) logMath.logToLinear(cp.getLogProbability()) + sum); } } for (Map.Entry<String, Float> entry : sums.entrySet()) { if (Math.abs(1.0 - entry.getValue()) > 0.001) { logger.warning("Word probabilities for class " + entry.getKey() + " sum to " + entry.getValue()); } } } /** * @param className Name of the class * @param word Word String */ private void addWordInClass(String className, String word) { Set<String> words = classToWord.get(className); if (words == null) { words = new HashSet<String>(); classToWord.put(className, words); } words.add(word); } }