/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.flink.api.common.operators.base; import org.apache.commons.collections.ResettableIterator; import org.apache.commons.collections.iterators.ListIteratorWrapper; import org.apache.flink.annotation.Internal; import org.apache.flink.api.common.ExecutionConfig; import org.apache.flink.api.common.functions.FlatJoinFunction; import org.apache.flink.api.common.functions.RuntimeContext; import org.apache.flink.api.common.functions.util.CopyingListCollector; import org.apache.flink.api.common.functions.util.FunctionUtils; import org.apache.flink.api.common.operators.BinaryOperatorInformation; import org.apache.flink.api.common.operators.util.ListKeyGroupedIterator; import org.apache.flink.api.common.operators.util.UserCodeClassWrapper; import org.apache.flink.api.common.operators.util.UserCodeObjectWrapper; import org.apache.flink.api.common.operators.util.UserCodeWrapper; import org.apache.flink.api.common.typeinfo.AtomicType; import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.common.typeutils.CompositeType; import org.apache.flink.api.common.typeutils.GenericPairComparator; import org.apache.flink.api.common.typeutils.TypeComparator; import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.util.Collector; import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.Comparator; import java.util.List; @Internal public class OuterJoinOperatorBase<IN1, IN2, OUT, FT extends FlatJoinFunction<IN1, IN2, OUT>> extends JoinOperatorBase<IN1, IN2, OUT, FT> { public static enum OuterJoinType {LEFT, RIGHT, FULL} private OuterJoinType outerJoinType; public OuterJoinOperatorBase(UserCodeWrapper<FT> udf, BinaryOperatorInformation<IN1, IN2, OUT> operatorInfo, int[] keyPositions1, int[] keyPositions2, String name, OuterJoinType outerJoinType) { super(udf, operatorInfo, keyPositions1, keyPositions2, name); this.outerJoinType = outerJoinType; } public OuterJoinOperatorBase(FT udf, BinaryOperatorInformation<IN1, IN2, OUT> operatorInfo, int[] keyPositions1, int[] keyPositions2, String name, OuterJoinType outerJoinType) { super(new UserCodeObjectWrapper<FT>(udf), operatorInfo, keyPositions1, keyPositions2, name); this.outerJoinType = outerJoinType; } public OuterJoinOperatorBase(Class<? extends FT> udf, BinaryOperatorInformation<IN1, IN2, OUT> operatorInfo, int[] keyPositions1, int[] keyPositions2, String name, OuterJoinType outerJoinType) { super(new UserCodeClassWrapper<FT>(udf), operatorInfo, keyPositions1, keyPositions2, name); this.outerJoinType = outerJoinType; } public void setOuterJoinType(OuterJoinType outerJoinType) { this.outerJoinType = outerJoinType; } public OuterJoinType getOuterJoinType() { return outerJoinType; } @Override protected List<OUT> executeOnCollections(List<IN1> leftInput, List<IN2> rightInput, RuntimeContext runtimeContext, ExecutionConfig executionConfig) throws Exception { TypeInformation<IN1> leftInformation = getOperatorInfo().getFirstInputType(); TypeInformation<IN2> rightInformation = getOperatorInfo().getSecondInputType(); TypeInformation<OUT> outInformation = getOperatorInfo().getOutputType(); TypeComparator<IN1> leftComparator = buildComparatorFor(0, executionConfig, leftInformation); TypeComparator<IN2> rightComparator = buildComparatorFor(1, executionConfig, rightInformation); TypeSerializer<IN1> leftSerializer = leftInformation.createSerializer(executionConfig); TypeSerializer<IN2> rightSerializer = rightInformation.createSerializer(executionConfig); OuterJoinListIterator<IN1, IN2> outerJoinIterator = new OuterJoinListIterator<>(leftInput, leftSerializer, leftComparator, rightInput, rightSerializer, rightComparator, outerJoinType); // -------------------------------------------------------------------- // Run UDF // -------------------------------------------------------------------- FlatJoinFunction<IN1, IN2, OUT> function = userFunction.getUserCodeObject(); FunctionUtils.setFunctionRuntimeContext(function, runtimeContext); FunctionUtils.openFunction(function, this.parameters); List<OUT> result = new ArrayList<>(); Collector<OUT> collector = new CopyingListCollector<>(result, outInformation.createSerializer(executionConfig)); while (outerJoinIterator.next()) { IN1 left = outerJoinIterator.getLeft(); IN2 right = outerJoinIterator.getRight(); function.join(left == null ? null : leftSerializer.copy(left), right == null ? null : rightSerializer.copy(right), collector); } FunctionUtils.closeFunction(function); return result; } @SuppressWarnings("unchecked") private <T> TypeComparator<T> buildComparatorFor(int input, ExecutionConfig executionConfig, TypeInformation<T> typeInformation) { TypeComparator<T> comparator; if (typeInformation instanceof AtomicType) { comparator = ((AtomicType<T>) typeInformation).createComparator(true, executionConfig); } else if (typeInformation instanceof CompositeType) { int[] keyPositions = getKeyColumns(input); boolean[] orders = new boolean[keyPositions.length]; Arrays.fill(orders, true); comparator = ((CompositeType<T>) typeInformation).createComparator(keyPositions, orders, 0, executionConfig); } else { throw new RuntimeException("Type information for input of type " + typeInformation.getClass() .getCanonicalName() + " is not supported. Could not generate a comparator."); } return comparator; } private static class OuterJoinListIterator<IN1, IN2> { private static enum MatchStatus { NONE_REMAINED, FIRST_REMAINED, SECOND_REMAINED, FIRST_EMPTY, SECOND_EMPTY } private OuterJoinType outerJoinType; private ListKeyGroupedIterator<IN1> leftGroupedIterator; private ListKeyGroupedIterator<IN2> rightGroupedIterator; private Iterable<IN1> currLeftSubset; private ResettableIterator currLeftIterator; private Iterable<IN2> currRightSubset; private ResettableIterator currRightIterator; private MatchStatus matchStatus; private GenericPairComparator<IN1, IN2> pairComparator; private IN1 leftReturn; private IN2 rightReturn; public OuterJoinListIterator(List<IN1> leftInput, TypeSerializer<IN1> leftSerializer, final TypeComparator<IN1> leftComparator, List<IN2> rightInput, TypeSerializer<IN2> rightSerializer, final TypeComparator<IN2> rightComparator, OuterJoinType outerJoinType) { this.outerJoinType = outerJoinType; pairComparator = new GenericPairComparator<>(leftComparator, rightComparator); leftGroupedIterator = new ListKeyGroupedIterator<>(leftInput, leftSerializer, leftComparator); rightGroupedIterator = new ListKeyGroupedIterator<>(rightInput, rightSerializer, rightComparator); // ---------------------------------------------------------------- // Sort // ---------------------------------------------------------------- Collections.sort(leftInput, new Comparator<IN1>() { @Override public int compare(IN1 o1, IN1 o2) { return leftComparator.compare(o1, o2); } }); Collections.sort(rightInput, new Comparator<IN2>() { @Override public int compare(IN2 o1, IN2 o2) { return rightComparator.compare(o1, o2); } }); } @SuppressWarnings("unchecked") private boolean next() throws IOException { boolean hasMoreElements; if ((currLeftIterator == null || !currLeftIterator.hasNext()) && (currRightIterator == null || !currRightIterator.hasNext())) { hasMoreElements = nextGroups(outerJoinType); if (hasMoreElements) { if (outerJoinType != OuterJoinType.LEFT) { currLeftIterator = new ListIteratorWrapper(currLeftSubset.iterator()); } leftReturn = (IN1) currLeftIterator.next(); if (outerJoinType != OuterJoinType.RIGHT) { currRightIterator = new ListIteratorWrapper(currRightSubset.iterator()); } rightReturn = (IN2) currRightIterator.next(); return true; } else { //no more elements return false; } } else if (currLeftIterator.hasNext() && !currRightIterator.hasNext()) { leftReturn = (IN1) currLeftIterator.next(); currRightIterator.reset(); rightReturn = (IN2) currRightIterator.next(); return true; } else { rightReturn = (IN2) currRightIterator.next(); return true; } } private boolean nextGroups(OuterJoinType outerJoinType) throws IOException { if (outerJoinType == OuterJoinType.FULL) { return nextGroups(); } else if (outerJoinType == OuterJoinType.LEFT) { boolean leftContainsElements = false; while (!leftContainsElements && nextGroups()) { currLeftIterator = new ListIteratorWrapper(currLeftSubset.iterator()); if (currLeftIterator.next() != null) { leftContainsElements = true; } currLeftIterator.reset(); } return leftContainsElements; } else if (outerJoinType == OuterJoinType.RIGHT) { boolean rightContainsElements = false; while (!rightContainsElements && nextGroups()) { currRightIterator = new ListIteratorWrapper(currRightSubset.iterator()); if (currRightIterator.next() != null) { rightContainsElements = true; } currRightIterator.reset(); } return rightContainsElements; } else { throw new IllegalArgumentException("Outer join of type '" + outerJoinType + "' not supported."); } } private boolean nextGroups() throws IOException { boolean firstEmpty = true; boolean secondEmpty = true; if (this.matchStatus != MatchStatus.FIRST_EMPTY) { if (this.matchStatus == MatchStatus.FIRST_REMAINED) { // comparator is still set correctly firstEmpty = false; } else { if (this.leftGroupedIterator.nextKey()) { this.pairComparator.setReference(leftGroupedIterator.getValues().getCurrent()); firstEmpty = false; } } } if (this.matchStatus != MatchStatus.SECOND_EMPTY) { if (this.matchStatus == MatchStatus.SECOND_REMAINED) { secondEmpty = false; } else { if (rightGroupedIterator.nextKey()) { secondEmpty = false; } } } if (firstEmpty && secondEmpty) { // both inputs are empty return false; } else if (firstEmpty && !secondEmpty) { // input1 is empty, input2 not this.currLeftSubset = Collections.singleton(null); this.currRightSubset = this.rightGroupedIterator.getValues(); this.matchStatus = MatchStatus.FIRST_EMPTY; return true; } else if (!firstEmpty && secondEmpty) { // input1 is not empty, input 2 is empty this.currLeftSubset = this.leftGroupedIterator.getValues(); this.currRightSubset = Collections.singleton(null); this.matchStatus = MatchStatus.SECOND_EMPTY; return true; } else { // both inputs are not empty final int comp = this.pairComparator.compareToReference(rightGroupedIterator.getValues().getCurrent()); if (0 == comp) { // keys match this.currLeftSubset = this.leftGroupedIterator.getValues(); this.currRightSubset = this.rightGroupedIterator.getValues(); this.matchStatus = MatchStatus.NONE_REMAINED; } else if (0 < comp) { // key1 goes first this.currLeftSubset = this.leftGroupedIterator.getValues(); this.currRightSubset = Collections.singleton(null); this.matchStatus = MatchStatus.SECOND_REMAINED; } else { // key 2 goes first this.currLeftSubset = Collections.singleton(null); this.currRightSubset = this.rightGroupedIterator.getValues(); this.matchStatus = MatchStatus.FIRST_REMAINED; } return true; } } private IN1 getLeft() { return leftReturn; } private IN2 getRight() { return rightReturn; } } }