package com.taobao.tddl.optimizer.costbased.after; import java.util.ArrayList; import java.util.List; import java.util.Map; import com.taobao.tddl.common.jdbc.ParameterContext; import com.taobao.tddl.optimizer.core.expression.IFunction; import com.taobao.tddl.optimizer.core.expression.IFunction.FunctionType; import com.taobao.tddl.optimizer.core.expression.ISelectable; import com.taobao.tddl.optimizer.core.plan.IDataNodeExecutor; import com.taobao.tddl.optimizer.core.plan.IQueryTree; import com.taobao.tddl.optimizer.core.plan.query.IJoin; import com.taobao.tddl.optimizer.core.plan.query.IMerge; import com.taobao.tddl.optimizer.core.plan.query.IQuery; /** * avg变成count + sum 要改变columns结构 * * @author Whisper */ public class FuckAvgOptimizer implements QueryPlanOptimizer { public FuckAvgOptimizer(){ } /** * 把query中的avg换成count,sum */ @Override public IDataNodeExecutor optimize(IDataNodeExecutor dne, Map<Integer, ParameterContext> parameterSettings, Map<String, Object> extraCmd) { if (dne instanceof IMerge && ((IMerge) dne).getSubNode().size() > 1) { for (IDataNodeExecutor sub : ((IMerge) dne).getSubNode()) { expendAvgFunction(sub); } for (IDataNodeExecutor sub : ((IMerge) dne).getSubNode()) { this.optimize(sub, parameterSettings, extraCmd); } } else if (dne instanceof IJoin) { IJoin join = (IJoin) dne; // join函数,采取map模式,不需要处理avg展开 // 递归处理子节点 this.optimize(join.getLeftNode(), parameterSettings, extraCmd); this.optimize(join.getRightNode(), parameterSettings, extraCmd); } else if (dne instanceof IQuery) { IQuery query = (IQuery) dne; // 如果是子查询,采取map模式,不需要处理avg展开 if (query.isSubQuery()) { this.optimize(query.getSubQuery(), parameterSettings, extraCmd);// 递归处理子节点 } } return dne; } private boolean hasArgsAvgFunction(IFunction func) { for (Object args : func.getArgs()) { if (args instanceof IFunction && ((IFunction) args).getColumnName().startsWith("AVG(")) { return true; } } return false; } /** * 将Avg函数展开为sum/count */ private void expendAvgFunction(IDataNodeExecutor sub) { if (sub instanceof IQuery || sub instanceof IJoin) { List<ISelectable> add = new ArrayList(); List<ISelectable> remove = new ArrayList(); for (Object sel : ((IQueryTree) sub).getColumns()) { ISelectable s = (ISelectable) sel; if (s instanceof IFunction) { if (s.getColumnName().startsWith("AVG(")) { IFunction sum = (IFunction) s.copy(); sum.setExtraFunction(null); sum.setFunctionName("SUM"); sum.setColumnName(s.getColumnName().replace("AVG(", "SUM(")); if (sum.getAlias() != null) { sum.setAlias(sum.getAlias() + "1");// 加个后缀1 } IFunction count = (IFunction) s.copy(); count.setExtraFunction(null); count.setFunctionName("COUNT"); count.setColumnName(s.getColumnName().replace("AVG(", "COUNT(")); if (count.getAlias() != null) { count.setAlias(count.getAlias() + "2");// 加个后缀2 } add.add(count); add.add(sum); remove.add(s); } else { // 删除底下AVG的相关函数,比如 1 + AVG(ID) // 目前这个只能上层来进行计算 // 可能的风险:还未支持的Function计算 if (FunctionType.Scalar.equals(((IFunction) s).getFunctionType()) && hasArgsAvgFunction((IFunction) s)) { remove.add(s); } } } } if (!remove.isEmpty()) { ((IQueryTree) sub).getColumns().removeAll(remove); ((IQueryTree) sub).getColumns().addAll(add); } } } }