/*
* Copyright (C) 2014 Indeed Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
* in compliance with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software distributed under the
* License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
* express or implied. See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.indeed.imhotep;
import com.indeed.util.core.io.Closeables2;
import com.indeed.imhotep.api.FTGSIterator;
import org.apache.log4j.Logger;
import javax.annotation.Nullable;
import java.io.Closeable;
import java.util.Arrays;
import java.util.Collection;
/**
* @author jsgroth
*/
public abstract class AbstractFTGSMerger implements FTGSIterator {
private static final Logger log = Logger.getLogger(AbstractFTGSMerger.class);
protected final FTGSIterator[] iterators;
private final int numIterators;
protected final int[] fieldIterators;
protected int numFieldIterators;
protected final int[] termIterators;
protected final int[] termIteratorIndexes;
protected int numTermIterators;
protected int termIteratorsRemaining;
private String fieldName;
protected boolean fieldIsIntType;
protected long termIntVal;
private boolean done;
private final Closeable doneCallback;
protected final GSVector accumulatedVec;
protected AbstractFTGSMerger(Collection<? extends FTGSIterator> iterators, int numStats, @Nullable Closeable doneCallback) {
this.doneCallback = doneCallback;
numIterators = iterators.size();
this.iterators = iterators.toArray(new FTGSIterator[numIterators]);
fieldIterators = new int[numIterators];
numFieldIterators = 0;
termIterators = new int[numIterators];
termIteratorIndexes = new int[numIterators];
numTermIterators = 0;
done = false;
accumulatedVec = new GSVector(numStats);
}
@Override
public final boolean nextField() {
if (done) return false;
numFieldIterators = 0;
final FTGSIterator first = iterators[0];
final boolean firstHasNextField = first.nextField();
for (int i = 1; i < numIterators; ++i) {
if (iterators[i].nextField() != firstHasNextField) {
throw new IllegalArgumentException("sub iterator fields do not match");
}
}
if (!firstHasNextField) {
close();
return false;
}
fieldName = first.fieldName();
fieldIsIntType = first.fieldIsIntType();
numFieldIterators = 0;
if (first.nextTerm()) {
fieldIterators[numFieldIterators++] = 0;
}
for (int i = 1; i < numIterators; ++i) {
final FTGSIterator itr = iterators[i];
if (!itr.fieldName().equals(fieldName) || itr.fieldIsIntType() != fieldIsIntType) {
throw new IllegalArgumentException("sub iterator fields do not match");
}
if (itr.nextTerm()) {
fieldIterators[numFieldIterators++] = i;
}
}
numTermIterators = 0;
return true;
}
@Override
public final String fieldName() {
return fieldName;
}
@Override
public final boolean fieldIsIntType() {
return fieldIsIntType;
}
@Override
public final long termDocFreq() {
long ret = 0L;
for (int i = 0; i < termIteratorsRemaining; ++i) {
ret += iterators[termIterators[i]].termDocFreq();
}
return ret;
}
@Override
public final long termIntVal() {
return termIntVal;
}
@Override
public final boolean nextGroup() {
while (true) {
if (accumulatedVec.nextGroup()) {
return true;
}
if (termIteratorsRemaining == 0) return false;
calculateNextGroupBatch();
}
}
private void calculateNextGroupBatch() {
int baseGroup = Integer.MAX_VALUE;
for (int i = 0; i < termIteratorsRemaining; ++i) {
final FTGSIterator itr = iterators[termIterators[i]];
final int group = itr.group()&0xFFFFF000;
if (group < baseGroup) {
baseGroup = group;
}
}
accumulatedVec.reset();
for (int i = 0; i < termIteratorsRemaining; ++i) {
final FTGSIterator itr = iterators[termIterators[i]];
if ((itr.group()&0xFFFFF000) == baseGroup) {
if (!accumulatedVec.mergeFromFtgs(itr)) {
swap(termIterators, i, --termIteratorsRemaining);
swap(termIteratorIndexes, i, termIteratorsRemaining);
--i;
}
}
}
}
@Override
public final int group() {
return accumulatedVec.group();
}
@Override
public final void groupStats(long[] stats) {
accumulatedVec.groupStats(stats);
}
@Override
public synchronized void close() {
if (!done) {
done = true;
Closeables2.closeAll(log, Closeables2.forArray(log, iterators), doneCallback);
}
}
protected static void swap(final int[] a, final int b, final int e) {
final int t = a[b];
a[b] = a[e];
a[e] = t;
}
static final class GSVector {
long bitset1;
final long[] bitset2 = new long[64];
final long[] metrics;
private final int numStats;
private int baseGroup = -1;
private int iteratorIndex;
private int group;
private final long[] statBuf;
public GSVector(final int numStats) {
this.numStats = numStats;
metrics = new long[numStats*4096];
statBuf = new long[numStats];
}
public void reset() {
if (group >= 0) {
final int start = (group - baseGroup) * numStats;
Arrays.fill(metrics, start, start+numStats, 0);
}
while (bitset1 != 0) {
final long lsb1 = bitset1 & -bitset1;
bitset1 ^= lsb1;
final int index1 = Long.bitCount(lsb1-1);
while (bitset2[index1] != 0) {
final long lsb2 = bitset2[index1] & -bitset2[index1];
bitset2[index1] ^= lsb2;
final int index2 = (index1<<6)+Long.bitCount(lsb2-1);
Arrays.fill(metrics, index2 * numStats, index2 * numStats + numStats, 0);
}
}
baseGroup = -1;
iteratorIndex = -1;
group = -1;
}
public boolean mergeFromFtgs(FTGSIterator ftgs) {
int group = ftgs.group();
final int newBaseGroup = group & 0xFFFFF000;
if (baseGroup != -1 && baseGroup != newBaseGroup) {
throw new IllegalStateException();
}
baseGroup = newBaseGroup;
group -= baseGroup;
do {
ftgs.groupStats(statBuf);
for (int i = 0; i < numStats; i++) {
metrics[group*numStats+i] += statBuf[i];
}
final int bitset2index = group>>>6;
bitset1 |= 1L<<bitset2index;
bitset2[bitset2index] |= 1L<<(group&0x3F);
if (!ftgs.nextGroup()) return false;
group = ftgs.group()-baseGroup;
} while (group < 4096);
return true;
}
//clears bitsets and metrics as it iterates
public boolean nextGroup() {
if (group >= 0) {
final int start = (group - baseGroup) * numStats;
Arrays.fill(metrics, start, start+numStats, 0);
}
if (iteratorIndex < 0 || bitset2[iteratorIndex] == 0) {
if (bitset1 == 0) {
return false;
}
final long lsb1 = bitset1 & -bitset1;
bitset1 ^= lsb1;
iteratorIndex = Long.bitCount(lsb1-1);
}
final long lsb2 = bitset2[iteratorIndex] & -bitset2[iteratorIndex];
bitset2[iteratorIndex] ^= lsb2;
final int bitset2index = Long.bitCount(lsb2 - 1);
final int groupOffset = (iteratorIndex << 6) | bitset2index;
group = baseGroup+groupOffset;
return true;
}
public int group() {
return group;
}
public void groupStats(long buf[]) {
System.arraycopy(metrics, (group-baseGroup)*numStats, buf, 0, numStats);
}
}
}