/**
* Copyright (C) 2009-2013 FoundationDB, LLC
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package com.foundationdb.sql.optimizer.rule;
import com.foundationdb.sql.optimizer.plan.*;
import com.foundationdb.sql.optimizer.plan.Sort.OrderByExpression;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.*;
/** Turn an aggregate that has grouping keys but no aggregate functions
 * into a {@link Distinct} over a {@link Project} of those keys.
 *
 * <p>After the replacement, the chain of nodes above it is walked for as
 * long as they are of a kind this rule understands ({@link Select},
 * {@link Sort}, {@link Project}, {@link Limit}), rewriting references to
 * the old aggregate so that they refer to the new project. When the
 * ordering allows it, a {@link Sort} is folded beneath the {@link Distinct}
 * so that the distinct can be computed with an explicit sort.
 */
public class AggregateToDistinctMapper extends BaseRule
{
    private static final Logger logger = LoggerFactory.getLogger(AggregateToDistinctMapper.class);

    @Override
    protected Logger getLogger() {
        return logger;
    }

    @Override
    public void apply(PlanContext plan) {
        // All candidates are collected before any rewriting starts, so the
        // finder never walks a tree that is being modified.
        List<AggregateSource> sources = new AggregateSourceFinder().find(plan.getPlan());
        for (AggregateSource source : sources) {
            Mapper m = new Mapper(plan.getRulesContext(), source);
            m.remap();
        }
    }

    /** Collects every {@link AggregateSource} in a plan whose aggregate
     * list is empty, i.e. one that only groups.
     * Every visit method returns {@code true}, so the traversal (including
     * expression nodes) is never pruned.
     */
    static class AggregateSourceFinder implements PlanVisitor, ExpressionVisitor {
        List<AggregateSource> result = new ArrayList<>();

        /** Walk the plan rooted at {@code root} and return the candidates found. */
        public List<AggregateSource> find(PlanNode root) {
            root.accept(this);
            return result;
        }

        @Override
        public boolean visitEnter(PlanNode n) {
            return visit(n);
        }

        @Override
        public boolean visitLeave(PlanNode n) {
            return true;
        }

        @Override
        public boolean visit(PlanNode n) {
            if (n instanceof AggregateSource) {
                AggregateSource a = (AggregateSource)n;
                if (a.getAggregates().isEmpty())
                    result.add(a);
            }
            return true;
        }

        @Override
        public boolean visitEnter(ExpressionNode n) {
            return visit(n);
        }

        @Override
        public boolean visitLeave(ExpressionNode n) {
            return true;
        }

        @Override
        public boolean visit(ExpressionNode n) {
            return true;
        }
    }

    /** Rewrites one key-only {@link AggregateSource} into
     * {@code Distinct(Project(groupBy))} and fixes up the nodes above it.
     */
    static class Mapper implements ExpressionRewriteVisitor {
        // NOTE(review): stored but never read within this class.
        private RulesContext rulesContext;
        private AggregateSource source;
        // The Project of the group-by keys that replaces the aggregate;
        // assigned by remap() before any expression is rewritten.
        private Project project;

        public Mapper(RulesContext rulesContext, AggregateSource source) {
            this.rulesContext = rulesContext;
            this.source = source;
        }

        /** Replace the aggregate, then walk upwards through the nodes this
         * rule understands, remapping their expressions and possibly
         * rearranging them.
         */
        public void remap() {
            if (source.getGroupBy().isEmpty()) {
                // No keys and no aggregates: just splice the node out.
                // NOTE(review): an aggregation without GROUP BY normally
                // yields a single group; confirm that removing it is correct
                // for every plan that can reach this point.
                source.getOutput().replaceInput(source, source.getInput());
                return;
            }
            project = new Project(source.getInput(), source.getGroupBy());
            Distinct distinct = new Distinct(project);
            source.getOutput().replaceInput(source, distinct);
            PlanNode n = distinct;
            while (true) {
                // Keep going as long as we're feeding something we understand.
                // Each helper returns the node whose output is to be
                // examined next.
                n = n.getOutput();
                if (n instanceof Select) {
                    remap(((Select)n).getConditions());
                }
                else if (n instanceof Sort) {
                    n = remapSort((Sort)n, distinct);
                }
                else if (n instanceof Project) {
                    n = remapProject((Project)n);
                }
                else if (n instanceof Limit) {
                    n = remapLimit((Limit)n);
                }
                else
                    break;
            }
        }

        /** Remap a {@link Sort} above the new distinct. If every sort key is
         * a column of the new project, extend the ordering with the remaining
         * project columns and move the sort beneath the distinct, so that the
         * distinct can use {@code Distinct.Implementation.EXPLICIT_SORT}.
         * Cf. ASTStatementLoader.Loader.adjustSortsForDistinct.
         * @return the node from which the upward walk continues
         */
        private PlanNode remapSort(Sort sort, Distinct distinct) {
            List<OrderByExpression> sorts = sort.getOrderBy();
            List<ExpressionNode> exprs = project.getFields();
            remapA(sorts);
            BitSet used = new BitSet(exprs.size());
            for (OrderByExpression orderBy : sorts) {
                ExpressionNode expr = orderBy.getExpression();
                if (!(expr instanceof ColumnExpression))
                    return sort; // Not mergeable; leave the sort in place.
                ColumnExpression column = (ColumnExpression)expr;
                if (column.getTable() != project)
                    return sort;
                used.set(column.getPosition());
            }
            // Added keys take the direction of the first existing key, or
            // ascending when the ordering is empty (sorts.get(0) would throw).
            boolean ascending = sorts.isEmpty() || sorts.get(0).isAscending();
            for (int i = 0; i < exprs.size(); i++) {
                if (!used.get(i)) {
                    ExpressionNode expr = exprs.get(i);
                    ExpressionNode cexpr = new ColumnExpression(project, i,
                                                                expr.getSQLtype(),
                                                                expr.getSQLsource(),
                                                                expr.getType());
                    sorts.add(new OrderByExpression(cexpr, ascending));
                }
            }
            PlanNode next = moveBeneath(sort, distinct);
            distinct.setImplementation(Distinct.Implementation.EXPLICIT_SORT);
            return next;
        }

        /** Remap a {@link Project} above the new distinct. A project that
         * merely re-selects the old aggregate's columns in order duplicates
         * the new project and is spliced out; otherwise its fields are
         * rewritten.
         * @return the node from which the upward walk continues
         */
        private PlanNode remapProject(Project upper) {
            List<ExpressionNode> fields = upper.getFields();
            if (isIdentityOfSource(fields)) {
                PlanNode below = upper.getInput();
                upper.getOutput().replaceInput(upper, below);
                return below;
            }
            remap(fields);
            return upper;
        }

        /** Whether {@code fields} has as many entries as the new project and
         * entry {@code i} is column {@code i} of the old aggregate, for all i.
         */
        private boolean isIdentityOfSource(List<ExpressionNode> fields) {
            if (fields.size() != project.getFields().size())
                return false;
            for (int i = 0; i < fields.size(); i++) {
                ExpressionNode expr = fields.get(i);
                if (!(expr instanceof ColumnExpression))
                    return false;
                ColumnExpression column = (ColumnExpression)expr;
                if ((column.getTable() != source) || (column.getPosition() != i))
                    return false;
            }
            return true;
        }

        /** A {@link Limit} whose input is a {@link Project} (normally one that
         * was kept above) swaps places with that project so that the limit
         * can apply to the distinct.
         * @return the node from which the upward walk continues
         */
        private PlanNode remapLimit(Limit limit) {
            if (limit.getInput() instanceof Project)
                return moveBeneath(limit, (Project)limit.getInput());
            return limit;
        }

        /** Move {@code node} so that it sits directly beneath {@code output}.
         * @return the node that now occupies {@code node}'s former position
         * (its former input), so that traversal can continue from there
         */
        protected PlanNode moveBeneath(BasePlanWithInput node,
                                       BasePlanWithInput output) {
            PlanNode next = node.getInput(); // Where to continue.
            // Remove from current position.
            node.getOutput().replaceInput(node, next);
            // Splice below current input to desired output.
            PlanNode input = output.getInput();
            node.replaceInput(next, input);
            output.replaceInput(input, node);
            return next;
        }

        /** Replace each element of {@code exprs} in place with its rewritten
         * form. The cast is unchecked.
         * NOTE(review): assumes the rewrite keeps each element assignable to
         * {@code T}; {@link #visit} only swaps one ColumnExpression for another.
         */
        @SuppressWarnings("unchecked")
        protected <T extends ExpressionNode> void remap(List<T> exprs) {
            for (int i = 0; i < exprs.size(); i++) {
                exprs.set(i, (T)exprs.get(i).accept(this));
            }
        }

        /** Rewrite the expression held by each annotated expression in place. */
        protected void remapA(List<? extends AnnotatedExpression> exprs) {
            for (AnnotatedExpression expr : exprs) {
                expr.setExpression(expr.getExpression().accept(this));
            }
        }

        // NOTE(review): false presumably requests parent-before-children
        // visiting; confirm against ExpressionRewriteVisitor.
        @Override
        public boolean visitChildrenFirst(ExpressionNode expr) {
            return false;
        }

        /** Redirect a column of the old aggregate to the same position in the
         * new project; every other expression is returned unchanged.
         * NOTE(review): relies on the aggregate's column positions lining up
         * with its group-by list.
         */
        @Override
        public ExpressionNode visit(ExpressionNode expr) {
            if (expr instanceof ColumnExpression) {
                ColumnExpression column = (ColumnExpression)expr;
                if (column.getTable() == source) {
                    return new ColumnExpression(project, column.getPosition(),
                                                expr.getSQLtype(), expr.getSQLsource(), expr.getType());
                }
            }
            return expr;
        }
    }
}