/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.flink.api.common.operators.base; import org.apache.flink.annotation.Internal; import org.apache.flink.api.common.ExecutionConfig; import org.apache.flink.api.common.InvalidProgramException; import org.apache.flink.api.common.functions.CoGroupFunction; import org.apache.flink.api.common.functions.Partitioner; import org.apache.flink.api.common.functions.RuntimeContext; import org.apache.flink.api.common.functions.util.CopyingListCollector; import org.apache.flink.api.common.functions.util.FunctionUtils; import org.apache.flink.api.common.operators.BinaryOperatorInformation; import org.apache.flink.api.common.operators.DualInputOperator; import org.apache.flink.api.common.operators.Ordering; import org.apache.flink.api.common.operators.util.ListKeyGroupedIterator; import org.apache.flink.api.common.operators.util.UserCodeClassWrapper; import org.apache.flink.api.common.operators.util.UserCodeObjectWrapper; import org.apache.flink.api.common.operators.util.UserCodeWrapper; import org.apache.flink.api.common.typeinfo.AtomicType; import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.common.typeutils.CompositeType; import org.apache.flink.api.common.typeutils.GenericPairComparator; import org.apache.flink.api.common.typeutils.TypeComparator; import org.apache.flink.api.common.typeutils.TypePairComparator; import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.util.Collector; import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.Comparator; import java.util.List; /** * @see org.apache.flink.api.common.functions.CoGroupFunction */ @Internal public class CoGroupOperatorBase<IN1, IN2, OUT, FT extends CoGroupFunction<IN1, IN2, OUT>> extends DualInputOperator<IN1, IN2, OUT, FT> { /** The ordering for the order inside a group from input one. */ private Ordering groupOrder1; /** The ordering for the order inside a group from input two. */ private Ordering groupOrder2; private Partitioner<?> customPartitioner; private boolean combinableFirst; private boolean combinableSecond; // -------------------------------------------------------------------------------------------- public CoGroupOperatorBase(UserCodeWrapper<FT> udf, BinaryOperatorInformation<IN1, IN2, OUT> operatorInfo, int[] keyPositions1, int[] keyPositions2, String name) { super(udf, operatorInfo, keyPositions1, keyPositions2, name); this.combinableFirst = false; this.combinableSecond = false; } public CoGroupOperatorBase(FT udf, BinaryOperatorInformation<IN1, IN2, OUT> operatorInfo, int[] keyPositions1, int[] keyPositions2, String name) { this(new UserCodeObjectWrapper<FT>(udf), operatorInfo, keyPositions1, keyPositions2, name); } public CoGroupOperatorBase(Class<? extends FT> udf, BinaryOperatorInformation<IN1, IN2, OUT> operatorInfo, int[] keyPositions1, int[] keyPositions2, String name) { this(new UserCodeClassWrapper<FT>(udf), operatorInfo, keyPositions1, keyPositions2, name); } // -------------------------------------------------------------------------------------------- /** * Sets the order of the elements within a group for the given input. * * @param inputNum The number of the input (here either <i>0</i> or <i>1</i>). * @param order The order for the elements in a group. */ public void setGroupOrder(int inputNum, Ordering order) { if (inputNum == 0) { this.groupOrder1 = order; } else if (inputNum == 1) { this.groupOrder2 = order; } else { throw new IndexOutOfBoundsException(); } } /** * Sets the order of the elements within a group for the first input. * * @param order The order for the elements in a group. */ public void setGroupOrderForInputOne(Ordering order) { setGroupOrder(0, order); } /** * Sets the order of the elements within a group for the second input. * * @param order The order for the elements in a group. */ public void setGroupOrderForInputTwo(Ordering order) { setGroupOrder(1, order); } /** * Gets the value order for an input, i.e. the order of elements within a group. * If no such order has been set, this method returns null. * * @param inputNum The number of the input (here either <i>0</i> or <i>1</i>). * @return The group order. */ public Ordering getGroupOrder(int inputNum) { if (inputNum == 0) { return this.groupOrder1; } else if (inputNum == 1) { return this.groupOrder2; } else { throw new IndexOutOfBoundsException(); } } /** * Gets the order of elements within a group for the first input. * If no such order has been set, this method returns null. * * @return The group order for the first input. */ public Ordering getGroupOrderForInputOne() { return getGroupOrder(0); } /** * Gets the order of elements within a group for the second input. * If no such order has been set, this method returns null. * * @return The group order for the second input. */ public Ordering getGroupOrderForInputTwo() { return getGroupOrder(1); } // -------------------------------------------------------------------------------------------- public boolean isCombinableFirst() { return this.combinableFirst; } public void setCombinableFirst(boolean combinableFirst) { this.combinableFirst = combinableFirst; } public boolean isCombinableSecond() { return this.combinableSecond; } public void setCombinableSecond(boolean combinableSecond) { this.combinableSecond = combinableSecond; } public void setCustomPartitioner(Partitioner<?> customPartitioner) { this.customPartitioner = customPartitioner; } public Partitioner<?> getCustomPartitioner() { return customPartitioner; } // ------------------------------------------------------------------------ @Override protected List<OUT> executeOnCollections(List<IN1> input1, List<IN2> input2, RuntimeContext ctx, ExecutionConfig executionConfig) throws Exception { // -------------------------------------------------------------------- // Setup // -------------------------------------------------------------------- TypeInformation<IN1> inputType1 = getOperatorInfo().getFirstInputType(); TypeInformation<IN2> inputType2 = getOperatorInfo().getSecondInputType(); // for the grouping / merging comparator int[] inputKeys1 = getKeyColumns(0); int[] inputKeys2 = getKeyColumns(1); boolean[] inputDirections1 = new boolean[inputKeys1.length]; boolean[] inputDirections2 = new boolean[inputKeys2.length]; Arrays.fill(inputDirections1, true); Arrays.fill(inputDirections2, true); final TypeSerializer<IN1> inputSerializer1 = inputType1.createSerializer(executionConfig); final TypeSerializer<IN2> inputSerializer2 = inputType2.createSerializer(executionConfig); final TypeComparator<IN1> inputComparator1 = getTypeComparator(executionConfig, inputType1, inputKeys1, inputDirections1); final TypeComparator<IN2> inputComparator2 = getTypeComparator(executionConfig, inputType2, inputKeys2, inputDirections2); final TypeComparator<IN1> inputSortComparator1; final TypeComparator<IN2> inputSortComparator2; if (groupOrder1 == null || groupOrder1.getNumberOfFields() == 0) { // no group sorting inputSortComparator1 = inputComparator1; } else { // group sorting int[] groupSortKeys = groupOrder1.getFieldPositions(); int[] allSortKeys = new int[inputKeys1.length + groupOrder1.getNumberOfFields()]; System.arraycopy(inputKeys1, 0, allSortKeys, 0, inputKeys1.length); System.arraycopy(groupSortKeys, 0, allSortKeys, inputKeys1.length, groupSortKeys.length); boolean[] groupSortDirections = groupOrder1.getFieldSortDirections(); boolean[] allSortDirections = new boolean[inputKeys1.length + groupSortKeys.length]; Arrays.fill(allSortDirections, 0, inputKeys1.length, true); System.arraycopy(groupSortDirections, 0, allSortDirections, inputKeys1.length, groupSortDirections.length); inputSortComparator1 = getTypeComparator(executionConfig, inputType1, allSortKeys, allSortDirections); } if (groupOrder2 == null || groupOrder2.getNumberOfFields() == 0) { // no group sorting inputSortComparator2 = inputComparator2; } else { // group sorting int[] groupSortKeys = groupOrder2.getFieldPositions(); int[] allSortKeys = new int[inputKeys2.length + groupOrder2.getNumberOfFields()]; System.arraycopy(inputKeys2, 0, allSortKeys, 0, inputKeys2.length); System.arraycopy(groupSortKeys, 0, allSortKeys, inputKeys2.length, groupSortKeys.length); boolean[] groupSortDirections = groupOrder2.getFieldSortDirections(); boolean[] allSortDirections = new boolean[inputKeys2.length + groupSortKeys.length]; Arrays.fill(allSortDirections, 0, inputKeys2.length, true); System.arraycopy(groupSortDirections, 0, allSortDirections, inputKeys2.length, groupSortDirections.length); inputSortComparator2 = getTypeComparator(executionConfig, inputType2, allSortKeys, allSortDirections); } CoGroupSortListIterator<IN1, IN2> coGroupIterator = new CoGroupSortListIterator<IN1, IN2>(input1, inputSortComparator1, inputComparator1, inputSerializer1, input2, inputSortComparator2, inputComparator2, inputSerializer2); // -------------------------------------------------------------------- // Run UDF // -------------------------------------------------------------------- CoGroupFunction<IN1, IN2, OUT> function = userFunction.getUserCodeObject(); FunctionUtils.setFunctionRuntimeContext(function, ctx); FunctionUtils.openFunction(function, parameters); List<OUT> result = new ArrayList<OUT>(); Collector<OUT> resultCollector = new CopyingListCollector<OUT>(result, getOperatorInfo().getOutputType().createSerializer(executionConfig)); while (coGroupIterator.next()) { function.coGroup(coGroupIterator.getValues1(), coGroupIterator.getValues2(), resultCollector); } FunctionUtils.closeFunction(function); return result; } @SuppressWarnings("unchecked") private <T> TypeComparator<T> getTypeComparator(ExecutionConfig executionConfig, TypeInformation<T> inputType, int[] inputKeys, boolean[] inputSortDirections) { if (inputType instanceof CompositeType) { return ((CompositeType<T>) inputType).createComparator(inputKeys, inputSortDirections, 0, executionConfig); } else if (inputType instanceof AtomicType) { return ((AtomicType<T>) inputType).createComparator(inputSortDirections[0], executionConfig); } throw new InvalidProgramException("Input type of coGroup must be one of composite types or atomic types."); } private static class CoGroupSortListIterator<IN1, IN2> { private static enum MatchStatus { NONE_REMAINED, FIRST_REMAINED, SECOND_REMAINED, FIRST_EMPTY, SECOND_EMPTY } private final ListKeyGroupedIterator<IN1> iterator1; private final ListKeyGroupedIterator<IN2> iterator2; private final TypePairComparator<IN1, IN2> pairComparator; private MatchStatus matchStatus; private Iterable<IN1> firstReturn; private Iterable<IN2> secondReturn; private CoGroupSortListIterator( List<IN1> input1, final TypeComparator<IN1> inputSortComparator1, TypeComparator<IN1> inputComparator1, TypeSerializer<IN1> serializer1, List<IN2> input2, final TypeComparator<IN2> inputSortComparator2, TypeComparator<IN2> inputComparator2, TypeSerializer<IN2> serializer2) { this.pairComparator = new GenericPairComparator<IN1, IN2>(inputComparator1, inputComparator2); this.iterator1 = new ListKeyGroupedIterator<IN1>(input1, serializer1, inputComparator1); this.iterator2 = new ListKeyGroupedIterator<IN2>(input2, serializer2, inputComparator2); // ---------------------------------------------------------------- // Sort // ---------------------------------------------------------------- Collections.sort(input1, new Comparator<IN1>() { @Override public int compare(IN1 o1, IN1 o2) { return inputSortComparator1.compare(o1, o2); } }); Collections.sort(input2, new Comparator<IN2>() { @Override public int compare(IN2 o1, IN2 o2) { return inputSortComparator2.compare(o1, o2); } }); } private boolean next() throws IOException { boolean firstEmpty = true; boolean secondEmpty = true; if (this.matchStatus != MatchStatus.FIRST_EMPTY) { if (this.matchStatus == MatchStatus.FIRST_REMAINED) { // comparator is still set correctly firstEmpty = false; } else { if (this.iterator1.nextKey()) { this.pairComparator.setReference(iterator1.getValues().getCurrent()); firstEmpty = false; } } } if (this.matchStatus != MatchStatus.SECOND_EMPTY) { if (this.matchStatus == MatchStatus.SECOND_REMAINED) { secondEmpty = false; } else { if (iterator2.nextKey()) { secondEmpty = false; } } } if (firstEmpty && secondEmpty) { // both inputs are empty return false; } else if (firstEmpty && !secondEmpty) { // input1 is empty, input2 not this.firstReturn = Collections.emptySet(); this.secondReturn = this.iterator2.getValues(); this.matchStatus = MatchStatus.FIRST_EMPTY; return true; } else if (!firstEmpty && secondEmpty) { // input1 is not empty, input 2 is empty this.firstReturn = this.iterator1.getValues(); this.secondReturn = Collections.emptySet(); this.matchStatus = MatchStatus.SECOND_EMPTY; return true; } else { // both inputs are not empty final int comp = this.pairComparator.compareToReference(iterator2.getValues().getCurrent()); if (0 == comp) { // keys match this.firstReturn = this.iterator1.getValues(); this.secondReturn = this.iterator2.getValues(); this.matchStatus = MatchStatus.NONE_REMAINED; } else if (0 < comp) { // key1 goes first this.firstReturn = this.iterator1.getValues(); this.secondReturn = Collections.emptySet(); this.matchStatus = MatchStatus.SECOND_REMAINED; } else { // key 2 goes first this.firstReturn = Collections.emptySet(); this.secondReturn = this.iterator2.getValues(); this.matchStatus = MatchStatus.FIRST_REMAINED; } return true; } } private Iterable<IN1> getValues1() { return firstReturn; } private Iterable<IN2> getValues2() { return secondReturn; } } }