/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.operator;
import com.facebook.presto.operator.exchange.LocalPartitionGenerator;
import com.facebook.presto.spi.Page;
import com.facebook.presto.spi.PageBuilder;
import com.facebook.presto.spi.type.Type;
import javax.annotation.Nullable;
import javax.annotation.concurrent.GuardedBy;
import javax.annotation.concurrent.NotThreadSafe;
import java.util.Arrays;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Supplier;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.base.Verify.verify;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static java.lang.Integer.numberOfTrailingZeros;
import static java.lang.Math.toIntExact;
@NotThreadSafe
public class PartitionedLookupSource
implements LookupSource
{
public static Supplier<LookupSource> createPartitionedLookupSourceSupplier(List<Supplier<LookupSource>> partitions, List<Type> hashChannelTypes, boolean outer)
{
Optional<OuterPositionTracker.Factory> outerPositionTrackerFactory = outer ?
Optional.of(new OuterPositionTracker.Factory(
partitions.stream()
.map(partition -> partition.get().getJoinPositionCount())
.collect(toImmutableList())))
: Optional.empty();
return () -> new PartitionedLookupSource(
partitions.stream()
.map(Supplier::get)
.collect(toImmutableList()),
hashChannelTypes,
outerPositionTrackerFactory.map(OuterPositionTracker.Factory::create));
}
private final LookupSource[] lookupSources;
private final LocalPartitionGenerator partitionGenerator;
private final int partitionMask;
private final int shiftSize;
@Nullable
private final OuterPositionTracker outerPositionTracker;
private PartitionedLookupSource(List<? extends LookupSource> lookupSources, List<Type> hashChannelTypes, Optional<OuterPositionTracker> outerPositionTracker)
{
this.lookupSources = lookupSources.toArray(new LookupSource[lookupSources.size()]);
// this generator is only used for getJoinPosition without a rawHash and in this case
// the hash channels are always packed in a page without extra columns
int[] hashChannels = new int[hashChannelTypes.size()];
for (int i = 0; i < hashChannels.length; i++) {
hashChannels[i] = i;
}
this.partitionGenerator = new LocalPartitionGenerator(new InterpretedHashGenerator(hashChannelTypes, hashChannels), lookupSources.size());
this.partitionMask = lookupSources.size() - 1;
this.shiftSize = numberOfTrailingZeros(lookupSources.size()) + 1;
this.outerPositionTracker = outerPositionTracker.orElse(null);
}
@Override
public int getChannelCount()
{
return lookupSources[0].getChannelCount();
}
@Override
public int getJoinPositionCount()
{
throw new UnsupportedOperationException("Parallel hash can not be used in a RIGHT or FULL outer join");
}
@Override
public long getInMemorySizeInBytes()
{
return Arrays.stream(lookupSources).mapToLong(LookupSource::getInMemorySizeInBytes).sum();
}
@Override
public long getJoinPosition(int position, Page hashChannelsPage, Page allChannelsPage)
{
return getJoinPosition(position, hashChannelsPage, allChannelsPage, partitionGenerator.getRawHash(position, hashChannelsPage));
}
@Override
public long getJoinPosition(int position, Page hashChannelsPage, Page allChannelsPage, long rawHash)
{
int partition = partitionGenerator.getPartition(rawHash);
LookupSource lookupSource = lookupSources[partition];
long joinPosition = lookupSource.getJoinPosition(position, hashChannelsPage, allChannelsPage, rawHash);
if (joinPosition < 0) {
return joinPosition;
}
return encodePartitionedJoinPosition(partition, toIntExact(joinPosition));
}
@Override
public long getNextJoinPosition(long currentJoinPosition, int probePosition, Page allProbeChannelsPage)
{
int partition = decodePartition(currentJoinPosition);
long joinPosition = decodeJoinPosition(currentJoinPosition);
LookupSource lookupSource = lookupSources[partition];
long nextJoinPosition = lookupSource.getNextJoinPosition(joinPosition, probePosition, allProbeChannelsPage);
if (nextJoinPosition < 0) {
return nextJoinPosition;
}
return encodePartitionedJoinPosition(partition, toIntExact(nextJoinPosition));
}
@Override
public boolean isJoinPositionEligible(long currentJoinPosition, int probePosition, Page allProbeChannelsPage)
{
int partition = decodePartition(currentJoinPosition);
long joinPosition = decodeJoinPosition(currentJoinPosition);
LookupSource lookupSource = lookupSources[partition];
return lookupSource.isJoinPositionEligible(joinPosition, probePosition, allProbeChannelsPage);
}
@Override
public void appendTo(long partitionedJoinPosition, PageBuilder pageBuilder, int outputChannelOffset)
{
int partition = decodePartition(partitionedJoinPosition);
int joinPosition = decodeJoinPosition(partitionedJoinPosition);
lookupSources[partition].appendTo(joinPosition, pageBuilder, outputChannelOffset);
if (outerPositionTracker != null) {
outerPositionTracker.positionVisited(partition, joinPosition);
}
}
@Override
public OuterPositionIterator getOuterPositionIterator()
{
checkState(outerPositionTracker != null, "This is not an outer lookup source");
return new PartitionedLookupOuterPositionIterator(lookupSources, outerPositionTracker.getVisitedPositions());
}
@Override
public void close()
{
if (outerPositionTracker != null) {
outerPositionTracker.commit();
}
}
private int decodePartition(long partitionedJoinPosition)
{
return (int) (partitionedJoinPosition & partitionMask);
}
private int decodeJoinPosition(long partitionedJoinPosition)
{
return toIntExact(partitionedJoinPosition >>> shiftSize);
}
private long encodePartitionedJoinPosition(int partition, int joinPosition)
{
return (((long) joinPosition) << shiftSize) | (partition);
}
private static class PartitionedLookupOuterPositionIterator
implements OuterPositionIterator
{
private final LookupSource[] lookupSources;
private final boolean[][] visitedPositions;
@GuardedBy("this")
private int currentSource;
@GuardedBy("this")
private int currentPosition;
public PartitionedLookupOuterPositionIterator(LookupSource[] lookupSources, boolean[][] visitedPositions)
{
this.lookupSources = lookupSources;
this.visitedPositions = visitedPositions;
}
@Override
public synchronized boolean appendToNext(PageBuilder pageBuilder, int outputChannelOffset)
{
while (currentSource < lookupSources.length) {
while (currentPosition < visitedPositions[currentSource].length) {
if (!visitedPositions[currentSource][currentPosition]) {
lookupSources[currentSource].appendTo(currentPosition, pageBuilder, outputChannelOffset);
currentPosition++;
return true;
}
currentPosition++;
}
currentPosition = 0;
currentSource++;
}
return false;
}
}
/**
* Each LookupSource has it's own copy of OuterPositionTracker instance.
* Each of those OuterPositionTracker must be committed after last write
* and before first read.
*
* All instances share visitedPositions array, but it is safe because each thread
* starts with visitedPositions filled with false values and marks only some positions
* to true. Since we don't care what will be the order of those writes to
* visitedPositions, writes can be without synchronization.
*
* Memory visibility between last writes in commit() and first read in
* getVisitedPositions() is guaranteed by accessing AtomicLong referenceCount
* variables in those two methods.
*/
private static class OuterPositionTracker
{
public static class Factory
{
private final boolean[][] visitedPositions;
private final AtomicLong referenceCount = new AtomicLong();
public Factory(List<Integer> positionCounts)
{
visitedPositions = new boolean[positionCounts.size()][];
for (int partition = 0; partition < visitedPositions.length; partition++) {
visitedPositions[partition] = new boolean[positionCounts.get(partition)];
}
}
public OuterPositionTracker create()
{
return new OuterPositionTracker(visitedPositions, referenceCount);
}
}
private final boolean[][] visitedPositions; // shared across multiple operators/drivers
private final AtomicLong referenceCount; // shared across multiple operators/drivers
private boolean written; // unique per each operator/driver
private OuterPositionTracker(boolean[][] visitedPositions, AtomicLong referenceCount)
{
this.visitedPositions = visitedPositions;
this.referenceCount = referenceCount;
}
/**
* No synchronization here, because it would be very expensive. Check comment above.
*/
public void positionVisited(int partition, int position)
{
if (!written) {
written = true;
incrementReferenceCount();
}
visitedPositions[partition][position] = true;
}
public void commit()
{
if (written) {
// touching atomic values ensures memory visibility between commit and getVisitedPositions
referenceCount.decrementAndGet();
}
}
public boolean[][] getVisitedPositions()
{
// touching atomic values ensures memory visibility between commit and getVisitedPositions
verify(referenceCount.get() == 0);
return visitedPositions;
}
private void incrementReferenceCount()
{
referenceCount.incrementAndGet();
}
}
}