/**
* diqube: Distributed Query Base.
*
* Copyright (C) 2015 Bastian Gloeckle
*
* This file is part of diqube.
*
* diqube is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as
* published by the Free Software Foundation, either version 3 of the
* License, or (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package org.diqube.execution.steps;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import java.util.concurrent.ConcurrentLinkedDeque;
import java.util.concurrent.atomic.AtomicBoolean;
import org.diqube.data.column.ColumnType;
import org.diqube.execution.consumers.AbstractThreadedColumnBuiltConsumer;
import org.diqube.execution.consumers.AbstractThreadedGroupDeltaConsumer;
import org.diqube.execution.consumers.ColumnBuiltConsumer;
import org.diqube.execution.consumers.ColumnVersionBuiltConsumer;
import org.diqube.execution.consumers.DoneConsumer;
import org.diqube.execution.consumers.GenericConsumer;
import org.diqube.execution.consumers.GroupDeltaConsumer;
import org.diqube.execution.consumers.GroupIntermediaryAggregationConsumer;
import org.diqube.execution.exception.ExecutablePlanBuildException;
import org.diqube.execution.exception.ExecutablePlanExecutionException;
import org.diqube.executionenv.ExecutionEnvironment;
import org.diqube.function.AggregationFunction;
import org.diqube.function.AggregationFunction.ValueProvider;
import org.diqube.function.FunctionFactory;
import org.diqube.function.IntermediaryResult;
import org.diqube.queries.QueryRegistry;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* Calculates intermediary results of a group aggregation function. This is typically executed on all cluster nodes,
* whereas the intermediary results are provided to the query master which then combines the intermediary results from
* all cluster nodes for one group (which is done by {@link GroupFinalAggregationStep}).
*
* <p>
* This step does not support a {@link ColumnVersionBuiltConsumer} as the step will be executed on remotes only and not
* on the query master. While executing on remotes, no intermediary columns are supported.
*
* <p>
* Input: 1 {@link GroupDeltaConsumer}, multiple optional {@link ColumnBuiltConsumer}<br>
* Output: {@link GroupIntermediaryAggregationConsumer}
*
* @author Bastian Gloeckle
*/
public class GroupIntermediaryAggregationStep extends AbstractThreadedExecutablePlanStep {
  private static final Logger logger = LoggerFactory.getLogger(GroupIntermediaryAggregationStep.class);

  /** Set to true as soon as all sources of {@link #groupDeltaConsumer} reported that they are done. */
  private final AtomicBoolean groupDeltaSourceIsDone = new AtomicBoolean(false);

  /** Group deltas (group ID to newly added row IDs) that were received but not yet handled by {@link #execute()}. */
  private final ConcurrentLinkedDeque<Map<Long, List<Long>>> newGroupChanges = new ConcurrentLinkedDeque<>();

  /** Input consumer that queues incoming group deltas and records when its sources are done. */
  private final AbstractThreadedGroupDeltaConsumer groupDeltaConsumer =
      new AbstractThreadedGroupDeltaConsumer(this) {
        @Override
        protected void allSourcesAreDone() {
          GroupIntermediaryAggregationStep.this.groupDeltaSourceIsDone.set(true);
        }

        @Override
        protected void doConsumeGroupDeltas(Map<Long, List<Long>> lastChangedGroups) {
          GroupIntermediaryAggregationStep.this.newGroupChanges.add(lastChangedGroups);
        }
      };

  /** Set to true as soon as all sources of {@link #columnBuiltConsumer} reported that they are done. */
  private final AtomicBoolean allColumnsBuilt = new AtomicBoolean(false);

  /**
   * Optional input consumer. Single "column built" events are ignored, only the information that all sources are done
   * is used by {@link #execute()}.
   */
  private final AbstractThreadedColumnBuiltConsumer columnBuiltConsumer =
      new AbstractThreadedColumnBuiltConsumer(this) {
        @Override
        protected void allSourcesAreDone() {
          GroupIntermediaryAggregationStep.this.allColumnsBuilt.set(true);
        }

        @Override
        protected void doColumnBuilt(String colName) {
          // intentionally empty: this step only waits for all columns to be built.
        }
      };

  private final ExecutionEnvironment env;

  private final FunctionFactory functionFactory;

  private final String functionNameLowerCase;

  private final String outputColName;

  /** One aggregation function instance per group ID, created when a group is seen for the first time. */
  private final Map<Long, AggregationFunction<Object, Object>> aggregationFunctions = new HashMap<>();

  /** can be null if no parameter is specified for the aggregation function (e.g. count()) */
  private final String inputColumnName;

  /** Constant parameters that are provided (by index) to each newly created per-group aggregation function. */
  private final List<Object> constantFunctionParameters;

  /**
   * @param stepId
   *          Passed on to the super class.
   * @param queryRegistry
   *          Passed on to the super class.
   * @param env
   *          Used to look up the type and the column shard of the input column.
   * @param functionFactory
   *          Used to create the aggregation function instances.
   * @param functionNameLowerCase
   *          Name of the aggregation function, as passed to the {@link FunctionFactory}.
   * @param outputColName
   *          Name of the output column that is reported along with each intermediary result.
   * @param inputColumnName
   *          Name of the input column, can be null if the function has no column parameter (e.g. count()).
   * @param constantFunctionParameters
   *          Constant parameters to provide to each created per-group aggregation function.
   */
  public GroupIntermediaryAggregationStep(int stepId, QueryRegistry queryRegistry, ExecutionEnvironment env,
      FunctionFactory functionFactory, String functionNameLowerCase, String outputColName, String inputColumnName,
      List<Object> constantFunctionParameters) {
    super(stepId, queryRegistry);
    this.env = env;
    this.functionFactory = functionFactory;
    this.functionNameLowerCase = functionNameLowerCase;
    this.outputColName = outputColName;
    this.inputColumnName = inputColumnName;
    this.constantFunctionParameters = constantFunctionParameters;
  }

  @Override
  protected void validateOutputConsumer(GenericConsumer consumer) throws IllegalArgumentException {
    boolean supported =
        consumer instanceof DoneConsumer || consumer instanceof GroupIntermediaryAggregationConsumer;
    if (!supported)
      throw new IllegalArgumentException("Only GroupIntermediaryAggregationConsumer supported.");
  }

  /**
   * Processes all group deltas that are currently queued and, once the group delta source is done and the queue is
   * empty, sends the final updates, informs the output consumers and marks this step as done.
   */
  @Override
  protected void execute() {
    if (columnBuiltConsumer.getNumberOfTimesWired() > 0 && !allColumnsBuilt.get())
      // wait until input columns are built.
      return;

    List<Map<Long, List<Long>>> activeGroupDeltas = pollQueuedGroupDeltas();
    // Intentionally checks the number of polled delta maps (not the merged result), so the aggregation function
    // lookup is executed whenever at least one delta map was received, even if all of them are empty.
    if (!activeGroupDeltas.isEmpty())
      processGroupDeltas(mergeGroupDeltas(activeGroupDeltas));

    if (groupDeltaSourceIsDone.get() && newGroupChanges.isEmpty()) {
      sendFinalAggregationFunctionUpdates();
      forEachOutputConsumerOfType(GenericConsumer.class, c -> c.sourceIsDone());
      doneProcessing();
    }
  }

  /** Removes all currently queued group deltas from {@link #newGroupChanges} and returns them in queue order. */
  private List<Map<Long, List<Long>>> pollQueuedGroupDeltas() {
    List<Map<Long, List<Long>>> polled = new ArrayList<>();
    Map<Long, List<Long>> grpDelta;
    while ((grpDelta = newGroupChanges.poll()) != null)
      polled.add(grpDelta);
    return polled;
  }

  /** Merges multiple group deltas into one map, concatenating the row IDs of each group in the order given. */
  private Map<Long, List<Long>> mergeGroupDeltas(List<Map<Long, List<Long>>> deltas) {
    Map<Long, List<Long>> merged = new HashMap<>();
    for (Map<Long, List<Long>> delta : deltas)
      for (Entry<Long, List<Long>> deltaEntry : delta.entrySet())
        merged.computeIfAbsent(deltaEntry.getKey(), groupId -> new ArrayList<>()).addAll(deltaEntry.getValue());
    return merged;
  }

  /**
   * Feeds the new row IDs of each group into the aggregation function of that group (creating the function if the
   * group is new) and sends the resulting intermediary updates to the output consumers.
   *
   * @throws ExecutablePlanExecutionException
   *           If the {@link FunctionFactory} does not provide the aggregation function.
   */
  private void processGroupDeltas(Map<Long, List<Long>> groupDeltas) {
    ColumnType inputColType = resolveInputColumnType();

    // This instance is only used to find out whether the function needs the actual column values.
    AggregationFunction<Object, Object> probeFn = createAggregationFunction(inputColType);

    // map from groupId to array of colShardIds, null if not pre-resolved.
    Map<Long, Long[]> preResolvedColShardIds =
        probeFn.needsActualValues() ? preResolveColumnValueIds(groupDeltas) : null;

    for (Entry<Long, List<Long>> groupDeltaEntry : groupDeltas.entrySet()) {
      Long groupId = groupDeltaEntry.getKey();
      List<Long> newRowIds = groupDeltaEntry.getValue();

      AggregationFunction<Object, Object> aggFn = aggregationFunctions.get(groupId);
      if (aggFn == null) {
        aggFn = createAggregationFunction(inputColType);
        for (int i = 0; i < constantFunctionParameters.size(); i++)
          aggFn.provideConstantParameter(i, constantFunctionParameters.get(i));
        aggregationFunctions.put(groupId, aggFn);
      }

      calculateAndSendUpdates(groupId, aggFn, createDeltaValueProvider(groupId, newRowIds, preResolvedColShardIds));
    }
  }

  /**
   * Resolves the column value IDs of all rows of all given groups with a single call to the input column shard. We
   * pre-resolve all values, as this should speed things up heavily if the input column is RunLength encoded.
   *
   * @return Map from group ID to the column value IDs of the rows of that group, in the order of the row IDs.
   */
  private Map<Long, Long[]> preResolveColumnValueIds(Map<Long, List<Long>> groupDeltas) {
    Set<Long> allRowIds = new HashSet<>();
    for (List<Long> rowIds : groupDeltas.values())
      allRowIds.addAll(rowIds);

    Map<Long, Long> rowIdToColShardId = env.getColumnShard(inputColumnName).resolveColumnValueIdsForRows(allRowIds);

    Map<Long, Long[]> resolved = new HashMap<>();
    for (Entry<Long, List<Long>> deltaEntry : groupDeltas.entrySet()) {
      List<Long> rowIds = deltaEntry.getValue();
      Long[] colShardIds = new Long[rowIds.size()];
      for (int i = 0; i < colShardIds.length; i++)
        colShardIds[i] = rowIdToColShardId.get(rowIds.get(i));
      resolved.put(deltaEntry.getKey(), colShardIds);
    }
    logger.trace("Pre-resolved column shard IDs for {} groups.", groupDeltas.size());
    return resolved;
  }

  /**
   * Creates the (non-final) {@link ValueProvider} for the new rows of one group.
   *
   * @param groupId
   *          ID of the group the rows belong to.
   * @param newRowIds
   *          The row IDs that were newly added to the group.
   * @param preResolvedColShardIds
   *          Pre-resolved column value IDs by group ID, or null if nothing was pre-resolved.
   */
  private ValueProvider<Object> createDeltaValueProvider(Long groupId, List<Long> newRowIds,
      Map<Long, Long[]> preResolvedColShardIds) {
    return new ValueProvider<Object>() {
      @Override
      public Object[] getValues() {
        Long[] columnValueIds;
        if (preResolvedColShardIds != null && preResolvedColShardIds.containsKey(groupId))
          columnValueIds = preResolvedColShardIds.get(groupId);
        else
          columnValueIds = env.getColumnShard(inputColumnName).resolveColumnValueIdsForRowsFlat(newRowIds);

        return env.getColumnShard(inputColumnName).getColumnShardDictionary().decompressValues(columnValueIds);
      }

      @Override
      public long size() {
        return newRowIds.size();
      }

      @Override
      public boolean isFinalSetOfValues() {
        return false; // the final call is made by sendFinalAggregationFunctionUpdates()
      }
    };
  }

  /**
   * Calls each known aggregation function one last time with an empty set of values that is marked as final and sends
   * the resulting updates to the output consumers.
   */
  private void sendFinalAggregationFunctionUpdates() {
    for (Entry<Long, AggregationFunction<Object, Object>> e : aggregationFunctions.entrySet()) {
      AggregationFunction<Object, Object> aggFn = e.getValue();
      long groupId = e.getKey();

      // empty array whose component type matches the input type of the function, null if there is no input type.
      Object[] resValues;
      if (aggFn.getInputType() == null)
        resValues = null;
      else {
        switch (aggFn.getInputType()) {
        case STRING:
          resValues = new String[0];
          break;
        case LONG:
          resValues = new Long[0];
          break;
        default:
          resValues = new Double[0];
        }
      }

      calculateAndSendUpdates(groupId, aggFn, new ValueProvider<Object>() {
        @Override
        public long size() {
          return 0;
        }

        @Override
        public boolean isFinalSetOfValues() {
          return true;
        }

        @Override
        public Object[] getValues() {
          return resValues;
        }
      });
    }
  }

  /**
   * Adds the values of the given provider to the aggregation function and sends the intermediary result from before
   * and after that addition to all {@link GroupIntermediaryAggregationConsumer}s.
   */
  private void calculateAndSendUpdates(long groupId, AggregationFunction<Object, Object> aggFn,
      ValueProvider<Object> valueProvider) {
    ColumnType inputColType = resolveInputColumnType();

    // snapshot the state before adding the new values.
    IntermediaryResult oldIntermediary = new IntermediaryResult(outputColName, inputColType);
    aggFn.populateIntermediary(oldIntermediary);

    aggFn.addValues(valueProvider);

    IntermediaryResult newIntermediary = new IntermediaryResult(outputColName, inputColType);
    aggFn.populateIntermediary(newIntermediary);

    logger.trace("New intermediary for group {} in col {}: new {}, old: {}", groupId, outputColName, newIntermediary,
        oldIntermediary);

    forEachOutputConsumerOfType(GroupIntermediaryAggregationConsumer.class,
        c -> c.consumeIntermediaryAggregationResult(groupId, outputColName, oldIntermediary, newIntermediary));
  }

  /** @return The type of the input column according to {@link #env}, or null if there is no input column. */
  private ColumnType resolveInputColumnType() {
    if (inputColumnName == null)
      return null;
    return env.getColumnType(inputColumnName);
  }

  /**
   * Creates a new instance of the aggregation function of this step.
   *
   * @throws ExecutablePlanExecutionException
   *           If the {@link FunctionFactory} returned null for the function name and input type.
   */
  private AggregationFunction<Object, Object> createAggregationFunction(ColumnType inputColType) {
    AggregationFunction<Object, Object> fn =
        functionFactory.createAggregationFunction(functionNameLowerCase, inputColType);
    if (fn == null)
      throw new ExecutablePlanExecutionException(
          "Cannot find function '" + functionNameLowerCase + "' with input data type " + inputColType);
    return fn;
  }

  @Override
  protected List<GenericConsumer> inputConsumers() {
    return new ArrayList<>(Arrays.<GenericConsumer> asList(groupDeltaConsumer, columnBuiltConsumer));
  }

  @Override
  protected String getAdditionalToStringDetails() {
    return "funcName=" + functionNameLowerCase + ", outputCol=" + outputColName;
  }

  @Override
  protected void validateWiredStatus() throws ExecutablePlanBuildException {
    if (groupDeltaConsumer.getNumberOfTimesWired() != 1)
      throw new ExecutablePlanBuildException("Group Delta input not wired.");
    // columnBuiltConsumer can be wired optionally.
  }
}