package org.apache.cassandra.hadoop2.multiquery;
import java.nio.ByteBuffer;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import com.datastax.driver.core.DataType;
import com.datastax.driver.core.ResultSet;
import com.datastax.driver.core.Row;
import com.google.common.collect.ComparisonChain;
import com.google.common.collect.Iterators;
import com.google.common.collect.Lists;
import com.google.common.collect.PeekingIterator;
import org.apache.commons.lang3.tuple.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* Iterator that operates over multiple Cassandra Row Iterators at once.
*
* An instance of this class will group together rows that are identical over a given set of
* columns (e.g., the partitioning columns, the clustering columns, etc.). The list of columns to
* use for comparing must contain *at least* all of the partitioning columns.
*
* The input rows are assumed to be ordered by these collections of columns in the iterators.
*
* TODO: The iterators are currently merged with a merge sort that compares only the values of the
* columns that we care about. It should sort first by token, and then by column values.
*
* The primary key must be the same for all of the rows.
*/
class MultiRowIterator implements Iterator<List<Row>> {
  private static final Logger LOG = LoggerFactory.getLogger(MultiRowIterator.class);

  /** Single merged stream of the rows of every input result set, ordered by the row comparator. */
  private final PeekingIterator<Row> mRowIterator;

  /** Private snapshot of the (column name, column type) pairs used to compare rows, in order. */
  private final List<Pair<String, DataType>> mColumnsToCompare;

  /** Comparator used both to merge the input iterators and to detect groups of equal rows. */
  private final RowComparator mRowComparator;

  /**
   * Create a multi-row iterator.
   *
   * @param resultSets A list of result sets. Their row iterators are merged into a single stream
   *     using a comparator built from {@code columnsToCompare}.
   * @param columnsToCompare A list of columns, in order, for comparing and sorting rows. The list
   *     is copied, so later changes made by the caller do not affect this iterator.
   */
  public MultiRowIterator(
      List<ResultSet> resultSets,
      // TODO: Maybe use ColumnMetadata here instead of String?
      // Would allow storing "type" as well as name...
      List<Pair<String, DataType>> columnsToCompare) {
    // Defensive copy: the comparator reads this list on every comparison, so it must not change
    // underneath us if the caller later mutates its own list.
    mColumnsToCompare = Lists.newArrayList(columnsToCompare);
    mRowComparator = new RowComparator(mColumnsToCompare);

    List<PeekingIterator<Row>> rowIterators = Lists.newArrayList();
    for (ResultSet resultSet : resultSets) {
      rowIterators.add(Iterators.peekingIterator(resultSet.iterator()));
    }
    mRowIterator = Iterators.peekingIterator(Iterators.mergeSorted(rowIterators, mRowComparator));
  }

  /** {@inheritDoc} */
  @Override
  public boolean hasNext() {
    return mRowIterator.hasNext();
  }

  /**
   * Returns the next group of rows, i.e., the next row of the merged stream together with every
   * immediately following row that compares equal to it on all of the columns to compare.
   *
   * <p>If no rows remain, the exception raised by the underlying merged iterator is propagated
   * (presumably a {@code NoSuchElementException}; this class does not check explicitly).</p>
   *
   * @return a non-empty list of rows that are equal over the compared columns.
   */
  @Override
  public List<Row> next() {
    // Get the first row in our iterator.
    Row firstRow = mRowIterator.next();
    // Parameterized logging avoids stringifying every row when debug logging is disabled.
    LOG.debug("First row = {}", firstRow);
    List<Row> rowsToReturnTogether = Lists.newArrayList(firstRow);

    // Continue popping rows off of the iterator as long as all of the columns that we want to
    // compare are the same between the first row and the next one.
    while (mRowIterator.hasNext()
        && mRowComparator.compare(firstRow, mRowIterator.peek()) == 0) {
      Row nextRow = mRowIterator.next();
      LOG.debug("Next row = {}", nextRow);
      rowsToReturnTogether.add(nextRow);
    }

    // Return all of these rows that had the same values for the compared columns.
    return rowsToReturnTogether;
  }

  /**
   * Always fails: removal is not supported.
   *
   * @throws UnsupportedOperationException always.
   */
  @Override
  public void remove() {
    throw new UnsupportedOperationException("Cannot remove from this iterator!");
  }

  // -----------------------------------------------------------------------------------------------

  /**
   * Compares {@link com.datastax.driver.core.Row} objects column by column, over a fixed ordered
   * list of (column name, column type) pairs. The first column whose values differ decides the
   * result; rows that agree on every listed column compare as equal.
   *
   * <p>Only the column values are compared. The partition key token is not taken into account.</p>
   *
   * <p>NOTE(review): apart from blob and inet columns, values are compared with their Java
   * natural ordering. That ordering may not match the server-side ordering for every type (for
   * example UUIDs, or clustering columns declared in descending order) -- confirm against the
   * tables this is used with.</p>
   */
  private static final class RowComparator implements Comparator<Row> {
    /**
     * Lexicographic comparison of the remaining bytes of two buffers, treating each byte as an
     * unsigned value in [0, 255]. Neither buffer's position is modified.
     */
    private static final Comparator<ByteBuffer> UNSIGNED_BYTES = new Comparator<ByteBuffer>() {
      @Override
      public int compare(ByteBuffer left, ByteBuffer right) {
        int leftLength = left.remaining();
        int rightLength = right.remaining();
        int commonLength = Math.min(leftLength, rightLength);
        for (int i = 0; i < commonLength; i++) {
          // Absolute reads, masked to unsigned.
          int leftByte = left.get(left.position() + i) & 0xFF;
          int rightByte = right.get(right.position() + i) & 0xFF;
          if (leftByte != rightByte) {
            return leftByte < rightByte ? -1 : 1;
          }
        }
        // All shared bytes are equal: the shorter buffer sorts first.
        if (leftLength == rightLength) {
          return 0;
        }
        return leftLength < rightLength ? -1 : 1;
      }
    };

    /** Columns to compare, in order of significance. */
    private final List<Pair<String, DataType>> mColumns;

    /**
     * Create a comparator over the given columns.
     *
     * @param columns (column name, column type) pairs, in order of significance.
     */
    private RowComparator(List<Pair<String, DataType>> columns) {
      mColumns = columns;
    }

    /**
     * {@inheritDoc}
     *
     * @throws UnsupportedOperationException if a column to compare has a type that this
     *     comparator does not know how to order.
     */
    @Override
    public int compare(Row o1, Row o2) {
      ComparisonChain chain = ComparisonChain.start();
      for (Pair<String, DataType> columnAndType : mColumns) {
        String columnName = columnAndType.getLeft();
        DataType dataType = columnAndType.getRight();
        switch (dataType.getName()) {
          case BOOLEAN:
            // false sorts before true (same order as comparing "false" and "true" as strings).
            chain = chain.compareFalseFirst(o1.getBool(columnName), o2.getBool(columnName));
            break;
          case INT:
            chain = chain.compare(o1.getInt(columnName), o2.getInt(columnName));
            break;
          case BIGINT:
          case COUNTER:
            chain = chain.compare(o1.getLong(columnName), o2.getLong(columnName));
            break;
          case DOUBLE:
            chain = chain.compare(o1.getDouble(columnName), o2.getDouble(columnName));
            break;
          case FLOAT:
            chain = chain.compare(o1.getFloat(columnName), o2.getFloat(columnName));
            break;
          case BLOB:
            // ByteBuffer's natural ordering compares bytes as signed values, whereas Cassandra
            // sorts blobs as unsigned bytes. The comparator must agree with the order in which
            // the (pre-sorted) input rows arrive, so compare unsigned.
            chain = chain.compare(
                o1.getBytes(columnName), o2.getBytes(columnName), UNSIGNED_BYTES);
            break;
          case VARCHAR:
          case TEXT:
          case ASCII:
            chain = chain.compare(o1.getString(columnName), o2.getString(columnName));
            break;
          case VARINT:
            chain = chain.compare(o1.getVarint(columnName), o2.getVarint(columnName));
            break;
          case DECIMAL:
            chain = chain.compare(o1.getDecimal(columnName), o2.getDecimal(columnName));
            break;
          case UUID:
          case TIMEUUID:
            chain = chain.compare(o1.getUUID(columnName), o2.getUUID(columnName));
            break;
          case INET:
            // Compare the raw address bytes rather than the textual form: as strings,
            // "/10.0.0.1" would sort before "/9.0.0.1".
            // NOTE(review): intended to match the server's byte-wise ordering of inet values --
            // confirm.
            chain = chain.compare(
                ByteBuffer.wrap(o1.getInet(columnName).getAddress()),
                ByteBuffer.wrap(o2.getInet(columnName).getAddress()),
                UNSIGNED_BYTES);
            break;
          default:
            throw new UnsupportedOperationException("Cannot sort by " + dataType.getName() + "!");
        }
      }
      return chain.result();
    }
  }
}