/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.hive.ql.optimizer.calcite.reloperators;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;
import org.apache.calcite.linq4j.Ord;
import org.apache.calcite.plan.RelOptCluster;
import org.apache.calcite.plan.RelTraitSet;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.metadata.RelMetadataQuery;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.hadoop.hive.ql.optimizer.calcite.TraitsUtil;
import com.google.common.collect.Sets;
public class HiveAggregate extends Aggregate implements HiveRelNode {
private LinkedHashSet<Integer> aggregateColumnsOrder;
public HiveAggregate(RelOptCluster cluster, RelTraitSet traitSet, RelNode child,
boolean indicator, ImmutableBitSet groupSet, List<ImmutableBitSet> groupSets,
List<AggregateCall> aggCalls) {
super(cluster, TraitsUtil.getDefaultTraitSet(cluster), child, indicator, groupSet,
groupSets, aggCalls);
}
@Override
public Aggregate copy(RelTraitSet traitSet, RelNode input,
boolean indicator, ImmutableBitSet groupSet,
List<ImmutableBitSet> groupSets, List<AggregateCall> aggCalls) {
return new HiveAggregate(getCluster(), traitSet, input, indicator, groupSet,
groupSets, aggCalls);
}
@Override
public void implement(Implementor implementor) {
}
// getRows will call estimateRowCount
@Override
public double estimateRowCount(RelMetadataQuery mq) {
return mq.getDistinctRowCount(this, groupSet, getCluster().getRexBuilder().makeLiteral(true));
}
public boolean isBucketedInput() {
return RelMetadataQuery.instance().distribution(this.getInput()).getKeys().
containsAll(groupSet.asList());
}
@Override
protected RelDataType deriveRowType() {
return deriveRowType(getCluster().getTypeFactory(), getInput().getRowType(),
indicator, groupSet, groupSets, aggCalls);
}
public static RelDataType deriveRowType(RelDataTypeFactory typeFactory,
final RelDataType inputRowType, boolean indicator,
ImmutableBitSet groupSet, List<ImmutableBitSet> groupSets,
final List<AggregateCall> aggCalls) {
final List<Integer> groupList = groupSet.asList();
assert groupList.size() == groupSet.cardinality();
final RelDataTypeFactory.FieldInfoBuilder builder = typeFactory.builder();
final List<RelDataTypeField> fieldList = inputRowType.getFieldList();
final Set<String> containedNames = Sets.newHashSet();
for (int groupKey : groupList) {
containedNames.add(fieldList.get(groupKey).getName());
builder.add(fieldList.get(groupKey));
}
if (indicator) {
for (int groupKey : groupList) {
final RelDataType booleanType =
typeFactory.createTypeWithNullability(
typeFactory.createSqlType(SqlTypeName.BOOLEAN), false);
String name = "i$" + fieldList.get(groupKey).getName();
int i = 0;
while (containedNames.contains(name)) {
name += "_" + i++;
}
containedNames.add(name);
builder.add(name, booleanType);
}
}
for (Ord<AggregateCall> aggCall : Ord.zip(aggCalls)) {
String name;
if (aggCall.e.name != null) {
name = aggCall.e.name;
} else {
name = "$f" + (groupList.size() + aggCall.i);
}
int i = 0;
while (containedNames.contains(name)) {
name += "_" + i++;
}
containedNames.add(name);
builder.add(name, aggCall.e.type);
}
return builder.build();
}
public void setAggregateColumnsOrder(LinkedHashSet<Integer> aggregateColumnsOrder) {
this.aggregateColumnsOrder = aggregateColumnsOrder;
}
public LinkedHashSet<Integer> getAggregateColumnsOrder() {
return this.aggregateColumnsOrder;
}
}