package com.taobao.tddl.executor.cursor.impl; import java.util.ArrayList; import java.util.HashMap; import java.util.LinkedList; import java.util.List; import java.util.Map; import java.util.Map.Entry; import com.taobao.tddl.common.exception.TddlException; import com.taobao.tddl.common.utils.GeneralUtil; import com.taobao.tddl.executor.cursor.IAggregateCursor; import com.taobao.tddl.executor.cursor.ICursorMeta; import com.taobao.tddl.executor.cursor.ISchematicCursor; import com.taobao.tddl.executor.cursor.SchematicCursor; import com.taobao.tddl.executor.function.ExtraFunction; import com.taobao.tddl.executor.rowset.ArrayRowSet; import com.taobao.tddl.executor.rowset.IRowSet; import com.taobao.tddl.executor.utils.ExecUtils; import com.taobao.tddl.optimizer.config.table.ColumnMeta; import com.taobao.tddl.optimizer.core.datatype.DataType; import com.taobao.tddl.optimizer.core.expression.IFunction; import com.taobao.tddl.optimizer.core.expression.IFunction.FunctionType; import com.taobao.tddl.optimizer.core.expression.IOrderBy; import com.taobao.tddl.optimizer.core.expression.ISelectable; import com.taobao.tddl.common.utils.logger.Logger; import com.taobao.tddl.common.utils.logger.LoggerFactory; /** * 用来计算聚合函数,group by * * @author mengshi.sunmengshi 2013-12-3 上午10:53:31 * @since 5.0.0 */ public class AggregateCursor extends SchematicCursor implements IAggregateCursor { private final static Logger logger = LoggerFactory.getLogger(AggregateCursor.class); /** * 查询中涉及的所有聚合函数 */ protected List<IFunction> aggregates = new LinkedList<IFunction>(); /** * 查询中涉及的所有scalar函数 */ List<IFunction> scalars = new LinkedList<IFunction>(); /** * 当前节点是不是归并节点 */ boolean isMerge = false; List<ColumnMeta> groupBys = new ArrayList<ColumnMeta>(); Map<ColumnMeta, Object> currentGroupByValue = null; private boolean schemaInited = false; private ICursorMeta cursorMeta = null; boolean end = false; IRowSet firstRowSetInCurrentGroup = null; boolean isFirstTime = true; public AggregateCursor(ISchematicCursor cursor, List<IFunction> functions, List<IOrderBy> groupBycols, List<ISelectable> retColumns, boolean isMerge){ super(cursor, null, cursor.getOrderBy()); this.groupBys.addAll(ExecUtils.getColumnMetaWithLogicTablesFromOrderBys(groupBycols)); for (IFunction f : functions) { if (f.getFunctionType().equals(FunctionType.Scalar)) { this.scalars.add(f); } } this.aggregates.addAll(this.getAllAggregates(functions)); this.isMerge = isMerge; } @Override public IRowSet next() throws TddlException { initSchema(); if (end) { return null; } if (isFirstTime) { if (firstRowSetInCurrentGroup == null) { end = true; return null; } isFirstTime = false; for (IFunction aggregate : aggregates) { aggregate.getExtraFunction().clear(); } } // 初始化currentGroupByValue,并把当前第一条记录中的值放进去 if (this.groupBys != null && !this.groupBys.isEmpty()) { if (this.currentGroupByValue == null) { this.currentGroupByValue = new HashMap(); } for (ColumnMeta cm : groupBys) { if (firstRowSetInCurrentGroup == null) { currentGroupByValue = null; } else { Object value = ExecUtils.getObject(firstRowSetInCurrentGroup.getParentCursorMeta(), firstRowSetInCurrentGroup, cm.getTableName(), cm.getName()); currentGroupByValue.put(cm, value); } } } IRowSet record = new ArrayRowSet(cursorMeta.getColumns().size(), cursorMeta); IRowSet kv = firstRowSetInCurrentGroup; // 这里无论KV是否是null,第一次都应该让函数来处理 if (kv == null && aggregates.isEmpty()) { end = true; return null; } if (kv != null) { // cursorMeta后面有函数,所以以kv的meta为准 for (int i = 0; i < kv.getParentCursorMeta().getColumns().size(); i++) { ColumnMeta cm = cursorMeta.getColumns().get(i); Integer index = kv.getParentCursorMeta().getIndex(cm.getTableName(), cm.getName()); if (index == null) { index = kv.getParentCursorMeta().getIndex(cm.getTableName(), cm.getAlias()); } record.setObject(i, kv.getObject(index)); } } if (!aggregates.isEmpty() || (this.groupBys != null && !this.groupBys.isEmpty())) { do { // 如果组的值发生了变化,则返回一条记录 if (isCurrentGroupByChanged(kv)) { this.firstRowSetInCurrentGroup = kv; break; } for (IFunction aggregate : aggregates) { if (this.isMerge() && !aggregate.isNeedDistinctArg()) { ((ExtraFunction) aggregate.getExtraFunction()).serverReduce(kv); } else { ((ExtraFunction) aggregate.getExtraFunction()).serverMap(kv); } } } while ((kv = super.next()) != null); } else { kv = super.next(); firstRowSetInCurrentGroup = kv; } // 将函数的结果放到结果集中 this.putFunctionsResultInRecord(aggregates, record); // 对于aggregate函数,需要遍历所有结果集 // 当两者同时存在时,scalar函数只处理第一条结果 // mysql是这样做的 // 对于scalar函数,只需要取一条结果 for (IFunction scalar : this.scalars) { if (this.isMerge()) { ((ExtraFunction) scalar.getExtraFunction()).serverReduce(record); } else { ((ExtraFunction) scalar.getExtraFunction()).serverMap(record); } } this.putFunctionsResultInRecord(scalars, record); for (IFunction aggregate : aggregates) { ((ExtraFunction) aggregate.getExtraFunction()).clear(); } end = (kv == null); return record; } // 递归遍历所有给定的函数,得到其中的聚合函数 // 包括参数中的函数 public List<IFunction> getAllAggregates(List<IFunction> functions) { List<IFunction> aggregates = new LinkedList<IFunction>(); for (IFunction f : functions) { List<IFunction> functionsInArgs = new ArrayList<IFunction>(f.getArgs().size()); for (Object arg : f.getArgs()) { if (arg instanceof IFunction) { functionsInArgs.add((IFunction) arg); } } List<IFunction> aggregatesInArgs = this.getAllAggregates(functionsInArgs); aggregates.addAll(aggregatesInArgs); if (f.getFunctionType().equals(FunctionType.Aggregate)) { aggregates.add(f); // 聚合函数不能使用聚合函数作为参数 // 如 max(count(id))是错误的 if (!aggregatesInArgs.isEmpty()) { throw new RuntimeException("Invalid use of group function"); } } } return aggregates; } private void initSchema() throws TddlException { if (schemaInited) { return; } schemaInited = true; firstRowSetInCurrentGroup = super.next(); if (firstRowSetInCurrentGroup == null) { return; } // 把聚合的结果放在最后 ICursorMeta meta = firstRowSetInCurrentGroup.getParentCursorMeta(); List<ColumnMeta> retColumns = new ArrayList<ColumnMeta>(meta.getColumns().size() + this.aggregates.size()); retColumns.addAll(meta.getColumns()); for (IFunction c : this.aggregates) { Integer index = meta.getIndex(c.getTableName(), c.getColumnName()); if (index == null) { index = meta.getIndex(c.getTableName(), c.getAlias()); } if (index == null) { putRetColumnInMeta(c, retColumns); } } for (IFunction c : this.scalars) { Integer index = meta.getIndex(c.getTableName(), c.getColumnName()); if (index == null) { index = meta.getIndex(c.getTableName(), c.getAlias()); } if (index == null) { putRetColumnInMeta(c, retColumns); } } cursorMeta = CursorMetaImp.buildNew(retColumns, retColumns.size()); if (logger.isDebugEnabled()) { logger.warn("firstRowSetInCurrentGroup:\n" + firstRowSetInCurrentGroup); logger.warn("cursorMeta:\n" + cursorMeta); } } private boolean isCurrentGroupByChanged(IRowSet kv) { if (this.groupBys != null && !this.groupBys.isEmpty()) { if (this.currentGroupByValue == null) { return false; } if (kv == null) return true; for (ColumnMeta cm : this.currentGroupByValue.keySet()) { Object valueFromKv = ExecUtils.getObject(kv.getParentCursorMeta(), kv, cm.getTableName(), cm.getName()); Object valueCurrent = this.currentGroupByValue.get(cm); if (valueFromKv == null) { if (valueCurrent != null) { return true; } } else { if (valueCurrent == null) { return true; } if (!valueFromKv.equals(valueCurrent)) { return true; } } } } return false; } @Override public IRowSet first() throws TddlException { this.end = false; this.isFirstTime = true; super.beforeFirst(); return this.next(); } @Override public void beforeFirst() throws TddlException { schemaInited = false; this.end = false; this.isFirstTime = true; super.beforeFirst(); } public boolean isMerge() { return this.isMerge; } public void setMerge(boolean isMerge) { this.isMerge = isMerge; } void putFunctionsResultInRecord(List<IFunction> functions, IRowSet record) { for (IFunction f : functions) { Integer index = this.cursorMeta.getIndex(f.getTableName(), f.getColumnName()); if (index == null) { index = this.cursorMeta.getIndex(f.getTableName(), f.getAlias()); } Object res = f.getExtraFunction().getResult(); if (res instanceof Map) { Map<String, Object> map = (Map<String, Object>) f.getExtraFunction().getResult(); for (Entry<String, Object> en : map.entrySet()) { record.setObject(index, en.getValue()); } } else { record.setObject(index, res); } } } void putRetColumnInMeta(ISelectable column, List<ColumnMeta> metaColumns) { String columnName; columnName = column.getColumnName(); DataType type = null; // 函数在Map和Reduce过程中的返回类型可以不同 // 如Avg,map过程返回String // reduce过程中返回数字类型 if (this.isMerge()) { type = column.getDataType(); } else { if (column instanceof IFunction) { type = ((IFunction) column).getExtraFunction().getMapReturnType(); } else { type = column.getDataType(); } } ColumnMeta cm = new ColumnMeta(ExecUtils.getLogicTableName(column.getTableName()), columnName, type, column.getAlias(), true); metaColumns.add(cm); } @Override public String toString() { return toStringWithInden(0); } @Override public String toStringWithInden(int inden) { try { initSchema(); } catch (Exception e) { e.printStackTrace(); } StringBuilder sb = new StringBuilder(); String tab = GeneralUtil.getTab(inden); sb.append(tab).append("【Aggregate cursor . agg funcs").append(aggregates).append("\n"); ExecUtils.printMeta(cursorMeta, inden, sb); ExecUtils.printOrderBy(orderBys, inden, sb); sb.append(super.toStringWithInden(inden)); return sb.toString(); } }