/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.sql.gen;
import com.facebook.presto.bytecode.BytecodeBlock;
import com.facebook.presto.bytecode.BytecodeNode;
import com.facebook.presto.bytecode.ClassDefinition;
import com.facebook.presto.bytecode.FieldDefinition;
import com.facebook.presto.bytecode.MethodDefinition;
import com.facebook.presto.bytecode.Parameter;
import com.facebook.presto.bytecode.Scope;
import com.facebook.presto.bytecode.Variable;
import com.facebook.presto.bytecode.expression.BytecodeExpressions;
import com.facebook.presto.metadata.FunctionRegistry;
import com.facebook.presto.spi.ConnectorSession;
import com.facebook.presto.sql.relational.CallExpression;
import com.facebook.presto.sql.relational.ConstantExpression;
import com.facebook.presto.sql.relational.InputReferenceExpression;
import com.facebook.presto.sql.relational.LambdaDefinitionExpression;
import com.facebook.presto.sql.relational.RowExpressionVisitor;
import com.facebook.presto.sql.relational.VariableReferenceExpression;
import com.facebook.presto.util.Reflection;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.primitives.Primitives;
import java.lang.invoke.MethodHandle;
import java.util.List;
import java.util.Map;
import static com.facebook.presto.bytecode.Access.FINAL;
import static com.facebook.presto.bytecode.Access.PRIVATE;
import static com.facebook.presto.bytecode.Access.PUBLIC;
import static com.facebook.presto.bytecode.Access.STATIC;
import static com.facebook.presto.bytecode.Access.a;
import static com.facebook.presto.bytecode.Parameter.arg;
import static com.facebook.presto.bytecode.ParameterizedType.type;
import static com.facebook.presto.bytecode.expression.BytecodeExpressions.constantClass;
import static com.facebook.presto.bytecode.expression.BytecodeExpressions.constantString;
import static com.facebook.presto.bytecode.expression.BytecodeExpressions.getStatic;
import static com.facebook.presto.bytecode.expression.BytecodeExpressions.invokeStatic;
import static com.facebook.presto.bytecode.expression.BytecodeExpressions.newArray;
import static com.facebook.presto.bytecode.expression.BytecodeExpressions.setStatic;
import static com.facebook.presto.sql.gen.BytecodeUtils.boxPrimitiveIfNecessary;
import static com.facebook.presto.sql.gen.BytecodeUtils.unboxPrimitiveIfNecessary;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static java.util.Objects.requireNonNull;
/**
 * Compiles lambda expressions into methods of a generated class and exposes each compiled lambda
 * through {@link MethodHandle} fields declared on that same class.
 */
public final class LambdaBytecodeGenerator
{
    private LambdaBytecodeGenerator()
    {
    }

    /**
     * Compiles {@code lambdaExpression} into a public method of {@code classDefinition} and declares the
     * {@link MethodHandle} fields that refer to it.
     * <p>
     * The generated method takes a {@link ConnectorSession} parameter named {@code session}, followed by one
     * parameter per lambda argument, named {@code lambda_<argumentName>`} and typed with the boxed form of the
     * argument's Java type.
     *
     * @param lambdaExpression the lambda whose body is compiled
     * @param fieldName the name given to both the generated method and the static method-handle field
     * @param classDefinition the class into which the method and the fields are declared
     * @param preGeneratedExpressions forwarded to the {@link BytecodeExpressionVisitor} that compiles the body
     * @param callSiteBinder forwarded to the {@link BytecodeExpressionVisitor} that compiles the body
     * @param cachedInstanceBinder forwarded to the {@link BytecodeExpressionVisitor} that compiles the body
     * @param functionRegistry forwarded to the {@link BytecodeExpressionVisitor} that compiles the body
     * @return a {@link LambdaExpressionField} holding the static and the instance {@link MethodHandle} fields
     *         that refer to the generated method
     */
    public static LambdaExpressionField preGenerateLambdaExpression(
            LambdaDefinitionExpression lambdaExpression,
            String fieldName,
            ClassDefinition classDefinition,
            PreGeneratedExpressions preGeneratedExpressions,
            CallSiteBinder callSiteBinder,
            CachedInstanceBinder cachedInstanceBinder,
            FunctionRegistry functionRegistry)
    {
        ImmutableList.Builder<Parameter> parameters = ImmutableList.builder();
        ImmutableMap.Builder<String, ParameterAndType> parameterMapBuilder = ImmutableMap.builder();

        parameters.add(arg("session", ConnectorSession.class));
        for (int i = 0; i < lambdaExpression.getArguments().size(); i++) {
            // Parameters are declared with boxed types; variable references unbox them when read.
            Class<?> type = Primitives.wrap(lambdaExpression.getArgumentTypes().get(i).getJavaType());
            String argumentName = lambdaExpression.getArguments().get(i);
            Parameter arg = arg("lambda_" + argumentName, type);
            parameters.add(arg);
            parameterMapBuilder.put(argumentName, new ParameterAndType(arg, type));
        }

        BytecodeExpressionVisitor innerExpressionVisitor = new BytecodeExpressionVisitor(
                callSiteBinder,
                cachedInstanceBinder,
                variableReferenceCompiler(parameterMapBuilder.build()),
                functionRegistry,
                preGeneratedExpressions);

        return defineLambdaMethodAndField(
                innerExpressionVisitor,
                classDefinition,
                fieldName,
                parameters.build(),
                lambdaExpression);
    }

    /**
     * Declares the lambda method and its method-handle fields on {@code classDefinition}.
     * <p>
     * The method is declared {@code public} without {@code static}. It resets {@code wasNull} to {@code false},
     * runs the compiled body, boxes the result if necessary, and returns the boxed body type. A
     * {@code private static final} {@link MethodHandle} field with the same name is initialized in the class
     * initializer via {@code Reflection.methodHandle}. A {@code private final} {@link MethodHandle} instance field
     * named {@code "binded_" + fieldAndMethodName} is declared but not initialized here; see
     * {@link LambdaExpressionField#generateInitialization}.
     */
    private static LambdaExpressionField defineLambdaMethodAndField(
            BytecodeExpressionVisitor innerExpressionVisitor,
            ClassDefinition classDefinition,
            String fieldAndMethodName,
            List<Parameter> inputParameters,
            LambdaDefinitionExpression lambda)
    {
        Class<?> returnType = Primitives.wrap(lambda.getBody().getType().getJavaType());
        MethodDefinition method = classDefinition.declareMethod(a(PUBLIC), fieldAndMethodName, type(returnType), inputParameters);

        Scope scope = method.getScope();
        Variable wasNull = scope.declareVariable(boolean.class, "wasNull");
        BytecodeNode compiledBody = lambda.getBody().accept(innerExpressionVisitor, scope);
        method.getBody()
                .putVariable(wasNull, false)
                .append(compiledBody)
                .append(boxPrimitiveIfNecessary(scope, returnType))
                .ret(returnType);

        FieldDefinition staticField = classDefinition.declareField(a(PRIVATE, STATIC, FINAL), fieldAndMethodName, type(MethodHandle.class));
        FieldDefinition instanceField = classDefinition.declareField(a(PRIVATE, FINAL), "binded_" + fieldAndMethodName, type(MethodHandle.class));

        // Look up the generated method by class, name and the exact declared parameter types.
        classDefinition.getClassInitializer().getBody()
                .append(setStatic(
                        staticField,
                        invokeStatic(
                                Reflection.class,
                                "methodHandle",
                                MethodHandle.class,
                                constantClass(classDefinition.getType()),
                                constantString(fieldAndMethodName),
                                newArray(
                                        type(Class[].class),
                                        inputParameters.stream()
                                                .map(Parameter::getType)
                                                .map(BytecodeExpressions::constantClass)
                                                .collect(toImmutableList())))));

        return new LambdaExpressionField(staticField, instanceField);
    }

    /**
     * Returns a visitor that only compiles variable references to lambda parameters; every other expression kind
     * throws {@link UnsupportedOperationException}. A variable reference loads the matching parameter and unboxes
     * it if necessary.
     *
     * @param parameterMap lambda argument name to its generated parameter and (boxed) type
     */
    private static RowExpressionVisitor<Scope, BytecodeNode> variableReferenceCompiler(Map<String, ParameterAndType> parameterMap)
    {
        return new RowExpressionVisitor<Scope, BytecodeNode>()
        {
            @Override
            public BytecodeNode visitInputReference(InputReferenceExpression node, Scope scope)
            {
                throw new UnsupportedOperationException();
            }

            @Override
            public BytecodeNode visitCall(CallExpression call, Scope scope)
            {
                throw new UnsupportedOperationException();
            }

            @Override
            public BytecodeNode visitConstant(ConstantExpression literal, Scope scope)
            {
                throw new UnsupportedOperationException();
            }

            @Override
            public BytecodeNode visitLambda(LambdaDefinitionExpression lambda, Scope context)
            {
                throw new UnsupportedOperationException();
            }

            @Override
            public BytecodeNode visitVariableReference(VariableReferenceExpression reference, Scope context)
            {
                // Fail with a descriptive message (instead of a bare NPE) when the name is not a lambda argument.
                ParameterAndType parameterAndType = requireNonNull(
                        parameterMap.get(reference.getName()),
                        () -> "no lambda parameter found for variable reference: " + reference.getName());
                Parameter parameter = parameterAndType.getParameter();
                Class<?> type = parameterAndType.getType();
                return new BytecodeBlock()
                        .append(parameter)
                        .append(unboxPrimitiveIfNecessary(context, type));
            }
        };
    }

    /**
     * The pair of {@link MethodHandle} fields generated for one compiled lambda.
     */
    static class LambdaExpressionField
    {
        // Static handle, initialized in the class initializer, that refers to the generated lambda method.
        private final FieldDefinition staticField;
        // Per-instance handle: the static handle bound to "this", assigned by generateInitialization
        // (intended to be called from the generated constructor).
        private final FieldDefinition instanceField;

        public LambdaExpressionField(FieldDefinition staticField, FieldDefinition instanceField)
        {
            this.staticField = requireNonNull(staticField, "staticField is null");
            this.instanceField = requireNonNull(instanceField, "instanceField is null");
        }

        public FieldDefinition getInstanceField()
        {
            return instanceField;
        }

        /**
         * Appends to {@code block} bytecode that assigns
         * {@code this.instanceField = staticField.bindTo((Object) this)}.
         *
         * @param thisVariable the variable holding the instance being initialized
         * @param block the bytecode block to append the assignment to
         */
        public void generateInitialization(Variable thisVariable, BytecodeBlock block)
        {
            block.append(
                    thisVariable.setField(
                            instanceField,
                            getStatic(staticField).invoke("bindTo", MethodHandle.class, thisVariable.cast(Object.class))));
        }
    }
}