package edu.berkeley.nlp.lm.cache; import edu.berkeley.nlp.lm.AbstractArrayEncodedNgramLanguageModel; import edu.berkeley.nlp.lm.ArrayEncodedNgramLanguageModel; import edu.berkeley.nlp.lm.bits.BitUtils; import edu.berkeley.nlp.lm.util.MurmurHash; /** * This class wraps {@link ArrayEncodedNgramLanguageModel} with a cache. * * * @author adampauls * * @param <W> */ public class ArrayEncodedCachingLmWrapper<W> extends AbstractArrayEncodedNgramLanguageModel<W> { /** * */ private static final long serialVersionUID = 1L; private final ArrayEncodedLmCache cache; private final ArrayEncodedNgramLanguageModel<W> lm; private final int capacity; /** * To use this wrapper in a multithreaded environment, you should create one * wrapper per thread. * * @param <T> * @param lm * @return */ public static <W> ArrayEncodedCachingLmWrapper<W> wrapWithCacheNotThreadSafe(final ArrayEncodedNgramLanguageModel<W> lm) { return wrapWithCacheNotThreadSafe(lm, 18); } public static <W> ArrayEncodedCachingLmWrapper<W> wrapWithCacheNotThreadSafe(final ArrayEncodedNgramLanguageModel<W> lm, final int cacheBits) { return new ArrayEncodedCachingLmWrapper<W>(lm, false, cacheBits); } /** * * This type of caching is threadsafe and (internally) maintains a separate * cache for each thread that calls it. Note each thread has its own cache, * so if you have lots of threads, memory usage could be substantial. * * @param <W> * @param lm * @return */ public static <W> ArrayEncodedCachingLmWrapper<W> wrapWithCacheThreadSafe(final ArrayEncodedNgramLanguageModel<W> lm) { return wrapWithCacheThreadSafe(lm, 16); } public static <W> ArrayEncodedCachingLmWrapper<W> wrapWithCacheThreadSafe(final ArrayEncodedNgramLanguageModel<W> lm, final int cacheBits) { return new ArrayEncodedCachingLmWrapper<W>(lm, true, cacheBits); } private ArrayEncodedCachingLmWrapper(final ArrayEncodedNgramLanguageModel<W> lm, final boolean threadSafe, int cacheBits) { this(lm, new ArrayEncodedDirectMappedLmCache(cacheBits, lm.getLmOrder(), threadSafe)); } private ArrayEncodedCachingLmWrapper(final ArrayEncodedNgramLanguageModel<W> lm, final ArrayEncodedLmCache cache) { super(lm.getLmOrder(), lm.getWordIndexer(), Float.NaN); this.cache = cache; this.lm = lm; this.capacity = cache.capacity(); } @Override public float getLogProb(final int[] ngram, final int startPos, final int endPos) { if (endPos - startPos <= 1) return lm.getLogProb(ngram, startPos, endPos); final int hash = hash(ngram, startPos, endPos) % capacity; float f = cache.getCached(ngram, startPos, endPos, hash); if (!Float.isNaN(f)) return f; f = lm.getLogProb(ngram, startPos, endPos); cache.putCached(ngram, startPos, endPos, f, hash); return f; } private static int hash(final int[] key, final int startPos, final int endPos) { final int hash = MurmurHash.hash32(key, startPos, endPos); return BitUtils.abs(hash); } }