package edu.berkeley.nlp.lm.map;
import java.io.Serializable;
import java.util.Iterator;
import java.util.List;
import edu.berkeley.nlp.lm.ConfigOptions;
import edu.berkeley.nlp.lm.array.CustomWidthArray;
import edu.berkeley.nlp.lm.array.LongArray;
import edu.berkeley.nlp.lm.bits.BitList;
import edu.berkeley.nlp.lm.bits.BitStream;
import edu.berkeley.nlp.lm.bits.VariableLengthBitCompressor;
import edu.berkeley.nlp.lm.util.Annotations.OutputParameter;
import edu.berkeley.nlp.lm.util.Logger;
import edu.berkeley.nlp.lm.values.CompressibleValueContainer;
public class CompressedNgramMap<T> extends AbstractNgramMap<T> implements Serializable
{
/**
* Serialization version identifier.
*/
private static final long serialVersionUID = 1L;
/** Number of longs written for every compressed block; copied from {@code opts.compressedBlockSize}. */
private final int compressedBlockSize;
/** Radix handed to the coder that encodes the offset stored in each block header. */
private static final int OFFSET_RADIX = 33;
/** Radix handed to the coder that encodes word deltas inside a block. */
private static final int WORD_RADIX = 2;
/** Encodes and decodes the offset of the first entry of each block. */
private final VariableLengthBitCompressor offsetCoder;
/** Encodes and decodes word deltas between consecutive entries of a block. */
private final VariableLengthBitCompressor wordCoder;
/** Encodes and decodes context-offset values and deltas between consecutive entries of a block. */
private final VariableLengthBitCompressor suffixCoder;
// Running totals updated by logCompressionInfo on every compression pass; they are only used for logging.
private double totalKeyBitsFinal = 0;
private double totalValueBitsFinal = 0;
// Accumulates the number of longs in the compressed arrays; it is multiplied by 64 when logged.
private double totalBitsFinal = 0;
private double totalSizeFinal = 0;
/** Radix used to build {@code suffixCoder}; copied from {@code opts.offsetDeltaRadix}. */
private final int offsetDeltaRadix;
/** One map per n-gram length, indexed by length minus one; entries are created on first use by put. */
private final CompressedMap[] maps;
/** Always true: a key combines the first word of an n-gram with the offset of the remaining words. */
private final boolean reverseTrie = true;
/** Expected number of n-grams per order; used to size a map when it is first created. */
private final long[] numNgramsForEachOrder;
public CompressedNgramMap(final CompressibleValueContainer<T> values, final long[] numNgramsForEachOrder, final ConfigOptions opts) {
super(values, opts);
offsetCoder = new VariableLengthBitCompressor(OFFSET_RADIX);
wordCoder = new VariableLengthBitCompressor(WORD_RADIX);
this.offsetDeltaRadix = opts.offsetDeltaRadix;
suffixCoder = new VariableLengthBitCompressor(offsetDeltaRadix);
this.compressedBlockSize = opts.compressedBlockSize;
this.numNgramsForEachOrder = numNgramsForEachOrder;
this.maps = new CompressedMap[numNgramsForEachOrder.length];
values.setMap(this);
}
/**
 * Looks up the n-gram obtained by combining {@code word} with the context stored at
 * {@code contextOffset}.
 *
 * @param contextOffset offset of the context n-gram within its own order
 * @param contextNgramOrder order index of the context; the looked-up n-gram has order index
 *            {@code contextNgramOrder + 1}
 * @param word word id to combine with the context; negative ids are never present
 * @param outputVal receives the value of the n-gram when it is found
 * @return the offset of the n-gram within its order, or {@code -1} when the word is negative,
 *         the resulting order is outside this map, no n-gram of that order has been added,
 *         or the key is not present
 */
@Override
public long getValueAndOffset(final long contextOffset, final int contextNgramOrder, final int word, @OutputParameter final T outputVal) {
    if (word < 0) return -1L;
    final int ngramOrder = contextNgramOrder + 1;
    // An order outside the table cannot contain anything; report "not found" rather than
    // failing with an index error.
    if (ngramOrder < 0 || ngramOrder >= maps.length) return -1L;
    // Maps are created lazily by put, so an untouched order is simply empty. This mirrors
    // the null check performed by getContextOffset.
    final CompressedMap map = maps[ngramOrder];
    if (map == null) return -1L;
    final long key = combineToKey(word, contextOffset);
    return decompressSearch(map.compressedKeys, key, ngramOrder, outputVal);
}
/**
 * Adds the n-gram {@code ngram[startPos, endPos)} together with its value.
 *
 * The key is built from one word of the n-gram and the offset of the remaining words, which
 * must already be stored in the map of the next lower order.
 *
 * @param ngram array holding the word ids
 * @param startPos index of the first word (inclusive)
 * @param endPos index one past the last word (exclusive)
 * @param val value to associate with the n-gram
 * @return the offset reported by the key map for the new entry, or {@code -1} when the
 *         context is not present or the value container rejects the entry
 */
@Override
public long put(final int[] ngram, final int startPos, final int endPos, final T val) {
    final int ngramOrder = endPos - startPos - 1;
    final int word;
    final long contextOffset;
    if (reverseTrie) {
        word = ngram[startPos];
        contextOffset = getContextOffset(ngram, startPos + 1, endPos, null);
    } else {
        word = ngram[endPos - 1];
        contextOffset = getContextOffset(ngram, startPos, endPos - 1, null);
    }
    if (contextOffset < 0) return -1;

    CompressedMap map = maps[ngramOrder];
    if (map == null) {
        // First n-gram of this order: allocate the key map and size the value storage.
        map = new CompressedMap();
        maps[ngramOrder] = map;
        final long expectedSize = numNgramsForEachOrder[ngramOrder];
        map.init(expectedSize);
        values.setSizeAtLeast(expectedSize, ngramOrder);
    }

    final long oldSize = map.size();
    final long newOffset = map.add(combineToKey(word, contextOffset));
    final boolean addWorked = values.add(ngram, startPos, endPos, ngramOrder, map.size() - 1, contextOffset, word, val, -1, map.size() == oldSize);
    return addWorked ? newOffset : -1;
}
/**
 * Walks the maps order by order to find the offset of {@code ngram[startPos, endPos)}.
 *
 * @param ngram array holding the word ids
 * @param startPos index of the first word (inclusive)
 * @param endPos index one past the last word (exclusive)
 * @param val passed to every per-order lookup as the output value; may be {@code null}
 * @return the offset of the n-gram in its order, {@code 0} for an empty range, or {@code -1}
 *         when some prefix of the walk is missing
 */
private long getContextOffset(final int[] ngram, final int startPos, final int endPos, T val) {
    if (endPos <= startPos) return 0;
    final int numWords = endPos - startPos;
    long suffixOffset = 0L;
    for (int order = 0; order < numWords; ++order) {
        final int nextWord;
        if (reverseTrie) {
            nextWord = ngram[endPos - order - 1];
        } else {
            nextWord = ngram[startPos + order];
        }
        final long key = combineToKey(nextWord, suffixOffset);
        final CompressedMap map = maps[order];
        if (map == null) return -1;
        final long found = decompressSearch(map.compressedKeys, key, order, val);
        if (found < 0) return -1;
        suffixOffset = found;
    }
    return suffixOffset;
}
/**
 * Finalizes the order that has just been fully added: sorts its keys (permuting the values
 * alongside), trims the storage and compresses the keys.
 *
 * Nothing happens when no n-gram of that order was ever added.
 *
 * @param justFinishedOrder one-based order that has just been completed; the map at index
 *            {@code justFinishedOrder - 1} is processed
 */
@Override
public void handleNgramsFinished(final int justFinishedOrder) {
    final int orderIndex = justFinishedOrder - 1;
    final CompressedMap finishedMap = maps[orderIndex];
    if (finishedMap == null) return;

    final LongArray keys = finishedMap.getUncompressedKeys();
    final long numKeys = keys.size();
    sort(keys, 0, numKeys - 1, orderIndex);
    finishedMap.trim();
    values.trimAfterNgram(orderIndex, numKeys);
    compress(orderIndex);
}
/**
 * Three-way comparison of two non-negative longs.
 *
 * @param a first value; asserted to be non-negative
 * @param b second value; asserted to be non-negative
 * @return {@code -1}, {@code 0} or {@code +1} when {@code a} is respectively smaller than,
 *         equal to, or greater than {@code b}
 */
protected static int compareLongsRaw(final long a, final long b) {
    assert a >= 0;
    assert b >= 0;
    return a < b ? -1 : (a == b ? 0 : +1);
}
/**
 * Replaces the uncompressed keys of one order with their compressed form and releases the
 * storage that is no longer needed. Order index 0 is left untouched.
 *
 * @param ngramOrder order index whose map should be compressed
 */
private void compress(final int ngramOrder) {
    if (ngramOrder <= 0) return;
    final CompressedMap map = maps[ngramOrder];
    final LongArray packedKeys = compress(map.getUncompressedKeys(), map.size(), ngramOrder);
    map.compressedKeys = packedKeys;
    final CompressibleValueContainer<T> compressibleValues = (CompressibleValueContainer<T>) values;
    compressibleValues.clearStorageAfterCompression(ngramOrder);
    map.clearUncompressedKeys();
}
/**
* Packs the keys of one n-gram order, together with their values, into fixed-size blocks of
* {@code compressedBlockSize} longs each.
*
* Each block is laid out as: the first key as a full long, a 16-bit length of the header plus
* body, the header (coded offset of the first entry, the word flag, the first entry's
* compressed value) and the body (for every further entry, its coded deltas followed by its
* compressed value). The caller in this class sorts the keys before invoking this method.
*
* @param uncompressed keys to compress
* @param uncompressedSize number of keys to read from {@code uncompressed}
* @param ngramOrder order index the keys belong to; used to fetch the matching values
* @return the array holding all blocks back to back
*/
private LongArray compress(final LongArray uncompressed, final long uncompressedSize, final int ngramOrder) {
Logger.startTrack("Compressing");
final LongArray compressedLongArray = LongArray.StaticMethods.newLongArray(Long.MAX_VALUE, uncompressedSize >>> 2);
long uncompressedPos = 0;
long totalNumKeyBits = 0;
long totalNumValueBits = 0;
long currBlock = 0;
final CompressibleValueContainer<T> compressibleValues = (CompressibleValueContainer<T>) values;
// Each iteration of this loop emits exactly one block, starting at uncompressedPos.
while (uncompressedPos < uncompressedSize) {
final BitList currBlockBits = new BitList();
final long firstKey = uncompressed.get(uncompressedPos);
if (currBlock++ % 1000 == 0) Logger.logs("On block " + currBlock + " starting at pos " + uncompressedPos);
// The first key of a block is stored verbatim so blocks can be binary-searched by key.
currBlockBits.addLong(firstKey);
final BitList offsetBits = offsetCoder.compress(uncompressedPos);
final BitList firstValueBits = compressibleValues.getCompressed(uncompressedPos, ngramOrder);
BitList headerBits = new BitList();
BitList bodyBits = new BitList();
long numKeyBits = 0;
long numValueBits = 0;
long currUncompressedPos = -1;
// try compression assuming all words are the same (wordBitOn = false), and if that fails,
// roll back and try with wordBitOn = true
OUTER: for (boolean wordBitOn = false, done = false; !done; wordBitOn = true) {
// Every attempt starts from scratch: counters, header and body are rebuilt.
numKeyBits = 0;
numValueBits = 0;
long lastFirstWord = wordOf(firstKey);
long lastSuffixPart = contextOffsetOf(firstKey);
headerBits = makeHeader(offsetBits, firstValueBits, wordBitOn);
bodyBits = new BitList();
final BitList currBits = new BitList();
for (currUncompressedPos = uncompressedPos + 1; currUncompressedPos < uncompressedSize; ++currUncompressedPos) {
final long currKey = uncompressed.get(currUncompressedPos);
final long currFirstWord = wordOf(currKey);
final long currSuffixPart = contextOffsetOf(currKey);
final long wordDelta = currFirstWord - lastFirstWord;
final long suffixDelta = currSuffixPart - lastSuffixPart;
currBits.clear();
// The word changed although this attempt encodes no word deltas: restart with wordBitOn = true.
if (wordDelta > 0 && !wordBitOn) continue OUTER;
if (wordBitOn) {
final BitList keyBits = wordCoder.compress(wordDelta);
currBits.addAll(keyBits);
if (wordDelta > 0) {
// When the word changes, the context offset is stored as an absolute value rather than a delta.
final BitList suffixBits = suffixCoder.compress(currSuffixPart);
currBits.addAll(suffixBits);
} else {
final BitList suffixBits = suffixCoder.compress(suffixDelta);
currBits.addAll(suffixBits);
}
} else {
final BitList suffixBits = suffixCoder.compress(suffixDelta);
currBits.addAll(suffixBits);
}
// At this point currBits holds only key bits; the value bits are appended by compressValue below.
numKeyBits += currBits.size();
lastFirstWord = currFirstWord;
numValueBits += compressValue(ngramOrder, currUncompressedPos, currBits);
lastSuffixPart = currSuffixPart;
// An entry that does not fit is not appended; it becomes the first entry of the next block.
if (blockFull(currBlockBits, bodyBits, headerBits, currBits)) {
break;
}
bodyBits.addAll(currBits);
}
done = true;
}
// Resume at the first entry that was not written into this block.
uncompressedPos = currUncompressedPos;
totalNumKeyBits += numKeyBits;
totalNumValueBits += numValueBits;
final int bitLength = bodyBits.size() + headerBits.size();
assert bitLength <= Short.MAX_VALUE;
currBlockBits.addShort((short) bitLength);
currBlockBits.addAll(headerBits);
currBlockBits.addAll(bodyBits);
assert currBlockBits.size() < Long.SIZE * compressedBlockSize;
writeBlockToArray(currBlockBits, compressedLongArray);
}
compressedLongArray.trim();
logCompressionInfo(uncompressedSize, compressedLongArray, totalNumKeyBits, totalNumValueBits);
Logger.endTrack();
return compressedLongArray;
}
/**
 * Appends one block to {@code array} as exactly {@code compressedBlockSize} longs. Bits are
 * packed most-significant first; positions past the end of {@code blockBits} are written as
 * zero.
 *
 * @param blockBits bits of the block to write
 * @param array destination the packed longs are appended to
 */
private void writeBlockToArray(final BitList blockBits, final LongArray array) {
    for (int wordIndex = 0; wordIndex < compressedBlockSize; ++wordIndex) {
        final int firstBit = wordIndex * Long.SIZE;
        long packed = 0L;
        for (int bit = 0; bit < Long.SIZE; ++bit) {
            final int bitIndex = firstBit + bit;
            packed <<= 1;
            if (bitIndex < blockBits.size() && blockBits.get(bitIndex)) {
                packed |= 1L;
            }
        }
        array.add(packed);
    }
    assert array.size() % compressedBlockSize == 0;
}
/**
 * Logs the average number of bits per entry for the order that was just compressed and
 * updates and logs the running averages over all orders compressed so far.
 *
 * @param uncompressedSize number of entries that were compressed
 * @param compressedLongArray the resulting compressed array
 * @param keyBits number of bits spent on keys
 * @param valueBits number of bits spent on values
 */
private void logCompressionInfo(final long uncompressedSize, final LongArray compressedLongArray, final long keyBits, final long valueBits) {
    final double numEntries = uncompressedSize;

    // Statistics for this order only.
    final double keyBitsPerEntry = keyBits / numEntries;
    Logger.logss("Key bits " + keyBitsPerEntry);
    final double valueBitsPerEntry = valueBits / numEntries;
    Logger.logss("Value bits " + valueBitsPerEntry);
    final double compressedBitsPerEntry = 64 * (double) compressedLongArray.size() / numEntries;
    Logger.logss("Compressed bits " + compressedBitsPerEntry);

    // Running statistics across every order compressed so far.
    totalKeyBitsFinal += keyBits;
    totalValueBitsFinal += valueBits;
    totalBitsFinal += compressedLongArray.size();
    totalSizeFinal += uncompressedSize;
    Logger.logss("Total key bits " + totalKeyBitsFinal / totalSizeFinal);
    Logger.logss("Total value bits " + totalValueBitsFinal / totalSizeFinal);
    Logger.logss("Total bits " + 64.0 * totalBitsFinal / totalSizeFinal);
}
/**
 * Tells whether appending {@code newBits} would fill or overflow the current block.
 *
 * @param currBits bits already committed to the block
 * @param restBits body bits collected so far
 * @param headerBits header bits of the block
 * @param newBits bits of the entry about to be appended
 * @return {@code true} when the total, including the 16-bit length field, reaches or exceeds
 *         the capacity of one block
 */
private boolean blockFull(final BitList currBits, final BitList restBits, final BitList headerBits, final BitList newBits) {
    final int capacityInBits = Long.SIZE * compressedBlockSize;
    int requiredBits = currBits.size();
    requiredBits += Short.SIZE;
    requiredBits += headerBits.size();
    requiredBits += restBits.size();
    requiredBits += newBits.size();
    return requiredBits >= capacityInBits;
}
/**
 * Appends the compressed value of one entry to {@code newBits}.
 *
 * @param ngramOrder order index the entry belongs to
 * @param currPos position of the entry within its order
 * @param newBits bit list the compressed value is appended to
 * @return number of bits the compressed value occupies
 */
private long compressValue(final int ngramOrder, final long currPos, final BitList newBits) {
    final CompressibleValueContainer<T> compressibleValues = (CompressibleValueContainer<T>) values;
    final BitList encodedValue = compressibleValues.getCompressed(currPos, ngramOrder);
    newBits.addAll(encodedValue);
    return encodedValue.size();
}
/**
 * Assembles a block header: the coded offset, then the word flag, then the first value.
 *
 * @param offsetBits coded offset of the first entry of the block
 * @param firstValueBits compressed value of the first entry of the block
 * @param wordBitOn whether the block body encodes word deltas
 * @return a new bit list holding the header
 */
private BitList makeHeader(final BitList offsetBits, final BitList firstValueBits, final boolean wordBitOn) {
    final BitList header = new BitList();
    header.addAll(offsetBits);
    header.add(wordBitOn);
    header.addAll(firstValueBits);
    return header;
}
/**
* Scans a single compressed block entry by entry, looking either for a key or for an offset.
*
* When {@code searchOffset >= 0}, {@code searchKey} is ignored: the entry stored at that
* offset is located and its key is returned. Otherwise the entry whose key equals
* {@code searchKey} is located and its offset is returned.
*
* @param compressed array holding all blocks of one order
* @param pos index of the first long of the block to scan
* @param searchKey key to look for; only used when {@code searchOffset < 0}
* @param ngramOrder order index the block belongs to
* @param outputVal passed to the value container on every value read
* @param searchOffset offset to look for, or a negative number to search by key
* @return the key (offset search) or the offset (key search) of the matching entry, or
*         {@code -1} when the block does not contain it
*/
private long decompressLinearSearch(final LongArray compressed, final long pos, final long searchKey, final int ngramOrder, final T outputVal,
final long searchOffset) {
// The first long of a block is the uncompressed key of its first entry.
final long firstKey = compressed.get(pos);
final BitStream bits = getCompressedBits(compressed, pos + 1);
// Header: offset of the first entry, then the flag telling whether word deltas are encoded.
final long offset = offsetCoder.decompress(bits);
final boolean wordBitOn = bits.nextBit();
int currWord = wordOf(firstKey);
long currSuffix = contextOffsetOf(firstKey);
final boolean foundKeyFirst = searchOffset >= 0 ? searchOffset == offset : firstKey == searchKey;
final CompressibleValueContainer<T> compressibleValues = (CompressibleValueContainer<T>) values;
// The value is read from the stream whether or not this entry matches.
// NOTE(review): the third argument presumably asks the container to skip the value instead of storing it in outputVal - confirm against CompressibleValueContainer.
compressibleValues.decompress(bits, ngramOrder, !foundKeyFirst, outputVal);
if (foundKeyFirst) return searchOffset >= 0 ? firstKey : offset;
long currKey = -1;
// k counts entries after the first one, so the k-th decoded entry has offset (offset + k).
for (int k = 1; !bits.finished(); ++k) {
int newWord = -1;
long nextSuffix = -1;
if (wordBitOn) {
final int wordDelta = (int) wordCoder.decompress(bits);
final boolean wordDeltaIsZero = wordDelta == 0;
final long suffixDelta = suffixCoder.decompress(bits);
newWord = currWord + wordDelta;
// With a changed word the decoded number is the absolute context offset, not a delta.
nextSuffix = wordDeltaIsZero ? (currSuffix + suffixDelta) : suffixDelta;
} else {
final long suffixDelta = suffixCoder.decompress(bits);
newWord = currWord;
nextSuffix = (currSuffix + suffixDelta);
}
currKey = combineToKey(newWord, nextSuffix);
currWord = newWord;
currSuffix = nextSuffix;
final long currOffset = offset + k;
final boolean foundKey = searchOffset >= 0 ? searchOffset == currOffset : currKey == searchKey;
compressibleValues.decompress(bits, ngramOrder, !foundKey, outputVal);
if (foundKey) { return searchOffset >= 0 ? currKey : currOffset; }
// Stop as soon as the scan has moved past the target.
if (searchOffset >= 0) {
if (currOffset > searchOffset) return -1;
} else if (currKey > searchKey) return -1;
}
return -1;
}
/**
 * Opens a bit stream over the header and body of a block.
 *
 * @param compressed array holding the blocks
 * @param pos index of the long whose top 16 bits hold the length of the header plus body
 * @return a stream that starts right after the 16-bit length field and spans that many bits
 */
private BitStream getCompressedBits(final LongArray compressed, final long pos) {
    final long lengthWord = compressed.get(pos);
    return new BitStream(compressed, pos, Short.SIZE, readShort(lengthWord));
}
/**
 * Extracts the 16 most significant bits of a long.
 *
 * @param l the long to read from
 * @return the top 16 bits of {@code l} as a short
 */
private short readShort(final long l) {
    final int shift = Long.SIZE - Short.SIZE;
    final long topBits = l >>> shift;
    return (short) topBits;
}
/**
 * Searches one order for {@code searchKey} and returns its offset.
 *
 * @param compressed compressed keys of the order
 * @param searchKey key to look for
 * @param ngramOrder order index being searched
 * @param outputVal receives the value of the entry when it is found
 * @return the offset of the key, or {@code -1} when it is absent
 */
private long decompressSearch(final LongArray compressed, final long searchKey, final int ngramOrder, final T outputVal) {
    // A negative search offset selects search-by-key in the general overload.
    final long searchByKey = -1;
    return decompressSearch(compressed, searchKey, ngramOrder, outputVal, searchByKey);
}
/**
 * Searches one order either by key or by offset.
 *
 * Order index 0 is answered directly from the word id without touching compressed storage;
 * every other order is answered by locating the candidate block and scanning it.
 *
 * @param compressed compressed keys of the order; may be {@code null}
 * @param searchKey key to look for; for order index 0 a non-negative value selects
 *            search-by-key
 * @param ngramOrder order index being searched
 * @param outputVal receives the value of the entry when it is found; may be {@code null}
 * @param searchOffset offset to look for, or a negative number to search by key
 * @return the offset (key search) or key (offset search) of the entry, or {@code -1} when it
 *         is absent
 */
private long decompressSearch(final LongArray compressed, final long searchKey, final int ngramOrder, final T outputVal, final long searchOffset) {
    if (ngramOrder != 0) {
        if (compressed == null) return -1;
        final long firstBlock = 0;
        final long lastBlock = (compressed.size() / compressedBlockSize) - 1;
        final long blockStart = binarySearchBlocks(compressed, compressed.size(), searchKey, firstBlock, lastBlock, searchOffset);
        if (blockStart < 0) return -1;
        return decompressLinearSearch(compressed, blockStart, searchKey, ngramOrder, outputVal, searchOffset);
    }

    final boolean lookingForOffset = searchKey >= 0;
    final int word;
    if (lookingForOffset) {
        word = wordOf(searchKey);
    } else {
        word = (int) searchOffset;
    }
    if (word < 0 || word >= maps[0].size()) return -1;
    if (outputVal != null) values.getFromOffset(word, 0, outputVal);
    if (lookingForOffset) return word;
    return combineToKey(word, 0);
}
/**
 * Binary-searches the blocks for the one that may contain the target.
 *
 * Blocks are compared by the offset of their first entry when {@code searchOffset >= 0}, and
 * by their first key otherwise.
 *
 * @param compressed array holding all blocks of one order
 * @param size number of longs in {@code compressed}; asserted to be a multiple of the block size
 * @param searchKey key to look for; only used when {@code searchOffset < 0}
 * @param low_ index of the first block to consider
 * @param high_ index of the last block to consider
 * @param searchOffset offset to look for, or a negative number to search by key
 * @return index of the first long of the candidate block, or {@code -1} when the target
 *         precedes every block
 */
private long binarySearchBlocks(final LongArray compressed, final long size, final long searchKey, final long low_, final long high_,
    final long searchOffset) {
    final boolean byOffset = searchOffset >= 0;
    final long target = byOffset ? searchOffset : searchKey;
    long lo = low_;
    long hi = high_;
    assert size % compressedBlockSize == 0;
    while (lo <= hi) {
        final long mid = (lo + hi) >>> 1;
        final long blockStart = mid * compressedBlockSize;
        final long blockValue;
        if (byOffset) {
            blockValue = offsetCoder.decompress(getCompressedBits(compressed, blockStart + 1));
        } else {
            blockValue = compressed.get(blockStart);
        }
        final int cmp = compareLongsRaw(blockValue, target);
        if (cmp == 0) {
            // Exact hit on the first entry of this block.
            lo = mid + 1;
            break;
        }
        if (cmp < 0) {
            lo = mid + 1;
        } else {
            hi = mid - 1;
        }
    }
    // lo is one past the candidate block.
    return lo <= 0 ? -1 : (lo - 1) * compressedBlockSize;
}
/**
* Sorts the inclusive range {@code [left0, right0]} of {@code array} in ascending order,
* in place and recursively, using the middle element as pivot. Every exchange goes through
* the four-argument {@code swap}, so the value container is permuted together with the keys.
*
* NOTE(review): the range is expected to be non-empty; with {@code right0 < left0} the pivot
* index is derived from a negative sum - confirm callers never pass an empty range.
*
* @param array keys to sort
* @param left0 index of the first element of the range
* @param right0 index of the last element of the range
* @param ngramOrder order index the keys belong to; forwarded to {@code swap}
*/
protected void sort(final LongArray array, final long left0, final long right0, final int ngramOrder) {
long left, right;
long pivot;
left = left0;
right = right0 + 1;
final long pivotIndex = (left0 + right0) >>> 1;
pivot = array.get(pivotIndex);//[outerArrayPart(pivotIndex)][innerArrayPart(pivotIndex)];
// Park the pivot at the left end while the rest of the range is partitioned.
swap(pivotIndex, left0, array, ngramOrder);
do {
// Advance left past elements smaller than the pivot, without leaving the range.
do
left++;
while (left <= right0 && compareLongsRaw(array.get(left), pivot) < 0);
// Move right back past elements larger than the pivot; the pivot at left0 stops this scan.
do
right--;
while (compareLongsRaw(array.get(right), pivot) > 0);
if (left < right) {
swap(left, right, array, ngramOrder);
}
} while (left <= right);
// Move the pivot from left0 to index right, the boundary between the two partitions.
swap(left0, right, array, ngramOrder);
if (left0 < right) sort(array, left0, right, ngramOrder);
if (left < right0) sort(array, left, right0, ngramOrder);
}
/**
 * Exchanges the keys at two positions and the corresponding values in the value container.
 *
 * @param a first position
 * @param b second position
 * @param array key array to modify
 * @param ngramOrder order index the keys belong to
 */
protected void swap(final long a, final long b, final LongArray array, final int ngramOrder) {
    swap(array, a, b);
    final CompressibleValueContainer<T> compressibleValues = (CompressibleValueContainer<T>) values;
    compressibleValues.swap(a, b, ngramOrder);
}
/**
 * Exchanges two elements of {@code array}.
 *
 * @param array array to modify
 * @param a first position
 * @param b second position
 */
protected void swap(final LongArray array, final long a, final long b) {
    final long valueAtA = array.get(a);
    final long valueAtB = array.get(b);
    array.set(a, valueAtB);
    array.set(b, valueAtA);
}
/**
 * Trims the value container. The key maps are not touched here.
 */
@Override
public void trim() {
    this.values.trim();
}
/**
* Does nothing in this implementation: the body is empty. Per-order sizes are supplied to the
* constructor as {@code numNgramsForEachOrder} instead.
*
* @param numNGrams not used
*/
@Override
public void initWithLengths(final List<Long> numNGrams) {
}
/**
 * @return the number of orders this map was created for, i.e. the length of the per-order
 *         map table
 */
@Override
public int getMaxNgramOrder() {
    final int numOrders = maps.length;
    return numOrders;
}
/**
 * Returns a view over every n-gram of one order, in offset order. Each entry is rebuilt by
 * looking up its offset and then following the stored context offsets down to order index 0.
 *
 * @param ngramOrder order index to iterate over
 * @return an iterable whose iterators yield one entry per stored n-gram; the iterators do
 *         not support removal
 */
@Override
public Iterable<Entry<T>> getNgramsForOrder(final int ngramOrder) {
    return new Iterable<Entry<T>>()
    {
        @Override
        public Iterator<edu.berkeley.nlp.lm.map.NgramMap.Entry<T>> iterator() {
            return new Iterator<edu.berkeley.nlp.lm.map.NgramMap.Entry<T>>()
            {
                // Offset of the next n-gram to return.
                long currOffset = 0;

                @Override
                public boolean hasNext() {
                    return currOffset < maps[ngramOrder].size();
                }

                @Override
                public edu.berkeley.nlp.lm.map.NgramMap.Entry<T> next() {
                    // Honor the Iterator contract instead of decoding a past-the-end offset.
                    if (!hasNext()) throw new java.util.NoSuchElementException("No more n-grams at offset " + currOffset);
                    final T scratch_ = values.getScratchValue();
                    long offset = currOffset;
                    final int[] ngram = new int[ngramOrder + 1];
                    for (int i = ngramOrder; i >= 0; --i) {
                        // Only the full n-gram's value is wanted; lower orders are walked for their words.
                        final T scratch = i == ngramOrder ? scratch_ : null;
                        final long foundKey = decompressSearch(maps[i].compressedKeys, -1, i, scratch, offset);
                        assert foundKey >= 0;
                        ngram[reverseTrie ? (ngramOrder - i) : i] = wordOf(foundKey);
                        offset = contextOffsetOf(foundKey);
                    }
                    currOffset++;
                    return new Entry<T>(ngram, scratch_);
                }

                @Override
                public void remove() {
                    throw new UnsupportedOperationException("Method not yet implemented");
                }
            };
        }
    };
}
/**
 * @param ngramOrder order index to query
 * @return the number of entries held by the map of that order
 */
@Override
public long getNumNgrams(final int ngramOrder) {
    final CompressedMap map = maps[ngramOrder];
    return map.size();
}
/**
 * Tells whether the n-gram {@code ngram[startPos, endPos)} is stored in this map.
 *
 * @param ngram array holding the word ids
 * @param startPos index of the first word (inclusive)
 * @param endPos index one past the last word (exclusive)
 * @return {@code true} when the lookup yields a non-negative offset
 */
@Override
public boolean contains(final int[] ngram, final int startPos, final int endPos) {
    final long offset = getContextOffset(ngram, startPos, endPos, null);
    return offset >= 0;
}
/**
 * Fetches the value stored for the n-gram {@code ngram[startPos, endPos)}.
 *
 * @param ngram array holding the word ids
 * @param startPos index of the first word (inclusive)
 * @param endPos index one past the last word (exclusive)
 * @return the scratch value filled in by the lookup, or {@code null} when the n-gram is
 *         absent
 */
@Override
public T get(int[] ngram, int startPos, int endPos) {
    final T result = values.getScratchValue();
    final boolean found = getContextOffset(ngram, startPos, endPos, result) >= 0;
    return found ? result : null;
}
/**
* Always returns {@code null} in this implementation.
*
* @param ngramOrder not used
* @return {@code null}
*/
@Override
public CustomWidthArray getValueStoringArray(final int ngramOrder) {
return null;
}
/**
 * Drops the reference to the map of every order, leaving all slots {@code null}.
 */
@Override
public void clearStorage() {
    int slot = maps.length;
    while (slot-- > 0) {
        maps[slot] = null;
    }
}
}