package edu.cmu.sphinx.linguist.language.ngram.trie; import edu.cmu.sphinx.linguist.language.ngram.trie.NgramTrieModel.TrieRange; /** * Trie structure that contains ngrams of order 2+ in reversed order. * Ngrams are stored in bit array for space efficiency. */ public class NgramTrie { private MiddleNgramSet[] middles; private LongestNgramSet longest; private NgramTrieBitarr bitArr; private int ordersNum; private int quantProbBoLen; private int quantProbLen; public NgramTrie(int[] counts, int quantProbBoLen, int quantProbLen) { int memLen = 0; int[] ngramMemSize = new int[counts.length - 1]; for (int i = 1; i <= counts.length - 1; i++) { int entryLen = requiredBits(counts[0]); if (i == counts.length - 1) { //longest ngram entryLen += quantProbLen; } else { //middle ngram entryLen += requiredBits(counts[i + 1]); entryLen += quantProbBoLen; } // Extra entry for next pointer at the end. // +7 then / 8 to round up bits and convert to bytes // +8 (or +sizeof(uint64))so that reading bit array doesn't exceed bounds // Note that this waste is O(order), not O(number of ngrams). int tmpLen = ((1 + counts[i]) * entryLen + 7) / 8 + 8; ngramMemSize[i - 1] = tmpLen; memLen += tmpLen; } bitArr = new NgramTrieBitarr(memLen); this.quantProbLen = quantProbLen; this.quantProbBoLen = quantProbBoLen; middles = new MiddleNgramSet[counts.length - 2]; int[] startPtrs = new int[counts.length - 2]; int startPtr = 0; for (int i = 0; i < counts.length - 2; i++) { startPtrs[i] = startPtr; startPtr += ngramMemSize[i]; } // Crazy backwards thing so we initialize using pointers to ones that have already been initialized for (int i = counts.length - 1; i >= 2; --i) { middles[i - 2] = new MiddleNgramSet(startPtrs[i - 2], quantProbBoLen, counts[i-1], counts[0], counts[i]); } longest = new LongestNgramSet(startPtr, quantProbLen, counts[0]); ordersNum = middles.length + 1; } /** * Getter for allocated byte array to which trie is mapped * @return byte[] with ngram trie */ public byte[] getMem() { return bitArr.getArr(); } /** * Finds ngram index which corresponds to ngram with specified wordId. * Search is performed in specified range. * Fills range with ngram successors if ngram was found, makes range invalid otherwise. * @param ngramSet - set of ngrams of certain order to look in * @param wordId - word id to look for * @param range - range to look in. range contains ngram successors or is invalid after method usage. * @return ngram index that can be converted into byte offset if ngram was found, -1 otherwise */ private int findNgram(NgramSet ngramSet, int wordId, TrieRange range) { int ptr; range.begin--; if ((ptr = uniformFind(ngramSet, range, wordId)) < 0) { range.setFound(false); return -1; } //read next order ngrams for future searches if (ngramSet instanceof MiddleNgramSet) ((MiddleNgramSet)ngramSet).readNextRange(ptr, range); return ptr; } /** * Finds ngram of cerain order in specified range and reads it's backoff. * Range contains ngram successors after function execution. * If ngram is not found, range will be invalid. * @param wordId - word id to look for * @param orderMinusTwo - order of ngram minus two * @param range - range to look in, contains ngram successors after function execution * @param quant - quantation object to decode compressed backoff stored in trie * @return backoff of ngram */ public float readNgramBackoff(int wordId, int orderMinusTwo, TrieRange range, NgramTrieQuant quant) { int ptr; NgramSet ngram = getNgram(orderMinusTwo); if ((ptr = findNgram(ngram, wordId, range)) < 0) return 0.0f; return quant.readBackoff(bitArr, ngram.memPtr, ngram.getNgramWeightsOffset(ptr), orderMinusTwo); } /** * Finds ngram of cerain order in specified range and reads it's probability. * Range contains ngram successors after function execution. * If ngram is not found, range will be invalid. * @param wordId - word id to look for * @param orderMinusTwo - order of ngram minus two * @param range - range to look in, contains ngram successors after function execution * @param quant - quantation object to decode compressed probability stored in trie * @return probability of ngram */ public float readNgramProb(int wordId, int orderMinusTwo, TrieRange range, NgramTrieQuant quant) { int ptr; NgramSet ngram = getNgram(orderMinusTwo); if ((ptr = findNgram(ngram, wordId, range)) < 0) return 0.0f; return quant.readProb(bitArr, ngram.memPtr, ngram.getNgramWeightsOffset(ptr), orderMinusTwo); } /** * Calculates pivot for binary search */ private int calculatePivot(int offset, int range, int width) { return (int)(((long)offset * width) / (range + 1)); } /** * Searches ngram index for given wordId in provided range */ private int uniformFind(NgramSet ngram, TrieRange range, int wordId) { TrieRange vocabRange = new TrieRange(0, ngram.maxVocab); while (range.getWidth() > 1) { int pivot = range.begin + 1 + calculatePivot(wordId - vocabRange.begin, vocabRange.getWidth(), range.getWidth() - 1); int mid = ngram.readNgramWord(pivot); if (mid < wordId) { range.begin = pivot; vocabRange.begin = mid; } else if (mid > wordId){ range.end = pivot; vocabRange.end = mid; } else { return pivot; } } return -1; } /** * Getter for ngram set by ngram order */ private NgramSet getNgram(int orderMinusTwo) { if (orderMinusTwo == ordersNum - 1) return longest; return middles[orderMinusTwo]; } /** * Calculates minimum amount of bits to store provided int */ private int requiredBits(int maxValue) { if (maxValue == 0) return 0; int res = 1; while ((maxValue >>= 1) != 0) res++; return res; } /** * Gives access to set of ngram of certain order (trie layer) */ abstract class NgramSet { int memPtr; int wordBits; int wordMask; int totalBits; int insertIdx; int maxVocab; NgramSet(int memPtr, int maxVocab, int remainingBits) { this.maxVocab = maxVocab; this.memPtr = memPtr; wordBits = requiredBits(maxVocab); if (wordBits > 25) throw new Error("Sorry, word indices more than" + (1 << 25) + " are not implemented"); totalBits = wordBits + remainingBits; wordMask = (1 << wordBits) - 1; insertIdx = 0; } int readNgramWord(int ngramIdx) { int offset = ngramIdx * totalBits; return bitArr.readInt(memPtr, offset, wordMask); } int getNgramWeightsOffset(int ngramIdx) { return ngramIdx * totalBits + wordBits; } abstract int getQuantBits(); } /** * Implementation of NgramSet for ngrams of order [2...Max Ngram Order - 1] */ class MiddleNgramSet extends NgramSet { int nextMask; int nextOrderMemPtr; MiddleNgramSet(int memPtr, int quantBits, int entries, int maxVocab, int maxNext) { super(memPtr, maxVocab, quantBits + requiredBits(maxNext)); nextMask = (1 << requiredBits(maxNext)) - 1; if (entries + 1 >= (1 << 25) || (maxNext >= (1 << 25))) throw new Error("Sorry, current implementation doesn't support more than " + (1 << 25) + " n-grams of particular order"); } void readNextRange(int ngramIdx, TrieRange range) { int offset = ngramIdx * totalBits; offset += wordBits; offset += getQuantBits(); range.begin = bitArr.readInt(memPtr, offset, nextMask); offset += totalBits; range.end = bitArr.readInt(memPtr, offset, nextMask); } @Override int getQuantBits() { return quantProbBoLen; } } /** * Implementation of NgramSet for ngrams of maximum order */ class LongestNgramSet extends NgramSet { LongestNgramSet(int memPtr, int quantBits, int maxVocab) { super(memPtr, maxVocab, quantBits); } @Override int getQuantBits() { return quantProbLen; } } }