/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.operator;
import com.facebook.presto.memory.AggregatedMemoryContext;
import com.facebook.presto.memory.LocalMemoryContext;
import com.facebook.presto.spi.Page;
import com.facebook.presto.spi.PageBuilder;
import com.facebook.presto.spi.type.Type;
import com.google.common.collect.Iterators;
import java.io.Closeable;
import java.util.Iterator;
import java.util.List;
import static java.util.Objects.requireNonNull;
import static java.util.stream.Collectors.toList;
/**
* This class performs a merge of page streams that were previously sorted by hash.
* <p>
* Positions are compared using their hash value. It is possible
* for two distinct values to have the same hash value, thus the returned
* stream of Pages can have interleaved positions with the same hash value.
*/
public class MergeHashSort
        implements Closeable
{
    private final AggregatedMemoryContext memoryContext;

    /**
     * @param memoryContext aggregated memory context from which a local context is created for
     * every input channel and for every merged output iterator; it is closed by {@link #close()}
     */
    public MergeHashSort(AggregatedMemoryContext memoryContext)
    {
        this.memoryContext = requireNonNull(memoryContext, "memoryContext is null");
    }

    /**
     * Merges the given page streams into a single stream of pages ordered by hash of the key columns.
     * <p>
     * Rows with same hash value are guaranteed to be in the same result page.
     * <p>
     * The hash is computed over the first {@code keyTypes.size()} channels of every page, so the key
     * columns are expected to be the leading columns. Each input channel is presumably already sorted
     * by that hash (the merge itself does not verify this).
     *
     * @param keyTypes types of the leading key columns used for hashing
     * @param allTypes types of all columns; used to build the output pages
     * @param channels input page streams to merge
     * @return iterator over merged pages
     */
    public Iterator<Page> merge(List<Type> keyTypes, List<Type> allTypes, List<Iterator<Page>> channels)
    {
        requireNonNull(keyTypes, "keyTypes is null");
        requireNonNull(allTypes, "allTypes is null");
        requireNonNull(channels, "channels is null");

        // Flatten every page stream into a stream of (page, position) pairs, each tracked by its own local memory context
        List<Iterator<PagePosition>> channelIterators = channels.stream()
                .map(channel -> new SingleChannelPagePositions(channel, memoryContext.newLocalMemoryContext()))
                .collect(toList());

        // Key columns occupy channels 0 .. keyTypes.size() - 1
        int[] hashChannels = new int[keyTypes.size()];
        for (int i = 0; i < hashChannels.length; i++) {
            hashChannels[i] = i;
        }
        HashGenerator hashGenerator = new InterpretedHashGenerator(keyTypes, hashChannels);

        return new PageRewriteIterator(
                hashGenerator,
                allTypes,
                Iterators.mergeSorted(
                        channelIterators,
                        (PagePosition left, PagePosition right) -> comparePages(hashGenerator, left, right)),
                memoryContext.newLocalMemoryContext());
    }

    /**
     * Orders two positions by the hash of their rows.
     * <p>
     * A position that is out of its page (see {@link PagePosition#isPositionOutOfPage()}) has no row to hash;
     * such positions compare equal to each other and order before every in-page position.
     */
    private static int comparePages(HashGenerator hashGenerator, PagePosition left, PagePosition right)
    {
        if (left.isPositionOutOfPage() && right.isPositionOutOfPage()) {
            return 0;
        }
        if (left.isPositionOutOfPage()) {
            return -1;
        }
        if (right.isPositionOutOfPage()) {
            return 1;
        }

        long leftHash = hashGenerator.hashPosition(left.getPosition(), left.getPage());
        long rightHash = hashGenerator.hashPosition(right.getPosition(), right.getPage());
        return Long.compare(leftHash, rightHash);
    }

    /**
     * Closes the aggregated memory context passed to the constructor.
     */
    @Override
    public void close()
    {
        memoryContext.close();
    }

    /**
     * Immutable reference to a single position (row index) within a page.
     */
    static class PagePosition
    {
        private final Page page;
        private final int position;

        public PagePosition(Page page, int position)
        {
            this.page = requireNonNull(page, "page is null");
            // position is a primitive: a null check would only box it and could never fail
            this.position = position;
        }

        public Page getPage()
        {
            return page;
        }

        public int getPosition()
        {
            return position;
        }

        /**
         * Returns true when the position does not address a row of the page, i.e. it is at or past the
         * page's position count. This is the case for position 0 of an empty page.
         */
        public boolean isPositionOutOfPage()
        {
            return position >= page.getPositionCount();
        }
    }

    /**
     * Marker interface for iterators over {@link PagePosition}; declares no additional members.
     */
    public interface PagePositions
            extends Iterator<PagePosition>
    {
    }

    /**
     * Iterates over all positions of all pages of a single channel, in order.
     * <p>
     * For every page taken from the channel, position 0 is returned even if the page is empty;
     * callers can detect that case with {@link PagePosition#isPositionOutOfPage()}.
     */
    public static class SingleChannelPagePositions
            implements PagePositions
    {
        private final Iterator<Page> channel;
        private final LocalMemoryContext memoryContext;
        private PagePosition current;

        public SingleChannelPagePositions(Iterator<Page> channel, LocalMemoryContext memoryContext)
        {
            this.channel = requireNonNull(channel, "channel is null");
            this.memoryContext = requireNonNull(memoryContext, "memoryContext is null");
        }

        @Override
        public boolean hasNext()
        {
            // More pages in the channel, or more positions left in the page currently being walked
            return channel.hasNext() || (current != null && current.getPosition() + 1 < current.getPage().getPositionCount());
        }

        @Override
        public PagePosition next()
        {
            if (current == null || current.getPosition() + 1 >= current.getPage().getPositionCount()) {
                // Current page is exhausted (or there is none yet): move on to the next page of the channel
                current = new PagePosition(channel.next(), 0);
                // Account for the page that is now being held
                memoryContext.setBytes(current.getPage().getRetainedSizeInBytes());
            }
            else {
                current = new PagePosition(current.getPage(), current.getPosition() + 1);
            }
            return current;
        }
    }

    /**
     * This class rewrites iterator over PagePosition to iterator over Pages.
     * <p>
     * Consecutive positions whose hashes compare equal are always written to the same output page,
     * even if that makes the page exceed the builder's "full" threshold.
     */
    public static class PageRewriteIterator
            implements Iterator<Page>
    {
        private final List<Type> allTypes;
        private final Iterator<PagePosition> pagePositions;
        private final HashGenerator hashGenerator;
        private final PageBuilder builder;
        private final LocalMemoryContext memoryContext;
        // First position of the next output page, carried over from the previous call to next(); null if none is pending
        private PagePosition currentPage;

        public PageRewriteIterator(HashGenerator hashGenerator, List<Type> allTypes, Iterator<PagePosition> pagePositions, LocalMemoryContext memoryContext)
        {
            this.hashGenerator = requireNonNull(hashGenerator, "hashGenerator is null");
            this.allTypes = requireNonNull(allTypes, "allTypes is null");
            this.pagePositions = requireNonNull(pagePositions, "pagePositions is null");
            this.builder = new PageBuilder(allTypes);
            this.memoryContext = requireNonNull(memoryContext, "memoryContext is null");
        }

        @Override
        public boolean hasNext()
        {
            return currentPage != null || pagePositions.hasNext();
        }

        @Override
        public Page next()
        {
            builder.reset();
            if (currentPage == null) {
                currentPage = pagePositions.next();
            }

            // Last position appended to the builder (initially the first candidate position)
            PagePosition previousPage = currentPage;
            // Keep going while the builder has room, and past that point for as long as the current
            // position compares equal to the last appended one, so equal hashes are never split
            while (comparePages(hashGenerator, currentPage, previousPage) == 0 || !builder.isFull()) {
                // Out-of-page positions (e.g. from empty pages) carry no row and are skipped
                if (!currentPage.isPositionOutOfPage()) {
                    builder.declarePosition();
                    for (int column = 0; column < allTypes.size(); column++) {
                        Type type = allTypes.get(column);
                        type.appendTo(currentPage.getPage().getBlock(column), currentPage.getPosition(), builder.getBlockBuilder(column));
                    }
                    previousPage = currentPage;
                    memoryContext.setBytes(builder.getRetainedSizeInBytes());
                }

                if (pagePositions.hasNext()) {
                    currentPage = pagePositions.next();
                }
                else {
                    // Input exhausted: nothing is pending for a subsequent call
                    currentPage = null;
                    break;
                }
            }
            return builder.build();
        }
    }
}