/* * Licensed to Crate under one or more contributor license agreements. * See the NOTICE file distributed with this work for additional * information regarding copyright ownership. Crate licenses this file * to you under the Apache License, Version 2.0 (the "License"); you may * not use this file except in compliance with the License. You may * obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or * implied. See the License for the specific language governing * permissions and limitations under the License. * * However, if you have executed another commercial license agreement * with Crate these terms will supersede the license and you may use the * software solely pursuant to the terms of the relevant commercial * agreement. */ package io.crate.operation.projectors; import com.google.common.collect.Iterables; import io.crate.analyze.symbol.AggregateMode; import io.crate.breaker.RamAccountingContext; import io.crate.breaker.SizeEstimator; import io.crate.breaker.SizeEstimatorFactory; import io.crate.data.Input; import io.crate.data.Row; import io.crate.data.RowN; import io.crate.operation.aggregation.AggregationFunction; import io.crate.operation.collect.CollectExpression; import io.crate.types.DataType; import javax.annotation.Nullable; import java.util.*; import java.util.function.BiConsumer; import java.util.function.BinaryOperator; import java.util.function.Function; import java.util.function.Supplier; import java.util.stream.Collector; /** * Collector implementation which uses {@link Aggregator}s and {@code keyInputs} * to group rows by key and aggregate the grouped values. * * @param <K> type of the key */ public class GroupingCollector<K> implements Collector<Row, Map<K, Object[]>, Iterable<Row>> { private final CollectExpression<Row, ?>[] expressions; private final AggregationFunction[] aggregations; private final AggregateMode mode; private final Input[][] inputs; private final RamAccountingContext ramAccountingContext; private final BiConsumer<K, Object[]> applyKeyToCells; private final int numKeyColumns; private final SizeEstimator<K> keySizeEstimator; private final Function<Row, K> keyExtractor; static GroupingCollector<Object> singleKey(CollectExpression<Row, ?>[] expressions, AggregateMode mode, AggregationFunction[] aggregations, Input[][] inputs, RamAccountingContext ramAccountingContext, Input<?> keyInput, DataType keyType) { return new GroupingCollector<>( expressions, aggregations, mode, inputs, ramAccountingContext, (key, cells) -> cells[0] = key, 1, SizeEstimatorFactory.create(keyType), row -> keyInput.value() ); } static GroupingCollector<List<Object>> manyKeys(CollectExpression<Row, ?>[] expressions, AggregateMode mode, AggregationFunction[] aggregations, Input[][] inputs, RamAccountingContext ramAccountingContext, List<Input<?>> keyInputs, List<? extends DataType> keyTypes) { return new GroupingCollector<>( expressions, aggregations, mode, inputs, ramAccountingContext, GroupingCollector::applyKeysToCells, keyInputs.size(), new MultiSizeEstimator(keyTypes), row -> evalKeyInputs(keyInputs) ); } private static List<Object> evalKeyInputs(List<Input<?>> keyInputs) { List<Object> key = new ArrayList<>(keyInputs.size()); for (Input<?> keyInput : keyInputs) { key.add(keyInput.value()); } return key; } private static void applyKeysToCells(List<Object> keys, Object[] cells) { for (int i = 0; i < keys.size(); i++) { cells[i] = keys.get(i); } } private GroupingCollector(CollectExpression<Row, ?>[] expressions, AggregationFunction[] aggregations, AggregateMode mode, Input[][] inputs, RamAccountingContext ramAccountingContext, BiConsumer<K, Object[]> applyKeyToCells, int numKeyColumns, SizeEstimator<K> keySizeEstimator, Function<Row, K> keyExtractor) { this.expressions = expressions; this.aggregations = aggregations; this.mode = mode; this.inputs = inputs; this.ramAccountingContext = ramAccountingContext; this.applyKeyToCells = applyKeyToCells; this.numKeyColumns = numKeyColumns; this.keySizeEstimator = keySizeEstimator; this.keyExtractor = keyExtractor; } @Override public Supplier<Map<K, Object[]>> supplier() { return HashMap::new; } @Override public BiConsumer<Map<K, Object[]>, Row> accumulator() { return this::onNextRow; } @Override public BinaryOperator<Map<K, Object[]>> combiner() { return (state1, state2) -> { throw new UnsupportedOperationException("combine not supported"); }; } @Override public Function<Map<K, Object[]>, Iterable<Row>> finisher() { return this::mapToRows; } @Override public Set<Characteristics> characteristics() { return Collections.emptySet(); } private void onNextRow(Map<K, Object[]> statesByKey, Row row) { for (CollectExpression<Row, ?> expression : expressions) { expression.setNextRow(row); } K key = keyExtractor.apply(row); Object[] states = statesByKey.get(key); if (states == null) { addNewEntry(statesByKey, key); } else { for (int i = 0; i < aggregations.length; i++) { states[i] = mode.onRow(ramAccountingContext, aggregations[i], states[i], inputs[i]); } } } private void addNewEntry(Map<K, Object[]> statesByKey, K key) { Object[] states; states = new Object[aggregations.length]; for (int i = 0; i < aggregations.length; i++) { AggregationFunction aggregation = aggregations[i]; states[i] = mode.onRow( ramAccountingContext, aggregation, aggregation.newState(ramAccountingContext), inputs[i]); } ramAccountingContext.addBytes( // key size + 32 bytes for entry + 4 bytes for increased capacity RamAccountingContext.roundUp(keySizeEstimator.estimateSize(key) + 36L)); statesByKey.put(key, states); } private Iterable<Row> mapToRows(Map<K, Object[]> statesByKey) { return Iterables.transform(statesByKey.entrySet(), new com.google.common.base.Function<Map.Entry<K, Object[]>, Row>() { RowN row = new RowN(numKeyColumns + aggregations.length); Object[] cells = new Object[row.numColumns()]; { row.cells(cells); } @Nullable @Override public Row apply(@Nullable Map.Entry<K, Object[]> input) { assert input != null : "input must not be null"; applyKeyToCells.accept(input.getKey(), cells); int c = numKeyColumns; Object[] states = input.getValue(); for (int i = 0; i < states.length; i++) { cells[c] = mode.finishCollect(ramAccountingContext, aggregations[i], states[i]); c++; } return row; } }); } private static class MultiSizeEstimator extends SizeEstimator<List<Object>> { private final List<SizeEstimator<Object>> subEstimators; MultiSizeEstimator(List<? extends DataType> keyTypes) { subEstimators = new ArrayList<>(keyTypes.size()); for (DataType keyType : keyTypes) { subEstimators.add(SizeEstimatorFactory.create(keyType)); } } @Override public long estimateSize(@Nullable List<Object> value) { assert value != null && value.size() == subEstimators.size() : "value must have the same number of items as there are keyTypes/sizeEstimators"; long size = 0; for (int i = 0; i < value.size(); i++) { size += subEstimators.get(i).estimateSize(value.get(i)); } return size; } } }