package com.taobao.tddl.optimizer.core.expression;
import java.lang.reflect.Modifier;
import java.util.List;
import java.util.Map;
import org.apache.commons.lang.StringUtils;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.taobao.tddl.common.utils.extension.ExtensionLoader;
import com.taobao.tddl.optimizer.exceptions.FunctionException;
import com.taobao.tddl.optimizer.utils.PackageUtils;
import com.taobao.tddl.optimizer.utils.PackageUtils.ClassFilter;
import com.taobao.tddl.common.utils.logger.Logger;
import com.taobao.tddl.common.utils.logger.LoggerFactory;
/**
* {@linkplain IExtraFunction}加载器,以类名做为Function Name,<strong>注意:忽略了大小写</stong>
*
* <pre>
* Function加载:
* 1. 自动扫描IExtraFunction对应Package目录下的所有Function实现
* 2. 自动扫描Extension扩展方式下的自定义实现,比如在META-INF/tddl 或 META-INF/services 添加扩展配置文件
* </pre>
*
* @author jianghang 2013-11-8 下午5:30:35
* @since 5.0.0
*/
public class ExtraFunctionManager {
private static final Logger logger = LoggerFactory.getLogger(ExtraFunctionManager.class);
private static Map<String, Class<?>> functionCaches = Maps.newConcurrentMap();
private static String DUMMAY_FUNCTION = "DUMMY";
private static String DUMMAY_TEST_FUNCTION = "DUMMYTEST";
private static IExtraFunction dummyFunction; // 缓存一下dummy,避免每次都反射创建
static {
initFunctions();
dummyFunction = getExtraFunction(DUMMAY_FUNCTION);
if (dummyFunction == null) {
dummyFunction = getExtraFunction(DUMMAY_TEST_FUNCTION);
}
}
/**
* 查找对应名字的函数类,忽略大小写
*
* @param functionName
* @return
*/
public static IExtraFunction getExtraFunction(String functionName) {
String name = buildKey(functionName);
Class clazz = functionCaches.get(name);
IExtraFunction result = null;
if (clazz == null) {
return dummyFunction;
}
if (clazz != null) {
try {
result = (IExtraFunction) clazz.newInstance();
} catch (Exception e) {
throw new FunctionException("init function failed", e);
}
}
if (result == null) {
throw new FunctionException("not found Function : " + functionName);
}
return result;
}
public static void addFuncion(Class clazz) {
String name = clazz.getSimpleName();
Class oldClazz = functionCaches.put(buildKey(name.toUpperCase()), clazz);
if (oldClazz != null) {
logger.warn(" dup function :" + name + ", old class : " + oldClazz.getName());
}
}
private static void initFunctions() {
List<Class> classes = Lists.newArrayList();
// 查找默认build-in的函数
ClassFilter filter = new ClassFilter() {
public boolean filter(Class clazz) {
int mod = clazz.getModifiers();
return !Modifier.isAbstract(mod) && !Modifier.isInterface(mod)
&& IExtraFunction.class.isAssignableFrom(clazz);
}
public boolean preFilter(String classFulName) {
return StringUtils.contains(classFulName, "function");// 包含function名字的类
}
};
classes.addAll(PackageUtils.findClassesInPackage("com.taobao.tddl", filter));
// 查找用户自定义的扩展函数
classes.addAll(ExtensionLoader.getAllExtendsionClass(IExtraFunction.class));
for (Class clazz : classes) {
addFuncion(clazz);
}
}
private static String buildKey(String name) {
if (IFunction.BuiltInFunction.ADD.equals(name)) {
return "ADD";
} else if (IFunction.BuiltInFunction.SUB.equals(name)) {
return "SUB";
} else if (IFunction.BuiltInFunction.MULTIPLY.equals(name)) {
return "MULTIPLY";
} else if (IFunction.BuiltInFunction.DIVISION.equals(name)) {
return "DIVISION";
} else if (IFunction.BuiltInFunction.MOD.equals(name)) {
return "MOD";
} else if (IFunction.BuiltInFunction.BITAND.equals(name)) {
return "BITAND";
} else if (IFunction.BuiltInFunction.BITOR.equals(name)) {
return "BITOR";
} else if (IFunction.BuiltInFunction.BITXOR.equals(name)) {
return "BITXOR";
} else if (IFunction.BuiltInFunction.BITLSHIFT.equals(name)) {
return "BITLSHIFT";
} else if (IFunction.BuiltInFunction.BITRSHIFT.equals(name)) {
return "BITRSHIFT";
}
return name;// 默认语法树中所有节点均为大写
}
}