/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.flink.api.common.operators.base; import org.apache.flink.annotation.Internal; import org.apache.flink.api.common.InvalidProgramException; import org.apache.flink.api.common.functions.CoGroupFunction; import org.apache.flink.api.common.functions.RuntimeContext; import org.apache.flink.api.common.functions.util.CopyingListCollector; import org.apache.flink.api.common.functions.util.FunctionUtils; import org.apache.flink.api.common.operators.BinaryOperatorInformation; import org.apache.flink.api.common.operators.DualInputOperator; import org.apache.flink.api.common.operators.Ordering; import org.apache.flink.api.common.operators.util.UserCodeClassWrapper; import org.apache.flink.api.common.operators.util.UserCodeObjectWrapper; import org.apache.flink.api.common.operators.util.UserCodeWrapper; import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.common.typeutils.CompositeType; import org.apache.flink.api.common.typeutils.TypeComparator; import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.util.Collector; import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.Comparator; import java.util.Iterator; import java.util.List; import org.apache.flink.api.common.ExecutionConfig; /** * @see org.apache.flink.api.common.functions.CoGroupFunction */ @Internal public class CoGroupRawOperatorBase<IN1, IN2, OUT, FT extends CoGroupFunction<IN1, IN2, OUT>> extends DualInputOperator<IN1, IN2, OUT, FT> { /** * The ordering for the order inside a group from input one. */ private Ordering groupOrder1; /** * The ordering for the order inside a group from input two. */ private Ordering groupOrder2; // -------------------------------------------------------------------------------------------- private boolean combinableFirst; private boolean combinableSecond; public CoGroupRawOperatorBase(UserCodeWrapper<FT> udf, BinaryOperatorInformation<IN1, IN2, OUT> operatorInfo, int[] keyPositions1, int[] keyPositions2, String name) { super(udf, operatorInfo, keyPositions1, keyPositions2, name); this.combinableFirst = false; this.combinableSecond = false; } public CoGroupRawOperatorBase(FT udf, BinaryOperatorInformation<IN1, IN2, OUT> operatorInfo, int[] keyPositions1, int[] keyPositions2, String name) { this(new UserCodeObjectWrapper<FT>(udf), operatorInfo, keyPositions1, keyPositions2, name); } public CoGroupRawOperatorBase(Class<? extends FT> udf, BinaryOperatorInformation<IN1, IN2, OUT> operatorInfo, int[] keyPositions1, int[] keyPositions2, String name) { this(new UserCodeClassWrapper<FT>(udf), operatorInfo, keyPositions1, keyPositions2, name); } // -------------------------------------------------------------------------------------------- /** * Sets the order of the elements within a group for the given input. * * @param inputNum The number of the input (here either <i>0</i> or <i>1</i>). * @param order The order for the elements in a group. */ public void setGroupOrder(int inputNum, Ordering order) { if (inputNum == 0) { this.groupOrder1 = order; } else if (inputNum == 1) { this.groupOrder2 = order; } else { throw new IndexOutOfBoundsException(); } } /** * Sets the order of the elements within a group for the first input. * * @param order The order for the elements in a group. */ public void setGroupOrderForInputOne(Ordering order) { setGroupOrder(0, order); } /** * Sets the order of the elements within a group for the second input. * * @param order The order for the elements in a group. */ public void setGroupOrderForInputTwo(Ordering order) { setGroupOrder(1, order); } /** * Gets the value order for an input, i.e. the order of elements within a group. * If no such order has been set, this method returns null. * * @param inputNum The number of the input (here either <i>0</i> or <i>1</i>). * @return The group order. */ public Ordering getGroupOrder(int inputNum) { if (inputNum == 0) { return this.groupOrder1; } else if (inputNum == 1) { return this.groupOrder2; } else { throw new IndexOutOfBoundsException(); } } /** * Gets the order of elements within a group for the first input. * If no such order has been set, this method returns null. * * @return The group order for the first input. */ public Ordering getGroupOrderForInputOne() { return getGroupOrder(0); } /** * Gets the order of elements within a group for the second input. * If no such order has been set, this method returns null. * * @return The group order for the second input. */ public Ordering getGroupOrderForInputTwo() { return getGroupOrder(1); } // -------------------------------------------------------------------------------------------- public boolean isCombinableFirst() { return this.combinableFirst; } public void setCombinableFirst(boolean combinableFirst) { this.combinableFirst = combinableFirst; } public boolean isCombinableSecond() { return this.combinableSecond; } public void setCombinableSecond(boolean combinableSecond) { this.combinableSecond = combinableSecond; } // ------------------------------------------------------------------------ @Override protected List<OUT> executeOnCollections(List<IN1> input1, List<IN2> input2, RuntimeContext ctx, ExecutionConfig executionConfig) throws Exception { // -------------------------------------------------------------------- // Setup // -------------------------------------------------------------------- TypeInformation<IN1> inputType1 = getOperatorInfo().getFirstInputType(); TypeInformation<IN2> inputType2 = getOperatorInfo().getSecondInputType(); int[] inputKeys1 = getKeyColumns(0); int[] inputKeys2 = getKeyColumns(1); boolean[] inputSortDirections1 = new boolean[inputKeys1.length]; boolean[] inputSortDirections2 = new boolean[inputKeys2.length]; Arrays.fill(inputSortDirections1, true); Arrays.fill(inputSortDirections2, true); final TypeSerializer<IN1> inputSerializer1 = inputType1.createSerializer(executionConfig); final TypeSerializer<IN2> inputSerializer2 = inputType2.createSerializer(executionConfig); final TypeComparator<IN1> inputComparator1 = getTypeComparator(executionConfig, inputType1, inputKeys1, inputSortDirections1); final TypeComparator<IN2> inputComparator2 = getTypeComparator(executionConfig, inputType2, inputKeys2, inputSortDirections2); SimpleListIterable<IN1> iterator1 = new SimpleListIterable<IN1>(input1, inputComparator1, inputSerializer1); SimpleListIterable<IN2> iterator2 = new SimpleListIterable<IN2>(input2, inputComparator2, inputSerializer2); // -------------------------------------------------------------------- // Run UDF // -------------------------------------------------------------------- CoGroupFunction<IN1, IN2, OUT> function = userFunction.getUserCodeObject(); FunctionUtils.setFunctionRuntimeContext(function, ctx); FunctionUtils.openFunction(function, parameters); List<OUT> result = new ArrayList<OUT>(); Collector<OUT> resultCollector = new CopyingListCollector<OUT>(result, getOperatorInfo().getOutputType().createSerializer(executionConfig)); function.coGroup(iterator1, iterator2, resultCollector); FunctionUtils.closeFunction(function); return result; } private <T> TypeComparator<T> getTypeComparator(ExecutionConfig executionConfig, TypeInformation<T> inputType, int[] inputKeys, boolean[] inputSortDirections) { if (!(inputType instanceof CompositeType)) { throw new InvalidProgramException("Input types of coGroup must be composite types."); } return ((CompositeType<T>) inputType).createComparator(inputKeys, inputSortDirections, 0, executionConfig); } public static class SimpleListIterable<IN> implements Iterable<IN> { private List<IN> values; private TypeSerializer<IN> serializer; private boolean copy; public SimpleListIterable(List<IN> values, final TypeComparator<IN> comparator, TypeSerializer<IN> serializer) throws IOException { this.values = values; this.serializer = serializer; Collections.sort(values, new Comparator<IN>() { @Override public int compare(IN o1, IN o2) { return comparator.compare(o1, o2); } }); } @Override public Iterator<IN> iterator() { return new SimpleListIterator<IN>(values, serializer); } protected class SimpleListIterator<IN> implements Iterator<IN> { private final List<IN> values; private final TypeSerializer<IN> serializer; private int pos = 0; public SimpleListIterator(List<IN> values, TypeSerializer<IN> serializer) { this.values = values; this.serializer = serializer; } @Override public boolean hasNext() { return pos < values.size(); } @Override public IN next() { IN current = values.get(pos++); return serializer.copy(current); } @Override public void remove() { //unused } } } }