/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.api.common.operators.base;
import org.apache.flink.annotation.Internal;
import org.apache.flink.api.common.ExecutionConfig;
import org.apache.flink.api.common.InvalidProgramException;
import org.apache.flink.api.common.functions.Partitioner;
import org.apache.flink.api.common.functions.ReduceFunction;
import org.apache.flink.api.common.functions.RuntimeContext;
import org.apache.flink.api.common.functions.util.FunctionUtils;
import org.apache.flink.api.common.operators.SingleInputOperator;
import org.apache.flink.api.common.operators.UnaryOperatorInformation;
import org.apache.flink.api.common.operators.util.TypeComparable;
import org.apache.flink.api.common.operators.util.UserCodeClassWrapper;
import org.apache.flink.api.common.operators.util.UserCodeObjectWrapper;
import org.apache.flink.api.common.operators.util.UserCodeWrapper;
import org.apache.flink.api.common.typeinfo.AtomicType;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeutils.CompositeType;
import org.apache.flink.api.common.typeutils.TypeComparator;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
/**
* Base data flow operator for Reduce user-defined functions. Accepts reduce functions
* and key positions. The key positions are expected in the flattened common data model.
*
* @see org.apache.flink.api.common.functions.ReduceFunction
*
* @param <T> The type (parameters and return type) of the reduce function.
* @param <FT> The type of the reduce function.
*/
@Internal
public class ReduceOperatorBase<T, FT extends ReduceFunction<T>> extends SingleInputOperator<T, T, FT> {
/**
* An enumeration of hints, optionally usable to tell the system exactly how to execute the combiner phase
* of a reduce.
* (Note: The final reduce phase (after combining) is currently always executed by a sort-based strategy.)
*/
public enum CombineHint {
/**
* Leave the choice how to do the combine to the optimizer. (This currently defaults to SORT.)
*/
OPTIMIZER_CHOOSES,
/**
* Use a sort-based strategy.
*/
SORT,
/**
* Use a hash-based strategy. This should be faster in most cases, especially if the number
* of different keys is small compared to the number of input elements (eg. 1/10).
*/
HASH
}
private CombineHint hint;
private Partitioner<?> customPartitioner;
/**
* Creates a grouped reduce data flow operator.
*
* @param udf The user-defined function, contained in the UserCodeWrapper.
* @param operatorInfo The type information, describing input and output types of the reduce function.
* @param keyPositions The positions of the key fields, in the common data model (flattened).
* @param name The name of the operator (for logging and messages).
*/
public ReduceOperatorBase(UserCodeWrapper<FT> udf, UnaryOperatorInformation<T, T> operatorInfo, int[] keyPositions, String name) {
super(udf, operatorInfo, keyPositions, name);
}
/**
* Creates a grouped reduce data flow operator.
*
* @param udf The user-defined function, as a function object.
* @param operatorInfo The type information, describing input and output types of the reduce function.
* @param keyPositions The positions of the key fields, in the common data model (flattened).
* @param name The name of the operator (for logging and messages).
*/
public ReduceOperatorBase(FT udf, UnaryOperatorInformation<T, T> operatorInfo, int[] keyPositions, String name) {
super(new UserCodeObjectWrapper<FT>(udf), operatorInfo, keyPositions, name);
}
/**
* Creates a grouped reduce data flow operator.
*
* @param udf The class representing the parameterless user-defined function.
* @param operatorInfo The type information, describing input and output types of the reduce function.
* @param keyPositions The positions of the key fields, in the common data model (flattened).
* @param name The name of the operator (for logging and messages).
*/
public ReduceOperatorBase(Class<? extends FT> udf, UnaryOperatorInformation<T, T> operatorInfo, int[] keyPositions, String name) {
super(new UserCodeClassWrapper<FT>(udf), operatorInfo, keyPositions, name);
}
// --------------------------------------------------------------------------------------------
// Non-grouped reduce operations
// --------------------------------------------------------------------------------------------
/**
* Creates a non-grouped reduce data flow operator (all-reduce).
*
* @param udf The user-defined function, contained in the UserCodeWrapper.
* @param operatorInfo The type information, describing input and output types of the reduce function.
* @param name The name of the operator (for logging and messages).
*/
public ReduceOperatorBase(UserCodeWrapper<FT> udf, UnaryOperatorInformation<T, T> operatorInfo, String name) {
super(udf, operatorInfo, name);
}
/**
* Creates a non-grouped reduce data flow operator (all-reduce).
*
* @param udf The user-defined function, as a function object.
* @param operatorInfo The type information, describing input and output types of the reduce function.
* @param name The name of the operator (for logging and messages).
*/
public ReduceOperatorBase(FT udf, UnaryOperatorInformation<T, T> operatorInfo, String name) {
super(new UserCodeObjectWrapper<FT>(udf), operatorInfo, name);
}
/**
* Creates a non-grouped reduce data flow operator (all-reduce).
*
* @param udf The class representing the parameterless user-defined function.
* @param operatorInfo The type information, describing input and output types of the reduce function.
* @param name The name of the operator (for logging and messages).
*/
public ReduceOperatorBase(Class<? extends FT> udf, UnaryOperatorInformation<T, T> operatorInfo, String name) {
super(new UserCodeClassWrapper<FT>(udf), operatorInfo, name);
}
// --------------------------------------------------------------------------------------------
public void setCustomPartitioner(Partitioner<?> customPartitioner) {
if (customPartitioner != null) {
int[] keys = getKeyColumns(0);
if (keys == null || keys.length == 0) {
throw new IllegalArgumentException("Cannot use custom partitioner for a non-grouped GroupReduce (AllGroupReduce)");
}
if (keys.length > 1) {
throw new IllegalArgumentException("Cannot use the key partitioner for composite keys (more than one key field)");
}
}
this.customPartitioner = customPartitioner;
}
public Partitioner<?> getCustomPartitioner() {
return customPartitioner;
}
// --------------------------------------------------------------------------------------------
@Override
protected List<T> executeOnCollections(List<T> inputData, RuntimeContext ctx, ExecutionConfig executionConfig) throws Exception {
// make sure we can handle empty inputs
if (inputData.isEmpty()) {
return Collections.emptyList();
}
ReduceFunction<T> function = this.userFunction.getUserCodeObject();
UnaryOperatorInformation<T, T> operatorInfo = getOperatorInfo();
TypeInformation<T> inputType = operatorInfo.getInputType();
int[] inputColumns = getKeyColumns(0);
if (!(inputType instanceof CompositeType) && inputColumns.length > 1) {
throw new InvalidProgramException("Grouping is only possible on composite types.");
}
FunctionUtils.setFunctionRuntimeContext(function, ctx);
FunctionUtils.openFunction(function, this.parameters);
TypeSerializer<T> serializer = getOperatorInfo().getInputType().createSerializer(executionConfig);
if (inputColumns.length > 0) {
boolean[] inputOrderings = new boolean[inputColumns.length];
TypeComparator<T> inputComparator = inputType instanceof AtomicType
? ((AtomicType<T>) inputType).createComparator(false, executionConfig)
: ((CompositeType<T>) inputType).createComparator(inputColumns, inputOrderings, 0, executionConfig);
Map<TypeComparable<T>, T> aggregateMap = new HashMap<TypeComparable<T>, T>(inputData.size() / 10);
for (T next : inputData) {
TypeComparable<T> wrapper = new TypeComparable<T>(next, inputComparator);
T existing = aggregateMap.get(wrapper);
T result;
if (existing != null) {
result = function.reduce(existing, serializer.copy(next));
} else {
result = next;
}
result = serializer.copy(result);
aggregateMap.put(wrapper, result);
}
FunctionUtils.closeFunction(function);
return new ArrayList<T>(aggregateMap.values());
}
else {
T aggregate = inputData.get(0);
aggregate = serializer.copy(aggregate);
for (int i = 1; i < inputData.size(); i++) {
T next = function.reduce(aggregate, serializer.copy(inputData.get(i)));
aggregate = serializer.copy(next);
}
FunctionUtils.setFunctionRuntimeContext(function, ctx);
return Collections.singletonList(aggregate);
}
}
public void setCombineHint(CombineHint hint) {
if (hint == null) {
throw new IllegalArgumentException("Reduce Hint must not be null.");
}
this.hint = hint;
}
public CombineHint getCombineHint() {
return hint;
}
}