package edu.berkeley.nlp.lm.array; import java.io.ObjectStreamException; import java.io.Serializable; /** * An array with a custom word "width" in bits. Borrows heavily from Sux4J * (http://sux.dsi.unimi.it/) * * @author adampauls * */ public final class CustomWidthArray implements Serializable { public int getKeyWidth() { return keyWidth; } private static final long serialVersionUID = 1L; private final static int LOG2_BITS_PER_WORD = 6; private final static int BITS_PER_WORD = 1 << LOG2_BITS_PER_WORD; private final static int WORD_MASK = BITS_PER_WORD - 1; private long size; private final int keyWidth; private final int fullWidth; final long widthDiff; private final LongArray data; private final static long numLongs(final long size) { return ((size + WORD_MASK) >>> LOG2_BITS_PER_WORD); } private final static long word(final long index) { return (index >>> LOG2_BITS_PER_WORD); } private final static long bit(final long index) { return (index & WORD_MASK); } private final static long mask(final long index) { return 1L << (index & WORD_MASK); } public CustomWidthArray(final long numWords, final int keyWidth) { this(numWords, keyWidth, keyWidth); } public CustomWidthArray(final long numWords, final int keyWidth, final int fullWidth) { assert keyWidth > 0; assert fullWidth > 0; this.keyWidth = keyWidth; this.fullWidth = fullWidth; this.widthDiff = Long.SIZE - keyWidth; final long numBits = numWords * fullWidth; data = new LongArray(numLongs(numBits));// new long[numLongs(numBits)]; size = 0; } private long length() { return size; } public void ensureCapacity(final long numWords) { final long numBits = numWords * fullWidth; final long numLongs = numLongs(numBits); data.ensureCapacity(numLongs); if (numLongs > data.size()) data.setAndGrowIfNeeded(numLongs - 1, 0); } public void trim() { trimToSize(size); } /** * @param sizeHere */ public void trimToSize(final long sizeHere) { final long numBits = sizeHere * fullWidth; data.trimToSize(numLongs(numBits)); } private void rangeCheck(final long index) { if (index >= length()) { // throw new IndexOutOfBoundsException("Index (" + index + ") is greater than length (" + (length()) + ")"); } } public boolean getBit(final long index) { rangeCheck(index); return (data.get(word(index)) & mask(index)) != 0; } public void clear(final long index) { rangeCheck(index); data.set(word(index), data.get(word(index)) & ~mask(index)); } private long getLong(final long from, final long l) { if (l == Long.SIZE) return 0; final long startWord = word(from); final long startBit = bit(from); if (startBit <= l) return data.get(startWord) << l - startBit >>> l; else return data.get(startWord) >>> startBit | data.get(startWord + 1) << Long.SIZE + l - startBit >>> l; } public boolean add(final long value) { return addHelp(value, true); } public boolean addWithFixedCapacity(final long value) { return addHelp(value, false); } /** * @param value * @return */ private boolean addHelp(final long value, final boolean growCapacity) { assert fullWidth == keyWidth; final long length = this.size * fullWidth; final long startWord = word(length); final long startBit = bit(length); if (growCapacity) ensureCapacity(this.size + 1); if (startBit + keyWidth <= Long.SIZE) data.set(startWord, data.get(startWord) | (value << startBit)); else { data.set(startWord, data.get(startWord) | (value << startBit)); data.set(startWord + 1, value >>> BITS_PER_WORD - startBit); } this.size++; return true; } public long get(final long index) { return getHelp(index, 0, keyWidth); } public long get(final long index, int offset, int width) { return getHelp(index, offset, width); } /** * @param index * @return */ private long getHelp(final long index, int offset, int width) { final long start = index * fullWidth + offset; return getLong(start, Long.SIZE - width); } public static int numBitsNeeded(final long n) { if (n == 0) return 1; if (Long.bitCount(n) == 1) return Long.numberOfTrailingZeros(n) + 1; else return Long.SIZE - Long.numberOfLeadingZeros(n - 1); } public void set(final long index, final long value) { rangeCheck(index); final int offset = 0; final int width = keyWidth; setHelp(index, value, offset, width); } public void set(final long index, final long value, final int offset, final int width) { rangeCheck(index); setHelp(index, value, offset, width); } /** * @param index * @param value * @param offset */ private void setHelp(final long index, final long value, final int offset, final int width) { assert numBitsNeeded(value) <= width : "Value " + value + " bits " + width; final long start = index * fullWidth + offset; final long startWord = word(start); final long endWord = word(start + width - 1); final long startBit = bit(start); final long fullMask = width == Long.SIZE ? -1L : ((1L << width) - 1); if (startWord == endWord) { long startWordLong = data.get(startWord); startWordLong &= ~(fullMask << startBit); startWordLong |= value << startBit; data.set(startWord, startWordLong); assert value == (startWordLong >>> startBit & fullMask) : startWord + " " + startBit + " " + value; } else { // Here startBit > 0. long startWordLong = data.get(startWord); startWordLong &= ((1L << startBit) - 1); startWordLong |= (value << startBit); data.set(startWord, startWordLong); long endWordLong = data.get(endWord); endWordLong &= (-(1L << width - BITS_PER_WORD + startBit)); endWordLong |= (value >>> BITS_PER_WORD - startBit); data.set(endWord, endWordLong); assert value == (startWordLong >>> startBit | endWordLong << (BITS_PER_WORD - startBit) & fullMask); } } public void setAndGrowIfNeeded(final long pos, final long value) { if (pos >= size) { ensureCapacity(pos + 2); this.size = pos + 1; } set(pos, value); } public void setAndGrowIfNeeded(final long pos, final long value, final int offset, final int width) { if (pos >= size) { ensureCapacity(pos + 2); this.size = pos + 1; } set(pos, value, offset, width); } public long size() { return length(); } public void fill(final long l, final long n) { final long numBits = n * fullWidth; final long numLongs = numLongs(numBits); data.fill(l, numLongs); size = Math.max(n, size); } public long linearSearch(final long key, final long rangeStart, final long rangeEnd, final long startIndex, final long emptyKey, final boolean returnFirstEmptyIndex) { for (long i = startIndex; i < rangeEnd; ++i) { final long searchKey = getHelp(i, 0, keyWidth); if (searchKey == key) return i; if (searchKey == emptyKey) return returnFirstEmptyIndex ? i : -1L; } for (long i = rangeStart; i < startIndex; ++i) { final long searchKey = getHelp(i, 0, keyWidth); if (searchKey == key) return i; if (searchKey == emptyKey) return returnFirstEmptyIndex ? i : -1L; } return -1L; } public void incrementCount(final long index, final long count) { if (index >= size()) { setAndGrowIfNeeded(index, count); } else { final long curr = get(index); set(index, curr + count); } } public int getFullWidth() { return fullWidth; } }