/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.operator.exchange;
import com.facebook.presto.operator.HashGenerator;
import com.facebook.presto.operator.InterpretedHashGenerator;
import com.facebook.presto.operator.PrecomputedHashGenerator;
import com.facebook.presto.spi.Page;
import com.facebook.presto.spi.block.Block;
import com.facebook.presto.spi.type.Type;
import com.google.common.collect.ImmutableList;
import com.google.common.primitives.Ints;
import it.unimi.dsi.fastutil.ints.IntArrayList;
import it.unimi.dsi.fastutil.ints.IntList;
import java.util.List;
import java.util.Optional;
import java.util.function.Consumer;
import java.util.function.LongConsumer;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static java.util.Objects.requireNonNull;
/**
 * Splits each incoming {@link Page} row by row into one page per partition and hands
 * every non-empty partition page to the consumer registered at that partition's index.
 *
 * <p>The partition of a row is chosen by a {@link LocalPartitionGenerator} configured with
 * the number of registered consumers. For every page handed out, its retained size is
 * reported to the memory tracker, and the negated amount is reported from the callback
 * given to the {@link PageReference}.
 *
 * <p>{@link #accept(Page)} is {@code synchronized} because it reuses the per-partition
 * scratch lists in {@code partitionAssignments} across calls.
 */
class PartitioningExchanger
        implements Consumer<Page>
{
    private final List<Consumer<PageReference>> buffers;
    private final LongConsumer memoryTracker;
    private final LocalPartitionGenerator partitionGenerator;
    // One scratch list per partition; cleared and refilled on every accept() call.
    private final IntList[] partitionAssignments;

    /**
     * @param partitions consumers receiving the split pages; index {@code i} receives partition {@code i}
     * @param memoryTracker receives the retained size of each page handed out, and the negated
     *         size from the callback attached to the page's {@link PageReference}
     * @param types channel types of the incoming pages; only consulted when {@code hashChannel} is empty
     * @param partitionChannels channels to hash on; only consulted when {@code hashChannel} is empty
     * @param hashChannel if present, the channel passed to {@link PrecomputedHashGenerator};
     *         otherwise an {@link InterpretedHashGenerator} over {@code partitionChannels} is used
     * @throws NullPointerException if {@code partitions}, {@code memoryTracker} or {@code hashChannel}
     *         is null, or if {@code hashChannel} is empty and {@code types} or {@code partitionChannels} is null
     */
    public PartitioningExchanger(
            List<Consumer<PageReference>> partitions,
            LongConsumer memoryTracker,
            List<? extends Type> types,
            List<Integer> partitionChannels,
            Optional<Integer> hashChannel)
    {
        this.buffers = ImmutableList.copyOf(requireNonNull(partitions, "partitions is null"));
        this.memoryTracker = requireNonNull(memoryTracker, "memoryTracker is null");
        requireNonNull(hashChannel, "hashChannel is null");

        HashGenerator hashGenerator;
        if (hashChannel.isPresent()) {
            hashGenerator = new PrecomputedHashGenerator(hashChannel.get());
        }
        else {
            // These two are only dereferenced on this path, so they are validated here
            // to avoid rejecting inputs that were previously accepted.
            requireNonNull(partitionChannels, "partitionChannels is null");
            requireNonNull(types, "types is null");
            List<Type> partitionChannelTypes = partitionChannels.stream()
                    .map(types::get)
                    .collect(toImmutableList());
            hashGenerator = new InterpretedHashGenerator(partitionChannelTypes, Ints.toArray(partitionChannels));
        }
        partitionGenerator = new LocalPartitionGenerator(hashGenerator, buffers.size());

        partitionAssignments = new IntList[buffers.size()];
        for (int i = 0; i < partitionAssignments.length; i++) {
            partitionAssignments[i] = new IntArrayList();
        }
    }

    /**
     * Partitions the rows of {@code page} and forwards one new page per non-empty partition.
     *
     * @param page the page to split; its blocks are read but positions are copied into new blocks
     */
    @Override
    public synchronized void accept(Page page)
    {
        // reset the assignment lists left over from the previous call
        for (IntList partitionAssignment : partitionAssignments) {
            partitionAssignment.clear();
        }

        // assign each row to a partition
        int positionCount = page.getPositionCount();
        for (int position = 0; position < positionCount; position++) {
            int partition = partitionGenerator.getPartition(position, page);
            partitionAssignments[partition].add(position);
        }

        // build a page for each partition
        Block[] sourceBlocks = page.getBlocks();
        // NOTE(review): this scratch array is reused for every partition, which is only safe
        // if the Page constructor does not retain the array it is given -- confirm against Page.
        Block[] outputBlocks = new Block[sourceBlocks.length];
        for (int partition = 0; partition < buffers.size(); partition++) {
            List<Integer> positions = partitionAssignments[partition];
            if (positions.isEmpty()) {
                // nothing was routed to this partition; its consumer is not invoked
                continue;
            }

            for (int i = 0; i < sourceBlocks.length; i++) {
                outputBlocks[i] = sourceBlocks[i].copyPositions(positions);
            }

            Page pageSplit = new Page(positions.size(), outputBlocks);
            // Read the size once so the amount released by the callback is exactly the
            // amount reserved here, even if a later size query would return a different value.
            long retainedSizeInBytes = pageSplit.getRetainedSizeInBytes();
            memoryTracker.accept(retainedSizeInBytes);
            buffers.get(partition).accept(new PageReference(pageSplit, 1, () -> memoryTracker.accept(-retainedSizeInBytes)));
        }
    }
}