/**
 * diqube: Distributed Query Base.
 *
 * Copyright (C) 2015 Bastian Gloeckle
 *
 * This file is part of diqube.
 *
 * diqube is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Affero General Public License as
 * published by the Free Software Foundation, either version 3 of the
 * License, or (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU Affero General Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with this program. If not, see <http://www.gnu.org/licenses/>.
 */
package org.diqube.execution.steps;

import java.lang.reflect.Array;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Function;
import java.util.function.LongConsumer;
import java.util.stream.Collectors;
import java.util.stream.LongStream;

import org.diqube.data.column.ColumnShard;
import org.diqube.data.column.ColumnType;
import org.diqube.data.types.dbl.DoubleColumnShard;
import org.diqube.data.types.lng.LongColumnShard;
import org.diqube.data.types.str.StringColumnShard;
import org.diqube.execution.consumers.AbstractThreadedColumnBuiltConsumer;
import org.diqube.execution.consumers.ColumnBuiltConsumer;
import org.diqube.execution.consumers.DoneConsumer;
import org.diqube.execution.consumers.GenericConsumer;
import org.diqube.execution.exception.ExecutablePlanBuildException;
import org.diqube.execution.exception.ExecutablePlanExecutionException;
import org.diqube.executionenv.ExecutionEnvironment;
import org.diqube.executionenv.querystats.QueryableColumnShard;
import org.diqube.executionenv.util.ColumnPatternUtil;
import org.diqube.executionenv.util.ColumnPatternUtil.ColumnPatternContainer;
import org.diqube.executionenv.util.ColumnPatternUtil.LengthColumnMissingException;
import org.diqube.function.AggregationFunction;
import org.diqube.function.AggregationFunction.ValueProvider;
import org.diqube.function.FunctionFactory;
import org.diqube.loader.LoaderColumnInfo;
import org.diqube.loader.columnshard.ColumnShardBuilder;
import org.diqube.loader.columnshard.ColumnShardBuilderFactory;
import org.diqube.loader.columnshard.ColumnShardBuilderManager;
import org.diqube.queries.QueryRegistry;
import org.diqube.queries.QueryUuid;
import org.diqube.queries.QueryUuid.QueryUuidThreadState;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.google.common.collect.Iterables;

/**
 * Step that aggregates the values of multiple columns in a single row to a single value.
 *
 * In contrast to {@link GroupIntermediaryAggregationStep} and {@link GroupFinalAggregationStep}, this does <b>not</b>
 * aggregate values of multiple rows. Therefore this can be fully executed on query remotes.
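 *
 * <p>
 * Illustrative example (hypothetical column names): for the input pattern <code>a[*].b</code>, a row whose repeated
 * column <code>a</code> has length 2 gets the values of its columns <code>a[0].b</code> and <code>a[1].b</code>
 * aggregated into a single value in the output column.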
 *
 * <p>
 * Input: 1 optional {@link ColumnBuiltConsumer}, <br>
 * Output: {@link ColumnBuiltConsumer}
 *
 * @author Bastian Gloeckle
 */
public class ColumnAggregationStep extends AbstractThreadedExecutablePlanStep {
  private static final Logger logger = LoggerFactory.getLogger(ColumnAggregationStep.class);

  private static final int BATCH_SIZE = ColumnShardBuilder.PROPOSAL_ROWS / 2; // approx. work on half ColumnPages.

  private AtomicBoolean allColumnsAreBuilt = new AtomicBoolean(false);

  private AbstractThreadedColumnBuiltConsumer columnBuiltConsumer = new AbstractThreadedColumnBuiltConsumer(this) {
    @Override
    protected void allSourcesAreDone() {
      allColumnsAreBuilt.set(true);
    }

    @Override
    protected void doColumnBuilt(String colName) {
      // noop, we try anew each time.
    }
  };

  private FunctionFactory functionFactory;
  private String functionNameLowerCase;
  private String outputColName;
  private ExecutionEnvironment defaultEnv;
  private ColumnShardBuilderFactory columnShardBuilderFactory;
  private String inputColumnNamePattern;
  private Function<ColumnType, ColumnShardBuilderManager> columnShardBuilderManagerSupplier;
  private ColumnPatternUtil columnPatternUtil;
  private List<Object> constantFunctionParameters;

  public ColumnAggregationStep(int stepId, QueryRegistry queryRegistry, ExecutionEnvironment defaultEnv,
      ColumnPatternUtil columnPatternUtil, ColumnShardBuilderFactory columnShardBuilderFactory,
      FunctionFactory functionFactory, String functionNameLowerCase, String outputColName,
      String inputColumnNamePattern, List<Object> constantFunctionParameters) {
    super(stepId, queryRegistry);
    this.defaultEnv = defaultEnv;
    this.columnPatternUtil = columnPatternUtil;
    this.columnShardBuilderFactory = columnShardBuilderFactory;
    this.functionFactory = functionFactory;
    this.functionNameLowerCase = functionNameLowerCase;
    this.outputColName = outputColName;
    this.inputColumnNamePattern = inputColumnNamePattern;
    this.constantFunctionParameters = constantFunctionParameters;
  }

  @Override
  public void initialize() {
    columnShardBuilderManagerSupplier = (outputColType) -> {
      LoaderColumnInfo columnInfo = new LoaderColumnInfo(outputColType);
      return columnShardBuilderFactory.createColumnShardBuilderManager(columnInfo, defaultEnv.getFirstRowIdInShard());
    };
  }
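
  // Lifecycle note (an assumption derived from the retry logic in #execute() below): execute() may be called
  // repeatedly while input columns are still being built. As long as some input columns are missing, it simply
  // returns and waits for the next call; only after columnBuiltConsumer has reported all sources done ("lastRun")
  // is a missing column an actual error. The columnShardBuilderManagerSupplier prepared in #initialize() is a
  // Function because the output column type is only known after the aggregation function is instantiated.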

  @Override
  protected void execute() {
    boolean lastRun = allColumnsAreBuilt.get();

    // validate if all "length" columns are available and all [index] columns, too - we do this by looking for all
    // columns with all indices that are contained in the length columns (= the maximum).
    Set<String> allColNames;
    ColumnPatternContainer columnPatternContainer;
    try {
      columnPatternContainer = columnPatternUtil.findColNamesForColNamePattern(defaultEnv, inputColumnNamePattern);
      allColNames = columnPatternContainer.getMaximumColumnPatternsSinglePattern();
    } catch (LengthColumnMissingException e) {
      if (lastRun)
        throw new ExecutablePlanExecutionException("When trying to aggregate column values, not all repeated "
            + "columns were available. It was expected to have repeated columns where the input column pattern "
            + "specifies '[*]'. Perhaps not all of these columns are repeated columns?");
      return;
    }

    if (allColNames.isEmpty())
      throw new ExecutablePlanExecutionException("Input col name pattern did not contain '[*]'.");

    // all length columns are available, check all the [index] columns now.
    boolean notAllColsAvailable =
        allColNames.stream().anyMatch(requiredCol -> defaultEnv.getColumnShard(requiredCol) == null);
    if (notAllColsAvailable) {
      logger.trace("Columns {} missing. Not proceeding.", allColNames.stream()
          .filter(requiredCol -> defaultEnv.getColumnShard(requiredCol) == null).collect(Collectors.toList()));
      if (lastRun)
        throw new ExecutablePlanExecutionException("When trying to aggregate column values, not all repeated "
            + "columns were available. It was expected to have repeated columns where the input column pattern "
            + "specifies '[*]'. Perhaps not all of these columns are repeated columns?");
      return;
    }

    // Ok, all columns that we need seem to be available.
    logger.trace("Starting to column aggregate with output col {}", outputColName);

    ColumnType inputColType = defaultEnv.getColumnShard(Iterables.getFirst(allColNames, null)).getColumnType();

    AggregationFunction<Object, Object> tmpFunction =
        functionFactory.createAggregationFunction(functionNameLowerCase, inputColType);
    if (tmpFunction == null)
      throw new ExecutablePlanExecutionException(
          "Cannot find function '" + functionNameLowerCase + "' with input data type " + inputColType);

    ColumnShardBuilderManager colShardBuilderManager =
        columnShardBuilderManagerSupplier.apply(tmpFunction.getOutputType());

    long lastRowIdInShard = defaultEnv.getLastRowIdInShard();
    final Map<String, Integer> finalAllColNames = new HashMap<>();
    int tmp = 0;
    for (String colName : allColNames)
      finalAllColNames.put(colName, tmp++);
    final ColumnPatternContainer finalColumnPatternContainer = columnPatternContainer;
    QueryUuidThreadState uuidState = QueryUuid.getCurrentThreadState();

    // work on all rowIds with a specific batch size
    LongStream.rangeClosed(defaultEnv.getFirstRowIdInShard(), lastRowIdInShard). //
        parallel().filter(l -> (l - defaultEnv.getFirstRowIdInShard()) % BATCH_SIZE == 0).forEach(new LongConsumer() {
          @Override
          public void accept(long firstRowId) {
            QueryUuid.setCurrentThreadState(uuidState);
            try {
              // To not have to decompress a lot of single values later, let's decompress all values of the current
              // batch for all columns. This might decompress some values that we do not need later for some columns,
              // but in general it should be faster.
              logger.trace("Resolving colValue IDs for batch {}", firstRowId);
              Class<?> valueClass = null;
              Object[][] valuesByCol = new Object[finalAllColNames.size()][];
              List<Long> rowIds = LongStream.range(firstRowId, Math.min(firstRowId + BATCH_SIZE, lastRowIdInShard + 1))
                  .mapToObj(Long::valueOf).collect(Collectors.toList());
              for (Entry<String, Integer> inputColNameEntry : finalAllColNames.entrySet()) {
                String inputColName = inputColNameEntry.getKey();
                QueryableColumnShard colShard = defaultEnv.getColumnShard(inputColName);
                Long[] colValueIds = colShard.resolveColumnValueIdsForRowsFlat(rowIds);
                Object[] values = colShard.getColumnShardDictionary().decompressValues(colValueIds);
                if (valueClass == null)
                  valueClass = values[0].getClass();
                valuesByCol[inputColNameEntry.getValue()] = values;
              }
              logger.trace("ColValue IDs for batch {} resolved.", firstRowId);
              final Class<?> finalValueClass = valueClass;

              logger.trace("Starting to apply aggregation function to all rows in batch {}", firstRowId);
              Object[] resValueArray = null;
              for (long rowId = firstRowId; rowId < firstRowId + BATCH_SIZE && rowId <= lastRowIdInShard; rowId++) {
                // Ok, let's work on this single row. Let's first find all the column names that are important for
                // this row - based on the value of the "length" columns at this row. The other columns (with indices
                // >= length) will contain some sort of default values, which we do not want to include in the
                // aggregation!
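                // Illustrative example (hypothetical names): for the pattern "a[*].b" and a length of 2 at this row,
                // colNamesForCurRow will contain { "a[0].b", "a[1].b" }, even if the maximum set computed above
                // contains more indices.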
                final long finalRowId = rowId;
                Set<String> colNamesForCurRow = finalColumnPatternContainer.getColumnPatternsSinglePattern(finalRowId);

                AggregationFunction<Object, Object> aggFunction =
                    functionFactory.createAggregationFunction(functionNameLowerCase, inputColType);
                for (int i = 0; i < constantFunctionParameters.size(); i++)
                  aggFunction.provideConstantParameter(i, constantFunctionParameters.get(i));

                // add the values to the aggregation, resolve them using the pre-computed arrays from above.
                aggFunction.addValues(new ValueProvider<Object>() {
                  @Override
                  public Object[] getValues() {
                    Object[] res = (Object[]) Array.newInstance(finalValueClass, colNamesForCurRow.size());
                    int i = 0;
                    int[] colIndices = new int[colNamesForCurRow.size()];
                    for (String colName : colNamesForCurRow)
                      colIndices[i++] = finalAllColNames.get(colName);
                    int rowIndex = (int) (finalRowId - firstRowId);
                    for (int j = 0; j < colIndices.length; j++)
                      res[j] = valuesByCol[colIndices[j]][rowIndex];
                    return res;
                  }

                  @Override
                  public long size() {
                    return colNamesForCurRow.size();
                  }

                  @Override
                  public boolean isFinalSetOfValues() {
                    return true; // we will not use this AggregationFunction object again.
                  }
                });

                Object resValue = aggFunction.calculate();

                // check if we still need to create the result array
                if (resValueArray == null) {
                  if (firstRowId + BATCH_SIZE <= lastRowIdInShard)
                    resValueArray = (Object[]) Array.newInstance(resValue.getClass(), BATCH_SIZE);
                  else
                    resValueArray =
                        (Object[]) Array.newInstance(resValue.getClass(), (int) (lastRowIdInShard - firstRowId + 1));
                }

                resValueArray[(int) (rowId - firstRowId)] = resValue;
              }
              logger.trace("Aggregation function applied to all rows in batch {}", firstRowId);

              colShardBuilderManager.addValues(outputColName, resValueArray, firstRowId);
            } finally {
              QueryUuid.clearCurrent();
            }
          }
        });
    QueryUuid.setCurrentThreadState(uuidState);

    if (Thread.interrupted()) {
      // If we were interrupted, exit quietly before we start to build the new col.
      logger.info("Interrupted. Stopping processing.");
      doneProcessing();
      return;
    }

    logger.trace("Building output column {}", outputColName);
    ColumnShard outputCol = colShardBuilderManager.buildAndFree(outputColName);
    logger.trace("Column {} built.", outputColName);

    switch (outputCol.getColumnType()) {
    case STRING:
      defaultEnv.storeTemporaryStringColumnShard((StringColumnShard) outputCol);
      break;
    case LONG:
      defaultEnv.storeTemporaryLongColumnShard((LongColumnShard) outputCol);
      break;
    case DOUBLE:
      defaultEnv.storeTemporaryDoubleColumnShard((DoubleColumnShard) outputCol);
      break;
    }

    forEachOutputConsumerOfType(ColumnBuiltConsumer.class, c -> c.columnBuilt(outputColName));
    forEachOutputConsumerOfType(GenericConsumer.class, c -> c.sourceIsDone());
    doneProcessing();
  }

  @Override
  protected void validateOutputConsumer(GenericConsumer consumer) throws IllegalArgumentException {
    if (!(consumer instanceof DoneConsumer) && !(consumer instanceof ColumnBuiltConsumer))
      throw new IllegalArgumentException("Only ColumnBuiltConsumer supported.");
  }

  @Override
  protected List<GenericConsumer> inputConsumers() {
    return Arrays.asList(columnBuiltConsumer);
  }

  @Override
  protected void validateWiredStatus() throws ExecutablePlanBuildException {
    // noop. if input is wired it's fine, if not, that's fine too.
  }

  @Override
  protected String getAdditionalToStringDetails() {
    return "outputColName=" + outputColName;
  }
}