/*
* JBoss, Home of Professional Open Source.
* See the COPYRIGHT.txt file distributed with this work for information
* regarding copyright ownership. Some portions may be licensed
* to Red Hat, Inc. under one or more contributor license agreements.
*
* This library is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 2.1 of the License, or (at your option) any later version.
*
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this library; if not, write to the Free Software
* Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
* 02110-1301 USA.
*/
package org.teiid.query.processor.relational;
import static org.teiid.query.analysis.AnalysisRecord.*;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.ListIterator;
import java.util.Map;
import org.teiid.api.exception.query.ExpressionEvaluationException;
import org.teiid.api.exception.query.FunctionExecutionException;
import org.teiid.client.plan.PlanNode;
import org.teiid.common.buffer.BlockedException;
import org.teiid.common.buffer.BufferManager;
import org.teiid.common.buffer.STree;
import org.teiid.common.buffer.STree.InsertMode;
import org.teiid.common.buffer.TupleBatch;
import org.teiid.common.buffer.TupleBuffer;
import org.teiid.common.buffer.TupleSource;
import org.teiid.core.TeiidComponentException;
import org.teiid.core.TeiidProcessingException;
import org.teiid.core.TeiidRuntimeException;
import org.teiid.core.types.DataTypeManager;
import org.teiid.language.SortSpecification.NullOrdering;
import org.teiid.query.eval.Evaluator;
import org.teiid.query.function.aggregate.*;
import org.teiid.query.processor.BatchCollector;
import org.teiid.query.processor.BatchCollector.BatchProducer;
import org.teiid.query.processor.ProcessorDataManager;
import org.teiid.query.processor.relational.SortUtility.Mode;
import org.teiid.query.sql.LanguageObject;
import org.teiid.query.sql.lang.OrderBy;
import org.teiid.query.sql.lang.OrderByItem;
import org.teiid.query.sql.symbol.AggregateSymbol;
import org.teiid.query.sql.symbol.AggregateSymbol.Type;
import org.teiid.query.sql.symbol.ElementSymbol;
import org.teiid.query.sql.symbol.Expression;
import org.teiid.query.sql.symbol.TextLine;
import org.teiid.query.sql.util.SymbolMap;
import org.teiid.query.util.CommandContext;
/**
 * Relational node that evaluates aggregate functions over groups of child rows.
 * <p>
 * Processing is a state machine driven by {@code phase}:
 * <ul>
 * <li>COLLECTION: choose a strategy and set up a tuple source that projects each child row onto
 * the collected expressions (grouping expressions, all source elements when removing duplicates,
 * and aggregate arguments / aggregate order by items / aggregate filter conditions).</li>
 * <li>SORT then GROUP: sort the projected rows on the grouping columns (no sort when there is no
 * grouping) and emit an output row each time the grouping values change.</li>
 * <li>GROUP_SORT then GROUP_SORT_OUTPUT: used when there is neither rollup nor duplicate removal
 * and every aggregate reports its state types; partial accumulator state is kept in an
 * {@link STree} keyed by the grouping values and results are produced by walking that tree.</li>
 * </ul>
 * With rollup enabled each aggregate keeps one accumulator per rollup level, and closeGroup emits
 * the extra super-aggregate rows.
 */
public class GroupingNode extends SubqueryAwareRelationalNode {
/**
 * Tuple source over a child {@link BatchProducer} that projects each incoming row onto a list of
 * expressions. Expressions present in {@code elementMap} are copied from the child row by index;
 * all others are evaluated against the child row with the supplied {@link Evaluator}.
 */
static class ProjectingTupleSource extends
BatchCollector.BatchProducerTupleSource {
private Evaluator eval;
private List<Expression> collectedExpressions;
// For each collected expression: its index in the child row, or -1 if it must be evaluated
private int[] projectionIndexes;
ProjectingTupleSource(BatchProducer sourceNode, Evaluator eval, List<Expression> expressions, Map<Expression, Integer> elementMap) {
super(sourceNode);
this.eval = eval;
this.collectedExpressions = expressions;
this.projectionIndexes = new int[this.collectedExpressions.size()];
Arrays.fill(this.projectionIndexes, -1);
for (int i = 0; i < expressions.size(); i++) {
Integer index = elementMap.get(expressions.get(i));
if(index != null) {
projectionIndexes[i] = index;
}
}
}
/**
 * Builds the projected row for one child row.
 * @param tuple the child row
 * @return a new list holding one value per collected expression, in collected order
 */
@Override
protected List<Object> updateTuple(List<?> tuple) throws ExpressionEvaluationException, BlockedException, TeiidComponentException {
int columns = collectedExpressions.size();
List<Object> exprTuple = new ArrayList<Object>(columns);
for(int col = 0; col<columns; col++) {
int index = projectionIndexes[col];
Object value = null;
if (index != -1) {
value = tuple.get(index);
} else {
// The following call may throw BlockedException. exprTuple is a local, so no partial
// result is kept between attempts; a retried call simply rebuilds this tuple from the start.
value = eval.evaluate(collectedExpressions.get(col), tuple);
}
exprTuple.add(value);
}
return exprTuple;
}
}
// Grouping columns set by the planner; they are also the sort keys / STree key columns
private List<OrderByItem> orderBy;
// When true, the input is sorted on the grouping columns plus all source elements (distinctCols) with duplicate removal
private boolean removeDuplicates;
// Maps this node's output elements to the expressions (aggregates or grouping expressions) they stand for
private SymbolMap outputMapping;
// Collection phase
private int phase = COLLECTION;
private Map<Expression, Integer> elementMap; // Map of incoming symbol to index in source elements
private LinkedHashMap<Expression, Integer> collectedExpressions; // Collected Expressions -> position in the projected rows
// Number of leading collected columns sorted on for duplicate removal; -1 until set in initialize
private int distinctCols = -1;
// Sort phase
private SortUtility sortUtility;
private TupleBuffer sortBuffer;
// Source of projected (and, when grouping, sorted) rows consumed by the GROUP phase
private TupleSource groupTupleSource;
// Group phase
// Accumulators per output column: functions[column][level]; level 0 is the detail level, levels 1.. exist only for aggregates under rollup
private AggregateFunction[][] functions;
// A row of the group currently being accumulated; compared against new rows to detect a group change
private List<?> lastRow;
// Row read from the source but not yet added to the accumulators; kept across BlockedException and full batches
private List<?> currentGroupTuple;
// Group sort
// Partial accumulator state keyed by the grouping values
private STree tree;
// The aggregate (non-constant) accumulators used by the group sort strategy
private AggregateFunction[] groupSortfunctions;
// Number of state columns each groupSortfunctions entry stores in a tree row
private int[] accumulatorStateCount;
// Projected child rows during GROUP_SORT, then the tree contents during GROUP_SORT_OUTPUT
private TupleSource groupSortTupleSource;
// Output column -> position in the (grouping values + aggregate results) row built in GROUP_SORT_OUTPUT
private int[] projection;
private static final int COLLECTION = 1;
private static final int SORT = 2;
private static final int GROUP = 3;
private static final int GROUP_SORT = 4;
private static final int GROUP_SORT_OUTPUT = 5;
// Positions of the grouping columns within the projected rows
private int[] indexes;
private boolean rollup;
// Rollup only: collected index of grouping column i -> orderBy.size() - i, the lowest rollup level at which that column is output as null
private HashMap<Integer, Integer> indexMap;
public GroupingNode(int nodeID) {
super(nodeID);
}
/**
 * Returns the node to the COLLECTION phase and resets every accumulator so the node can be
 * executed again. Sort resources are only dereferenced here; closeDirect() is what removes them
 * (and the STree).
 */
public void reset() {
super.reset();
phase = COLLECTION;
sortUtility = null;
sortBuffer = null;
lastRow = null;
currentGroupTuple = null;
if (this.functions != null) {
for (AggregateFunction[] functions : this.functions) {
for (AggregateFunction function : functions) {
function.reset();
}
}
}
}
/**
 * @param removeDuplicates whether the input should also be de-duplicated on all source elements
 */
public void setRemoveDuplicates(boolean removeDuplicates) {
this.removeDuplicates = removeDuplicates;
}
/**
 * @param orderBy the grouping columns; null when there is no GROUP BY
 */
public void setOrderBy(List<OrderByItem> orderBy) {
this.orderBy = orderBy;
}
/**
 * @param outputMapping maps output elements to the aggregate / grouping expressions they represent
 */
public void setOutputMapping(SymbolMap outputMapping) {
this.outputMapping = outputMapping;
}
/**
 * Builds the collected-expression layout and the accumulators. After super.initialize this runs
 * only once: it returns early when the accumulators already exist.
 * <p>
 * Grouping expressions are registered first so they occupy the leading positions of the projected
 * rows; with duplicate removal all source elements follow; aggregate arguments, aggregate order by
 * items and aggregate conditions are appended by initAccumulator.
 */
@Override
public void initialize(CommandContext context, BufferManager bufferManager,
ProcessorDataManager dataMgr) {
super.initialize(context, bufferManager, dataMgr);
if (this.functions != null) {
return;
}
// Incoming elements and lookup map for evaluating expressions
List<? extends Expression> sourceElements = this.getChildren()[0].getElements();
this.elementMap = createLookupMap(sourceElements);
this.collectedExpressions = new LinkedHashMap<Expression, Integer>();
// List should contain all grouping columns / expressions as we need those for sorting
if(this.orderBy != null) {
for (OrderByItem item : this.orderBy) {
Expression ex = SymbolMap.getExpression(item.getSymbol());
getIndex(ex, this.collectedExpressions);
}
if (removeDuplicates) {
for (Expression ses : sourceElements) {
getIndex(ses, collectedExpressions);
}
distinctCols = collectedExpressions.size();
}
}
// Construct aggregate function state accumulators
functions = new AggregateFunction[getElements().size()][];
for(int i=0; i<getElements().size(); i++) {
Expression symbol = getElements().get(i);
if (this.outputMapping != null) {
symbol = outputMapping.getMappedExpression((ElementSymbol)symbol);
}
Class<?> outputType = symbol.getType();
if(symbol instanceof AggregateSymbol) {
AggregateSymbol aggSymbol = (AggregateSymbol) symbol;
// With rollup: one accumulator for the detail level plus one per rollup level
functions[i] = new AggregateFunction[rollup?orderBy.size()+1:1];
for (int j = 0; j < functions[i].length; j++) {
functions[i][j] = initAccumulator(aggSymbol, this, this.collectedExpressions);
}
} else {
// Non-aggregate outputs read a collected (grouping) column. If the symbol was never
// collected, get() returns null and the unboxing below throws NullPointerException.
AggregateFunction af = new ConstantFunction();
af.setArgIndexes(new int[] {this.collectedExpressions.get(symbol)});
af.initialize(outputType, new Class<?>[]{symbol.getType()});
functions[i] = new AggregateFunction[] {af};
}
}
}
/**
 * Returns the position of {@code ex} in {@code expressionIndexes}, appending it with the next
 * position (the current map size) if it is not present yet.
 */
static Integer getIndex(Expression ex, LinkedHashMap<Expression, Integer> expressionIndexes) {
Integer index = expressionIndexes.get(ex);
if (index == null) {
index = expressionIndexes.size();
expressionIndexes.put(ex, index);
}
return index;
}
/**
 * Creates and initializes the accumulator for one aggregate symbol. The aggregate's arguments,
 * its ORDER BY items and its filter condition are registered in {@code expressionIndexes} so
 * they become columns of the projected rows.
 * <p>
 * When the aggregate has an ORDER BY, the function is wrapped in a {@link SortingFilter} whose
 * rows are the arguments followed by the order by values (distinct if the aggregate is
 * DISTINCT). A DISTINCT aggregate without ORDER BY is wrapped in a distinct SortingFilter.
 *
 * @param aggSymbol the aggregate to build an accumulator for
 * @param node supplies the buffer manager and connection id for any SortingFilter
 * @param expressionIndexes the collected expressions, extended in place
 * @return the initialized accumulator
 * @throws TeiidRuntimeException wrapping a FunctionExecutionException from a user defined aggregate
 */
static AggregateFunction initAccumulator(AggregateSymbol aggSymbol,
RelationalNode node, LinkedHashMap<Expression, Integer> expressionIndexes) {
int[] argIndexes = new int[aggSymbol.getArgs().length];
AggregateFunction result = null;
Expression[] args = aggSymbol.getArgs();
Class<?>[] inputTypes = new Class[args.length];
for (int j = 0; j < args.length; j++) {
inputTypes[j] = args[j].getType();
argIndexes[j] = getIndex(args[j], expressionIndexes);
}
Type function = aggSymbol.getAggregateFunction();
switch (function) {
case RANK:
case DENSE_RANK:
result = new RankingFunction(function);
break;
case ROW_NUMBER: //same as count(*)
case COUNT:
result = new Count();
break;
case SUM:
result = new Sum();
break;
case AVG:
result = new Avg();
break;
case MIN:
result = new Min();
break;
case MAX:
result = new Max();
break;
case XMLAGG:
result = new XMLAgg();
break;
case ARRAY_AGG:
result = new ArrayAgg();
break;
case JSONARRAY_AGG:
result = new JSONArrayAgg();
break;
case TEXTAGG:
result = new TextAgg((TextLine)args[0]);
break;
case STRING_AGG:
result = new StringAgg(aggSymbol.getType() == DataTypeManager.DefaultDataClasses.BLOB);
break;
case FIRST_VALUE:
result = new FirstLastValue(aggSymbol.getType(), true);
break;
case LAST_VALUE:
result = new FirstLastValue(aggSymbol.getType(), false);
break;
case LEAD:
case LAG:
result = new LeadLagValue();
break;
case USER_DEFINED:
try {
result = new UserDefined(aggSymbol.getFunctionDescriptor());
} catch (FunctionExecutionException e) {
throw new TeiidRuntimeException(e);
}
break;
default:
result = new StatsFunction(function);
}
if (aggSymbol.getOrderBy() != null) {
int numOrderByItems = aggSymbol.getOrderBy().getOrderByItems().size();
List<OrderByItem> orderByItems = new ArrayList<OrderByItem>(numOrderByItems);
List<ElementSymbol> schema = createSortSchema(result, inputTypes);
// The filter reads the arguments followed by one column per order by item
argIndexes = Arrays.copyOf(argIndexes, argIndexes.length + numOrderByItems);
for (ListIterator<OrderByItem> iterator = aggSymbol.getOrderBy().getOrderByItems().listIterator(); iterator.hasNext();) {
OrderByItem item = iterator.next();
argIndexes[args.length + iterator.previousIndex()] = getIndex(item.getSymbol(), expressionIndexes);
// Order by columns in the filter schema are named by their position in the order by
ElementSymbol element = new ElementSymbol(String.valueOf(iterator.previousIndex()));
element.setType(item.getSymbol().getType());
schema.add(element);
OrderByItem newItem = item.clone();
newItem.setSymbol(element);
orderByItems.add(newItem);
}
SortingFilter filter = new SortingFilter(result, node.getBufferManager(), node.getConnectionID(), aggSymbol.isDistinct());
filter.setElements(schema);
filter.setSortItems(orderByItems);
result = filter;
} else if(aggSymbol.isDistinct()) {
SortingFilter filter = new SortingFilter(result, node.getBufferManager(), node.getConnectionID(), true);
List<ElementSymbol> elements = createSortSchema(result, inputTypes);
filter.setElements(elements);
result = filter;
}
result.setArgIndexes(argIndexes);
if (aggSymbol.getCondition() != null) {
result.setConditionIndex(getIndex(aggSymbol.getCondition(), expressionIndexes));
}
result.initialize(aggSymbol.getType(), inputTypes);
return result;
}
/**
 * Builds the element list ("val0", "val1", ...) for a SortingFilter's rows and points the wrapped
 * function's argument indexes at 0..n-1, so it reads its arguments from the filter's rows instead
 * of the projected rows.
 *
 * @param af the function being wrapped; its arg indexes are overwritten
 * @param inputTypes the argument types, one element per type
 * @return a mutable list of the created elements
 */
private static List<ElementSymbol> createSortSchema(AggregateFunction af,
Class<?>[] inputTypes) {
List<ElementSymbol> elements = new ArrayList<ElementSymbol>(inputTypes.length);
int[] filteredArgIndexes = new int[inputTypes.length];
for (int i = 0; i < inputTypes.length; i++) {
ElementSymbol element = new ElementSymbol("val" + i); //$NON-NLS-1$
element.setType(inputTypes[i]);
elements.add(element);
filteredArgIndexes[i] = i;
}
af.setArgIndexes(filteredArgIndexes);
return elements;
}
/**
 * @return the accumulators, indexed as [output column][rollup level]
 */
AggregateFunction[][] getFunctions() {
return functions;
}
/**
 * Advances the phase state machine. A BlockedException from any phase propagates to the caller;
 * because {@code phase} only advances once a phase completes, the next call resumes where
 * processing stopped.
 */
public TupleBatch nextBatchDirect()
throws BlockedException, TeiidComponentException, TeiidProcessingException {
// Take inputs, evaluate expressions, and build initial tuple source
if(this.phase == COLLECTION) {
collectionPhase();
}
// If necessary, sort to determine groups (if no group cols, no need to sort)
if(this.phase == SORT) {
sortPhase();
}
// Walk through the sorted results and for each group, emit a row
if(this.phase == GROUP) {
return groupPhase();
}
if (this.phase == GROUP_SORT) {
groupSortPhase();
}
if (this.phase == GROUP_SORT_OUTPUT) {
return groupSortOutputPhase();
}
// Not expected to be reached: collectionPhase always moves to SORT, GROUP or GROUP_SORT,
// and both output phases return above
this.terminateBatches();
return pullBatch();
}
/**
 * Creates a tuple source over the first child that projects each child row onto the collected
 * expressions, in collected order. Used by every strategy, not only group sort.
 */
public TupleSource getGroupSortTupleSource() {
final RelationalNode sourceNode = this.getChildren()[0];
return new ProjectingTupleSource(sourceNode, getEvaluator(elementMap), new ArrayList<Expression>(collectedExpressions.keySet()), elementMap);
}
/**
 * Returns the child's output elements.
 * NOTE(review): presumably consumed by SubqueryAwareRelationalNode when preparing evaluation — confirm.
 */
@Override
public Collection<? extends LanguageObject> getObjects() {
return this.getChildren()[0].getOutputElements();
}
/**
 * Chooses the processing strategy and prepares its input:
 * no grouping goes straight to GROUP on unsorted rows; rollup or duplicate removal, or any
 * aggregate without state types, goes through SORT; otherwise GROUP_SORT builds an STree.
 */
private void collectionPhase() {
if(this.orderBy == null) {
// No need to sort
this.groupTupleSource = getGroupSortTupleSource();
this.phase = GROUP;
} else {
List<NullOrdering> nullOrdering = new ArrayList<NullOrdering>(orderBy.size());
List<Boolean> sortTypes = new ArrayList<Boolean>(orderBy.size());
int size = orderBy.size();
if (this.removeDuplicates) {
//sort on all inputs
size = distinctCols;
}
int[] sortIndexes = new int[size];
for (int i = 0; i < size; i++) {
int index = i;
if (i < this.orderBy.size()) {
OrderByItem item = this.orderBy.get(i);
nullOrdering.add(item.getNullOrdering());
sortTypes.add(item.isAscending());
index = collectedExpressions.get(SymbolMap.getExpression(item.getSymbol()));
} else {
// Extra duplicate-removal columns: ascending, no explicit null ordering
nullOrdering.add(null);
sortTypes.add(OrderBy.ASC);
}
sortIndexes[i] = index;
}
// Only the grouping columns define group boundaries
this.indexes = Arrays.copyOf(sortIndexes, orderBy.size());
if (rollup) {
this.indexMap = new HashMap<Integer, Integer>();
for (int i = 0; i < indexes.length; i++) {
this.indexMap.put(indexes[i], orderBy.size() - i);
}
} else if (!removeDuplicates) {
// Group sort is possible only if every aggregate can expose its state as columns
boolean groupSort = true;
List<AggregateFunction> aggs = new ArrayList<AggregateFunction>();
List<Class<?>> allTypes = new ArrayList<Class<?>>();
accumulatorStateCount = new int[this.functions.length];
for (AggregateFunction[] afs : this.functions) {
if (afs[0] instanceof ConstantFunction) {
continue;
}
aggs.add(afs[0]);
List<? extends Class<?>> types = afs[0].getStateTypes();
if (types == null) {
groupSort = false;
break;
}
accumulatorStateCount[aggs.size() - 1] = types.size();
allTypes.addAll(types);
}
if (groupSort) {
this.groupSortfunctions = aggs.toArray(new AggregateFunction[aggs.size()]);
// Tree schema: grouping expressions (the key) followed by every accumulator state column
List<Expression> schema = new ArrayList<Expression>();
for (OrderByItem item : this.orderBy) {
schema.add(SymbolMap.getExpression(item.getSymbol()));
}
List<? extends Expression> elements = getElements();
this.projection = new int[elements.size()];
int index = 0;
for (int i = 0; i < elements.size(); i++) {
Expression symbol = elements.get(i);
if (this.outputMapping != null) {
symbol = outputMapping.getMappedExpression((ElementSymbol)symbol);
}
if (symbol instanceof AggregateSymbol) {
projection[i] = schema.size() + index++;
} else {
projection[i] = schema.indexOf(symbol);
}
}
//add in accumulator value types
for (Class<?> type : allTypes) {
ElementSymbol es = new ElementSymbol("x");
es.setType(type);
schema.add(es);
}
tree = this.getBufferManager().createSTree(schema, this.getConnectionID(), orderBy.size());
//non-default order needs to update the comparator
tree.getComparator().setNullOrdering(nullOrdering);
tree.getComparator().setOrderTypes(sortTypes);
this.groupSortTupleSource = this.getGroupSortTupleSource();
this.phase = GROUP_SORT;
return;
}
}
this.sortUtility = new SortUtility(getGroupSortTupleSource(), removeDuplicates?Mode.DUP_REMOVE_SORT:Mode.SORT, getBufferManager(),
getConnectionID(), new ArrayList<Expression>(collectedExpressions.keySet()), sortTypes, nullOrdering, sortIndexes);
this.phase = SORT;
}
}
/**
 * Process the input and store the partial accumulator values.
 * For each projected row, look up the stored state for its grouping key; restore each aggregate's
 * state from that tree row (or reset it for a new group), add the row, and write the key plus the
 * new state back to the tree as an update or a new entry.
 * NOTE(review): the key is taken from the first orderBy.size() projected columns, which assumes the
 * grouping expressions are distinct (a repeated expression shares one collected index and shifts
 * the rest) — confirm the planner guarantees this. tree.find is given the whole projected row;
 * presumably STree compares only its key columns — confirm.
 * @throws TeiidComponentException
 * @throws TeiidProcessingException
 */
private void groupSortPhase() throws TeiidComponentException, TeiidProcessingException {
List<?> tuple = null;
while ((tuple = groupSortTupleSource.nextTuple()) != null) {
List<?> current = tree.find(tuple);
boolean update = false;
List<Object> accumulated = new ArrayList<Object>();
//not all collected expressions are needed for the key
for (int i = 0; i < orderBy.size(); i++) {
accumulated.add(tuple.get(i));
}
if (current != null) {
update = true;
}
// State columns start right after the key columns
int index = orderBy.size();
for (int i = 0; i < this.groupSortfunctions.length; i++) {
AggregateFunction aggregateFunction = this.groupSortfunctions[i];
if (update) {
aggregateFunction.setState(current, index);
} else {
aggregateFunction.reset();
}
index+=this.accumulatorStateCount[i];
aggregateFunction.addInput(tuple, getContext());
aggregateFunction.getState(accumulated);
}
tree.insert(accumulated, update?InsertMode.UPDATE:InsertMode.NEW, -1);
}
this.groupSortTupleSource.closeSource();
this.groupSortTupleSource = tree.getTupleSource(true);
this.phase = GROUP_SORT_OUTPUT;
}
/**
 * Walk the tree to produce the results: for each tree row, restore every aggregate's state,
 * compute its result and project the (grouping values + results) row onto the output columns.
 * @return the next batch of output rows; the batch is terminated once the tree is exhausted
 * @throws FunctionExecutionException
 * @throws ExpressionEvaluationException
 * @throws TeiidComponentException
 * @throws TeiidProcessingException
 */
private TupleBatch groupSortOutputPhase() throws FunctionExecutionException, ExpressionEvaluationException, TeiidComponentException, TeiidProcessingException {
List<?> tuple = null;
int size = orderBy.size();
// Reused per row: grouping values followed by one result per aggregate
List<Object> vals = Arrays.asList(new Object[size + groupSortfunctions.length]);
while ((tuple = groupSortTupleSource.nextTuple()) != null) {
for (int i = 0; i < size; i++) {
vals.set(i, tuple.get(i));
}
int index = size;
for (int i = 0; i < this.groupSortfunctions.length; i++) {
AggregateFunction aggregateFunction = this.groupSortfunctions[i];
aggregateFunction.setState(tuple, index);
index+=this.accumulatorStateCount[i];
vals.set(size + i, aggregateFunction.getResult(getContext()));
}
List<?> result = RelationalNode.projectTuple(projection, vals);
addBatchRow(result);
if (isBatchFull()) {
return pullBatch();
}
}
terminateBatches();
return pullBatch();
}
/**
 * Sorts the projected rows (with duplicate removal if requested) and opens a forward-only
 * source over the result for the GROUP phase.
 */
private void sortPhase() throws BlockedException, TeiidComponentException, TeiidProcessingException {
this.sortBuffer = this.sortUtility.sort();
this.sortBuffer.setForwardOnly(true);
this.groupTupleSource = this.sortBuffer.createIndexedTupleSource();
this.phase = GROUP;
}
/**
 * Streams the grouped rows, closing a group whenever the grouping values change.
 * currentGroupTuple is cleared only after the row has been added to the accumulators, so a
 * BlockedException or an early return on a full batch leaves the row to be processed on the next
 * call. On that next call lastRow is already that row, so it is treated as the same group.
 * When there is no grouping (orderBy == null) a final row is emitted even for empty input.
 */
private TupleBatch groupPhase() throws BlockedException, TeiidComponentException, TeiidProcessingException {
CommandContext context = getContext();
while(true) {
if (currentGroupTuple == null) {
currentGroupTuple = this.groupTupleSource.nextTuple();
if (currentGroupTuple == null) {
break;
}
}
if(lastRow == null) {
// First row we've seen
lastRow = currentGroupTuple;
} else {
int colDiff = sameGroup(indexes, currentGroupTuple, lastRow);
if (colDiff != -1) {
// Close old group
closeGroup(colDiff, true, context);
// Reset last tuple
lastRow = currentGroupTuple;
// Save in output batch
if (this.isBatchFull()) {
return pullBatch();
}
}
}
// Update function accumulators with new row - can throw blocked exception
// NOTE(review): if a later accumulator blocks, earlier ones have already seen this row and
// will see it again on retry — confirm addInput cannot block after the first accumulator.
updateAggregates(currentGroupTuple);
currentGroupTuple = null;
}
if(lastRow != null || orderBy == null) {
// Close last group
closeGroup(-1, false, context);
}
this.terminateBatches();
return pullBatch();
}
/**
 * Emits the output row for the finished group and, with rollup, the super-aggregate rows for
 * levels 1..rollups-1 where rollups = orderBy.size() - colDiff (colDiff is -1 at end of input, so
 * all levels including the grand total are emitted then). At rollup level j, a grouping column
 * whose indexMap value is &lt;= j is output as null.
 * @param colDiff value returned by sameGroup for the row that ended the group (presumably the index
 * of the first differing grouping column), or -1 at end of input
 * @param reset whether to reset the accumulators of the emitted levels
 * @param context the command context used to compute results
 */
private void closeGroup(int colDiff, boolean reset, CommandContext context) throws FunctionExecutionException,
ExpressionEvaluationException, TeiidComponentException,
TeiidProcessingException {
List<Object> row = new ArrayList<Object>(functions.length);
for(int i=0; i<functions.length; i++) {
row.add( functions[i][0].getResult(context) );
// Under rollup the level 0 accumulators (including grouping values) are still needed
// for the rollup rows below, so their reset is deferred
if (reset && !rollup) {
functions[i][0].reset();
}
}
addBatchRow(row);
if (rollup) {
int rollups = orderBy.size() - colDiff;
for (int j = 1; j < rollups; j++) {
row = new ArrayList<Object>(functions.length);
for(int i=0; i<functions.length; i++) {
if (functions[i].length == 1) {
// Grouping column: null once it is rolled up at this level, else the group value
int index = functions[i][0].getArgIndexes()[0];
Integer val = this.indexMap.get(index);
if (val != null && val <= j) {
row.add(null);
} else {
row.add(functions[i][0].getResult(context));
}
} else {
row.add( functions[i][j].getResult(context) );
if (reset) {
functions[i][j].reset();
}
}
}
addBatchRow(row);
}
if (reset) {
for(int i=0; i<functions.length; i++) {
functions[i][0].reset();
}
}
}
}
/**
 * Compares two rows on the grouping columns.
 * @param indexes positions of the grouping columns, or null when there is no grouping
 * @return -1 when there is no grouping or the rows belong to the same group; otherwise the value
 * returned by MergeJoinStrategy.compareTuples
 */
public static int sameGroup(int[] indexes, List<?> newTuple, List<?> oldTuple) {
if (indexes == null) {
return -1;
}
return MergeJoinStrategy.compareTuples(newTuple, oldTuple, indexes, indexes, true, true);
}
/**
 * Adds the row to every accumulator of every output column (all rollup levels).
 */
private void updateAggregates(List<?> tuple)
throws TeiidComponentException, TeiidProcessingException {
for(int i=0; i<functions.length; i++) {
for (AggregateFunction function : functions[i]) {
function.addInput(tuple, getContext());
}
}
}
/**
 * Releases the sort buffer, the sort utility and the STree, if present.
 * NOTE(review): does not call super.closeDirect(); confirm SubqueryAwareRelationalNode has nothing
 * to release here.
 */
public void closeDirect() {
if (this.sortBuffer != null) {
this.sortBuffer.remove();
this.sortBuffer = null;
}
if (this.sortUtility != null) {
this.sortUtility.remove();
this.sortUtility = null;
}
if (this.tree != null) {
this.tree.remove();
this.tree = null;
}
}
/**
 * Appends the grouping columns and, if present, the output mapping to the node description.
 */
protected void getNodeString(StringBuffer str) {
super.getNodeString(str);
str.append(orderBy);
if (outputMapping != null) {
str.append(outputMapping);
}
}
/**
 * Copies the planner-provided configuration. Runtime state, including the accumulators, is not
 * copied; the clone builds it in initialize.
 */
public Object clone(){
GroupingNode clonedNode = new GroupingNode(super.getID());
super.copyTo(clonedNode);
clonedNode.removeDuplicates = removeDuplicates;
clonedNode.outputMapping = outputMapping;
clonedNode.orderBy = orderBy;
clonedNode.rollup = rollup;
return clonedNode;
}
/**
 * Adds grouping-specific properties to the base plan description: grouping columns, output
 * mapping, sort mode (the removeDuplicates flag) and rollup.
 */
public PlanNode getDescriptionProperties() {
// Start from the inherited properties and add the grouping-specific ones
PlanNode props = super.getDescriptionProperties();
if(orderBy != null) {
int elements = orderBy.size();
List<String> groupCols = new ArrayList<String>(elements);
for(int i=0; i<elements; i++) {
groupCols.add(this.orderBy.get(i).toString());
}
props.addProperty(PROP_GROUP_COLS, groupCols);
}
if (outputMapping != null) {
List<String> groupCols = new ArrayList<String>(outputMapping.asMap().size());
for(Map.Entry<ElementSymbol, Expression> entry : outputMapping.asMap().entrySet()) {
groupCols.add(entry.toString());
}
props.addProperty(PROP_GROUP_MAPPING, groupCols);
}
props.addProperty(PROP_SORT_MODE, String.valueOf(this.removeDuplicates));
if (rollup) {
props.addProperty(PROP_ROLLUP, Boolean.TRUE.toString());
}
return props;
}
/**
 * @param rollup whether to emit ROLLUP super-aggregate rows
 */
public void setRollup(boolean rollup) {
this.rollup = rollup;
}
}