package com.taobao.tddl.executor.function.aggregate; import com.taobao.tddl.executor.function.AggregateFunction; import com.taobao.tddl.optimizer.core.datatype.DataType; import com.taobao.tddl.optimizer.core.datatype.DataTypeUtil; import com.taobao.tddl.optimizer.core.expression.IFunction; import com.taobao.tddl.optimizer.core.expression.ISelectable; import com.taobao.tddl.optimizer.exceptions.FunctionException; /** * Avg函数处理比较特殊,会将AVG转化为SUM + COUNT,拿到所有库的数据后再计算AVG * * @since 5.0.0 */ public class Avg extends AggregateFunction { private Long count = 0L; private Object total = null; @Override public void serverMap(Object[] args) throws FunctionException { count++; Object o = args[0]; DataType type = getSumType(); if (o != null) { if (total == null) { total = type.convertFrom(o); } else { total = type.getCalculator().add(total, o); } } } @Override public void serverReduce(Object[] args) throws FunctionException { if (args[0] == null || args[1] == null) { return; } count += DataType.LongType.convertFrom(args[1]); Object o = args[0]; DataType type = getSumType(); if (total == null) { total = type.convertFrom(o); } else { total = type.getCalculator().add(total, o); } } @Override public String getDbFunction() { return bulidAvgSql(function); } private String bulidAvgSql(IFunction func) { String colName = func.getColumnName(); StringBuilder sb = new StringBuilder(); if (func.getAlias() != null) {// 如果有别名,需要和FuckAvgOptimizer中保持一致 sb.append(func.getAlias() + "1").append(",").append(func.getAlias() + "2"); } else { sb.append(colName.replace("AVG", "SUM")); sb.append(",").append(colName.replace("AVG", "COUNT")); } return sb.toString(); } @Override public Object getResult() { DataType type = this.getReturnType(); if (total == null) { return type.getCalculator().divide(0L, count); } else { return type.getCalculator().divide(total, count); } } @Override public void clear() { this.total = null; this.count = 0L; } @Override public DataType getReturnType() { return getMapReturnType(); } @Override public DataType getMapReturnType() { Object[] args = function.getArgs().toArray(); DataType type = null; if (args[0] instanceof ISelectable) { type = ((ISelectable) args[0]).getDataType(); } if (type == null) { type = DataTypeUtil.getTypeOfObject(args[0]); } if (type == DataType.BigIntegerType) { // 如果是大整数,返回bigDecimal return DataType.BigDecimalType; } else { // 尽可能都返回为BigDecimalType,double类型容易出现精度问题,会和mysql出现误差 // [zhuoxue.yll, 2516885.8000] // [zhuoxue.yll, 2516885.799999999813735485076904296875] // return DataType.DoubleType; return DataType.BigDecimalType; } } public DataType getSumType() { Object[] args = function.getArgs().toArray(); DataType type = null; if (args[0] instanceof ISelectable) { type = ((ISelectable) args[0]).getDataType(); } if (type == null) { type = DataTypeUtil.getTypeOfObject(args[0]); } if (type == DataType.IntegerType || type == DataType.ShortType) { return DataType.LongType; } else { return type; } } }