/**
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to you under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.hadoop.hive.ql.optimizer.calcite.functions;

import java.util.ArrayList;
import java.util.List;

import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlFunctionCategory;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.SqlSplittableAggFunction;
import org.apache.calcite.sql.SqlSplittableAggFunction.SumSplitter;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.type.ReturnTypes;
import org.apache.calcite.sql.type.SqlOperandTypeChecker;
import org.apache.calcite.sql.type.SqlOperandTypeInference;
import org.apache.calcite.sql.type.SqlReturnTypeInference;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.util.ImmutableIntList;

import com.google.common.collect.ImmutableList;

/**
 * Hive flavour of the <code>sum</code> aggregate function.
 *
 * <p>The function is registered under the name <code>"sum"</code> with kind
 * {@link SqlKind#SUM} in the {@link SqlFunctionCategory#NUMERIC} category.
 * The return type inference, operand type inference and operand type checker
 * are not fixed by this class; they are supplied by the caller through the
 * constructor and are reused whenever this class derives further aggregate
 * functions (see {@link HiveSumSplitter}).
 */
public class HiveSqlSumAggFunction extends SqlAggFunction implements CanAggregateDistinct {

  /** Distinct flag reported by {@link #isDistinct()}. */
  final boolean isDistinct;
  /** Return type inference handed to the superclass and to derived functions. */
  final SqlReturnTypeInference returnTypeInference;
  /** Operand type inference handed to the superclass and to derived functions. */
  final SqlOperandTypeInference operandTypeInference;
  /** Operand type checker handed to the superclass and to derived functions. */
  final SqlOperandTypeChecker operandTypeChecker;

  //~ Constructors -----------------------------------------------------------

  /**
   * Creates a Hive <code>sum</code> aggregate function.
   *
   * @param isDistinct           value returned by {@link #isDistinct()}
   * @param returnTypeInference  strategy used to derive the result type
   * @param operandTypeInference strategy used to infer operand types
   * @param operandTypeChecker   strategy used to validate operand types
   */
  public HiveSqlSumAggFunction(boolean isDistinct, SqlReturnTypeInference returnTypeInference,
      SqlOperandTypeInference operandTypeInference, SqlOperandTypeChecker operandTypeChecker) {
    super("sum", SqlKind.SUM, returnTypeInference, operandTypeInference, operandTypeChecker,
        SqlFunctionCategory.NUMERIC);
    this.isDistinct = isDistinct;
    this.returnTypeInference = returnTypeInference;
    this.operandTypeInference = operandTypeInference;
    this.operandTypeChecker = operandTypeChecker;
  }

  //~ Methods ----------------------------------------------------------------

  /** Returns the distinct flag this function was constructed with. */
  @Override
  public boolean isDistinct() {
    return isDistinct;
  }

  /**
   * Hands out a fresh {@link HiveSumSplitter} when the
   * {@link SqlSplittableAggFunction} interface is requested; every other
   * request is delegated to the superclass.
   */
  @Override
  public <T> T unwrap(Class<T> clazz) {
    if (clazz != SqlSplittableAggFunction.class) {
      return super.unwrap(clazz);
    }
    return clazz.cast(new HiveSumSplitter());
  }

  /**
   * {@link SumSplitter} that builds Hive aggregate functions
   * ({@link HiveSqlCountAggFunction}, {@link HiveSqlSumAggFunction}) for the
   * aggregate calls it creates, reusing the enclosing function's distinct
   * flag and type strategies.
   */
  class HiveSumSplitter extends SumSplitter {

    /**
     * Builds a <code>count</code> aggregate call with no arguments whose
     * result type is nullable <code>BIGINT</code>.
     *
     * @param typeFactory factory used to create the <code>BIGINT</code> type
     * @param e           not consulted by this implementation
     * @return the <code>count</code> aggregate call
     */
    @Override
    public AggregateCall other(RelDataTypeFactory typeFactory, AggregateCall e) {
      final RelDataType bigintType = typeFactory.createSqlType(SqlTypeName.BIGINT);
      final RelDataType nullableBigintType =
          typeFactory.createTypeWithNullability(bigintType, true);
      final HiveSqlCountAggFunction countFunction = new HiveSqlCountAggFunction(isDistinct,
          ReturnTypes.explicit(nullableBigintType), operandTypeInference, operandTypeChecker);
      return AggregateCall.create(countFunction, false, ImmutableIntList.of(), -1,
          nullableBigintType, "count");
    }

    /**
     * Builds the <code>sum</code> aggregate call that combines the sub-totals
     * found in the input row.
     *
     * <p>A negative sub-total index means that side contributes nothing. With
     * one sub-total the input reference is summed as is; with two, their
     * product, cast to the type of <code>aggregateCall</code>, is summed.
     *
     * @param rexBuilder    builder for the row expressions
     * @param extra         registry in which the expression to be summed is
     *                      registered; the ordinal it returns becomes the sole
     *                      argument of the resulting call
     * @param offset        not consulted by this implementation
     * @param inputRowType  row type the sub-total indexes refer to
     * @param aggregateCall call whose type and name are carried over to the
     *                      result
     * @param leftSubTotal  field index of the first sub-total, or negative
     * @param rightSubTotal field index of the second sub-total, or negative
     * @return the combining <code>sum</code> aggregate call
     * @throws AssertionError if both sub-total indexes are negative
     */
    @Override
    public AggregateCall topSplit(RexBuilder rexBuilder, Registry<RexNode> extra, int offset,
        RelDataType inputRowType, AggregateCall aggregateCall, int leftSubTotal,
        int rightSubTotal) {
      final List<RelDataTypeField> inputFields = inputRowType.getFieldList();
      final List<RexNode> subTotals = new ArrayList<>();
      // Left is visited before right so the operand order of the product is stable.
      for (int fieldIndex : new int[] {leftSubTotal, rightSubTotal}) {
        if (fieldIndex >= 0) {
          final RelDataType fieldType = inputFields.get(fieldIndex).getType();
          subTotals.add(rexBuilder.makeInputRef(fieldType, fieldIndex));
        }
      }

      final RexNode toSum;
      final int subTotalCount = subTotals.size();
      if (subTotalCount == 1) {
        toSum = subTotals.get(0);
      } else if (subTotalCount == 2) {
        final RexNode product = rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, subTotals);
        toSum = rexBuilder.makeAbstractCast(aggregateCall.type, product);
      } else {
        throw new AssertionError("unexpected count " + subTotals);
      }

      final int toSumOrdinal = extra.register(toSum);
      final HiveSqlSumAggFunction sumFunction = new HiveSqlSumAggFunction(isDistinct,
          returnTypeInference, operandTypeInference, operandTypeChecker);
      return AggregateCall.create(sumFunction, false, ImmutableList.of(toSumOrdinal), -1,
          aggregateCall.type, aggregateCall.name);
    }
  }
}

// End HiveSqlSumAggFunction.java