package edu.berkeley.nlp.lm.cache;
import edu.berkeley.nlp.lm.AbstractContextEncodedNgramLanguageModel;
import edu.berkeley.nlp.lm.ContextEncodedNgramLanguageModel;
import edu.berkeley.nlp.lm.WordIndexer;
import edu.berkeley.nlp.lm.bits.BitUtils;
import edu.berkeley.nlp.lm.util.MurmurHash;
import edu.berkeley.nlp.lm.util.Annotations.OutputParameter;
/**
* This class wraps {@link ContextEncodedNgramLanguageModel} with a cache.
*
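* A minimal usage sketch; {@code rawLm} stands in for any existing
* {@link ContextEncodedNgramLanguageModel}:
*
* <pre>
* {@code
* ContextEncodedNgramLanguageModel<String> cachedLm =
*   ContextEncodedCachingLmWrapper.wrapWithCacheNotThreadSafe(rawLm);
* float score = cachedLm.scoreSentence(Arrays.asList("the", "cat", "ran"));
* }
* </pre>
*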
* @author adampauls
*
* @param <T> type of word tokens in the wrapped language model
*/
public class ContextEncodedCachingLmWrapper<T> extends AbstractContextEncodedNgramLanguageModel<T>
{
private static final long serialVersionUID = 1L;
private final ContextEncodedLmCache contextCache;
private final ContextEncodedNgramLanguageModel<T> lm;
private final int capacity;
/**
* This type of caching is not thread-safe: if multiple threads are used, each
* thread must have its own cache wrapper.
*
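* A sketch of the one-wrapper-per-thread pattern (assumes Java 8+ for
* {@code ThreadLocal.withInitial}; {@code sharedLm} and {@code sentence} are
* hypothetical):
*
* <pre>
* {@code
* ThreadLocal<ContextEncodedNgramLanguageModel<String>> perThreadLm =
*   ThreadLocal.withInitial(() -> wrapWithCacheNotThreadSafe(sharedLm));
* float score = perThreadLm.get().scoreSentence(sentence);
* }
* </pre>
*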
* @param <T> type of word tokens in the wrapped language model
* @param lm the language model to wrap
* @return a caching wrapper around {@code lm} with a default cache size
*/
public static <T> ContextEncodedCachingLmWrapper<T> wrapWithCacheNotThreadSafe(final ContextEncodedNgramLanguageModel<T> lm) {
return wrapWithCacheNotThreadSafe(lm, 18);
}
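/**
* As {@link #wrapWithCacheNotThreadSafe(ContextEncodedNgramLanguageModel)}, but with an
* explicit cache size.
*
* @param cacheBits controls the cache size; the underlying direct-mapped cache holds on
* the order of 2^cacheBits entries
*/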
public static <T> ContextEncodedCachingLmWrapper<T> wrapWithCacheNotThreadSafe(final ContextEncodedNgramLanguageModel<T> lm, final int cacheBits) {
return new ContextEncodedCachingLmWrapper<T>(lm, false, cacheBits);
}
/**
* This type of caching is thread-safe and (internally) maintains a separate
* cache for each thread that calls it. Because every thread gets its own
* cache, memory usage can be substantial when many threads are in play.
*
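* A sketch of sharing a single thread-safe wrapper ({@code rawLm} and the
* worker setup are hypothetical):
*
* <pre>
* {@code
* final ContextEncodedNgramLanguageModel<String> shared =
*   ContextEncodedCachingLmWrapper.wrapWithCacheThreadSafe(rawLm);
* // each worker thread may safely call shared.scoreSentence(...) concurrently
* }
* </pre>
*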
* @param <T> type of word tokens in the wrapped language model
* @param lm the language model to wrap
* @return a caching wrapper around {@code lm} that may be shared across threads
*/
public static <T> ContextEncodedCachingLmWrapper<T> wrapWithCacheThreadSafe(final ContextEncodedNgramLanguageModel<T> lm) {
return wrapWithCacheThreadSafe(lm, 16);
}
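/**
* As {@link #wrapWithCacheThreadSafe(ContextEncodedNgramLanguageModel)}, but with an
* explicit cache size.
*
* @param cacheBits controls the per-thread cache size; each thread's direct-mapped cache
* holds on the order of 2^cacheBits entries
*/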
public static <T> ContextEncodedCachingLmWrapper<T> wrapWithCacheThreadSafe(final ContextEncodedNgramLanguageModel<T> lm, final int cacheBits) {
return new ContextEncodedCachingLmWrapper<T>(lm, true, cacheBits);
}
private ContextEncodedCachingLmWrapper(final ContextEncodedNgramLanguageModel<T> lm, final boolean threadSafe, final int cacheBits) {
this(lm, new ContextEncodedDirectMappedLmCache(cacheBits, threadSafe));
}
private ContextEncodedCachingLmWrapper(final ContextEncodedNgramLanguageModel<T> lm, final ContextEncodedLmCache cache) {
// NaN is passed as the OOV log-prob: scoring is always delegated to the wrapped LM, so this value should never be consulted
super(lm.getLmOrder(), lm.getWordIndexer(), Float.NaN);
this.lm = lm;
this.contextCache = cache;
capacity = contextCache.capacity();
}
@Override
public WordIndexer<T> getWordIndexer() {
return lm.getWordIndexer();
}
@Override
public LmContextInfo getOffsetForNgram(final int[] ngram, final int startPos, final int endPos) {
return lm.getOffsetForNgram(ngram, startPos, endPos);
}
@Override
public int[] getNgramForOffset(final long contextOffset, final int contextOrder, final int word) {
return lm.getNgramForOffset(contextOffset, contextOrder, word);
}
@Override
public float getLogProb(final long contextOffset, final int contextOrder, final int word, @OutputParameter final LmContextInfo contextOutput) {
// a negative context order signals an empty/invalid context, so bypass the cache
if (contextOrder < 0) return lm.getLogProb(contextOffset, contextOrder, word, contextOutput);
final int hash = hash(contextOffset, contextOrder, word) % capacity;
float f = contextCache.getCached(contextOffset, contextOrder, word, hash, contextOutput);
// the cache signals a miss with NaN, which is why NaN is never stored as a score
if (!Float.isNaN(f)) return f;
// on a miss, ask the wrapped LM and install the answer in the cache
f = lm.getLogProb(contextOffset, contextOrder, word, contextOutput);
contextCache.putCached(contextOffset, contextOrder, word, f, hash, contextOutput);
return f;
}
private static int hash(final long contextOffset, final int contextOrder, final int word) {
// mix all three keys into one hash and force it non-negative so the modulus above yields a valid cache index
final int hash = (int) MurmurHash.hashThreeLongs(contextOffset, contextOrder, word);
return BitUtils.abs(hash);
}
}