package edu.cmu.sphinx.linguist.language.ngram.trie; import java.io.BufferedInputStream; import java.io.ByteArrayOutputStream; import java.io.ByteArrayInputStream; import java.io.DataInputStream; import java.io.File; import java.io.FileInputStream; import java.io.IOException; import java.io.InputStream; import java.net.URL; import edu.cmu.sphinx.linguist.language.ngram.trie.NgramTrieModel.TrieUnigram; import edu.cmu.sphinx.util.Utilities; /** * Class that provides utils to load NgramTrieModel * from binary file created with sphinx_lm_convert. * Routines should be called in certain order according to format: * <ul> * <li>verifyHeader</li> * <li>readCounts</li> * <li>readQuant</li> * <li>readUnigrams</li> * <li>readTrieByteArr</li> * <li>readWords</li> * </ul> */ public class BinaryLoader { private static final String TRIE_HEADER = "Trie Language Model"; private DataInputStream inStream; public BinaryLoader(File location) throws IOException { inStream = new DataInputStream(new FileInputStream(location)); } private void loadModelData(InputStream stream) throws IOException { DataInputStream dataStream = new DataInputStream(new BufferedInputStream(stream)); ByteArrayOutputStream bytes = new ByteArrayOutputStream(); byte[] buffer = new byte[4096]; while (true) { if (dataStream.read(buffer) < 0) break; bytes.write(buffer); } inStream = new DataInputStream(new ByteArrayInputStream(bytes.toByteArray())); } public BinaryLoader(URL location) throws IOException { loadModelData(location.openStream()); } /** * Reads header from stream and checks if it matches trie header * @throws IOException if reading from stream failed */ public void verifyHeader() throws IOException { String readHeader = readString(inStream, TRIE_HEADER.length()); if (!readHeader.equals(TRIE_HEADER)) { throw new Error("Bad binary LM file header: " + readHeader); } } /** * Reads language model order and ngram counts * @return array of counts where ordinal number is ngram order * @throws IOException if reading from stream failed */ public int[] readCounts() throws IOException { int order = readOrder(); int[] counts = new int[order]; for (int i = 0; i < counts.length; i++) { counts[i] = Utilities.readLittleEndianInt(inStream); } return counts; } /** * Reads weights quantation object from stream * @param order - max order of ngrams for this model * @return quantation object, see {@link NgramTrieQuant} * @throws IOException if reading from stream failed */ public NgramTrieQuant readQuant(int order) throws IOException { int quantTypeInt = Utilities.readLittleEndianInt(inStream); if (quantTypeInt < 0 || quantTypeInt >= NgramTrieQuant.QuantType.values().length) throw new Error("Unknown quantatization type: " + quantTypeInt); NgramTrieQuant.QuantType quantType = NgramTrieQuant.QuantType.values()[quantTypeInt]; NgramTrieQuant quant = new NgramTrieQuant(order, quantType); //reading tables for (int i = 2; i <= order; i++) { quant.setTable(readFloatArr(quant.getProbTableLen()), i, true); if (i < order) quant.setTable(readFloatArr(quant.getBackoffTableLen()), i, false); } return quant; } /** * Reads array of language model unigrams * @param count - amount of unigrams according to counts previously read * @return array of language model unigrams, see {@link NgramTrieModel.TrieUnigram} * @throws IOException if reading from stream failed */ public TrieUnigram[] readUnigrams(int count) throws IOException { TrieUnigram[] unigrams = new TrieUnigram[count + 1]; for (int i = 0; i < count + 1; i++) { unigrams[i] = new TrieUnigram(); unigrams[i].prob = Utilities.readLittleEndianFloat(inStream); unigrams[i].backoff = Utilities.readLittleEndianFloat(inStream); unigrams[i].next = Utilities.readLittleEndianInt(inStream); } return unigrams; } /** * Reads trie in form of byte array into provided array. * Size of byte array is computed from previously read language model specifications, * see {@link NgramTrie} * @param arr - byte array to read trie to * @throws IOException if reading from stream failed */ public void readTrieByteArr(byte[] arr) throws IOException { inStream.read(arr); } /** * Reads vocabulary of language model. Ordinal number of word stays for wordId. * @param unigramNum - amount of unigrams * @return array of strings - vocabulary of language model * @throws IOException of reading from stream failed */ public String[] readWords(int unigramNum) throws IOException { int len = Utilities.readLittleEndianInt(inStream); if (len <= 0) { throw new Error("Bad word string size: " + len); } String[] words = new String[unigramNum]; byte[] bytes = new byte[len]; inStream.read(bytes); int s = 0; int wordStart = 0; for (int i = 0; i < len; i++) { char c = (char) (bytes[i] & 0xFF); if (c == '\0') { // if its the end of a string, add it to the 'words' array words[s] = new String(bytes, wordStart, i - wordStart); wordStart = i + 1; s++; } } assert (s == unigramNum); return words; } /** * Should be called when model reading finished * @throws IOException if stream was corrupted */ public void close() throws IOException { inStream.close(); } /** * Reads language model max depth. * Order is stored in uint8 or in byte * @return order of language model * @throws IOException if reading from stream failed */ private int readOrder() throws IOException { return (int)inStream.readByte(); } /** * Reads float array of specified length. * Quantation tables are stored in form of float arrays, * see {@link readQuant} * @param len - length of array to read * @return array of floats that was read from stream * @throws IOException if reading from stream failed */ private float[] readFloatArr(int len) throws IOException { float[] arr = new float[len]; for (int i = 0; i < len; i++) arr[i] = Utilities.readLittleEndianFloat(inStream); return arr; } /** * Reads a string of the given length from the given DataInputStream. It is assumed that the DataInputStream * contains 8-bit chars. * * @param stream the DataInputStream to read from * @param length the number of characters in the returned string * @return a string of the given length from the given DataInputStream * @throws java.io.IOException */ private String readString(DataInputStream stream, int length) throws IOException { StringBuilder builder = new StringBuilder(); byte[] bytes = new byte[length]; stream.read(bytes); for (int i = 0; i < length; i++) { builder.append((char) bytes[i]); } return builder.toString(); } }