package edu.berkeley.nlp.lm.cache; import java.util.Arrays; import java.util.concurrent.atomic.AtomicIntegerArray; import edu.berkeley.nlp.lm.ContextEncodedNgramLanguageModel.LmContextInfo; import edu.berkeley.nlp.lm.util.Annotations.OutputParameter; public final class ContextEncodedDirectMappedLmCache implements ContextEncodedLmCache { /** * */ private static final long serialVersionUID = 1L; private static int pos = 0; private static final int VAL_AND_WORD_OFFSET = pos++; private static final int CONTEXT_OFFSET = pos++; private static final int OUTPUT_CONTEXT_OFFSET = pos++; private static final int STRUCT_LENGTH = pos; private static int NUM_ORDER_BITS = 4;// good for up to 16-grams private static int NUM_OFFSETS_BITS = (Long.SIZE - NUM_ORDER_BITS); private static long ORDER_BIT_MASK = ((1L << NUM_ORDER_BITS) - 1) << (NUM_OFFSETS_BITS); private static long OFFSET_BIT_MASK = ((1L << NUM_OFFSETS_BITS) - 1); private static long WORD_MASK = ((1L << Integer.SIZE) - 1) << Integer.SIZE; private static long FLOAT_MASK = ((1L << Integer.SIZE) - 1); // for efficiency, this array fakes a struct with fields // float prob; // int word; // long contextOffset; (also contains order of context) // long outputContextOffset; (also contains order of context) private final long[] threadUnsafeArray; private final ThreadLocal<long[]> threadSafeArray; private final int cacheSize; private final boolean threadSafe; public ContextEncodedDirectMappedLmCache(final int cacheBits, final boolean threadSafe) { cacheSize = (1 << cacheBits) - 1; this.threadSafe = threadSafe; if (threadSafe) { threadUnsafeArray = null; threadSafeArray = new ThreadLocal<long[]>() { @Override protected long[] initialValue() { return allocCache(); } }; } else { threadSafeArray = null; threadUnsafeArray = allocCache(); } } /** * @return */ private long[] allocCache() { final long[] array = new long[STRUCT_LENGTH * cacheSize]; Arrays.fill(array, -1); return array; } @Override public float getCached(final long contextOffset, final int contextOrder, final int word, final int hash, @OutputParameter final LmContextInfo outputPrefix) { final long[] array = !threadSafe ? threadUnsafeArray : threadSafeArray.get(); final int cachedWordHere = getWord(hash, array); if (word >= 0 && word == cachedWordHere && getLong(hash, CONTEXT_OFFSET, array) == combine(contextOrder, contextOffset)) { final float f = getVal(hash, array); if (outputPrefix == null) return f; final long outputOrderAndOffset = getLong(hash, OUTPUT_CONTEXT_OFFSET, array); if (outputOrderAndOffset >= 0) { outputPrefix.order = orderOf(outputOrderAndOffset); outputPrefix.offset = offsetOf(outputOrderAndOffset); return f; } } return Float.NaN; } @Override public void putCached(final long contextOffset, final int contextOrder, final int word, final float score, final int hash, @OutputParameter final LmContextInfo outputPrefix) { final long[] array = !threadSafe ? threadUnsafeArray : threadSafeArray.get(); setWordAndVal(hash, word, score, array); setOutputContextOrderAndOffset(hash, outputPrefix == null ? -1 : outputPrefix.order, outputPrefix == null ? -1 : outputPrefix.offset, array); setContextOrderAndOffset(hash, contextOrder, contextOffset, array); } private static long offsetOf(final long key) { return (key & OFFSET_BIT_MASK); } /** * @param key * @return */ private static int orderOf(final long key) { return (int) ((key & ORDER_BIT_MASK) >>> (NUM_OFFSETS_BITS)); } private int getWord(final int hash, long[] array) { return (int) ((array[startOfStruct(hash) + VAL_AND_WORD_OFFSET] & WORD_MASK) >>> Integer.SIZE); } /** * @param hash * @param off * @return */ private long getLong(final int hash, final int off, long[] array) { return array[startOfStruct(hash) + off]; } private float getVal(final int hash, long[] array) { return Float.intBitsToFloat((int) array[startOfStruct(hash) + VAL_AND_WORD_OFFSET]); } private void setWordAndVal(final int hash, final int word, final float val, long[] array) { final long together = combineWordAndVal(word, val); array[startOfStruct(hash) + VAL_AND_WORD_OFFSET] = together; } private long combineWordAndVal(int word, float val) { return (((long) word) << Integer.SIZE) | (Float.floatToIntBits(val) & FLOAT_MASK); } private void setContextOrderAndOffset(final int hash, final int order, final long offset, long[] array) { final long together = combine(order, offset); setLong(hash, together, CONTEXT_OFFSET, array); } private void setOutputContextOrderAndOffset(final int hash, final int order, final long offset, long[] array) { final long together = combine(order, offset); setLong(hash, together, OUTPUT_CONTEXT_OFFSET, array); } private static long combine(final int order, final long offset) { return (((long) order) << (NUM_OFFSETS_BITS)) | offset; } /** * @param hash * @param l * @param off */ private void setLong(final int hash, final long l, final int off, long[] array) { array[startOfStruct(hash) + off] = l; } private static int startOfStruct(final int hash) { return hash * STRUCT_LENGTH; } @Override public int capacity() { return cacheSize; } }