/*
* Copyright (C) 2014 Indeed Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
* in compliance with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software distributed under the
* License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
* express or implied. See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.indeed.imhotep;
import com.indeed.imhotep.api.FTGSIterator;
import com.indeed.util.core.io.Closeables2;
import org.apache.log4j.Logger;
import javax.annotation.Nullable;
import java.io.Closeable;
import java.util.Arrays;
import java.util.Collection;
/**
* @author jsgroth
*/
public final class FastIntFTGSMerger implements FTGSIterator {
private static final Logger log = Logger.getLogger(FastIntFTGSMerger.class);
protected final FTGSIterator[] iterators;
private int numFieldIterators = 0;
private String fieldName;
protected boolean fieldIsIntType;
protected long termIntVal;
private boolean done;
private final int numGroups;
private final Closeable doneCallback;
protected final GSVector accumulatedVec;
public FastIntFTGSMerger(Collection<? extends FTGSIterator> iterators, int numStats, int numGroups, @Nullable Closeable doneCallback) {
this.numGroups = numGroups;
this.doneCallback = doneCallback;
this.iterators = iterators.toArray(new FTGSIterator[iterators.size()]);
done = false;
accumulatedVec = new GSVector(numStats, numGroups);
}
@Override
public final boolean nextField() {
if (done) return false;
final FTGSIterator first = iterators[0];
if (!first.nextField()) {
for (int i = 1; i < iterators.length; ++i) {
if (iterators[i].nextField()) {
throw new IllegalArgumentException("sub iterator fields do not match");
}
}
close();
return false;
}
fieldName = first.fieldName();
fieldIsIntType = first.fieldIsIntType();
numFieldIterators = iterators.length;
for (int i = 1; i < iterators.length; ++i) {
final FTGSIterator itr = iterators[i];
if (!itr.nextField() || !itr.fieldName().equals(fieldName) || itr.fieldIsIntType() != fieldIsIntType) {
throw new IllegalArgumentException("sub iterator fields do not match");
}
}
for (int i = iterators.length-1; i >= 0; i--) {
while (true) {
if (!iterators[i].nextTerm()) {
numFieldIterators--;
swap(iterators, i, numFieldIterators);
break;
}
if (!iterators[i].nextGroup()) {
continue;
}
break;
}
}
accumulatedVec.resetNewField();
return true;
}
private static void swap(Object[] array, int indexA, int indexB) {
final Object a = array[indexA];
array[indexA] = array[indexB];
array[indexB] = a;
}
@Override
public final String fieldName() {
return fieldName;
}
@Override
public final boolean fieldIsIntType() {
return fieldIsIntType;
}
private void refill() {
long minBaseTermGroup = Long.MAX_VALUE;
for (int i = 0; i < numFieldIterators; i++) {
final long baseTermGroup = (iterators[i].termIntVal()*numGroups+iterators[i].group())&~0xFFF;
if (baseTermGroup < minBaseTermGroup) {
minBaseTermGroup = baseTermGroup;
}
}
accumulatedVec.reset();
for (int i = numFieldIterators-1; i >= 0; i--) {
final long baseTermGroup = (iterators[i].termIntVal()*numGroups+iterators[i].group())&~0xFFF;
if (baseTermGroup == minBaseTermGroup) {
if (!accumulatedVec.mergeFromFtgs(iterators[i])) {
numFieldIterators--;
swap(iterators, i, numFieldIterators);
}
}
}
}
@Override
public boolean nextTerm() {
while (true) {
if (accumulatedVec.nextTerm()) {
return true;
}
if (numFieldIterators == 0) return false;
refill();
}
}
@Override
public final long termDocFreq() {
return 1;
}
@Override
public final long termIntVal() {
return accumulatedVec.term;
}
public String termStringVal() {
throw new UnsupportedOperationException();
}
@Override
public final boolean nextGroup() {
if (accumulatedVec.nextGroup()) {
return true;
}
if (accumulatedVec.bitset1 == 0 && accumulatedVec.bitset2[accumulatedVec.iteratorIndex] == 0) {
if (numFieldIterators == 0) return false;
refill();
return accumulatedVec.nextGroup();
}
return false;
}
@Override
public final int group() {
return accumulatedVec.group();
}
@Override
public final void groupStats(long[] stats) {
accumulatedVec.groupStats(stats);
}
@Override
public synchronized void close() {
if (!done) {
done = true;
Closeables2.closeAll(log, Closeables2.forArray(log, iterators), doneCallback);
}
}
protected static void swap(final int[] a, final int b, final int e) {
final int t = a[b];
a[b] = a[e];
a[e] = t;
}
static final class GSVector {
long bitset1;
final long[] bitset2 = new long[64];
final long[] metrics;
private final int numStats;
private final int numGroups;
private long base = -1;
private int iteratorIndex = -1;
private long term = 0;
private int group = -1;
private final long[] statBuf;
public GSVector(final int numStats, final int numGroups) {
this.numStats = numStats;
this.numGroups = numGroups;
metrics = new long[numStats*4096];
statBuf = new long[numStats];
}
public void resetNewField() {
reset();
term = Long.MIN_VALUE;
}
public void reset() {
if (group >= 0) {
final int start = (int)(term*numGroups+group-base) * numStats;
Arrays.fill(metrics, start, start+numStats, 0);
}
while (bitset1 != 0) {
final long lsb1 = bitset1 & -bitset1;
bitset1 ^= lsb1;
final int index1 = Long.bitCount(lsb1-1);
while (bitset2[index1] != 0) {
final long lsb2 = bitset2[index1] & -bitset2[index1];
bitset2[index1] ^= lsb2;
final int index2 = (index1<<6)+Long.bitCount(lsb2-1);
Arrays.fill(metrics, index2 * numStats, index2 * numStats + numStats, 0);
}
}
base = -1;
iteratorIndex = -1;
group = -1;
}
public boolean mergeFromFtgs(FTGSIterator ftgs) {
final long termGroup = ftgs.termIntVal()*numGroups+ftgs.group();
final long newBase = termGroup & ~0xFFF;
if (base != -1 && base != newBase) {
throw new IllegalStateException();
}
base = newBase;
int termGroupOffset = (int)(termGroup-base);
do {
ftgs.groupStats(statBuf);
for (int i = 0; i < numStats; i++) {
metrics[termGroupOffset*numStats+i] += statBuf[i];
}
final int bitset2index = termGroupOffset>>>6;
bitset1 |= 1L<<bitset2index;
bitset2[bitset2index] |= 1L<<(termGroupOffset&0x3F);
while (true) {
if (!ftgs.nextGroup()) {
if (!ftgs.nextTerm()) {
return false;
}
continue;
}
break;
}
termGroupOffset = (int) (ftgs.termIntVal()*numGroups+ftgs.group() - base);
} while (termGroupOffset < 4096);
return true;
}
//clears bitsets and metrics as it iterates
public boolean nextGroup() {
if (group >= 0) {
final int start = (int)((term*numGroups+group-base) * numStats);
Arrays.fill(metrics, start, start+numStats, 0);
}
if (iteratorIndex < 0 || bitset2[iteratorIndex] == 0) {
if (bitset1 == 0) {
return false;
}
final long lsb1 = bitset1 & -bitset1;
bitset1 ^= lsb1;
iteratorIndex = Long.bitCount(lsb1-1);
}
final long lsb2 = bitset2[iteratorIndex] & -bitset2[iteratorIndex];
final int bitset2index = Long.bitCount(lsb2 - 1);
final int groupOffset = (iteratorIndex << 6) | bitset2index;
final long termGroup = base+groupOffset;
if (termGroup >= term*numGroups+numGroups) {
return false;
}
bitset2[iteratorIndex] ^= lsb2;
group = (int) (termGroup-term*numGroups);
return true;
}
public boolean nextTerm() {
while (nextGroup()) {
//finish previous term
}
if (iteratorIndex < 0 || bitset2[iteratorIndex] == 0) {
if (bitset1 == 0) {
return false;
}
final long lsb1 = bitset1 & -bitset1;
bitset1 ^= lsb1;
iteratorIndex = Long.bitCount(lsb1-1);
}
final long lsb2 = bitset2[iteratorIndex] & -bitset2[iteratorIndex];
final int bitset2index = Long.bitCount(lsb2 - 1);
final int groupOffset = (iteratorIndex << 6) | bitset2index;
final long termGroup = base+groupOffset;
term = lfloordiv(termGroup, numGroups);
group = -1;
return true;
}
public static long lfloordiv( long n, long d ) {
if (n >= 0) {
return n / d;
} else {
return ~(~n / d);
}
}
public long term() {
return term;
}
public int group() {
return group;
}
public void groupStats(long buf[]) {
System.arraycopy(metrics, (int) ((term*numGroups+group-base)*numStats), buf, 0, numStats);
}
public String toString() {
String ret = String.format("%64s", Long.toBinaryString(bitset1)).replace(' ', '0')+"\n";
long tmpBitset1 = bitset1;
while (tmpBitset1 != 0) {
long lsb = tmpBitset1 & -tmpBitset1;
int index = Long.bitCount(lsb-1);
tmpBitset1 ^= lsb;
ret += String.format("%64s", Long.toBinaryString(bitset2[index])).replace(' ', '0')+"\n";
}
return ret;
}
}
}