package edu.berkeley.nlp.lm.map;
import java.io.Serializable;
import java.util.Iterator;
import edu.berkeley.nlp.lm.array.CustomWidthArray;
import edu.berkeley.nlp.lm.array.LongArray;
import edu.berkeley.nlp.lm.bits.BitUtils;
import edu.berkeley.nlp.lm.collections.Iterators;
import edu.berkeley.nlp.lm.util.Annotations.PrintMemoryCount;
import edu.berkeley.nlp.lm.util.Logger;
import edu.berkeley.nlp.lm.util.MurmurHash;
/**
 * Low-level hash map which stores context-encoded parent pointers in a trie.
 *
 * @author adampauls
 *
 */
final class ImplicitWordHashMap implements Serializable, HashMap
{

	private static final long serialVersionUID = 1L;

	/**
	 * Key slots. Slot i holds the encoded key plus one, so that
	 * {@link #EMPTY_KEY} (0) can mark a free slot. When words are stored
	 * explicitly ({@code wordRanges == null}), the key is the packed form
	 * produced by {@link #shrinkKey(long)}; otherwise only the context
	 * (suffix) offset is stored and the word is implied by which per-word
	 * range the slot falls in.
	 */
	@PrintMemoryCount
	final CustomWidthArray keys;

	/**
	 * Start offset of each word's hash range inside {@link #keys}, indexed by
	 * {@code word * maxNgramOrder + ngramOrder - 1} (the array is shared
	 * across orders). When {@link #fitsInInt} is true, two 32-bit starts are
	 * packed into each long (see {@link #wordRangeAt(int)}). Null when words
	 * are stored explicitly inside {@link #keys} instead.
	 */
	@PrintMemoryCount
	private final long[] wordRanges;

	/** Owning map; supplies key encoding/decoding and sizing helpers. */
	private final HashNgramMap<?> ngramMap;

	/** Number of occupied slots in {@link #keys}. */
	private long numFilled = 0;

	/** Sentinel value in {@link #keys} marking a free slot (hence the +1 offset on stored keys). */
	private static final int EMPTY_KEY = 0;

	/** Number of words with a (possibly empty) n-gram range at this order. */
	private final int numWords;

	/** Order of the n-grams stored in this map (always >= 1). */
	private final int ngramOrder;

	@SuppressWarnings("unused")
	private final int totalNumWords;

	/** Highest order in the trie; used as the stride when indexing {@link #wordRanges}. */
	private final int maxNgramOrder;

	/** True when every range start fits in 32 bits, allowing two starts per long in {@link #wordRanges}. */
	private final boolean fitsInInt;

	/** Width in bits of the suffix-offset field of a packed key (see {@link #shrinkKey(long)}). */
	private final int numSuffixBits;

	/**
	 * @param numNgramsForEachWord number of n-grams of this order beginning with each word;
	 *            its size determines {@link #numWords}
	 * @param wordRanges shared backing array for per-word range starts (unused if {@code storeWords})
	 * @param ngramOrder order of the n-grams stored here (>= 1)
	 * @param maxNgramOrder highest order stored in the trie (stride into {@code wordRanges})
	 * @param numNgramsForPreviousOrder sizes the suffix-offset bit field of packed keys
	 * @param totalNumWords vocabulary size; sizes the word bit field when words are stored
	 * @param ngramMap owning map, used for key encoding/decoding and sizing
	 * @param fitsInInt whether range starts fit in 32 bits (two packed per long)
	 * @param storeWords if true, store the word bits inside {@link #keys} rather than
	 *            partitioning the key array into per-word ranges
	 */
	public ImplicitWordHashMap(final LongArray numNgramsForEachWord, final long[] wordRanges, final int ngramOrder, final int maxNgramOrder,
		final long numNgramsForPreviousOrder, final int totalNumWords, final HashNgramMap<?> ngramMap, final boolean fitsInInt, final boolean storeWords) {
		this.ngramOrder = ngramOrder;
		this.ngramMap = ngramMap;
		assert ngramOrder >= 1;
		this.maxNgramOrder = maxNgramOrder;
		this.totalNumWords = totalNumWords;
		this.numWords = (int) numNgramsForEachWord.size();
		this.fitsInInt = fitsInInt;
		// When words are stored explicitly in the key slots, per-word ranges are not used.
		this.wordRanges = storeWords ? null : wordRanges;
		final long totalNumNgrams = setWordRanges(numNgramsForEachWord, numWords);
		// +1 accounts for the offset-by-one convention (0 is reserved for EMPTY_KEY).
		numSuffixBits = CustomWidthArray.numBitsNeeded(numNgramsForPreviousOrder + 1);
		final int numBitsHere = numSuffixBits + (storeWords ? CustomWidthArray.numBitsNeeded(totalNumWords) : 0);
		keys = new CustomWidthArray(totalNumNgrams, numBitsHere, numBitsHere + ngramMap.getValues().numValueBits(ngramOrder));
		keys.fill(EMPTY_KEY, totalNumNgrams);
		numFilled = 0;
	}

	/*
	 * (non-Javadoc)
	 *
	 * @see edu.berkeley.nlp.lm.map.HashMap#put(long)
	 */
	@Override
	public long put(final long key) {
		// NOTE(review): linearSearch can return -1 (unknown word / empty range); callers
		// presumably guarantee capacity for every inserted key — confirm at call sites.
		final long i = linearSearch(key, true);
		// Slot held EMPTY_KEY iff it was free (stored keys are offset by +1, so never 0).
		if (keys.get(i) == EMPTY_KEY) numFilled++;
		setKey(i, key);
		return i;
	}

	/**
	 * Records the hash-range start of each word (when ranges are in use) and
	 * computes the total number of key slots needed.
	 *
	 * @param numNgramsForEachWord n-gram counts per word at this order
	 * @param numWords number of words to process
	 * @return capacity of the key array; when ranges are unused, the raw
	 *         n-gram count is inflated by the map's load factor here, whereas
	 *         per-range sizes (via getRangeSizeForWord) already include any padding
	 */
	private long setWordRanges(final LongArray numNgramsForEachWord, final long numWords) {
		long currStart = 0;
		for (int w = (0); w < numWords; ++w) {
			if (wordRanges != null) {
				setWordRangeStart(w, currStart);
				currStart += ngramMap.getRangeSizeForWord(numNgramsForEachWord, w);
			} else {
				currStart += numNgramsForEachWord.get(w);
			}
		}
		return wordRanges == null ? Math.round(currStart * 1.0 / ngramMap.getLoadFactor()) : currStart;
	}

	/**
	 * Writes the encoded key into the given slot, offset by +1 so that 0
	 * remains the empty-slot marker.
	 */
	private void setKey(final long index, final long putKey) {
		// With ranges, the word is implied by the slot's range, so only the context offset is kept.
		final long contextOffset = wordRanges == null ? shrinkKey(putKey) : ngramMap.contextOffsetOf(putKey);
		assert contextOffset >= 0;
		keys.set(index, contextOffset + 1);
	}

	/**
	 * Packs a key's word and suffix (context) offset into
	 * {@code (word << numSuffixBits) | suffixIndex}. Used only when words are
	 * stored explicitly ({@code wordRanges == null}).
	 *
	 * @param key full key as understood by {@link HashNgramMap}
	 * @return packed word/suffix form
	 */
	private final long shrinkKey(final long key) {
		final int word = ngramMap.wordOf(key);
		final long suffixIndex = ngramMap.contextOffsetOf(key);
		return (((long) word) << (numSuffixBits)) | suffixIndex;
	}

	/** Inverse of {@link #shrinkKey(long)}: unpacks word and suffix offset and re-combines them into a full key. */
	private final long expandKey(final long key) {
		final int word = (int) (key >>> numSuffixBits);
		final long suffixIndex = key & ((1L << numSuffixBits) - 1);
		return ngramMap.combineToKey(word, suffixIndex);
	}

	@Override
	public final long getOffset(final long key) {
		return linearSearch(key, false);
	}

	/**
	 * Open-addressing probe for a key within the hash range of its word.
	 *
	 * @param key full key to look up
	 * @param returnFirstEmptyIndex if true, an empty slot counts as a hit
	 *            (used to find the insertion point in {@link #put(long)})
	 * @return slot index, or -1 when the word is out of range or has an empty
	 *         range; on a miss, behavior is delegated to
	 *         {@link CustomWidthArray#linearSearch} — presumably it returns -1
	 *         when the key is absent and {@code !returnFirstEmptyIndex} (confirm there)
	 */
	private long linearSearch(final long key, final boolean returnFirstEmptyIndex) {
		final int word = ngramMap.wordOf(key);
		if (word >= numWords) return -1;
		final long rangeStart = wordRangeStart(word);
		final long rangeEnd = wordRangeEnd(word);
		final long numHashPositions = rangeEnd - rangeStart;
		if (numHashPositions == 0) return -1L;
		final long startIndex = hash(key, numHashPositions, rangeStart);
		// Search target mirrors setKey(): packed key when words are stored, else context offset.
		final long contextOffsetOf = wordRanges == null ? shrinkKey(key) : ngramMap.contextOffsetOf(key);
		assert contextOffsetOf >= 0;
		assert word >= 0;
		assert startIndex >= rangeStart;
		assert startIndex < rangeEnd;
		// +1 matches the offset-by-one storage convention.
		final long index = keys.linearSearch(contextOffsetOf + 1, rangeStart, rangeEnd, startIndex, EMPTY_KEY, returnFirstEmptyIndex);
		return index;
	}

	@Override
	public long getCapacity() {
		return keys.size();
	}

	@Override
	public double getLoadFactor() {
		return (double) numFilled / getCapacity();
	}

	/**
	 * Hashes the key into its word's range: Murmur hash of the full key,
	 * reduced modulo the range size and shifted to the range start.
	 */
	private long hash(final long key, final long numHashPositions, final long startOfRange) {
		long hash = BitUtils.abs(MurmurHash.hashOneLong(key, 0x9747b28c));
		hash %= numHashPositions;
		return hash + startOfRange;
	}

	/*
	 * (non-Javadoc)
	 *
	 * @see edu.berkeley.nlp.lm.map.HashMap#getNextOffset(long)
	 */
	// Undoes the +1 storage offset; yields the raw stored key (packed key or context offset).
	long getNextOffset(final long offset) {
		return keys.get(offset) - 1;
	}

	/*
	 * (non-Javadoc)
	 *
	 * @see edu.berkeley.nlp.lm.map.HashMap#getWordForContext(long)
	 */
	// Recovers the word implied by a slot index: binary-search the range starts,
	// step back to the containing range on an inexact hit (-result - 2 == insertionPoint - 1),
	// then skip forward over zero-width ranges that share the same start offset.
	int getWordForContext(final long contextOffset) {
		int binarySearch = binarySearch(contextOffset);
		binarySearch = binarySearch >= 0 ? binarySearch : (-binarySearch - 2);
		while (binarySearch < numWords - 1 && wordRangeStart(binarySearch) == wordRangeEnd(binarySearch))
			binarySearch++;
		return binarySearch;
	}

	/**
	 * Standard binary search over per-word range starts.
	 *
	 * @return index of the word whose range starts exactly at {@code key}, or
	 *         {@code -(insertionPoint + 1)} when no range starts there
	 *         (same convention as {@code java.util.Arrays#binarySearch})
	 */
	private int binarySearch(final long key) {
		int low = 0;
		int high = numWords - 1;
		while (low <= high) {
			final int mid = (low + high) >>> 1;
			final long midVal = wordRangeStart(mid);
			if (midVal < key)
				low = mid + 1;
			else if (midVal > key)
				high = mid - 1;
			else
				return mid; // key found
		}
		return -(low + 1); // key not found.
	}

	@Override
	public long getKey(final long contextOffset) {
		// Reconstruct the full key from a slot: either unpack the stored word/suffix,
		// or recover the word from the slot's range and combine it with the stored offset.
		return wordRanges == null ? expandKey(getNextOffset(contextOffset)) : ngramMap.combineToKey(getWordForContext(contextOffset),
			getNextOffset(contextOffset));
	}

	@Override
	public boolean isEmptyKey(final long key) {
		return key == EMPTY_KEY;
	}

	@Override
	public long size() {
		return numFilled;
	}

	@Override
	public Iterable<Long> keys() {
		return Iterators.able(new KeyIterator(keys));
	}

	/**
	 * Iterates over the indices of the occupied (non-empty) slots of a key
	 * array. Note: {@link #next()} yields slot indices, not decoded keys.
	 */
	public static class KeyIterator implements Iterator<Long>
	{
		private final CustomWidthArray keys;

		public KeyIterator(final CustomWidthArray keys) {
			this.keys = keys;
			end = keys.size();
			// Prime the iterator: advance from -1 to the first occupied slot.
			next = -1;
			nextIndex();
		}

		@Override
		public boolean hasNext() {
			return end > 0 && next < end;
		}

		@Override
		public Long next() {
			final long nextIndex = nextIndex();
			return nextIndex;
		}

		// Returns the current slot index and advances past any empty slots to the next occupied one.
		long nextIndex() {
			final long curr = next;
			do {
				next++;
			} while (next < end && keys != null && keys.get(next) == EMPTY_KEY);
			return curr;
		}

		@Override
		public void remove() {
			throw new UnsupportedOperationException();
		}

		// Index of the next occupied slot to be returned (or >= end when exhausted).
		private long next;

		// One past the last valid slot index.
		private final long end;
	}

	@Override
	public boolean hasContexts(final int word) {
		// A word with no n-grams at this order gets a zero-width hash range.
		if (word >= numWords) return false;
		final long rangeStart = wordRangeStart(word);
		final long rangeEnd = wordRangeEnd(word);
		return (rangeEnd - rangeStart > 0);
	}

	// Start of word w's hash range; 0 (the whole array is one range) when ranges are unused.
	private final long wordRangeStart(final int w) {
		return wordRanges == null ? 0 : wordRangeAt(w * maxNgramOrder + ngramOrder - 1);
	}

	// End of word w's range = start of word w+1's range; the last word (and the
	// no-ranges case) ends at the array capacity.
	private final long wordRangeEnd(final int w) {
		return wordRanges == null || w == numWords - 1 ? getCapacity() : wordRangeAt((w + 1) * maxNgramOrder + ngramOrder - 1);
	}

	/**
	 * Reads a range start. When {@link #fitsInInt} is set, two 32-bit starts
	 * are packed per long: even logical indices live in the low half, odd in
	 * the high half.
	 *
	 * @param logicalIndex {@code word * maxNgramOrder + ngramOrder - 1}
	 * @return range start offset into {@link #keys}
	 */
	private long wordRangeAt(final int logicalIndex) {
		if (fitsInInt) {
			return logicalIndex % 2 == 0 ? BitUtils.getLowInt(wordRanges[logicalIndex / 2]) : BitUtils.getHighInt(wordRanges[logicalIndex >> 1]);
		} else {
			return wordRanges[logicalIndex];
		}
	}

	// Writes word w's range start, using the same packed layout as wordRangeAt().
	private void setWordRangeStart(int w, long currStart) {
		final int logicalIndex = w * maxNgramOrder + ngramOrder - 1;
		if (fitsInInt) {
			if (logicalIndex % 2 == 0)
				wordRanges[logicalIndex / 2] = BitUtils.setLowInt(wordRanges[logicalIndex / 2], (int) currStart);
			else
				wordRanges[logicalIndex / 2] = BitUtils.setHighInt(wordRanges[logicalIndex / 2], (int) currStart);
		} else {
			wordRanges[logicalIndex] = currStart;
		}
	}
}