/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.flink.api.common.operators.base; import org.apache.flink.annotation.Internal; import org.apache.flink.api.common.ExecutionConfig; import org.apache.flink.api.common.functions.FlatJoinFunction; import org.apache.flink.api.common.functions.RuntimeContext; import org.apache.flink.api.common.functions.util.CopyingListCollector; import org.apache.flink.api.common.functions.util.FunctionUtils; import org.apache.flink.api.common.operators.BinaryOperatorInformation; import org.apache.flink.api.common.operators.util.UserCodeClassWrapper; import org.apache.flink.api.common.operators.util.UserCodeObjectWrapper; import org.apache.flink.api.common.operators.util.UserCodeWrapper; import org.apache.flink.api.common.typeinfo.AtomicType; import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.common.typeutils.CompositeType; import org.apache.flink.api.common.typeutils.GenericPairComparator; import org.apache.flink.api.common.typeutils.TypeComparator; import org.apache.flink.api.common.typeutils.TypePairComparator; import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.util.Collector; import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; import java.util.List; import java.util.Map; /** * @see org.apache.flink.api.common.functions.FlatJoinFunction */ @Internal public class InnerJoinOperatorBase<IN1, IN2, OUT, FT extends FlatJoinFunction<IN1, IN2, OUT>> extends JoinOperatorBase<IN1, IN2, OUT, FT> { public InnerJoinOperatorBase(UserCodeWrapper<FT> udf, BinaryOperatorInformation<IN1, IN2, OUT> operatorInfo, int[] keyPositions1, int[] keyPositions2, String name) { super(udf, operatorInfo, keyPositions1, keyPositions2, name); } public InnerJoinOperatorBase(FT udf, BinaryOperatorInformation<IN1, IN2, OUT> operatorInfo, int[] keyPositions1, int[] keyPositions2, String name) { super(new UserCodeObjectWrapper<FT>(udf), operatorInfo, keyPositions1, keyPositions2, name); } public InnerJoinOperatorBase(Class<? extends FT> udf, BinaryOperatorInformation<IN1, IN2, OUT> operatorInfo, int[] keyPositions1, int[] keyPositions2, String name) { super(new UserCodeClassWrapper<FT>(udf), operatorInfo, keyPositions1, keyPositions2, name); } // -------------------------------------------------------------------------------------------- @SuppressWarnings("unchecked") @Override protected List<OUT> executeOnCollections(List<IN1> inputData1, List<IN2> inputData2, RuntimeContext runtimeContext, ExecutionConfig executionConfig) throws Exception { FlatJoinFunction<IN1, IN2, OUT> function = userFunction.getUserCodeObject(); FunctionUtils.setFunctionRuntimeContext(function, runtimeContext); FunctionUtils.openFunction(function, this.parameters); TypeInformation<IN1> leftInformation = getOperatorInfo().getFirstInputType(); TypeInformation<IN2> rightInformation = getOperatorInfo().getSecondInputType(); TypeInformation<OUT> outInformation = getOperatorInfo().getOutputType(); TypeSerializer<IN1> leftSerializer = leftInformation.createSerializer(executionConfig); TypeSerializer<IN2> rightSerializer = rightInformation.createSerializer(executionConfig); TypeComparator<IN1> leftComparator; TypeComparator<IN2> rightComparator; if (leftInformation instanceof AtomicType) { leftComparator = ((AtomicType<IN1>) leftInformation).createComparator(true, executionConfig); } else if (leftInformation instanceof CompositeType) { int[] keyPositions = getKeyColumns(0); boolean[] orders = new boolean[keyPositions.length]; Arrays.fill(orders, true); leftComparator = ((CompositeType<IN1>) leftInformation).createComparator(keyPositions, orders, 0, executionConfig); } else { throw new RuntimeException("Type information for left input of type " + leftInformation.getClass() .getCanonicalName() + " is not supported. Could not generate a comparator."); } if (rightInformation instanceof AtomicType) { rightComparator = ((AtomicType<IN2>) rightInformation).createComparator(true, executionConfig); } else if (rightInformation instanceof CompositeType) { int[] keyPositions = getKeyColumns(1); boolean[] orders = new boolean[keyPositions.length]; Arrays.fill(orders, true); rightComparator = ((CompositeType<IN2>) rightInformation).createComparator(keyPositions, orders, 0, executionConfig); } else { throw new RuntimeException("Type information for right input of type " + rightInformation.getClass() .getCanonicalName() + " is not supported. Could not generate a comparator."); } TypePairComparator<IN1, IN2> pairComparator = new GenericPairComparator<IN1, IN2>(leftComparator, rightComparator); List<OUT> result = new ArrayList<OUT>(); Collector<OUT> collector = new CopyingListCollector<OUT>(result, outInformation.createSerializer(executionConfig)); Map<Integer, List<IN2>> probeTable = new HashMap<Integer, List<IN2>>(); //Build hash table for (IN2 element : inputData2) { List<IN2> list = probeTable.get(rightComparator.hash(element)); if (list == null) { list = new ArrayList<IN2>(); probeTable.put(rightComparator.hash(element), list); } list.add(element); } //Probing for (IN1 left : inputData1) { List<IN2> matchingHashes = probeTable.get(leftComparator.hash(left)); if (matchingHashes != null) { pairComparator.setReference(left); for (IN2 right : matchingHashes) { if (pairComparator.equalToReference(right)) { function.join(leftSerializer.copy(left), rightSerializer.copy(right), collector); } } } } FunctionUtils.closeFunction(function); return result; } }