/**
* Copyright (C) 2009-2013 FoundationDB, LLC
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package com.foundationdb.sql.optimizer.rule;
import com.foundationdb.sql.optimizer.plan.*;
import com.foundationdb.sql.optimizer.plan.Sort.OrderByExpression;
import com.foundationdb.sql.parser.ValueNode;
import com.foundationdb.sql.types.DataTypeDescriptor;
import com.foundationdb.sql.types.TypeId;
import com.foundationdb.server.error.InvalidOptimizerPropertyException;
import com.foundationdb.server.error.NoAggregateWithGroupByException;
import com.foundationdb.server.error.UnsupportedSQLException;
import com.foundationdb.server.types.TInstance;
import com.foundationdb.ais.model.Column;
import com.foundationdb.ais.model.IndexColumn;
import com.foundationdb.ais.model.TableIndex;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.*;
/** Resolve aggregate functions and group by expressions to output
* columns of the "group table," that is, the result of aggregation.
*/
public class AggregateMapper extends BaseRule
{
private static final Logger logger = LoggerFactory.getLogger(AggregateMapper.class);
/** Supplies the logger dedicated to this rule. */
@Override
protected Logger getLogger() {
    return AggregateMapper.logger;
}
/**
 * Maps aggregate functions in the plan onto their aggregate sources.
 *
 * <p>The work happens in three steps: aggregate functions are annotated
 * (recording references to outer aggregates), each aggregate source is
 * processed by {@link FindHavingSources}, and finally the annotated
 * functions are added to their sources by {@link AddAggregates}.
 *
 * @param plan the plan context to rewrite
 * @throws UnsupportedSQLException if aggregate functions were found but
 *         the plan contains no aggregate source at all
 */
@Override
public void apply(PlanContext plan) {
    AggregateSourceAndFunctionFinder finder = new AggregateSourceAndFunctionFinder(plan);
    List<AggregateSourceState> sourceStates = finder.find();
    List<AggregateFunctionExpression> aggregateFunctions = finder.getFunctions();

    // Aggregate functions without any aggregate source to attach them to.
    boolean orphanedFunctions = sourceStates.isEmpty() && !aggregateFunctions.isEmpty();
    if (orphanedFunctions) {
        throw new UnsupportedSQLException("Aggregate not allowed in WHERE",
                                          aggregateFunctions.get(0).getSQLsource());
    }

    // Step 1: find references to outer aggregates while turning every
    // AggregateFunctionExpression into an AnnotatedAggregateFunctionExpression.
    new Annotator(plan.getPlan(), finder.getTablesToSources()).run();

    // Step 2: one FindHavingSources run per aggregate source; each run makes two passes:
    //   1) map aggregate functions that are not references to outer aggregates,
    //   2) check that all the column expressions are okay.
    for (AggregateSourceState state : sourceStates) {
        FindHavingSources havingSources =
            new FindHavingSources((SchemaRulesContext)plan.getRulesContext(),
                                  state.aggregateSource,
                                  state.containingQuery);
        havingSources.run(state.aggregateSource);
    }

    // Step 3: add every aggregate to its source, or fail with an error.
    new AddAggregates(plan.getPlan(), finder.getTablesToSources()).run();
}
/**
 * An aggregate function expression that additionally remembers the
 * {@link AggregateSource} it has been assigned to (or {@code null} while
 * no source has been assigned).
 */
static class AnnotatedAggregateFunctionExpression extends AggregateFunctionExpression {
    private AggregateSource source;

    /** Copies every attribute of {@code aggregateFunc}; the source starts out as {@code null}. */
    public AnnotatedAggregateFunctionExpression(AggregateFunctionExpression aggregateFunc) {
        this(aggregateFunc.getFunction(),
             aggregateFunc.getOperand(),
             aggregateFunc.isDistinct(),
             aggregateFunc.getSQLtype(),
             aggregateFunc.getSQLsource(),
             aggregateFunc.getType(),
             aggregateFunc.getOption(),
             aggregateFunc.getOrderBy(),
             null);
    }

    /** Builds an aggregate function from its parts and annotates it with {@code source}. */
    public AnnotatedAggregateFunctionExpression(String function, ExpressionNode operand,
                                                boolean distinct,
                                                DataTypeDescriptor sqlType, ValueNode sqlSource,
                                                TInstance type,
                                                Object option, List<OrderByExpression> orderBy,
                                                AggregateSource source) {
        super(function, operand, distinct, sqlType, sqlSource, type, option, orderBy);
        this.source = source;
    }

    /**
     * Replaces the annotated source.
     *
     * @return this expression, to allow chaining
     */
    public AnnotatedAggregateFunctionExpression setSource(AggregateSource source) {
        this.source = source;
        return this;
    }

    /** @return the annotated source, possibly {@code null} */
    public AggregateSource getSource() {
        return source;
    }

    /** @return a new plain {@link AggregateFunctionExpression} with the same attributes but no annotation */
    public AggregateFunctionExpression getWithoutAnnotation() {
        return new AggregateFunctionExpression(getFunction(),
                                               getOperand(),
                                               isDistinct(),
                                               getSQLtype(),
                                               getSQLsource(),
                                               getType(),
                                               getOption(),
                                               getOrderBy());
    }

    /** Hands only this node to the rewriting visitor and returns whatever the visitor returns. */
    @Override
    public ExpressionNode accept(ExpressionRewriteVisitor v) {
        return v.visit(this);
    }
}
/**
 * Walks the plan and records every {@link AggregateSource} together with
 * the query that was current when it was visited.
 */
static class AggregateSourceFinder extends SubqueryBoundTablesTracker {
    List<AggregateSourceState> sources = new ArrayList<>();

    public AggregateSourceFinder(PlanContext planContext) {
        super(planContext);
    }

    /**
     * Runs the traversal.
     *
     * @return the aggregate sources collected, in visiting order
     */
    public List<AggregateSourceState> find() {
        run();
        return sources;
    }

    @Override
    public boolean visit(PlanNode n) {
        super.visit(n);
        if (!(n instanceof AggregateSource)) {
            return true;
        }
        AggregateSource found = (AggregateSource)n;
        sources.add(new AggregateSourceState(found, currentQuery()));
        return true;
    }
}
/**
 * Extends {@link AggregateSourceFinder} to also collect every aggregate
 * function expression, reject nested aggregate functions, and remember
 * for each table source the aggregate source found most recently before it.
 */
static class AggregateSourceAndFunctionFinder extends AggregateSourceFinder {
    List<AggregateFunctionExpression> functions = new ArrayList<>();
    // Aggregate functions currently being traversed; used to detect nesting.
    Deque<AggregateFunctionExpression> functionsStack = new ArrayDeque<>();
    // Collected here for later use by Annotator and AddAggregates.
    Map<TableSource, AggregateSourceState> tablesToSources = new HashMap<>();

    public AggregateSourceAndFunctionFinder(PlanContext planContext) {
        super(planContext);
    }

    /** @return every aggregate function expression visited so far */
    public List<AggregateFunctionExpression> getFunctions() {
        return functions;
    }

    /** @return the table source to aggregate source mapping collected so far */
    public Map<TableSource, AggregateSourceState> getTablesToSources() {
        return tablesToSources;
    }

    /**
     * @throws UnsupportedSQLException if an aggregate function is entered
     *         while another one is still being traversed
     */
    @Override
    public boolean visitEnter(ExpressionNode n) {
        if (n instanceof AggregateFunctionExpression) {
            // The deque never holds null, so a non-null head means it is not empty.
            AggregateFunctionExpression enclosing = functionsStack.peek();
            if (enclosing != null) {
                throw new UnsupportedSQLException("Cannot nest aggregate functions",
                                                  enclosing.getSQLsource());
            }
            functionsStack.push((AggregateFunctionExpression)n);
        }
        return visit(n);
    }

    @Override
    public boolean visit(PlanNode n) {
        super.visit(n);
        if (n instanceof TableSource && !sources.isEmpty()) {
            // Associate the table with the aggregate source found most recently.
            AggregateSourceState latest = sources.get(sources.size() - 1);
            tablesToSources.put((TableSource)n, latest);
        }
        return true;
    }

    @Override
    public boolean visit(ExpressionNode n) {
        super.visit(n);
        if (n instanceof AggregateFunctionExpression) {
            AggregateFunctionExpression function = (AggregateFunctionExpression)n;
            functions.add(function);
        }
        return true;
    }

    @Override
    public boolean visitLeave(ExpressionNode n) {
        if (n instanceof AggregateFunctionExpression) {
            functionsStack.pop();
        }
        return true;
    }
}
/** Pairs an aggregate source with the query in which it was found. */
static class AggregateSourceState {
    /** The aggregate source itself. */
    AggregateSource aggregateSource;
    /** The query that was current when the aggregate source was visited. */
    BaseQuery containingQuery;

    public AggregateSourceState(AggregateSource aggregateSource,
                                BaseQuery containingQuery) {
        this.containingQuery = containingQuery;
        this.aggregateSource = aggregateSource;
    }
}
/**
 * Rewrites every {@link AggregateFunctionExpression} in the plan into an
 * {@link AnnotatedAggregateFunctionExpression}. When the function's operand
 * is a column of a table mapped to an aggregate source of an enclosing
 * (but not the innermost) query, that source is stored in the annotation.
 */
static class Annotator implements PlanVisitor, ExpressionRewriteVisitor {
    PlanNode plan;
    // Queries currently being traversed, innermost first.
    Deque<BaseQuery> subqueries = new ArrayDeque<>();
    Map<TableSource, AggregateSourceState> tablesToSources;

    public Annotator(PlanNode plan,
                     Map<TableSource, AggregateSourceState> tablesToSources) {
        this.plan = plan;
        this.tablesToSources = tablesToSources;
    }

    /** Traverses the plan, annotating aggregate functions along the way. */
    public void run() {
        plan.accept(this);
    }

    /**
     * Wraps {@code expr} in an annotated copy.
     *
     * @return the annotated copy; its source is the outer aggregate source
     *         the operand refers to, or {@code null} when there is none
     */
    public ExpressionNode annotateAggregate(AggregateFunctionExpression expr) {
        AnnotatedAggregateFunctionExpression annotated =
            new AnnotatedAggregateFunctionExpression(expr);
        return annotated.setSource(outerSourceOf(expr));
    }

    // Looks for a reference to an outer aggregate; null when nothing qualifies.
    private AggregateSource outerSourceOf(AggregateFunctionExpression expr) {
        if (!(expr.getOperand() instanceof ColumnExpression)) {
            return null;
        }
        ColumnSource columnSource = ((ColumnExpression)expr.getOperand()).getTable();
        if (!(columnSource instanceof TableSource)) {
            return null;
        }
        if (!tablesToSources.containsKey((TableSource)columnSource)) {
            return null;
        }
        AggregateSourceState state = tablesToSources.get((TableSource)columnSource);
        // Must belong to a query on the stack other than the innermost one.
        boolean inInnermostQuery = (state.containingQuery == subqueries.peek());
        if (inInnermostQuery || !subqueries.contains(state.containingQuery)) {
            return null;
        }
        return state.aggregateSource;
    }

    @Override
    public boolean visitChildrenFirst(ExpressionNode n) {
        return false;
    }

    @Override
    public ExpressionNode visit(ExpressionNode n) {
        return (n instanceof AggregateFunctionExpression)
            ? annotateAggregate((AggregateFunctionExpression)n)
            : n;
    }

    @Override
    public boolean visitEnter(PlanNode n) {
        if (n instanceof BaseQuery) {
            BaseQuery entered = (BaseQuery)n;
            subqueries.push(entered);
        }
        return visit(n);
    }

    @Override
    public boolean visitLeave(PlanNode n) {
        if (n instanceof BaseQuery) {
            subqueries.pop();
        }
        return true;
    }

    @Override
    public boolean visit(PlanNode n) {
        return true;
    }
}
/**
 * Base for visitors that rewrite the expressions of the nodes fed by a
 * given plan node. Subclasses supply the actual expression rewrite via
 * {@code visit(ExpressionNode)}.
 */
static abstract class Remapper implements ExpressionRewriteVisitor, PlanVisitor {
    /**
     * Follows the output chain starting after {@code n} and rewrites the
     * expressions of each Select, Sort and Project found there. Limit
     * nodes are passed over; any other kind of node ends the walk.
     */
    public void remap(PlanNode n) {
        PlanNode node = n.getOutput();
        while (true) {
            if (node instanceof Select) {
                remap(((Select)node).getConditions());
            }
            else if (node instanceof Sort) {
                remapA(((Sort)node).getOrderBy());
            }
            else if (node instanceof Project) {
                Project project = (Project)node;
                remap(project.getFields());
            }
            else if (!(node instanceof Limit)) {
                // Not something we understand: stop here.
                return;
            }
            // A Limit is understood but has nothing to map.
            node = node.getOutput();
        }
    }

    /** Replaces each element of {@code exprs} with the result of rewriting it. */
    @SuppressWarnings("unchecked")
    protected <T extends ExpressionNode> void remap(List<T> exprs) {
        int index = 0;
        while (index < exprs.size()) {
            T rewritten = (T)exprs.get(index).accept(this);
            exprs.set(index, rewritten);
            index++;
        }
    }

    /** Rewrites the expression held by each annotated expression in place. */
    protected void remapA(List<? extends AnnotatedExpression> exprs) {
        for (AnnotatedExpression annotated : exprs) {
            ExpressionNode rewritten = annotated.getExpression().accept(this);
            annotated.setExpression(rewritten);
        }
    }

    @Override
    public boolean visitChildrenFirst(ExpressionNode expr) {
        return false;
    }

    @Override
    public boolean visit(PlanNode n) {
        return true;
    }

    @Override
    public boolean visitEnter(PlanNode n) {
        return visit(n);
    }

    @Override
    public boolean visitLeave(PlanNode n) {
        return true;
    }
}
/**
 * Rewrites the expressions downstream of one aggregate source so that they
 * refer to the source's output columns.
 *
 * <p>{@link #run} makes two passes over the same nodes. The first pass
 * ({@code FINDING_SOURCES}) assigns this source to aggregate functions that
 * have none yet (rewriting AVG and the variance family into more basic
 * aggregates first). The second pass ({@code CHECKING_ERRORS}) looks for
 * columns that are used without being grouped or aggregated and handles
 * them according to the {@code implicitAggregate} property.
 */
static class FindHavingSources extends Remapper {
    private SchemaRulesContext rulesContext;
    private AggregateSource source;
    private BaseQuery query;
    // Queries entered while traversing nested plans, innermost first.
    private Deque<BaseQuery> subqueries = new ArrayDeque<>();
    // Column sources whose columns may be referenced freely: the aggregate
    // source itself and any Project already remapped.
    private Set<ColumnSource> aggregated = new HashSet<>();
    // Expressions already mapped to an output column of the aggregate source.
    private Map<ExpressionNode,ExpressionNode> map =
        new HashMap<>();

    private enum State {
        FINDING_SOURCES, CHECKING_ERRORS
    }

    private State state;
    // Set during the first pass once an aggregate function is assigned to this source.
    boolean hasAggregates;

    private enum ImplicitAggregateSetting {
        ERROR, FIRST, FIRST_IF_UNIQUE
    }

    private ImplicitAggregateSetting implicitAggregateSetting;
    // Lazily created cache of tables already found to be uniquely grouped.
    private Set<TableSource> uniqueGroupedTables;

    /**
     * Reads (once) and returns the {@code implicitAggregate} property,
     * which defaults to {@code "error"}.
     *
     * @throws InvalidOptimizerPropertyException if the property value is
     *         not one of {@code error}, {@code first}, {@code firstIfUnique}
     */
    protected ImplicitAggregateSetting getImplicitAggregateSetting() {
        if (implicitAggregateSetting == null) {
            String setting = rulesContext.getProperty("implicitAggregate", "error");
            if ("error".equals(setting))
                implicitAggregateSetting = ImplicitAggregateSetting.ERROR;
            else if ("first".equals(setting))
                implicitAggregateSetting = ImplicitAggregateSetting.FIRST;
            else if ("firstIfUnique".equals(setting))
                implicitAggregateSetting = ImplicitAggregateSetting.FIRST_IF_UNIQUE;
            else
                throw new InvalidOptimizerPropertyException("implicitAggregate", setting);
        }
        return implicitAggregateSetting;
    }

    /**
     * @param rulesContext source of optimizer properties and the types translator
     * @param source the aggregate source whose consumers are to be remapped
     * @param query the query containing {@code source}
     */
    public FindHavingSources(SchemaRulesContext rulesContext, AggregateSource source, BaseQuery query) {
        this.rulesContext = rulesContext;
        this.source = source;
        this.query = query;
        aggregated.add(source);
        // Map all the group by expressions at the start.
        // This means that if you GROUP BY x+1, you can ORDER BY
        // x+1, or x+1+1, but not x+2. Postgres is like that, too.
        List<ExpressionNode> groupBy = source.getGroupBy();
        for (int i = 0; i < groupBy.size(); i++) {
            ExpressionNode expr = groupBy.get(i);
            map.put(expr, new ColumnExpression(source, i,
                                               expr.getSQLtype(), expr.getSQLsource(), expr.getType()));
        }
    }

    /**
     * Runs both passes over the nodes fed by {@code n}: first finding
     * aggregate sources, then checking for non-aggregated columns.
     */
    public void run(PlanNode n) {
        state = State.FINDING_SOURCES;
        hasAggregates = false;
        remap(n);
        state = State.CHECKING_ERRORS;
        remap(n);
    }

    /**
     * Same walk as {@link Remapper#remap(PlanNode)}, except that each
     * Project is added to the set of aggregated column sources once its
     * fields have been remapped.
     */
    @Override
    public void remap(PlanNode n) {
        while (true) {
            // Keep going as long as we're feeding something we understand.
            n = n.getOutput();
            if (n instanceof Select) {
                remap(((Select)n).getConditions());
            }
            else if (n instanceof Sort) {
                remapA(((Sort)n).getOrderBy());
            }
            else if (n instanceof Project) {
                Project p = (Project)n;
                remap(p.getFields());
                aggregated.add(p);
            }
            else if (n instanceof Limit) {
                // Understood but not mapped.
            }
            else
                break;
        }
    }

    /**
     * Rewrites one expression according to the current pass.
     *
     * <p>Expressions already present in the map are replaced by their
     * mapped column in either pass. Otherwise the first pass handles only
     * annotated aggregate functions and the second pass only column
     * expressions; everything else is returned unchanged.
     */
    @Override
    public ExpressionNode visit(ExpressionNode expr) {
        ExpressionNode nexpr = map.get(expr);
        if (nexpr != null)
            return nexpr;
        switch (state) {
        case FINDING_SOURCES:
            if (expr instanceof AnnotatedAggregateFunctionExpression) {
                AnnotatedAggregateFunctionExpression a = (AnnotatedAggregateFunctionExpression)expr;
                nexpr = rewrite(a);
                if (nexpr == null) {
                    // Only claim aggregates of this query level that do not
                    // already refer to an outer aggregate source.
                    if (subqueries.isEmpty() && a.getSource() == null) {
                        a.setSource(source);
                        hasAggregates = true;
                    }
                    return a;
                }
                // Visit the replacement so its component aggregates get sources too.
                return nexpr.accept(this);
            }
            // Do not fall through: columns are only checked in the second
            // pass, once hasAggregates is known for the whole chain.
            break;
        case CHECKING_ERRORS:
            if (expr instanceof ColumnExpression) {
                ColumnExpression column = (ColumnExpression)expr;
                ColumnSource table = column.getTable();
                if ((!map.isEmpty() || hasAggregates) &&
                    !aggregated.contains(table) &&
                    !boundElsewhere(table)) {
                    return nonAggregate(column);
                }
            }
            break;
        }
        return expr;
    }

    @Override
    public boolean visitEnter(PlanNode n) {
        if (n instanceof BaseQuery)
            subqueries.push((BaseQuery)n);
        return visit(n);
    }

    @Override
    public boolean visitLeave(PlanNode n) {
        if (n instanceof BaseQuery)
            subqueries.pop();
        return true;
    }

    /**
     * Rewrite aggregate functions that aren't well behaved wrt pre-aggregation.
     *
     * <p>AVG becomes {@code divide(SUM, COUNT)}; VAR_POP, VAR_SAMP,
     * STDDEV_POP and STDDEV_SAMP become a function named with a leading
     * underscore over {@code _VAR_SUM_2}, {@code _VAR_SUM} and COUNT. The
     * component aggregates inherit the source annotation of {@code expr}.
     *
     * @return the replacement expression, or {@code null} when the
     *         function needs no rewriting
     */
    protected ExpressionNode rewrite(AnnotatedAggregateFunctionExpression expr) {
        // Locale-independent upper-casing so the name comparison does not
        // depend on the JVM's default locale.
        String function = expr.getFunction().toUpperCase(java.util.Locale.ROOT);
        if ("AVG".equals(function)) {
            ExpressionNode operand = expr.getOperand();
            List<ExpressionNode> noperands = new ArrayList<>(2);
            noperands.add(new AnnotatedAggregateFunctionExpression("SUM", operand, expr.isDistinct(),
                                                                   operand.getSQLtype(), null,
                                                                   operand.getType(), null, null,
                                                                   expr.getSource()));
            DataTypeDescriptor intType = new DataTypeDescriptor(TypeId.INTEGER_ID, false);
            TInstance intInst = rulesContext.getTypesTranslator().typeForSQLType(intType);
            noperands.add(new AnnotatedAggregateFunctionExpression("COUNT", operand, expr.isDistinct(),
                                                                   intType, null, intInst, null, null,
                                                                   expr.getSource()));
            return new FunctionExpression("divide",
                                          noperands,
                                          expr.getSQLtype(), expr.getSQLsource(), expr.getType());
        }
        if ("VAR_POP".equals(function) ||
            "VAR_SAMP".equals(function) ||
            "STDDEV_POP".equals(function) ||
            "STDDEV_SAMP".equals(function)) {
            ExpressionNode operand = expr.getOperand();
            List<ExpressionNode> noperands = new ArrayList<>(3);
            noperands.add(new AnnotatedAggregateFunctionExpression("_VAR_SUM_2", operand, expr.isDistinct(),
                                                                   operand.getSQLtype(), null,
                                                                   operand.getType(), null, null,
                                                                   expr.getSource()));
            noperands.add(new AnnotatedAggregateFunctionExpression("_VAR_SUM", operand, expr.isDistinct(),
                                                                   operand.getSQLtype(), null,
                                                                   operand.getType(), null, null,
                                                                   expr.getSource()));
            DataTypeDescriptor intType = new DataTypeDescriptor(TypeId.INTEGER_ID, false);
            TInstance intInst = rulesContext.getTypesTranslator().typeForSQLType(intType);
            noperands.add(new AnnotatedAggregateFunctionExpression("COUNT", operand, expr.isDistinct(),
                                                                   intType, null, intInst, null, null,
                                                                   expr.getSource()));
            return new FunctionExpression("_" + function,
                                          noperands,
                                          expr.getSQLtype(), expr.getSQLsource(), expr.getType());
        }
        return null;
    }

    /**
     * Adds {@code expr} as a further GROUP BY key of the source and maps
     * it to the resulting output column.
     *
     * @return the column expression now standing for {@code expr}
     */
    protected ExpressionNode addKey(ExpressionNode expr) {
        int position = source.addGroupBy(expr);
        ColumnExpression nexpr = new ColumnExpression(source, position,
                                                      expr.getSQLtype(), expr.getSQLsource(), expr.getType());
        map.put(expr, nexpr);
        return nexpr;
    }

    /**
     * @return {@code true} if {@code table} is an outer table of the query
     *         containing the source, or if a subquery is being traversed
     *         and {@code table} is not one of that subquery's outer tables
     */
    protected boolean boundElsewhere(ColumnSource table) {
        if (query.getOuterTables().contains(table))
            return true;    // Bound outside.
        BaseQuery subquery = subqueries.peek();
        if (subquery != null) {
            if (!subquery.getOuterTables().contains(table))
                return true; // Must be introduced by subquery.
        }
        return false;
    }

    /**
     * Use of a column not in GROUP BY without aggregate function.
     *
     * @return a replacement for the column: a new group key when its table
     *         is uniquely grouped and the source has no aggregates yet,
     *         otherwise a FIRST aggregate over the column
     * @throws NoAggregateWithGroupByException if the implicit aggregate
     *         setting is ERROR, or FIRST_IF_UNIQUE and the column's table
     *         is not uniquely grouped
     */
    protected ExpressionNode nonAggregate(ColumnExpression column) {
        boolean isUnique = isUniqueGroupedTable(column.getTable());
        ImplicitAggregateSetting setting = getImplicitAggregateSetting();
        if ((setting == ImplicitAggregateSetting.ERROR) ||
            ((setting == ImplicitAggregateSetting.FIRST_IF_UNIQUE) && !isUnique))
            throw new NoAggregateWithGroupByException(column.getSQLsource());
        if (isUnique && source.getAggregates().isEmpty())
            // Add unique as another key in hopes of turning the
            // whole things into a distinct.
            return addKey(column);
        else
            return new AnnotatedAggregateFunctionExpression("FIRST", column, false,
                                                            column.getSQLtype(), null, column.getType(), null, null,
                                                            source);
    }

    /**
     * Determines whether the GROUP BY columns taken from {@code columnSource}
     * include every key column of some unique index of that table. Positive
     * answers are cached.
     *
     * @return {@code false} when {@code columnSource} is not a table source,
     *         contributes no GROUP BY column, or has no such unique index
     */
    protected boolean isUniqueGroupedTable(ColumnSource columnSource) {
        if (!(columnSource instanceof TableSource))
            return false;
        TableSource table = (TableSource)columnSource;
        if (uniqueGroupedTables == null)
            uniqueGroupedTables = new HashSet<>();
        if (uniqueGroupedTables.contains(table))
            return true;
        // Columns of this table that appear directly as GROUP BY expressions.
        Set<Column> columns = new HashSet<>();
        for (ExpressionNode groupBy : source.getGroupBy()) {
            if (groupBy instanceof ColumnExpression) {
                ColumnExpression groupColumn = (ColumnExpression)groupBy;
                if (groupColumn.getTable() == table) {
                    columns.add(groupColumn.getColumn());
                }
            }
        }
        if (columns.isEmpty()) return false;
        // Find a unique index all of whose columns are in the GROUP BY.
        // TODO: Use column equivalences.
        find_index:
        for (TableIndex index : table.getTable().getTable().getIndexes()) {
            if (!index.isUnique()) continue;
            for (IndexColumn indexColumn : index.getKeyColumns()) {
                if (!columns.contains(indexColumn.getColumn())) {
                    continue find_index;
                }
            }
            uniqueGroupedTables.add(table);
            return true;
        }
        return false;
    }
}
/**
 * Final step of the rule: replaces every annotated aggregate function in
 * the plan by a column of the aggregate source it was assigned to,
 * registering the aggregate with that source when it is not there yet.
 */
static class AddAggregates implements PlanVisitor, ExpressionRewriteVisitor {
    PlanNode plan;
    // Queries currently being traversed, innermost first.
    Deque<BaseQuery> subqueries = new ArrayDeque<>();
    Map<TableSource, AggregateSourceState> tablesToSources;

    public AddAggregates(PlanNode plan,
                         Map<TableSource, AggregateSourceState> tablesToSources) {
        this.plan = plan;
        this.tablesToSources = tablesToSources;
    }

    /** Traverses the plan, replacing annotated aggregate functions along the way. */
    public void run() {
        plan.accept(this);
    }

    /**
     * Resolves one annotated aggregate function to a column of its source.
     *
     * @return a column expression at the aggregate's position in the source
     * @throws UnsupportedSQLException if no source was assigned to {@code expr}
     */
    public ExpressionNode addAggregate(AnnotatedAggregateFunctionExpression expr) {
        AggregateSource target = expr.getSource();
        if (target == null) {
            throw new UnsupportedSQLException("Aggregate not allowed in WHERE",
                                              expr.getSQLsource());
        }
        // Reuse the existing position when the source already has this
        // aggregate; otherwise register the un-annotated form.
        int position = target.hasAggregate(expr)
            ? target.getPosition(expr.getWithoutAnnotation())
            : target.addAggregate(expr.getWithoutAnnotation());
        return new ColumnExpression(target, position,
                                    expr.getSQLtype(), expr.getSQLsource(), expr.getType());
    }

    @Override
    public boolean visitChildrenFirst(ExpressionNode n) {
        return false;
    }

    @Override
    public ExpressionNode visit(ExpressionNode n) {
        return (n instanceof AnnotatedAggregateFunctionExpression)
            ? addAggregate((AnnotatedAggregateFunctionExpression)n)
            : n;
    }

    @Override
    public boolean visitEnter(PlanNode n) {
        if (n instanceof BaseQuery) {
            BaseQuery entered = (BaseQuery)n;
            subqueries.push(entered);
        }
        return visit(n);
    }

    @Override
    public boolean visitLeave(PlanNode n) {
        if (n instanceof BaseQuery) {
            subqueries.pop();
        }
        return true;
    }

    @Override
    public boolean visit(PlanNode n) {
        return true;
    }
}
}