package edu.berkeley.cs.succinct.util.stream;
import edu.berkeley.cs.succinct.util.CommonUtils;
import edu.berkeley.cs.succinct.util.dictionary.Tables;
import edu.berkeley.cs.succinct.util.stream.serops.BitMapOps;
import org.apache.hadoop.fs.FSDataInputStream;
import java.io.IOException;
import static edu.berkeley.cs.succinct.util.DictionaryUtils.*;
/**
 * Reads a serialized wavelet tree directly from an {@link FSDataInputStream} and answers
 * lookups by walking the tree and performing select operations on each node's dictionary.
 *
 * <p>As read by this class, every node starts with a header of one byte (the split value
 * {@code m}) followed by two longs (offsets of the left and right children, relative to
 * {@code startPos}; {@code 0} means "no child"). The node's dictionary begins immediately after
 * the header.
 *
 * <p>NOTE(review): the instance shares and repositions the underlying stream on every call, so
 * it should presumably not be used concurrently — confirm against callers.
 */
public class WaveletTreeStream {

  /** Underlying stream; repositioned by every lookup. */
  private final FSDataInputStream stream;

  /** Absolute stream offset of the root node; child offsets are relative to this. */
  private final long startPos;

  /**
   * Creates a reader over a wavelet tree serialized in {@code stream}.
   *
   * @param stream   stream containing the serialized tree
   * @param startPos absolute offset in {@code stream} at which the root node starts
   * @throws IOException declared for interface compatibility; the constructor itself performs
   *                     no I/O
   */
  public WaveletTreeStream(FSDataInputStream stream, long startPos) throws IOException {
    this.stream = stream;
    this.startPos = startPos;
  }

  /**
   * Performs a lookup starting at the root of the tree.
   *
   * @param contextPos context value used to choose the branch taken at each node
   * @param cellPos    zero-based index passed to the select operation at the deepest node reached
   * @param startIdx   lower bound of the context range covered by the root
   * @param endIdx     upper bound (inclusive) of the context range covered by the root
   * @return the position produced by the select operation on the root's dictionary
   * @throws IOException if reading from or seeking in the underlying stream fails
   */
  public long lookup(int contextPos, int cellPos, int startIdx, int endIdx) throws IOException {
    stream.seek(startPos);
    return waveletTreeLookup(contextPos, cellPos, startIdx, endIdx);
  }

  /**
   * Recursive step of {@link #lookup}: reads the node at the current stream position, descends
   * into the matching child (if any), and maps the child's answer back through this node's
   * dictionary with a select operation.
   *
   * <p>The stream must be positioned at the first byte of a node header when this is called.
   *
   * @param contextPos context value used to choose the branch
   * @param cellPos    zero-based index used for the select at the deepest node reached
   * @param startIdx   lower bound of the context range covered by this node
   * @param endIdx     upper bound (inclusive) of the context range covered by this node
   * @return result of select1 (right branch) or select0 (left branch) on this node's dictionary
   * @throws IOException if reading from or seeking in the underlying stream fails
   */
  private long waveletTreeLookup(int contextPos, int cellPos, int startIdx, int endIdx)
      throws IOException {
    byte m = stream.readByte();
    // Offsets and positions are kept as long: narrowing them to int corrupts any position or
    // offset beyond 2^31 - 1 bytes, which happens as soon as the tree lives past the 2 GiB mark
    // of the file.
    long left = stream.readLong();
    long right = stream.readLong();
    // The dictionary of this node starts right after the header that was just consumed.
    long dictPos = stream.getPos();
    long p, v;
    if (contextPos > m && contextPos <= endIdx) {
      // Context lies in the upper half (m, endIdx]: follow the right child, use 1-bits.
      if (right == 0) {
        return select1(dictPos, cellPos);
      }
      stream.seek(startPos + right);
      p = waveletTreeLookup(contextPos, cellPos, m + 1, endIdx);
      v = select1(dictPos, (int) p);
    } else {
      // Otherwise follow the left child, use 0-bits.
      if (left == 0) {
        return select0(dictPos, cellPos);
      }
      stream.seek(startPos + left);
      p = waveletTreeLookup(contextPos, cellPos, startIdx, m);
      v = select0(dictPos, (int) p);
    }
    return v;
  }

  /**
   * Returns the position of the {@code (i + 1)}-th 0-bit in the dictionary stored at
   * {@code dictPos}.
   *
   * <p>The search narrows in four stages, using the granularities visible in the code: L3
   * entries (one per {@code CommonUtils.two32} bits), L12 entries (one per 2048 bits, each
   * carrying three further 512-bit sub-counts), and finally a linear scan over encoded 16-bit
   * blocks. Rank entries store counts of 1-bits, so zero counts are derived as
   * {@code bitsCovered - ones}.
   *
   * <p>Layout as indexed by this method (assumes {@code RandomAccessLongStream.get()} advances
   * the position and {@code get(int)} is an absolute read — TODO confirm): size, L3 ranks, L3
   * positions, L12 ranks, L12 positions, one unused long, encoded bitmap.
   *
   * @param dictPos absolute stream offset of the dictionary
   * @param i       zero-based index of the requested 0-bit; must be non-negative
   * @return position of the requested 0-bit; if it is not found inside the final 16-bit block,
   *         the start position of that block
   * @throws IOException if reading from the underlying stream fails
   */
  private long select0(long dictPos, int i) throws IOException {
    assert (i >= 0);
    RandomAccessLongStream dictBuf = new RandomAccessLongStream(stream, dictPos, Integer.MAX_VALUE);
    long size = dictBuf.get();
    long val = i + 1; // 1-based count of zeros still to be skipped/found
    int sp = 0;
    int ep = (int) (size / CommonUtils.two32);
    int m;
    long r;
    int pos = 0; // running offset into the encoded bitmap, accumulated from the pos entries
    int blockClass, blockOffset;
    long sel;
    int lastBlock;
    long rankL12, posL12;
    int l3Size = (int) ((size / CommonUtils.two32) + 1);
    int l12Size = (int) ((size / 2048) + 1);
    int basePos = (int) dictBuf.position();

    // Stage 1: binary search over the L3 entries.
    while (sp <= ep) {
      m = (sp + ep) / 2;
      r = (m * CommonUtils.two32 - dictBuf.get(basePos + m));
      if (val > r) {
        sp = m + 1;
      } else {
        ep = m - 1;
      }
    }
    ep = Math.max(ep, 0);
    val -= (ep * CommonUtils.two32 - dictBuf.get(basePos + ep));
    pos += dictBuf.get(basePos + l3Size + ep);

    // Restrict the L12 search to the entries that fall inside the chosen L3 block.
    sp = (int) (ep * CommonUtils.two32 / 2048);
    ep = (int) (Math.min(((ep + 1) * CommonUtils.two32 / 2048), Math.ceil((double) size / 2048.0))
        - 1);
    assert (val <= CommonUtils.two32);
    assert (pos >= 0);
    dictBuf.position(basePos + 2 * l3Size);
    basePos = (int) dictBuf.position();

    // Stage 2: binary search over the L12 entries.
    while (sp <= ep) {
      m = (sp + ep) / 2;
      // 2048L forces long arithmetic: m can reach 2^20 for bitmaps of 2^31 bits or more, at
      // which point the int product m * 2048 would overflow.
      r = m * 2048L - GETRANKL2(dictBuf.get(basePos + m));
      if (val > r) {
        sp = m + 1;
      } else {
        ep = m - 1;
      }
    }
    ep = Math.max(ep, 0);
    sel = (long) (ep) * 2048L;
    rankL12 = dictBuf.get(basePos + ep);
    posL12 = dictBuf.get(basePos + l12Size + ep);
    val -= (ep * 2048L - GETRANKL2(rankL12)); // long arithmetic, same reason as above
    pos += GETPOSL2(posL12);
    assert (val <= 2048);
    assert (pos >= 0);

    // Stage 3: step over up to three 512-bit sub-blocks using the packed L1 counts.
    r = (512 - GETRANKL1(rankL12, 1));
    if (sel + 512 < size && val > r) {
      pos += GETPOSL1(posL12, 1);
      val -= r;
      sel += 512;
      r = (512 - GETRANKL1(rankL12, 2));
      if (sel + 512 < size && val > r) {
        pos += GETPOSL1(posL12, 2);
        val -= r;
        sel += 512;
        r = (512 - GETRANKL1(rankL12, 3));
        if (sel + 512 < size && val > r) {
          pos += GETPOSL1(posL12, 3);
          val -= r;
          sel += 512;
        }
      }
    }
    dictBuf.position(basePos + 2 * l12Size);
    assert (val <= 512);
    assert (pos >= 0);
    dictBuf.get(); // TODO: Could remove this field altogether

    // Stage 4: scan the encoded 16-bit blocks; each is a 4-bit class followed by
    // Tables.offsetBits[class] offset bits.
    while (true) {
      blockClass = (int) BitMapOps.getValPos(dictBuf, pos, 4);
      short offsetSize = (short) Tables.offsetBits[blockClass];
      pos += 4;
      // NOTE(review): class 0 appears to cover two cases told apart by one offset bit (the bit
      // adds 16 to the ones-count) — confirm against Tables.
      blockOffset = (int) ((blockClass == 0) ? BitMapOps.getBit(dictBuf, pos) * 16 : 0);
      pos += offsetSize;
      if (val <= (16 - (blockClass + blockOffset))) {
        // Target lies in this block: rewind to the block's start so it can be decoded below.
        pos -= (4 + offsetSize);
        break;
      }
      val -= (16 - (blockClass + blockOffset));
      sel += 16;
    }

    // Decode the final block and locate the val-th 0-bit, most significant bit first.
    blockClass = (int) BitMapOps.getValPos(dictBuf, pos, 4);
    pos += 4;
    blockOffset = (int) BitMapOps.getValPos(dictBuf, pos, Tables.offsetBits[blockClass]);
    lastBlock = Tables.decodeTable[blockClass][blockOffset];
    long count = 0;
    for (int bit = 0; bit < 16; bit++) {
      if (((lastBlock >>> (15 - bit)) & 1) == 0) {
        count++;
      }
      if (count == val) {
        return sel + bit;
      }
    }
    return sel;
  }

  /**
   * Returns the position of the {@code (i + 1)}-th 1-bit in the dictionary stored at
   * {@code dictPos}.
   *
   * <p>Mirrors {@link #select0(long, int)} but compares against the stored 1-bit counts
   * directly: L3 entries (one per {@code CommonUtils.two32} bits), L12 entries (one per 2048
   * bits with three 512-bit sub-counts), then a linear scan over encoded 16-bit blocks.
   *
   * <p>Layout as indexed by this method (assumes {@code RandomAccessLongStream.get()} advances
   * the position and {@code get(int)} is an absolute read — TODO confirm): size, L3 ranks, L3
   * positions, L12 ranks, L12 positions, one unused long, encoded bitmap.
   *
   * @param dictPos absolute stream offset of the dictionary
   * @param i       zero-based index of the requested 1-bit; must be non-negative
   * @return position of the requested 1-bit; if it is not found inside the final 16-bit block,
   *         the start position of that block
   * @throws IOException if reading from the underlying stream fails
   */
  private long select1(long dictPos, int i) throws IOException {
    assert (i >= 0);
    RandomAccessLongStream dictBuf = new RandomAccessLongStream(stream, dictPos, Integer.MAX_VALUE);
    long size = dictBuf.get();
    long val = i + 1; // 1-based count of ones still to be skipped/found
    int sp = 0;
    int ep = (int) (size / CommonUtils.two32);
    int m;
    long r;
    int pos = 0; // running offset into the encoded bitmap, accumulated from the pos entries
    int blockClass, blockOffset;
    long sel;
    int lastBlock;
    long rankL12, posL12;
    int l3Size = (int) ((size / CommonUtils.two32) + 1);
    int l12Size = (int) ((size / 2048) + 1);
    int basePos = (int) dictBuf.position();

    // Stage 1: binary search over the L3 entries.
    while (sp <= ep) {
      m = (sp + ep) / 2;
      r = dictBuf.get(basePos + m);
      if (val > r) {
        sp = m + 1;
      } else {
        ep = m - 1;
      }
    }
    ep = Math.max(ep, 0);
    val -= dictBuf.get(basePos + ep);
    pos += dictBuf.get(basePos + l3Size + ep);

    // Restrict the L12 search to the entries that fall inside the chosen L3 block.
    sp = (int) (ep * CommonUtils.two32 / 2048);
    ep = (int) (Math.min(((ep + 1) * CommonUtils.two32 / 2048), Math.ceil((double) size / 2048.0))
        - 1);
    assert (val <= CommonUtils.two32);
    assert (pos >= 0);
    dictBuf.position(basePos + 2 * l3Size);
    basePos = (int) dictBuf.position();

    // Stage 2: binary search over the L12 entries.
    while (sp <= ep) {
      m = (sp + ep) / 2;
      r = GETRANKL2(dictBuf.get(basePos + m));
      if (val > r) {
        sp = m + 1;
      } else {
        ep = m - 1;
      }
    }
    ep = Math.max(ep, 0);
    sel = (long) (ep) * 2048L;
    rankL12 = dictBuf.get(basePos + ep);
    posL12 = dictBuf.get(basePos + l12Size + ep);
    val -= GETRANKL2(rankL12);
    pos += GETPOSL2(posL12);
    assert (val <= 2048);
    assert (pos >= 0);

    // Stage 3: step over up to three 512-bit sub-blocks using the packed L1 counts.
    r = GETRANKL1(rankL12, 1);
    if (sel + 512 < size && val > r) {
      pos += GETPOSL1(posL12, 1);
      val -= r;
      sel += 512;
      r = GETRANKL1(rankL12, 2);
      if (sel + 512 < size && val > r) {
        pos += GETPOSL1(posL12, 2);
        val -= r;
        sel += 512;
        r = GETRANKL1(rankL12, 3);
        if (sel + 512 < size && val > r) {
          pos += GETPOSL1(posL12, 3);
          val -= r;
          sel += 512;
        }
      }
    }
    dictBuf.position(basePos + 2 * l12Size);
    assert (val <= 512);
    assert (pos >= 0);
    dictBuf.get(); // TODO: Could remove this field altogether

    // Stage 4: scan the encoded 16-bit blocks; each is a 4-bit class followed by
    // Tables.offsetBits[class] offset bits.
    while (true) {
      blockClass = (int) BitMapOps.getValPos(dictBuf, pos, 4);
      short offsetSize = (short) Tables.offsetBits[blockClass];
      pos += 4;
      // NOTE(review): class 0 appears to cover two cases told apart by one offset bit (the bit
      // adds 16 to the ones-count) — confirm against Tables.
      blockOffset = (int) ((blockClass == 0) ? BitMapOps.getBit(dictBuf, pos) * 16 : 0);
      pos += offsetSize;
      if (val <= (blockClass + blockOffset)) {
        // Target lies in this block: rewind to the block's start so it can be decoded below.
        pos -= (4 + offsetSize);
        break;
      }
      val -= (blockClass + blockOffset);
      sel += 16;
    }

    // Decode the final block and locate the val-th 1-bit, most significant bit first.
    blockClass = (int) BitMapOps.getValPos(dictBuf, pos, 4);
    pos += 4;
    blockOffset = (int) BitMapOps.getValPos(dictBuf, pos, Tables.offsetBits[blockClass]);
    lastBlock = Tables.decodeTable[blockClass][blockOffset];
    long count = 0;
    for (int bit = 0; bit < 16; bit++) {
      if (((lastBlock >>> (15 - bit)) & 1) == 1) {
        count++;
      }
      if (count == val) {
        return sel + bit;
      }
    }
    return sel;
  }

  /**
   * Closes the underlying stream.
   *
   * @throws IOException if closing the stream fails
   */
  public void close() throws IOException {
    stream.close();
  }
}