package edu.berkeley.nlp.lm.map;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import edu.berkeley.nlp.lm.ConfigOptions;
import edu.berkeley.nlp.lm.ContextEncodedNgramLanguageModel.LmContextInfo;
import edu.berkeley.nlp.lm.array.CustomWidthArray;
import edu.berkeley.nlp.lm.array.LongArray;
import edu.berkeley.nlp.lm.collections.Iterators;
import edu.berkeley.nlp.lm.util.Annotations.OutputParameter;
import edu.berkeley.nlp.lm.util.Annotations.PrintMemoryCount;
import edu.berkeley.nlp.lm.util.Logger;
import edu.berkeley.nlp.lm.util.LongRef;
import edu.berkeley.nlp.lm.values.ValueContainer;
/**
 * Hash-table-backed n-gram map supporting both raw (word-array) and
 * context-encoded lookups. Two storage schemes exist: "explicit" per-order
 * hash tables that are allocated lazily and rehashed as they fill, and
 * pre-sized "implicit" tables built from per-word n-gram counts.
 *
 * @author adampauls
 *
 * @param <T> type of the value attached to each n-gram
 */
public final class HashNgramMap<T> extends AbstractNgramMap<T> implements ContextEncodedNgramMap<T>
{
/**
 * Serialization version for the Serializable hierarchy.
 */
private static final long serialVersionUID = 1L;
// per-order tables for the explicit scheme (index = n-gram order); null in implicit mode
@PrintMemoryCount
private ExplicitWordHashMap[] explicitMaps;
// per-order tables for orders >= 1 in the implicit scheme; null in explicit mode
@PrintMemoryCount
private final ImplicitWordHashMap[] implicitMaps;
// unigram table for the implicit scheme; null in explicit mode
@PrintMemoryCount
private final UnigramHashMap implicitUnigramMap;
// initial capacity used when an explicit table is first allocated; null in implicit mode
private long[] initCapacities;
// load factor at or beyond which explicit tables get rehashed
private final double maxLoadFactor;
// true for the explicit scheme, false for the implicit one
private final boolean isExplicit;
// if true, n-grams are stored with their words in reversed order
private final boolean reversed;
// whether the value container also needs each n-gram's suffix offset
private final boolean storeSuffixOffsets;
/**
 * Builds a HashNgramMap that uses implicit word storage, sized up front
 * from the per-order, per-word n-gram counts.
 */
public static <T> HashNgramMap<T> createImplicitWordHashNgramMap(final ValueContainer<T> values, final ConfigOptions opts,
    final LongArray[] numNgramsForEachWord, final boolean reversed) {
    final HashNgramMap<T> map = new HashNgramMap<T>(values, opts, numNgramsForEachWord, reversed);
    return map;
}
/**
 * Constructor for the implicit-word scheme: builds one pre-sized table per
 * order from the given per-word n-gram counts. Tables never grow.
 *
 * @param numNgramsForEachWord one LongArray per order; entry w is the
 *        number of n-grams of that order whose indexed word is w
 * @param reversed whether n-grams are stored in reversed word order
 */
private HashNgramMap(final ValueContainer<T> values, final ConfigOptions opts, final LongArray[] numNgramsForEachWord, final boolean reversed) {
super(values, opts);
this.reversed = reversed;
this.maxLoadFactor = opts.hashTableLoadFactor;
this.storeSuffixOffsets = values.storeSuffixoffsets();
final int maxNgramOrder = numNgramsForEachWord.length;
// explicit-scheme state stays unused in this mode
explicitMaps = null;
isExplicit = false;
implicitMaps = new ImplicitWordHashMap[maxNgramOrder - 1];
final long numWords = numNgramsForEachWord[0].size();
implicitUnigramMap = new UnigramHashMap(numWords, this);
initCapacities = null;
final long maxSize = getMaximumSize(numNgramsForEachWord);
// a little ugly: store word ranges for all orders in the same array to increase cache locality
// also, if we can, store two ints per long for cache locality
final boolean fitsInInt = maxSize < Integer.MAX_VALUE;
// NOTE(review): this product is computed in int arithmetic and could
// overflow for very large vocabularies/orders — TODO confirm upstream bounds
final int logicalNumRangeEntries = (maxNgramOrder - 1) * (int) numWords;
// when entries fit in an int, two entries are packed per long (hence the halving)
final long[] wordRanges = new long[fitsInInt ? (logicalNumRangeEntries / 2 + logicalNumRangeEntries % 2) : logicalNumRangeEntries];
values.setMap(this);
values.setSizeAtLeast(numWords, 0);
for (int ngramOrder = 1; ngramOrder < maxNgramOrder; ++ngramOrder) {
// each order's table is keyed by offsets into the previous order's table
final long numNgramsForPreviousOrder = ngramOrder == 1 ? numWords : implicitMaps[ngramOrder - 2].getCapacity();
implicitMaps[ngramOrder - 1] = new ImplicitWordHashMap(numNgramsForEachWord[ngramOrder], wordRanges, ngramOrder, maxNgramOrder - 1,
numNgramsForPreviousOrder, (int) numWords, this, fitsInInt, !opts.storeRankedProbBackoffs);
// make sure value storage can hold every slot of the new table
values.setSizeAtLeast(implicitMaps[ngramOrder - 1].getCapacity(), ngramOrder);
}
}
/**
 * Largest table size needed by any single n-gram order.
 */
private long getMaximumSize(final LongArray[] numNgramsForEachWord) {
    long largest = Long.MIN_VALUE;
    for (final LongArray countsForOrder : numNgramsForEachWord) {
        final long sizeForOrder = getSizeOfOrder(countsForOrder);
        if (sizeForOrder > largest) largest = sizeForOrder;
    }
    return largest;
}
/**
 * Total table size for one order: the sum of the per-word range sizes.
 */
private long getSizeOfOrder(final LongArray numNgramsForEachWord) {
    long total = 0;
    for (int word = 0; word < numNgramsForEachWord.size(); ++word) {
        total += getRangeSizeForWord(numNgramsForEachWord, word);
    }
    return total;
}
/**
 * Number of hash slots to reserve for one word's n-grams at one order:
 * the raw count when it is tiny (at most 3), otherwise the count scaled
 * up by the maximum load factor to leave probing headroom.
 *
 * @param numNgramsForEachWord per-word n-gram counts for a single order
 * @param w word index
 * @return slot count for word {@code w}
 */
long getRangeSizeForWord(final LongArray numNgramsForEachWord, int w) {
    final long count = numNgramsForEachWord.get(w);
    if (count <= 3) return count;
    return Math.round(count * 1.0 / maxLoadFactor);
}
/**
 * Builds a HashNgramMap backed by explicit per-order hash tables.
 * Note: an explicit HashNgramMap can grow beyond maxNgramOrder.
 *
 * @param maxNgramOrder initial number of orders to allocate slots for
 * @param reversed whether n-grams are stored in reversed word order
 */
public static <T> HashNgramMap<T> createExplicitWordHashNgramMap(final ValueContainer<T> values, final ConfigOptions opts, final int maxNgramOrder,
    final boolean reversed) {
    final HashNgramMap<T> map = new HashNgramMap<T>(values, opts, maxNgramOrder, reversed);
    return map;
}
/**
 * Constructor for the explicit scheme: tables are allocated lazily (see
 * getHashMapForOrder), each starting from a small default capacity.
 *
 * @param maxNgramOrder initial number of orders (the arrays can grow later)
 * @param reversed whether n-grams are stored in reversed word order
 */
private HashNgramMap(final ValueContainer<T> values, final ConfigOptions opts, final int maxNgramOrder, final boolean reversed) {
super(values, opts);
this.reversed = reversed;
this.storeSuffixOffsets = values.storeSuffixoffsets();
this.maxLoadFactor = opts.hashTableLoadFactor;
// implicit-scheme state stays unused in this mode
implicitMaps = null;
implicitUnigramMap = null;
isExplicit = true;
explicitMaps = new ExplicitWordHashMap[maxNgramOrder];
initCapacities = new long[maxNgramOrder];
// every order starts with a default capacity of 100; rehashing grows it on demand
Arrays.fill(initCapacities, 100);
values.setMap(this);
}
/**
 * Constructor used internally by rehash(): reuses the already-valid tables
 * for the lowest orders and records target capacities for the tables that
 * must be rebuilt.
 *
 * @param newCapacities per-order capacity to use when a table is (re)built
 * @param partialMaps existing tables for the orders below the rehash point
 */
private HashNgramMap(final ValueContainer<T> values, final ConfigOptions opts, final long[] newCapacities, final boolean reversed,
final ExplicitWordHashMap[] partialMaps) {
super(values, opts);
this.reversed = reversed;
this.storeSuffixOffsets = values.storeSuffixoffsets();
this.maxLoadFactor = opts.hashTableLoadFactor;
implicitMaps = null;
implicitUnigramMap = null;
isExplicit = true;
// pad with nulls up to newCapacities.length; missing tables are created lazily
explicitMaps = Arrays.copyOf(partialMaps, newCapacities.length);
this.initCapacities = newCapacities;
values.setMap(this);
}
/**
 * Allocates the explicit table for one order and makes sure the value
 * container has room for every slot in it.
 *
 * @param newCapacity capacity for the fresh table
 * @param ngramOrder order whose table is being created
 * @return the newly created table
 */
private ExplicitWordHashMap initMap(final long newCapacity, final int ngramOrder) {
    final ExplicitWordHashMap created = new ExplicitWordHashMap(newCapacity);
    explicitMaps[ngramOrder] = created;
    values.setSizeAtLeast(created.getCapacity(), ngramOrder);
    return created;
}
/**
 * Inserts an n-gram with its value, rehashing the table first if needed.
 *
 * @return offset of the n-gram, or -1 if its context could not be resolved
 */
@Override
public long put(final int[] ngram, final int startPos, final int endPos, final T val) {
    final boolean forcedNew = false;
    return putHelp(ngram, startPos, endPos, val, forcedNew);
}
/**
 * Inserts an n-gram, first growing (rehashing) the table for its order if
 * the current load factor has reached the maximum.
 *
 * @param forcedNew when true, skips the rehash check and tells the value
 *        container to treat the entry as newly added regardless of whether
 *        the key already existed
 * @return offset of the n-gram, or -1 if its context could not be found
 */
private long putHelp(final int[] ngram, final int startPos, final int endPos, final T val, final boolean forcedNew) {
final int ngramOrder = endPos - startPos - 1;
HashMap map = getHashMapForOrder(ngramOrder);
// only explicit tables can grow; implicit ones are pre-sized exactly
if (!forcedNew && map instanceof ExplicitWordHashMap && map.getLoadFactor() >= maxLoadFactor) {
rehash(ngramOrder, map.getCapacity() * 3 / 2, 1);
// rehash replaced the table object, so fetch it again
map = getHashMapForOrder(ngramOrder);
}
// key packs (head word, offset of the context n-gram)
final long key = getKey(ngram, startPos, endPos);
if (key < 0) return -1L;
return putHelp(map, ngram, startPos, endPos, key, val, forcedNew);
}
/**
 * Table for the given order, created on demand with its recorded initial
 * capacity when it does not exist yet (explicit scheme only).
 *
 * @param ngramOrder order whose table is requested
 * @return a non-null table for that order
 */
private HashMap getHashMapForOrder(final int ngramOrder) {
    final HashMap existing = getMap(ngramOrder);
    if (existing != null) return existing;
    final long capacity = initCapacities[ngramOrder];
    assert capacity >= 0 : "Bad capacity " + capacity + " for order " + ngramOrder;
    return initMap(capacity, ngramOrder);
}
/**
 * Warning: does not rehash when the load factor is exceeded; callers must
 * invoke rehashIfNecessary explicitly. This keeps previously returned
 * offsets valid. Do not use this unless you really know what you are doing.
 *
 * @param contextOffset offset of the (already-stored) context n-gram
 * @return offset of the inserted n-gram, or -1 on failure
 */
public long putWithOffset(final int[] ngram, final int startPos, final int endPos, final long contextOffset, final T val) {
    final int ngramOrder = endPos - startPos - 1;
    final long packedKey = combineToKey(ngram[endPos - 1], contextOffset);
    final HashMap tableForOrder = getHashMapForOrder(ngramOrder);
    return putHelp(tableForOrder, ngram, startPos, endPos, packedKey, val, false);
}
/**
 * Warning: does not rehash when the load factor is exceeded; callers must
 * invoke rehashIfNecessary explicitly. This keeps previously returned
 * offsets valid. Do not use this unless you really know what you are doing.
 *
 * @param contextOffset offset of the (already-stored) context n-gram
 * @param suffixOffset offset of the n-gram's suffix, passed straight through
 * @return offset of the inserted n-gram, or -1 on failure
 */
public long putWithOffsetAndSuffix(final int[] ngram, final int startPos, final int endPos, final long contextOffset, final long suffixOffset, final T val) {
    final int ngramOrder = endPos - startPos - 1;
    final long packedKey = combineToKey(ngram[endPos - 1], contextOffset);
    final HashMap tableForOrder = getHashMapForOrder(ngramOrder);
    return putHelpWithSuffixIndex(tableForOrder, ngram, startPos, endPos, packedKey, val, false, suffixOffset);
}
/**
 * Ensures room for num more inserts per order without exceeding the
 * maximum load factor. At most one explicit rehash is triggered per call;
 * rehash() itself opportunistically grows the higher orders.
 *
 * @param num number of pending inserts
 */
public void rehashIfNecessary(int num) {
// implicit tables are fixed-size; nothing to do
if (explicitMaps == null) return;
for (int ngramOrder = 0; ngramOrder < explicitMaps.length; ++ngramOrder) {
// table not allocated yet: just record a big-enough initial capacity
if (explicitMaps[ngramOrder] == null)
initCapacities[ngramOrder] = Math.max(100, num) * 3/2;
else if (explicitMaps[ngramOrder].getLoadFactor(num) >= maxLoadFactor) {
rehash(ngramOrder, (explicitMaps[ngramOrder].getCapacity() + num) * 3 / 2, num);
// rehash already handled all higher orders, so stop here
return;
}
}
}
/**
 * Resolves the suffix offset when value storage needs it, then delegates.
 */
private long putHelp(final HashMap map, final int[] ngram, final int startPos, final int endPos, final long key, final T val, final boolean forcedNew) {
    long suffixIndex = -1L;
    if (storeSuffixOffsets) {
        suffixIndex = getSuffixOffset(ngram, startPos, endPos);
    }
    return putHelpWithSuffixIndex(map, ngram, startPos, endPos, key, val, forcedNew, suffixIndex);
}
/**
 * Puts the packed key into the table and records the value at the
 * resulting offset.
 *
 * @param key packed (word, context offset) key
 * @param suffixIndex offset of the n-gram's suffix, or -1 if not tracked
 * @return offset of the n-gram, or -1 if the value container rejected the add
 */
private long putHelpWithSuffixIndex(final HashMap map, final int[] ngram, final int startPos, final int endPos, final long key, final T val,
final boolean forcedNew, final long suffixIndex) {
final int ngramOrder = endPos - startPos - 1;
// capture the size before the put so we can distinguish a fresh insert
// from an overwrite of an existing key
final long oldSize = map.size();
final long index = map.put(key);
final boolean addWorked = values.add(ngram, startPos, endPos, ngramOrder, index, contextOffsetOf(key), wordOf(key), val, suffixIndex,
map.size() > oldSize || forcedNew);
if (!addWorked) return -1;
return index;
}
/**
 * Looks up an n-gram by (context offset, word), also reading its value
 * into outputVal on a hit.
 */
@Override
public long getValueAndOffset(final long contextOffset, final int contextOrder, final int word, @OutputParameter final T outputVal) {
    final long offset = getOffsetForContextEncoding(contextOffset, contextOrder, word, outputVal);
    return offset;
}
/**
 * Looks up an n-gram by (context offset, word) without reading its value.
 */
@Override
public long getOffset(final long contextOffset, final int contextOrder, final int word) {
    final long offset = getOffsetForContextEncoding(contextOffset, contextOrder, word, null);
    return offset;
}
/**
 * Reconstructs the word array for a context-encoded n-gram.
 */
@Override
public int[] getNgramFromContextEncoding(final long contextOffset, final int contextOrder, final int word) {
    final int length = Math.max(1, contextOrder + 2);
    final int[] ngram = new int[length];
    getNgramFromContextEncodingHelp(contextOffset, contextOrder, word, ngram);
    return ngram;
}
/**
 * Fills scratch with the words of the n-gram encoded as
 * (contextOffset, contextOrder, word) by walking the context chain one
 * order at a time.
 *
 * @param contextOrder order of the context (-1 means a unigram)
 * @param word head word; written at index 0 for reversed maps, at the end otherwise
 * @param scratch output array of length contextOrder + 2 (1 for unigrams)
 */
private void getNgramFromContextEncodingHelp(final long contextOffset, final int contextOrder, final int word, final int[] scratch) {
if (contextOrder < 0) {
scratch[0] = word;
} else {
long contextOffset_ = contextOffset;
int word_ = word;
scratch[reversed ? 0 : (scratch.length - 1)] = word_;
for (int i = 0; i <= contextOrder; ++i) {
final int ngramOrder = contextOrder - i;
// each key yields one word plus the offset of the next-shorter context
final long key = getKey(contextOffset_, ngramOrder);
contextOffset_ = contextOffsetOf(key);
word_ = wordOf(key);
// fill front-to-back for reversed maps, back-to-front otherwise
scratch[reversed ? (i + 1) : (scratch.length - i - 2)] = word_;
}
}
}
/** Word stored in the key at the given offset and order. */
public int getNextWord(final long offset, final int ngramOrder) {
    final long key = getKey(offset, ngramOrder);
    return wordOf(key);
}
/** Context offset stored in the key at the given offset and order. */
public long getNextContextOffset(final long offset, final int ngramOrder) {
    final long key = getKey(offset, ngramOrder);
    return contextOffsetOf(key);
}
/**
 * Gets the "key" (word + context offset) stored at a given offset.
 *
 * @param offset slot in the table for the given order
 * @param ngramOrder order of the n-gram at that offset
 * @return the packed key
 */
private long getKey(final long offset, final int ngramOrder) {
    final HashMap tableForOrder = getMap(ngramOrder);
    return tableForOrder.getKey(offset);
}
/**
 * First word (in storage order) of the n-gram at the given offset, found
 * by following the context chain down to order 0.
 */
public int getFirstWordForOffset(final long offset, final int ngramOrder) {
    long currOffset = offset;
    for (int order = ngramOrder; ; --order) {
        final long key = getMap(order).getKey(currOffset);
        if (order == 0) return wordOf(key);
        currOffset = contextOffsetOf(key);
    }
}
/** Word stored directly in the key at the given offset (no chain walk). */
public int getLastWordForOffset(final long offset, final int ngramOrder) {
    final long key = getMap(ngramOrder).getKey(offset);
    final int word = wordOf(key);
    return word;
}
/** Allocates and fills a word array for the n-gram stored at offset. */
public int[] getNgramForOffset(final long offset, final int ngramOrder) {
    return getNgramForOffset(offset, ngramOrder, new int[ngramOrder + 1]);
}
/**
 * Fills ret with the words of the n-gram at offset by walking the context
 * chain from the highest order down to the unigram table.
 *
 * @param ret output array of length ngramOrder + 1
 * @return ret, for convenience
 */
public int[] getNgramForOffset(final long offset, final int ngramOrder, final int[] ret) {
long offset_ = offset;
for (int i = 0; i <= ngramOrder; ++i) {
final long key = getMap(ngramOrder - i).getKey(offset_);
offset_ = contextOffsetOf(key);
final int word_ = wordOf(key);
// positions fill front-to-back for reversed maps, back-to-front otherwise
ret[reversed ? (i) : (ngramOrder - i)] = word_;
}
return ret;
}
/**
 * Core lookup: offset of the n-gram formed by appending word to the
 * context stored at contextOffset_, optionally reading its value.
 *
 * @param contextOffset_ offset of the context; negative means "no context"
 * @param contextOrder order of the context (-1 for unigram lookups)
 * @param word last word to look up; negative means out-of-vocabulary
 * @param outputVal if non-null, filled with the value on a successful hit
 * @return offset of the n-gram, or -1 if absent or word is invalid
 */
private long getOffsetForContextEncoding(final long contextOffset_, final int contextOrder, final int word, @OutputParameter final T outputVal) {
if (word < 0) return -1;
final int ngramOrder = contextOrder + 1;
// negative context offsets are normalized to 0 (the empty context)
final long contextOffset = contextOffset_ >= 0 ? contextOffset_ : 0;
final long key = combineToKey(word, contextOffset);
final long offset = getOffsetHelpFromMap(ngramOrder, key);
if (outputVal != null && offset >= 0) {
values.getFromOffset(offset, ngramOrder, outputVal);
}
return offset;
}
/**
 * Probes the table for the given order, handling both storage schemes;
 * missing explicit tables simply mean the key is absent.
 */
private long getOffsetHelpFromMap(int ngramOrder, long key) {
    if (isExplicit) {
        if (ngramOrder >= explicitMaps.length) return -1;
        final ExplicitWordHashMap tableForOrder = explicitMaps[ngramOrder];
        return tableForOrder == null ? -1 : tableForOrder.getOffset(key);
    }
    if (ngramOrder == 0) return implicitUnigramMap.getOffset(key);
    return implicitMaps[ngramOrder - 1].getOffset(key);
}
/**
 * Rebuilds the explicit tables after growing the table for
 * changedNgramOrder to newCapacity. Resizing a table invalidates the
 * offsets stored as context pointers by higher orders, so every stored
 * n-gram must be decoded and re-inserted against the new layout.
 *
 * @param changedNgramOrder order whose table triggered the rehash
 * @param newCapacity new capacity for that order's table
 * @param numAdding pending inserts, used to pre-grow nearly-full higher orders
 */
private void rehash(final int changedNgramOrder, final long newCapacity, final int numAdding) {
assert isExplicit;
final long[] newCapacities = new long[explicitMaps.length];
Arrays.fill(newCapacities, -1L);
assert changedNgramOrder >= 0;
for (int ngramOrder = 0; ngramOrder < explicitMaps.length; ++ngramOrder) {
if (explicitMaps[ngramOrder] == null) break;
if (ngramOrder < changedNgramOrder) {
// lower orders keep their tables, so their offsets stay valid
newCapacities[ngramOrder] = explicitMaps[ngramOrder].getCapacity();
} else if (ngramOrder == changedNgramOrder) {
newCapacities[ngramOrder] = newCapacity;
} else {
// opportunistically grow higher orders that are at least half full
newCapacities[ngramOrder] = explicitMaps[ngramOrder].getLoadFactor(numAdding) >= maxLoadFactor / 2 ? ((explicitMaps[ngramOrder].getCapacity() + numAdding) * 3 / 2)
: explicitMaps[ngramOrder].getCapacity();
}
assert newCapacities[ngramOrder] >= 0 : "Bad capacity " + newCapacities[ngramOrder];
}
// build a parallel map that reuses the untouched low-order tables and
// lazily creates fresh tables at the recorded capacities for the rest
final ValueContainer<T> newValues = values.createFreshValues(newCapacities);
final HashNgramMap<T> newMap = new HashNgramMap<T>(newValues, opts, newCapacities, reversed, Arrays.copyOf(explicitMaps, changedNgramOrder));
for (int ngramOrder = 0; ngramOrder < explicitMaps.length; ++ngramOrder) {
final ExplicitWordHashMap currHashMap = explicitMaps[ngramOrder];
if (currHashMap == null) continue;
final ExplicitWordHashMap newHashMap = (ExplicitWordHashMap) newMap.getHashMapForOrder(ngramOrder);
final T val = values.getScratchValue();
final int[] scratchArray = new int[ngramOrder + 1];
// scan every slot of the old table and re-insert the occupied ones
for (long actualIndex = 0; actualIndex < currHashMap.getCapacity(); ++actualIndex) {
final long key = currHashMap.getKey(actualIndex);
if (currHashMap.isEmptyKey(key)) continue;
// decode the full n-gram under the OLD encoding, then re-encode it
// against the new map (context offsets differ after resizing)
getNgramFromContextEncodingHelp(contextOffsetOf(key), ngramOrder - 1, wordOf(key), scratchArray);
final long newKey = newMap.getKey(scratchArray, 0, scratchArray.length);
assert newKey >= 0 : "Failure for old n-gram " + Arrays.toString(scratchArray) + " :: " + newKey;
final long index = newHashMap.put(newKey);
assert index >= 0;
final long suffixIndex = storeSuffixOffsets ? newMap.getSuffixOffset(scratchArray, 0, scratchArray.length) : -1L;
assert !storeSuffixOffsets || suffixIndex >= 0 : "Could not find suffix offset for " + Arrays.toString(scratchArray);
values.getFromOffset(actualIndex, ngramOrder, val);
final boolean addWorked = newMap.values.add(scratchArray, 0, scratchArray.length, ngramOrder, index, contextOffsetOf(newKey), wordOf(newKey),
val, suffixIndex, true);
assert addWorked;
}
// free the old per-order value storage as soon as it has been migrated
values.clearStorageForOrder(ngramOrder);
}
// adopt the rebuilt tables and values in place
System.arraycopy(newMap.explicitMaps, 0, explicitMaps, 0, newMap.explicitMaps.length);
values.setFromOtherValues(newValues);
values.setMap(this);
}
/**
 * Offset of a raw (word-array) n-gram, or -1 when it is absent, contains
 * an out-of-vocabulary word, or exceeds the maximum order.
 */
private long getOffsetFromRawNgram(final int[] ngram, final int startPos, final int endPos) {
    if (containsOutOfVocab(ngram, startPos, endPos)) return -1;
    final int ngramOrder = endPos - startPos - 1;
    if (ngramOrder >= getMaxNgramOrder()) return -1;
    final long key = getKey(ngram, startPos, endPos);
    if (key < 0) return -1;
    final HashMap tableForOrder = getMap(ngramOrder);
    if (tableForOrder == null) return -1;
    return tableForOrder.getOffset(key);
}
/**
 * Finds the longest suffix of ngram[startPos..endPos) present in the map
 * and records its offset and order in the returned context info.
 */
@Override
public LmContextInfo getOffsetForNgram(final int[] ngram, final int startPos, final int endPos) {
final LmContextInfo lmContextInfo = new LmContextInfo();
// try successively longer suffixes; stop at the first one that is missing
for (int start = endPos - 1; start >= startPos; --start) {
final long offset = getOffsetFromRawNgram(ngram, start, endPos);
if (offset < 0) break;
lmContextInfo.offset = offset;
lmContextInfo.order = endPos - start - 1;
}
return lmContextInfo;
}
/**
 * Like {@link #getOffsetForNgram(int[], int, int)}, but assumes the full
 * n-gram is present in the map (no backoff to the longest stored suffix).
 *
 * @return offset of the full n-gram, or -1 if it is not stored
 */
public long getOffsetForNgramInModel(final int[] ngram, final int startPos, final int endPos) {
    final long offset = getOffsetFromRawNgram(ngram, startPos, endPos);
    return offset;
}
/**
 * No-op: this map needs no per-order post-processing when all n-grams of
 * an order have been added.
 */
@Override
public void handleNgramsFinished(final int justFinishedOrder) {
}
/**
 * No-op: tables here are either pre-sized from per-word counts or grown on
 * demand, so advance length information is not used.
 */
@Override
public void initWithLengths(final List<Long> numNGrams) {
}
/**
 * Trims value storage down to each allocated table's capacity and logs the
 * final per-order load factors.
 */
@Override
public void trim() {
    for (int ngramOrder = 0; ngramOrder < getMaxNgramOrder(); ++ngramOrder) {
        final HashMap tableForOrder = getMap(ngramOrder);
        if (tableForOrder == null) break;
        values.trimAfterNgram(ngramOrder, tableForOrder.getCapacity());
        Logger.logss("Load factor for " + (ngramOrder + 1) + ": " + tableForOrder.getLoadFactor());
    }
    values.trim();
}
/**
 * Offset of the n-gram shortened by one word: drops the last array element
 * when the map is reversed, otherwise the first. A single word (unigram)
 * always maps to offset 0.
 *
 * @return suffix offset, or -1 if the suffix is not in the map
 */
private long getSuffixOffset(final int[] ngram, final int startPos, final int endPos) {
if (endPos - startPos == 1) return 0;
final long offset = getOffsetFromRawNgram(ngram, reversed ? startPos : (startPos + 1), reversed ? (endPos - 1) : endPos);
return offset;
}
/**
 * Gets the offset of the context for an n-gram (represented by offset);
 * unigrams have no context, so -1 is returned for order 0.
 */
public long getPrefixOffset(final long offset, final int ngramOrder) {
    if (ngramOrder == 0) return -1;
    final long key = getKey(offset, ngramOrder);
    return contextOffsetOf(key);
}
/**
 * Computes the packed (head word, context offset) key for a raw n-gram by
 * resolving each successively longer context prefix via lookups.
 *
 * @return the key, or -1 if some context n-gram is missing from the map
 */
private long getKey(final int[] ngram, final int startPos, final int endPos) {
long contextOffset = 0;
for (int ngramOrder = 0; ngramOrder < endPos - startPos - 1; ++ngramOrder) {
// walk from the far end of the array for reversed maps, from the front otherwise
final int currNgramPos = reversed ? (endPos - ngramOrder - 1) : (startPos + ngramOrder);
contextOffset = getOffsetForContextEncoding(contextOffset, ngramOrder - 1, ngram[currNgramPos], null);
if (contextOffset == -1L) { return -1; }
}
return combineToKey(headWord(ngram, startPos, endPos), contextOffset);
}
/** Word packed into the key: first array element when reversed, else the last. */
private int headWord(final int[] ngram, final int startPos, final int endPos) {
    if (reversed) return ngram[startPos];
    return ngram[endPos - 1];
}
/** Number of n-gram orders this map currently supports. */
@Override
public int getMaxNgramOrder() {
    if (explicitMaps != null) return explicitMaps.length;
    return implicitMaps.length + 1;
}
/** Number of n-grams currently stored at the given order. */
@Override
public long getNumNgrams(final int ngramOrder) {
    final HashMap tableForOrder = getMap(ngramOrder);
    return tableForOrder.size();
}
/**
 * Iterates all n-grams of the given order as (word array, value) entries.
 * Entries are materialized lazily as the iterator advances; returns an
 * empty iterable if the order's table does not exist.
 */
@Override
public Iterable<Entry<T>> getNgramsForOrder(final int ngramOrder) {
final HashMap map = getMap(ngramOrder);
if (map == null)
return Collections.emptyList();
else
return Iterators.able(new Iterators.Transform<Long, Entry<T>>(map.keys().iterator())
{
@Override
protected Entry<T> transform(final Long next) {
// each iterated element is used as an offset into the table/value store
final long offset = next;
final T val = values.getScratchValue();
values.getFromOffset(offset, ngramOrder, val);
return new Entry<T>(getNgramForOffset(offset, ngramOrder), val);
}
});
}
/**
 * Offsets of all n-grams at the given order; empty when the order's table
 * has not been allocated.
 */
public Iterable<Long> getNgramOffsetsForOrder(final int ngramOrder) {
    final HashMap tableForOrder = getMap(ngramOrder);
    if (tableForOrder == null) return Collections.emptyList();
    return tableForOrder.keys();
}
/**
 * Table for the given order, or null if not yet allocated. In explicit
 * mode, asking for an order beyond the current arrays transparently grows
 * them — so despite the name, this "getter" can mutate state.
 * <p>
 * Fix: the previous code doubled the arrays exactly once, which still left
 * {@code explicitMaps[ngramOrder]} out of bounds whenever
 * {@code ngramOrder >= 2 * explicitMaps.length}; the new length now also
 * covers the requested order directly.
 */
private HashMap getMap(int ngramOrder) {
    if (explicitMaps == null) {
        return ngramOrder == 0 ? implicitUnigramMap : implicitMaps[ngramOrder - 1];
    }
    if (ngramOrder >= explicitMaps.length) {
        final int oldLength = explicitMaps.length;
        // grow at least geometrically, but always far enough for ngramOrder
        final int newLength = Math.max(oldLength * 2, ngramOrder + 1);
        explicitMaps = Arrays.copyOf(explicitMaps, newLength);
        initCapacities = Arrays.copyOf(initCapacities, newLength);
        // newly exposed orders get the default initial capacity
        Arrays.fill(initCapacities, oldLength, initCapacities.length, 100);
    }
    return explicitMaps[ngramOrder];
}
/** Whether this map stores n-grams with their words in reversed order. */
public boolean isReversed() {
    return reversed;
}
/**
 * Whether any stored bigram has the given word as its context word.
 * <p>
 * Fix: in explicit mode the per-order tables are allocated lazily, so
 * {@code explicitMaps[1]} can still be null even when the map allows
 * bigrams; the old code dereferenced it unconditionally and could throw a
 * NullPointerException. A missing bigram table simply means "no bigrams".
 */
@Override
public boolean wordHasBigrams(final int word) {
    if (getMaxNgramOrder() < 2) return false;
    if (explicitMaps == null) return implicitMaps[0].hasContexts(word);
    final ExplicitWordHashMap bigramMap = explicitMaps[1];
    return bigramMap != null && bigramMap.hasContexts(word);
}
/** Whether the raw n-gram ngram[startPos..endPos) is present in the map. */
@Override
public boolean contains(final int[] ngram, final int startPos, final int endPos) {
    final long offset = getOffsetFromRawNgram(ngram, startPos, endPos);
    return offset >= 0;
}
/** Value stored for a raw n-gram, or null when the n-gram is absent. */
@Override
public T get(int[] ngram, int startPos, int endPos) {
    final long offset = getOffsetFromRawNgram(ngram, startPos, endPos);
    if (offset < 0) return null;
    final T val = values.getScratchValue();
    values.getFromOffset(offset, endPos - startPos - 1, val);
    return val;
}
/** Total number of n-grams stored across all allocated orders. */
public long getTotalSize() {
    long total = 0L;
    for (int ngramOrder = 0; ngramOrder < getMaxNgramOrder(); ++ngramOrder) {
        final HashMap tableForOrder = getMap(ngramOrder);
        if (tableForOrder == null) break;
        total += tableForOrder.size();
    }
    return total;
}
/**
 * Backing key array for implicit tables of order >= 1; null for unigrams
 * and for explicit maps, which have no such shared storage.
 */
@Override
public CustomWidthArray getValueStoringArray(final int ngramOrder) {
    if (isExplicit || ngramOrder == 0) return null;
    return implicitMaps[ngramOrder - 1].keys;
}
/** Drops every hash table so its storage can be garbage-collected. */
@Override
public void clearStorage() {
    if (implicitMaps != null) {
        Arrays.fill(implicitMaps, null);
    }
    if (explicitMaps != null) {
        Arrays.fill(explicitMaps, null);
    }
}
/** Maximum load factor before an explicit table is rehashed. */
double getLoadFactor() {
    return maxLoadFactor;
}
}