/* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.facebook.presto.operator.exchange; import com.facebook.presto.spi.Page; import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.planner.PartitioningHandle; import com.google.common.collect.ImmutableList; import io.airlift.units.DataSize; import javax.annotation.concurrent.GuardedBy; import javax.annotation.concurrent.ThreadSafe; import java.io.Closeable; import java.util.HashSet; import java.util.List; import java.util.Optional; import java.util.Set; import java.util.function.Consumer; import java.util.function.Supplier; import static com.facebook.presto.operator.exchange.LocalExchangeSink.finishedLocalExchangeSink; import static com.facebook.presto.sql.planner.SystemPartitioningHandle.FIXED_ARBITRARY_DISTRIBUTION; import static com.facebook.presto.sql.planner.SystemPartitioningHandle.FIXED_BROADCAST_DISTRIBUTION; import static com.facebook.presto.sql.planner.SystemPartitioningHandle.FIXED_HASH_DISTRIBUTION; import static com.facebook.presto.sql.planner.SystemPartitioningHandle.SINGLE_DISTRIBUTION; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; import static com.google.common.collect.ImmutableList.toImmutableList; import static io.airlift.units.DataSize.Unit.MEGABYTE; import static java.util.Objects.requireNonNull; @ThreadSafe public class LocalExchange { private static final DataSize DEFAULT_MAX_BUFFERED_BYTES = new DataSize(32, MEGABYTE); private final List<Type> types; private final Supplier<Consumer<Page>> exchangerSupplier; private final List<LocalExchangeSource> sources; private final LocalExchangeMemoryManager memoryManager; @GuardedBy("this") private boolean allSourcesFinished; @GuardedBy("this") private boolean noMoreSinkFactories; @GuardedBy("this") private final Set<LocalExchangeSinkFactory> openSinkFactories = new HashSet<>(); @GuardedBy("this") private final Set<LocalExchangeSink> sinks = new HashSet<>(); public LocalExchange( PartitioningHandle partitioning, int defaultConcurrency, List<? extends Type> types, List<Integer> partitionChannels, Optional<Integer> partitionHashChannel) { this(partitioning, defaultConcurrency, types, partitionChannels, partitionHashChannel, DEFAULT_MAX_BUFFERED_BYTES); } public LocalExchange( PartitioningHandle partitioning, int defaultConcurrency, List<? extends Type> types, List<Integer> partitionChannels, Optional<Integer> partitionHashChannel, DataSize maxBufferedBytes) { this.types = ImmutableList.copyOf(requireNonNull(types, "types is null")); int bufferCount; if (partitioning.equals(SINGLE_DISTRIBUTION)) { bufferCount = 1; checkArgument(partitionChannels.isEmpty(), "Gather exchange must not have partition channels"); } else if (partitioning.equals(FIXED_BROADCAST_DISTRIBUTION)) { bufferCount = defaultConcurrency; checkArgument(partitionChannels.isEmpty(), "Broadcast exchange must not have partition channels"); } else if (partitioning.equals(FIXED_ARBITRARY_DISTRIBUTION)) { bufferCount = defaultConcurrency; checkArgument(partitionChannels.isEmpty(), "Random exchange must not have partition channels"); } else if (partitioning.equals(FIXED_HASH_DISTRIBUTION)) { bufferCount = defaultConcurrency; checkArgument(!partitionChannels.isEmpty(), "Partitioned exchange must have partition channels"); } else { throw new IllegalArgumentException("Unsupported local exchange partitioning " + partitioning); } ImmutableList.Builder<LocalExchangeSource> sources = ImmutableList.builder(); for (int i = 0; i < bufferCount; i++) { sources.add(new LocalExchangeSource(types, source -> checkAllSourcesFinished())); } this.sources = sources.build(); List<Consumer<PageReference>> buffers = this.sources.stream() .map(buffer -> (Consumer<PageReference>) buffer::addPage) .collect(toImmutableList()); this.memoryManager = new LocalExchangeMemoryManager(maxBufferedBytes.toBytes()); if (partitioning.equals(SINGLE_DISTRIBUTION)) { exchangerSupplier = () -> new BroadcastExchanger(buffers, memoryManager::updateMemoryUsage); } else if (partitioning.equals(FIXED_BROADCAST_DISTRIBUTION)) { exchangerSupplier = () -> new BroadcastExchanger(buffers, memoryManager::updateMemoryUsage); } else if (partitioning.equals(FIXED_ARBITRARY_DISTRIBUTION)) { exchangerSupplier = () -> new RandomExchanger(buffers, memoryManager::updateMemoryUsage); } else if (partitioning.equals(FIXED_HASH_DISTRIBUTION)) { exchangerSupplier = () -> new PartitioningExchanger(buffers, memoryManager::updateMemoryUsage, types, partitionChannels, partitionHashChannel); } else { throw new IllegalArgumentException("Unsupported local exchange partitioning " + partitioning); } } public List<Type> getTypes() { return types; } public int getBufferCount() { return sources.size(); } public long getBufferedBytes() { return memoryManager.getBufferedBytes(); } public synchronized LocalExchangeSinkFactory createSinkFactory() { checkState(!noMoreSinkFactories, "No more sink factories already set"); LocalExchangeSinkFactory newFactory = new LocalExchangeSinkFactory(this); openSinkFactories.add(newFactory); return newFactory; } public LocalExchangeSource getSource(int partitionIndex) { return sources.get(partitionIndex); } private void checkAllSourcesFinished() { checkNotHoldsLock(this); if (!sources.stream().allMatch(LocalExchangeSource::isFinished)) { return; } // all sources are finished, so finish the sinks ImmutableList<LocalExchangeSink> openSinks; synchronized (this) { allSourcesFinished = true; openSinks = ImmutableList.copyOf(sinks); sinks.clear(); } // since all sources are finished there is no reason to allow new pages to be added // this can happen with a limit query openSinks.forEach(LocalExchangeSink::finish); checkAllSinksComplete(); } private LocalExchangeSink createSink(LocalExchangeSinkFactory factory) { checkNotHoldsLock(this); synchronized (this) { checkState(openSinkFactories.contains(factory), "Factory is already closed"); if (allSourcesFinished) { // all sources have completed so return a sink that is already finished return finishedLocalExchangeSink(types, memoryManager); } // Note: exchanger can be stateful so create a new one for each sink Consumer<Page> exchanger = exchangerSupplier.get(); LocalExchangeSink sink = new LocalExchangeSink(types, exchanger, memoryManager, this::sinkFinished); sinks.add(sink); return sink; } } private void sinkFinished(LocalExchangeSink sink) { checkNotHoldsLock(this); synchronized (this) { sinks.remove(sink); } checkAllSinksComplete(); } private void noMoreSinkFactories() { checkNotHoldsLock(this); synchronized (this) { noMoreSinkFactories = true; } checkAllSinksComplete(); } private void sinkFactoryClosed(LocalExchangeSinkFactory sinkFactory) { checkNotHoldsLock(this); synchronized (this) { openSinkFactories.remove(sinkFactory); } checkAllSinksComplete(); } private void checkAllSinksComplete() { checkNotHoldsLock(this); synchronized (this) { if (!noMoreSinkFactories || !openSinkFactories.isEmpty() || !sinks.isEmpty()) { return; } } sources.forEach(LocalExchangeSource::finish); memoryManager.setNoBlockOnFull(); } private static void checkNotHoldsLock(Object lock) { checkState(!Thread.holdsLock(lock), "Can not execute this method while holding a lock"); } // Sink factory is entirely a pass thought to LocalExchange. // This class only exists as a separate entity to deal with the complex lifecycle caused // by operator factories (e.g., duplicate and noMoreSinkFactories). @ThreadSafe public static class LocalExchangeSinkFactory implements Closeable { private final LocalExchange exchange; private LocalExchangeSinkFactory(LocalExchange exchange) { this.exchange = requireNonNull(exchange, "exchange is null"); } public List<Type> getTypes() { return exchange.getTypes(); } public LocalExchangeSink createSink() { return exchange.createSink(this); } public LocalExchangeSinkFactory duplicate() { return exchange.createSinkFactory(); } @Override public void close() { exchange.sinkFactoryClosed(this); } public void noMoreSinkFactories() { exchange.noMoreSinkFactories(); } } }