/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.flink.api.java.operators; import java.util.ArrayList; import java.util.List; import org.apache.flink.annotation.Internal; import org.apache.flink.annotation.Public; import org.apache.flink.api.common.InvalidProgramException; import org.apache.flink.api.common.functions.GroupCombineFunction; import org.apache.flink.api.common.functions.GroupReduceFunction; import org.apache.flink.api.common.functions.RichGroupReduceFunction; import org.apache.flink.api.common.operators.Keys; import org.apache.flink.api.common.operators.Operator; import org.apache.flink.api.common.operators.SingleInputSemanticProperties; import org.apache.flink.api.common.operators.UnaryOperatorInformation; import org.apache.flink.api.common.operators.base.GroupReduceOperatorBase; import org.apache.flink.api.java.DataSet; import org.apache.flink.api.java.aggregation.AggregationFunction; import org.apache.flink.api.java.aggregation.AggregationFunctionFactory; import org.apache.flink.api.java.aggregation.Aggregations; import org.apache.flink.api.java.tuple.Tuple; import org.apache.flink.api.java.typeutils.TupleTypeInfoBase; import org.apache.flink.configuration.Configuration; import org.apache.flink.util.Collector; import org.apache.flink.util.Preconditions; /** * This operator represents the application of a "aggregate" operation on a data set, and the * result data set produced by the function. * * @param <IN> The type of the data set aggregated by the operator. */ @Public public class AggregateOperator<IN> extends SingleInputOperator<IN, IN, AggregateOperator<IN>> { private final List<AggregationFunction<?>> aggregationFunctions = new ArrayList<>(4); private final List<Integer> fields = new ArrayList<>(4); private final Grouping<IN> grouping; private final String aggregateLocationName; /** * <p> * Non grouped aggregation */ public AggregateOperator(DataSet<IN> input, Aggregations function, int field, String aggregateLocationName) { super(Preconditions.checkNotNull(input), input.getType()); Preconditions.checkNotNull(function); this.aggregateLocationName = aggregateLocationName; if (!input.getType().isTupleType()) { throw new InvalidProgramException("Aggregating on field positions is only possible on tuple data types."); } TupleTypeInfoBase<?> inType = (TupleTypeInfoBase<?>) input.getType(); if (field < 0 || field >= inType.getArity()) { throw new IllegalArgumentException("Aggregation field position is out of range."); } AggregationFunctionFactory factory = function.getFactory(); AggregationFunction<?> aggFunct = factory.createAggregationFunction(inType.getTypeAt(field).getTypeClass()); // this is the first aggregation operator after a regular data set (non grouped aggregation) this.aggregationFunctions.add(aggFunct); this.fields.add(field); this.grouping = null; } /** * * Grouped aggregation * * @param input * @param function * @param field */ public AggregateOperator(Grouping<IN> input, Aggregations function, int field, String aggregateLocationName) { super(Preconditions.checkNotNull(input).getInputDataSet(), input.getInputDataSet().getType()); Preconditions.checkNotNull(function); this.aggregateLocationName = aggregateLocationName; if (!input.getInputDataSet().getType().isTupleType()) { throw new InvalidProgramException("Aggregating on field positions is only possible on tuple data types."); } TupleTypeInfoBase<?> inType = (TupleTypeInfoBase<?>) input.getInputDataSet().getType(); if (field < 0 || field >= inType.getArity()) { throw new IllegalArgumentException("Aggregation field position is out of range."); } AggregationFunctionFactory factory = function.getFactory(); AggregationFunction<?> aggFunct = factory.createAggregationFunction(inType.getTypeAt(field).getTypeClass()); // set the aggregation fields this.aggregationFunctions.add(aggFunct); this.fields.add(field); this.grouping = input; } public AggregateOperator<IN> and(Aggregations function, int field) { Preconditions.checkNotNull(function); TupleTypeInfoBase<?> inType = (TupleTypeInfoBase<?>) getType(); if (field < 0 || field >= inType.getArity()) { throw new IllegalArgumentException("Aggregation field position is out of range."); } AggregationFunctionFactory factory = function.getFactory(); AggregationFunction<?> aggFunct = factory.createAggregationFunction(inType.getTypeAt(field).getTypeClass()); this.aggregationFunctions.add(aggFunct); this.fields.add(field); return this; } public AggregateOperator<IN> andSum (int field) { return this.and(Aggregations.SUM, field); } public AggregateOperator<IN> andMin (int field) { return this.and(Aggregations.MIN, field); } public AggregateOperator<IN> andMax (int field) { return this.and(Aggregations.MAX, field); } @SuppressWarnings("unchecked") @Override @Internal protected org.apache.flink.api.common.operators.base.GroupReduceOperatorBase<IN, IN, GroupReduceFunction<IN, IN>> translateToDataFlow(Operator<IN> input) { // sanity check if (this.aggregationFunctions.isEmpty() || this.aggregationFunctions.size() != this.fields.size()) { throw new IllegalStateException(); } // construct the aggregation function AggregationFunction<Object>[] aggFunctions = new AggregationFunction[this.aggregationFunctions.size()]; int[] fields = new int[this.fields.size()]; StringBuilder genName = new StringBuilder(); for (int i = 0; i < fields.length; i++) { aggFunctions[i] = (AggregationFunction<Object>) this.aggregationFunctions.get(i); fields[i] = this.fields.get(i); genName.append(aggFunctions[i].toString()).append('(').append(fields[i]).append(')').append(','); } genName.append(" at ").append(aggregateLocationName); genName.setLength(genName.length()-1); @SuppressWarnings("rawtypes") RichGroupReduceFunction<IN, IN> function = new AggregatingUdf(aggFunctions, fields); String name = getName() != null ? getName() : genName.toString(); // distinguish between grouped reduce and non-grouped reduce if (this.grouping == null) { // non grouped aggregation UnaryOperatorInformation<IN, IN> operatorInfo = new UnaryOperatorInformation<>(getInputType(), getResultType()); GroupReduceOperatorBase<IN, IN, GroupReduceFunction<IN, IN>> po = new GroupReduceOperatorBase<IN, IN, GroupReduceFunction<IN, IN>>(function, operatorInfo, new int[0], name); po.setCombinable(true); // set input po.setInput(input); // set parallelism po.setParallelism(this.getParallelism()); return po; } if (this.grouping.getKeys() instanceof Keys.ExpressionKeys) { // grouped aggregation int[] logicalKeyPositions = this.grouping.getKeys().computeLogicalKeyPositions(); UnaryOperatorInformation<IN, IN> operatorInfo = new UnaryOperatorInformation<>(getInputType(), getResultType()); GroupReduceOperatorBase<IN, IN, GroupReduceFunction<IN, IN>> po = new GroupReduceOperatorBase<IN, IN, GroupReduceFunction<IN, IN>>(function, operatorInfo, logicalKeyPositions, name); po.setCombinable(true); po.setInput(input); po.setParallelism(this.getParallelism()); po.setCustomPartitioner(grouping.getCustomPartitioner()); SingleInputSemanticProperties props = new SingleInputSemanticProperties(); for (int keyField : logicalKeyPositions) { boolean keyFieldUsedInAgg = false; for (int aggField : fields) { if (keyField == aggField) { keyFieldUsedInAgg = true; break; } } if (!keyFieldUsedInAgg) { props.addForwardedField(keyField, keyField); } } po.setSemanticProperties(props); return po; } else if (this.grouping.getKeys() instanceof Keys.SelectorFunctionKeys) { throw new UnsupportedOperationException("Aggregate does not support grouping with KeySelector functions, yet."); } else { throw new UnsupportedOperationException("Unrecognized key type."); } } // -------------------------------------------------------------------------------------------- @Internal public static final class AggregatingUdf<T extends Tuple> extends RichGroupReduceFunction<T, T> implements GroupCombineFunction<T, T> { private static final long serialVersionUID = 1L; private final int[] fieldPositions; private final AggregationFunction<Object>[] aggFunctions; public AggregatingUdf(AggregationFunction<Object>[] aggFunctions, int[] fieldPositions) { Preconditions.checkNotNull(aggFunctions); Preconditions.checkNotNull(aggFunctions); Preconditions.checkArgument(aggFunctions.length == fieldPositions.length); this.aggFunctions = aggFunctions; this.fieldPositions = fieldPositions; } @Override public void open(Configuration parameters) throws Exception { for (AggregationFunction<Object> aggFunction : aggFunctions) { aggFunction.initializeAggregate(); } } @Override public void reduce(Iterable<T> records, Collector<T> out) { final AggregationFunction<Object>[] aggFunctions = this.aggFunctions; final int[] fieldPositions = this.fieldPositions; // aggregators are initialized from before T outT = null; for (T record : records) { outT = record; for (int i = 0; i < fieldPositions.length; i++) { Object val = record.getFieldNotNull(fieldPositions[i]); aggFunctions[i].aggregate(val); } } for (int i = 0; i < fieldPositions.length; i++) { Object aggVal = aggFunctions[i].getAggregate(); outT.setField(aggVal, fieldPositions[i]); aggFunctions[i].initializeAggregate(); } out.collect(outT); } @Override public void combine(Iterable<T> records, Collector<T> out) { reduce(records, out); } } }