/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.runtime.operators.shipping;
import org.apache.flink.api.common.distributions.DataDistribution;
import org.apache.flink.api.common.functions.Partitioner;
import org.apache.flink.api.common.typeutils.TypeComparator;
import org.apache.flink.runtime.io.network.api.writer.ChannelSelector;
import org.apache.flink.runtime.plugable.SerializationDelegate;
import org.apache.flink.util.MathUtils;
/**
* The output emitter decides to which of the possibly multiple output channels a record is sent.
* It implement routing based on hash-partitioning, broadcasting, round-robin, custom partition
* functions, etc.
*
* @param <T> The type of the element handled by the emitter.
*/
public class OutputEmitter<T> implements ChannelSelector<SerializationDelegate<T>> {
/** the shipping strategy used by this output emitter */
private final ShipStrategyType strategy;
/** the reused array defining target channels */
private int[] channels;
/** counter to go over channels round robin */
private int nextChannelToSendTo = 0;
/** the comparator for hashing / sorting */
private final TypeComparator<T> comparator;
private Object[][] partitionBoundaries; // the partition boundaries for range partitioning
private DataDistribution distribution; // the data distribution to create the partition boundaries for range partitioning
private final Partitioner<Object> partitioner;
private TypeComparator[] flatComparators;
private Object[] keys;
private Object[] extractedKeys;
// ------------------------------------------------------------------------
// Constructors
// ------------------------------------------------------------------------
/**
* Creates a new channel selector that uses the given strategy (broadcasting, partitioning, ...)
* and uses the supplied task index perform a round robin distribution.
*
* @param strategy The distribution strategy to be used.
*/
public OutputEmitter(ShipStrategyType strategy, int indexInSubtaskGroup) {
this(strategy, indexInSubtaskGroup, null, null, null);
}
/**
* Creates a new channel selector that uses the given strategy (broadcasting, partitioning, ...)
* and uses the supplied comparator to hash / compare records for partitioning them deterministically.
*
* @param strategy The distribution strategy to be used.
* @param comparator The comparator used to hash / compare the records.
*/
public OutputEmitter(ShipStrategyType strategy, TypeComparator<T> comparator) {
this(strategy, 0, comparator, null, null);
}
@SuppressWarnings("unchecked")
public OutputEmitter(ShipStrategyType strategy, int indexInSubtaskGroup,
TypeComparator<T> comparator, Partitioner<?> partitioner, DataDistribution distribution) {
if (strategy == null) {
throw new NullPointerException();
}
this.strategy = strategy;
this.nextChannelToSendTo = indexInSubtaskGroup;
this.comparator = comparator;
this.partitioner = (Partitioner<Object>) partitioner;
this.distribution = distribution;
switch (strategy) {
case PARTITION_CUSTOM:
extractedKeys = new Object[1];
case FORWARD:
case PARTITION_HASH:
case PARTITION_RANDOM:
case PARTITION_FORCED_REBALANCE:
channels = new int[1];
break;
case PARTITION_RANGE:
channels = new int[1];
if (comparator != null) {
this.flatComparators = comparator.getFlatComparators();
this.keys = new Object[flatComparators.length];
}
break;
case BROADCAST:
break;
default:
throw new IllegalArgumentException("Invalid shipping strategy for OutputEmitter: " + strategy.name());
}
if (strategy == ShipStrategyType.PARTITION_CUSTOM && partitioner == null) {
throw new NullPointerException("Partitioner must not be null when the ship strategy is set to custom partitioning.");
}
}
// ------------------------------------------------------------------------
// Channel Selection
// ------------------------------------------------------------------------
@Override
public final int[] selectChannels(SerializationDelegate<T> record, int numberOfChannels) {
switch (strategy) {
case FORWARD:
return forward();
case PARTITION_RANDOM:
case PARTITION_FORCED_REBALANCE:
return robin(numberOfChannels);
case PARTITION_HASH:
return hashPartitionDefault(record.getInstance(), numberOfChannels);
case BROADCAST:
return broadcast(numberOfChannels);
case PARTITION_CUSTOM:
return customPartition(record.getInstance(), numberOfChannels);
case PARTITION_RANGE:
return rangePartition(record.getInstance(), numberOfChannels);
default:
throw new UnsupportedOperationException("Unsupported distribution strategy: " + strategy.name());
}
}
// --------------------------------------------------------------------------------------------
private int[] forward() {
return this.channels;
}
private int[] robin(int numberOfChannels) {
int nextChannel = this.nextChannelToSendTo;
if (nextChannel >= numberOfChannels) {
if (nextChannel == numberOfChannels) {
nextChannel = 0;
} else {
nextChannel %= numberOfChannels;
}
}
this.channels[0] = nextChannel;
this.nextChannelToSendTo = nextChannel + 1;
return this.channels;
}
private int[] broadcast(int numberOfChannels) {
if (channels == null || channels.length != numberOfChannels) {
channels = new int[numberOfChannels];
for (int i = 0; i < numberOfChannels; i++) {
channels[i] = i;
}
}
return channels;
}
private int[] hashPartitionDefault(T record, int numberOfChannels) {
int hash = this.comparator.hash(record);
this.channels[0] = MathUtils.murmurHash(hash) % numberOfChannels;
return this.channels;
}
private final int[] rangePartition(final T record, int numberOfChannels) {
if (this.channels == null || this.channels.length != 1) {
this.channels = new int[1];
}
if (this.partitionBoundaries == null) {
this.partitionBoundaries = new Object[numberOfChannels - 1][];
for (int i = 0; i < numberOfChannels - 1; i++) {
this.partitionBoundaries[i] = this.distribution.getBucketBoundary(i, numberOfChannels);
}
}
if (numberOfChannels == this.partitionBoundaries.length + 1) {
final Object[][] boundaries = this.partitionBoundaries;
// bin search the bucket
int low = 0;
int high = this.partitionBoundaries.length - 1;
while (low <= high) {
final int mid = (low + high) >>> 1;
final int result = compareRecordAndBoundary(record, boundaries[mid]);
if (result > 0) {
low = mid + 1;
} else if (result < 0) {
high = mid - 1;
} else {
this.channels[0] = mid;
return this.channels;
}
}
this.channels[0] = low; // key not found, but the low index is the target
// bucket, since the boundaries are the upper bound
return this.channels;
} else {
throw new IllegalStateException(
"The number of channels to partition among is inconsistent with the partitioners state.");
}
}
private int[] customPartition(T record, int numberOfChannels) {
if (channels == null) {
channels = new int[1];
extractedKeys = new Object[1];
}
try {
if (comparator.extractKeys(record, extractedKeys, 0) == 1) {
final Object key = extractedKeys[0];
channels[0] = partitioner.partition(key, numberOfChannels);
return channels;
}
else {
throw new RuntimeException("Inconsistency in the key comparator - comparator extracted more than one field.");
}
}
catch (Throwable t) {
throw new RuntimeException("Error while calling custom partitioner.", t);
}
}
private final int compareRecordAndBoundary(T record, Object[] boundary) {
this.comparator.extractKeys(record, keys, 0);
if (flatComparators.length != keys.length || flatComparators.length > boundary.length) {
throw new RuntimeException("Can not compare keys with boundary due to mismatched length.");
}
for (int i = 0; i < flatComparators.length; i++) {
int result = flatComparators[i].compare(keys[i], boundary[i]);
if (result != 0) {
return result;
}
}
return 0;
}
}