/* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.facebook.presto.operator.aggregation.builder; import com.facebook.presto.memory.LocalMemoryContext; import com.facebook.presto.operator.OperatorContext; import com.facebook.presto.operator.aggregation.AccumulatorFactory; import com.facebook.presto.spi.Page; import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.gen.JoinCompiler; import com.facebook.presto.sql.planner.plan.AggregationNode; import com.google.common.collect.ImmutableList; import io.airlift.units.DataSize; import java.io.Closeable; import java.util.Collections; import java.util.Iterator; import java.util.List; import java.util.Optional; public class MergingHashAggregationBuilder implements Closeable { private final List<AccumulatorFactory> accumulatorFactories; private final AggregationNode.Step step; private final int expectedGroups; private final ImmutableList<Integer> groupByPartialChannels; private final Optional<Integer> hashChannel; private final OperatorContext operatorContext; private final Iterator<Page> sortedPages; private InMemoryHashAggregationBuilder hashAggregationBuilder; private final List<Type> groupByTypes; private final LocalMemoryContext systemMemoryContext; private final long memorySizeBeforeSpill; private final int overwriteIntermediateChannelOffset; private final JoinCompiler joinCompiler; public MergingHashAggregationBuilder( List<AccumulatorFactory> accumulatorFactories, AggregationNode.Step step, int expectedGroups, List<Type> groupByTypes, Optional<Integer> hashChannel, OperatorContext operatorContext, Iterator<Page> sortedPages, LocalMemoryContext systemMemoryContext, long memorySizeBeforeSpill, int overwriteIntermediateChannelOffset, JoinCompiler joinCompiler) { ImmutableList.Builder<Integer> groupByPartialChannels = ImmutableList.builder(); for (int i = 0; i < groupByTypes.size(); i++) { groupByPartialChannels.add(i); } this.accumulatorFactories = accumulatorFactories; this.step = AggregationNode.Step.partialInput(step); this.expectedGroups = expectedGroups; this.groupByPartialChannels = groupByPartialChannels.build(); this.hashChannel = hashChannel.isPresent() ? Optional.of(groupByTypes.size()) : hashChannel; this.operatorContext = operatorContext; this.sortedPages = sortedPages; this.groupByTypes = groupByTypes; this.systemMemoryContext = systemMemoryContext; this.memorySizeBeforeSpill = memorySizeBeforeSpill; this.overwriteIntermediateChannelOffset = overwriteIntermediateChannelOffset; this.joinCompiler = joinCompiler; rebuildHashAggregationBuilder(); } public Iterator<Page> buildResult() { return new Iterator<Page>() { private Iterator<Page> resultPages = Collections.emptyIterator(); @Override public boolean hasNext() { return sortedPages.hasNext() || resultPages.hasNext(); } @Override public Page next() { if (!resultPages.hasNext()) { rebuildHashAggregationBuilder(); long memorySize = 0; // ensure that at least one merged page will be processed // we can produce output after every page, because sortedPages does not have // hash values that span multiple pages (guaranteed by MergeHashSort) while (sortedPages.hasNext() && !shouldProduceOutput(memorySize)) { hashAggregationBuilder.processPage(sortedPages.next()); memorySize = hashAggregationBuilder.getSizeInMemory(); systemMemoryContext.setBytes(memorySize); } resultPages = hashAggregationBuilder.buildResult(); } return resultPages.next(); } }; } @Override public void close() { hashAggregationBuilder.close(); } private boolean shouldProduceOutput(long memorySize) { return (memorySizeBeforeSpill > 0 && memorySize > memorySizeBeforeSpill); } private void rebuildHashAggregationBuilder() { this.hashAggregationBuilder = new InMemoryHashAggregationBuilder( accumulatorFactories, step, expectedGroups, groupByTypes, groupByPartialChannels, hashChannel, operatorContext, DataSize.succinctBytes(0), Optional.of(overwriteIntermediateChannelOffset), joinCompiler); } }