/** * Copyright 2013-2014 Recruit Technologies Co., Ltd. and contributors * (see CONTRIBUTORS.md) * * Licensed under the Apache License, Version 2.0 (the "License"); you may * not use this file except in compliance with the License. A copy of the * License is distributed with this work in the LICENSE.md file. You may * also obtain a copy of the License from * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.gennai.gungnir.topology.udf; import static org.gennai.gungnir.GungnirConfig.*; import static org.gennai.gungnir.GungnirConst.*; import java.io.IOException; import java.lang.reflect.Array; import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; import java.lang.reflect.Type; import java.nio.file.Paths; import java.util.Arrays; import java.util.EnumSet; import java.util.List; import java.util.Map; import org.gennai.gungnir.ql.FunctionEntity; import org.gennai.gungnir.tuple.schema.FieldType; import org.gennai.gungnir.utils.GungnirUtils; import org.gennai.gungnir.utils.ScalaConverters; import scala.reflect.ScalaSignature; import com.google.common.collect.Lists; import com.google.common.collect.Maps; public abstract class BaseInvokeFunction extends BaseFunction<Object> implements UserDefined { private static final long serialVersionUID = SERIAL_VERSION_UID; private FunctionEntity function; private transient String classPath; private transient Class<?> functionClass; private transient Map<Type, Type> primitiveTypesMap; protected BaseInvokeFunction(FunctionEntity function, String classPath) { this.function = function; this.classPath = classPath; } protected BaseInvokeFunction(BaseInvokeFunction c) { super(c); this.function = c.function; this.classPath = c.classPath; } @Override public FunctionEntity getFunction() { return function; } public String getClassPath() { if (classPath == null) { classPath = getConfig().getString(CLASS_PATH); } return classPath; } protected Class<?> getFunctionClass() throws IOException, ClassNotFoundException { if (functionClass == null) { String className = function.getLocation().substring(0, function.getLocation().length() - 6); if (getClassPath() != null) { ClassLoader classLoader = GungnirUtils.addToClassPath(Paths.get(getClassPath())); if (classLoader != null) { functionClass = Class.forName(className, true, classLoader); } else { functionClass = Class.forName(className); } } else { functionClass = Class.forName(className); } } return functionClass; } protected static final class InvokableMethod { private Method method; private boolean isScala; private Class<?>[] parameterTypes; private InvokableMethod(Method method) { this.method = method; if (method.getDeclaringClass().getAnnotation(ScalaSignature.class) != null) { isScala = true; } this.parameterTypes = new Class<?>[method.getParameterTypes().length]; for (int i = 0; i < method.getParameterTypes().length; i++) { if (ScalaConverters.isScalaType(method.getParameterTypes()[i])) { parameterTypes[i] = ScalaConverters.asJavaType(method.getParameterTypes()[i]); } else { parameterTypes[i] = method.getParameterTypes()[i]; } } } protected String getName() { return method.getName(); } private boolean isVarArgs() { return method.isVarArgs(); } private Class<?>[] getParameterTypes() { return parameterTypes; } protected Object invoke(Object obj, Object[] args) throws IllegalAccessException, IllegalArgumentException, InvocationTargetException { int paramLen = method.getParameterTypes().length; if (isScala) { if (method.isVarArgs()) { Class<?> compType = method.getParameterTypes()[paramLen - 1].getComponentType(); Object varArgs = Array.newInstance(compType, args.length - paramLen + 1); for (int i = 0, j = paramLen - 1; i < Array.getLength(varArgs); i++, j++) { Array.set(varArgs, i, ScalaConverters.asScala(args[j])); } Object[] args2 = new Object[paramLen]; for (int i = 0; i < args2.length - 1; i++) { args2[i] = ScalaConverters.asScala(args[i]); } args2[args2.length - 1] = varArgs; args = args2; } else { for (int i = 0; i < paramLen; i++) { args[i] = ScalaConverters.asScala(args[i]); } } } else { if (method.isVarArgs()) { Class<?> compType = method.getParameterTypes()[paramLen - 1].getComponentType(); Object varArgs = Array.newInstance(compType, args.length - paramLen + 1); for (int i = 0, j = paramLen - 1; i < Array.getLength(varArgs); i++, j++) { Array.set(varArgs, i, args[j]); } Object[] args2 = new Object[paramLen]; for (int i = 0; i < args2.length - 1; i++) { args2[i] = args[i]; } args2[args2.length - 1] = varArgs; args = args2; } } Object ret = method.invoke(obj, args); if (isScala) { return ScalaConverters.asJava(ret); } else { return ret; } } } protected List<InvokableMethod> getMethods(Class<?> functionClass, String name) { List<InvokableMethod> invokableMethods = Lists.newArrayList(); Method[] methods = functionClass.getMethods(); for (Method method : methods) { if (method.getName().equals(name) && (method.getParameterTypes().length == getParameters().length || (method.isVarArgs() && method.getParameterTypes().length - 1 <= getParameters().length))) { invokableMethods.add(new InvokableMethod(method)); } } return invokableMethods; } private Type getPrimitiveType(Type type) { if (primitiveTypesMap == null) { primitiveTypesMap = Maps.newHashMap(); EnumSet<FieldType.TypeDef> typeNames = EnumSet.allOf(FieldType.TypeDef.class); for (FieldType.TypeDef typeName : typeNames) { if (typeName.getPrimitiveType() != null) { primitiveTypesMap.put(typeName.getJavaType(), typeName.getPrimitiveType()); } } } return primitiveTypesMap.get(type); } private int matchParamType(Class<?> paramType, Object[] args, int index) { if (paramType.isArray()) { Class<?> argType = null; boolean isNull = true; for (int i = index; i < args.length; i++) { if (args[i] != null) { if (argType == null) { argType = args[i].getClass(); } else if (args[i].getClass() != argType) { argType = Object.class; } isNull = false; } } Class<?> compType = paramType.getComponentType(); if (isNull) { if (compType == Object.class) { return 2; } else { return 1; } } else { if (compType == Object.class) { return 1; } else { if (compType.isPrimitive()) { Type primitiveType = getPrimitiveType(argType); if (primitiveType != null && compType == primitiveType) { return 3; } } else { if (compType == argType || compType.isAssignableFrom(argType)) { return 4; } } return 0; } } } else { if (args[index] == null) { if (paramType == Object.class) { return 4; } else { return 3; } } else { if (paramType == Object.class) { return 2; } else { if (paramType.isPrimitive()) { Type primitiveType = getPrimitiveType(args[index].getClass()); if (primitiveType != null && paramType == primitiveType) { return 5; } } else { if (paramType == args[index].getClass() || paramType.isAssignableFrom(args[index].getClass())) { return 6; } } return 0; } } } } protected List<InvokableMethod> findInvokeMethods(List<InvokableMethod> methods, Object[] args) throws IllegalAccessException, InvocationTargetException { List<InvokableMethod> invokeMethods = null; int maxPriority = 0; for (InvokableMethod method : methods) { int priority = 0; if (args.length == 0) { if (method.getParameterTypes().length == 0) { priority = 2; } else if (method.isVarArgs() && method.getParameterTypes().length == 1) { priority = 1; } else { priority = 0; break; } } else { for (int i = 0; i < method.getParameterTypes().length; i++) { int p = matchParamType(method.getParameterTypes()[i], args, i); if (p > 0) { priority += p; } else { priority = 0; break; } } } if (priority > 0) { if (priority > maxPriority) { maxPriority = priority; invokeMethods = Lists.newArrayList(method); } else if (priority == maxPriority) { invokeMethods.add(method); } } } return invokeMethods; } @Override public String toString() { StringBuilder sb = new StringBuilder(); sb.append(function.getName()); try { sb.append('('); if (getParameters() != null) { for (int i = 0; i < getParameters().length; i++) { if (i > 0) { sb.append(", "); } if (getParameters()[i].getClass().isArray()) { sb.append(Arrays.toString((Object[]) getParameters()[i])); } else { sb.append(getParameters()[i].toString()); } } } sb.append(')'); if (getAliasName() != null) { sb.append(" AS "); sb.append(getAliasName()); } } catch (Exception e) { sb.append(e); } return sb.toString(); } }