/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.drill.exec.planner.physical;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import org.apache.calcite.linq4j.Ord;
import org.apache.calcite.util.BitSets;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.drill.common.expression.ExpressionPosition;
import org.apache.drill.common.expression.FieldReference;
import org.apache.drill.common.expression.FunctionCall;
import org.apache.drill.common.expression.LogicalExpression;
import org.apache.drill.common.expression.ValueExpressions;
import org.apache.drill.common.logical.data.NamedExpression;
import org.apache.drill.exec.planner.common.DrillAggregateRelBase;
import org.apache.drill.exec.planner.physical.visitor.PrelVisitor;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.InvalidRelException;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.plan.RelOptCluster;
import org.apache.calcite.plan.RelTraitSet;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlFunctionCategory;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.type.OperandTypes;
import org.apache.calcite.sql.type.ReturnTypes;
import java.util.BitSet;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
/**
 * Common base class for aggregate {@link Prel} implementations.
 *
 * <p>At construction time it derives, from the group set and aggregate calls
 * handed to the constructor:
 * <ul>
 *   <li>the grouping keys ({@link #getKeys()}),</li>
 *   <li>the aggregate expressions ({@link #getAggExprs()}), and</li>
 *   <li>when the operator phase is {@code PHASE_1of2}, the aggregate calls to
 *       be used by the second phase ({@link #getPhase2AggCalls()}).</li>
 * </ul>
 */
public abstract class AggPrelBase extends DrillAggregateRelBase implements Prel {

  /** Identifies whether this operator is the only phase, or the first/second phase of a two-phase aggregation. */
  protected enum OperatorPhase {PHASE_1of1, PHASE_1of2, PHASE_2of2}

  protected OperatorPhase operPhase = OperatorPhase.PHASE_1of1; // default phase

  /** Grouping keys; populated by {@link #createKeysAndExprs()}. */
  protected List<NamedExpression> keys = Lists.newArrayList();

  /** Aggregate expressions, one per aggregate call; populated by {@link #createKeysAndExprs()}. */
  protected List<NamedExpression> aggExprs = Lists.newArrayList();

  /** Aggregate calls for the second phase; only populated when the phase is {@code PHASE_1of2}. */
  protected List<AggregateCall> phase2AggCallList = Lists.newArrayList();

  /**
   * Specialized aggregate function for SUMing the COUNTs. Since return type of
   * COUNT is non-nullable and return type of SUM is nullable, this class enables
   * creating a SUM whose return type is non-nullable.
   *
   * <p>The function is registered under the name {@code $SUM0} and always
   * reports the type supplied to its constructor as both its parameter type
   * and its return type.
   *
   * <p>Note: this is a non-static inner class, so every instance keeps a
   * reference to the enclosing {@code AggPrelBase}.
   */
  public class SqlSumCountAggFunction extends SqlAggFunction {

    private final RelDataType type;

    /**
     * @param type the type reported by {@link #getType()},
     *             {@link #getReturnType(RelDataTypeFactory)} and
     *             {@link #getParameterTypes(RelDataTypeFactory)}
     */
    public SqlSumCountAggFunction(RelDataType type) {
      super("$SUM0",
          SqlKind.OTHER_FUNCTION,
          ReturnTypes.BIGINT, // use the inferred return type of SqlCountAggFunction
          null,
          OperandTypes.NUMERIC,
          SqlFunctionCategory.NUMERIC);

      this.type = type;
    }

    /** Returns a single-element list holding the type given at construction. */
    public List<RelDataType> getParameterTypes(RelDataTypeFactory typeFactory) {
      return ImmutableList.of(type);
    }

    /** Returns the type given at construction. */
    public RelDataType getType() {
      return type;
    }

    /** Returns the type given at construction; the type factory is not consulted. */
    public RelDataType getReturnType(RelDataTypeFactory typeFactory) {
      return type;
    }
  }

  /**
   * Creates the aggregate and immediately builds its keys and aggregate
   * expressions via {@link #createKeysAndExprs()}.
   *
   * <p>All arguments except {@code phase} are forwarded unchanged to the
   * {@link DrillAggregateRelBase} constructor.
   *
   * @param phase the phase this operator represents
   * @throws InvalidRelException declared by this constructor's signature
   */
  public AggPrelBase(RelOptCluster cluster,
                     RelTraitSet traits,
                     RelNode child,
                     boolean indicator,
                     ImmutableBitSet groupSet,
                     List<ImmutableBitSet> groupSets,
                     List<AggregateCall> aggCalls,
                     OperatorPhase phase) throws InvalidRelException {
    super(cluster, traits, child, indicator, groupSet, groupSets, aggCalls);
    // The phase must be set before createKeysAndExprs(), which branches on it.
    this.operPhase = phase;
    createKeysAndExprs();
  }

  /** Returns the phase this operator was constructed with. */
  public OperatorPhase getOperatorPhase() {
    return operPhase;
  }

  /** Returns the grouping keys (the internal list, not a copy). */
  public List<NamedExpression> getKeys() {
    return keys;
  }

  /** Returns the aggregate expressions (the internal list, not a copy). */
  public List<NamedExpression> getAggExprs() {
    return aggExprs;
  }

  /** Returns the second-phase aggregate calls (the internal list, not a copy). */
  public List<AggregateCall> getPhase2AggCalls() {
    return phase2AggCallList;
  }

  /**
   * Populates {@link #keys}, {@link #aggExprs} and, for {@code PHASE_1of2},
   * {@link #phase2AggCallList}.
   *
   * <p>Each index in the group set becomes a key that references the input
   * field of that index by name. Each aggregate call becomes a named
   * expression whose output name is the row-type field located after the
   * group columns (ordinal {@code groupSet.cardinality() + i}).
   *
   * <p>This method is invoked from the constructor, so an override runs before
   * the overriding subclass's own field initializers and constructor body.
   */
  protected void createKeysAndExprs() {
    final List<String> childFields = getInput().getRowType().getFieldNames();
    final List<String> fields = getRowType().getFieldNames();

    for (int group : BitSets.toIter(groupSet)) {
      FieldReference fr = FieldReference.getWithQuotedRef(childFields.get(group));
      keys.add(new NamedExpression(fr, fr));
    }

    // The group set is immutable, so its cardinality is the same for every aggregate call.
    final int groupCount = groupSet.cardinality();

    for (Ord<AggregateCall> aggCall : Ord.zip(aggCalls)) {
      int aggExprOrdinal = groupCount + aggCall.i;
      FieldReference ref = FieldReference.getWithQuotedRef(fields.get(aggExprOrdinal));
      LogicalExpression expr = toDrill(aggCall.e, childFields);
      NamedExpression ne = new NamedExpression(expr, ref);
      aggExprs.add(ne);

      if (getOperatorPhase() == OperatorPhase.PHASE_1of2) {
        // The phase-2 call takes this phase's output column (aggExprOrdinal) as its only argument.
        if (aggCall.e.getAggregation().getName().equals("COUNT")) {
          // If we are doing a COUNT aggregate in Phase1of2, then in Phase2of2 we should SUM the COUNTs,
          SqlAggFunction sumAggFun = new SqlSumCountAggFunction(aggCall.e.getType());
          AggregateCall newAggCall =
              new AggregateCall(
                  sumAggFun,
                  aggCall.e.isDistinct(),
                  Collections.singletonList(aggExprOrdinal),
                  aggCall.e.getType(),
                  aggCall.e.getName());

          phase2AggCallList.add(newAggCall);
        } else {
          // Any other aggregate is re-applied with the same function in the second phase.
          AggregateCall newAggCall =
              new AggregateCall(
                  aggCall.e.getAggregation(),
                  aggCall.e.isDistinct(),
                  Collections.singletonList(aggExprOrdinal),
                  aggCall.e.getType(),
                  aggCall.e.getName());

          phase2AggCallList.add(newAggCall);
        }
      }
    }
  }

  /**
   * Converts an aggregate call into a Drill {@link FunctionCall}.
   *
   * @param call the aggregate call to convert
   * @param fn   field names of the input; each argument index of {@code call}
   *             is resolved to the name at that position
   * @return a function call named after the aggregation (lower-cased) whose
   *         arguments are field references, or the literal {@code 1} when the
   *         call has no arguments
   */
  protected LogicalExpression toDrill(AggregateCall call, List<String> fn) {
    List<LogicalExpression> args = Lists.newArrayList();
    for (Integer i : call.getArgList()) {
      args.add(FieldReference.getWithQuotedRef(fn.get(i)));
    }

    // for count(1).
    if (args.isEmpty()) {
      args.add(new ValueExpressions.LongExpression(1L));
    }

    // Lower-case with Locale.ROOT so the function name does not depend on the
    // JVM default locale (e.g. "MIN" would become "mın" under a Turkish locale).
    String functionName = call.getAggregation().getName().toLowerCase(Locale.ROOT);
    return new FunctionCall(functionName, args, ExpressionPosition.UNKNOWN);
  }

  /** Returns the iterator produced by {@code PrelUtil.iter} for this node's input. */
  @Override
  public Iterator<Prel> iterator() {
    return PrelUtil.iter(getInput());
  }

  /** Dispatches to {@code visitPrel} on the given visitor. */
  @Override
  public <T, X, E extends Throwable> T accept(PrelVisitor<T, X, E> logicalVisitor, X value) throws E {
    return logicalVisitor.visitPrel(this, value);
  }

  /** Always returns {@code true}. */
  @Override
  public boolean needsFinalColumnReordering() {
    return true;
  }
}