/* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.facebook.presto.sql.gen; import com.facebook.presto.bytecode.BytecodeBlock; import com.facebook.presto.bytecode.BytecodeNode; import com.facebook.presto.bytecode.ClassDefinition; import com.facebook.presto.bytecode.FieldDefinition; import com.facebook.presto.bytecode.MethodDefinition; import com.facebook.presto.bytecode.Parameter; import com.facebook.presto.bytecode.Scope; import com.facebook.presto.bytecode.Variable; import com.facebook.presto.bytecode.control.ForLoop; import com.facebook.presto.bytecode.control.IfStatement; import com.facebook.presto.bytecode.expression.BytecodeExpression; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.operator.project.ConstantPageProjection; import com.facebook.presto.operator.project.InputChannels; import com.facebook.presto.operator.project.InputPageProjection; import com.facebook.presto.operator.project.PageFieldsToInputParametersRewriter; import com.facebook.presto.operator.project.PageFilter; import com.facebook.presto.operator.project.PageProjection; import com.facebook.presto.operator.project.SelectedPositions; import com.facebook.presto.spi.ConnectorSession; import com.facebook.presto.spi.Page; import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.block.BlockBuilder; import com.facebook.presto.spi.block.BlockBuilderStatus; import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.gen.LambdaBytecodeGenerator.LambdaExpressionField; import com.facebook.presto.sql.relational.CallExpression; import com.facebook.presto.sql.relational.ConstantExpression; import com.facebook.presto.sql.relational.DeterminismEvaluator; import com.facebook.presto.sql.relational.Expressions; import com.facebook.presto.sql.relational.InputReferenceExpression; import com.facebook.presto.sql.relational.LambdaDefinitionExpression; import com.facebook.presto.sql.relational.RowExpression; import com.facebook.presto.sql.relational.RowExpressionVisitor; import com.facebook.presto.sql.relational.Signatures; import com.google.common.base.VerifyException; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import com.google.common.primitives.Primitives; import javax.inject.Inject; import java.util.List; import java.util.Set; import java.util.TreeSet; import java.util.function.Consumer; import java.util.function.Supplier; import static com.facebook.presto.bytecode.Access.FINAL; import static com.facebook.presto.bytecode.Access.PRIVATE; import static com.facebook.presto.bytecode.Access.PUBLIC; import static com.facebook.presto.bytecode.Access.a; import static com.facebook.presto.bytecode.CompilerUtils.defineClass; import static com.facebook.presto.bytecode.CompilerUtils.makeClassName; import static com.facebook.presto.bytecode.Parameter.arg; import static com.facebook.presto.bytecode.ParameterizedType.type; import static com.facebook.presto.bytecode.expression.BytecodeExpressions.add; import static com.facebook.presto.bytecode.expression.BytecodeExpressions.and; import static com.facebook.presto.bytecode.expression.BytecodeExpressions.constantBoolean; import static com.facebook.presto.bytecode.expression.BytecodeExpressions.constantFalse; import static com.facebook.presto.bytecode.expression.BytecodeExpressions.constantInt; import static com.facebook.presto.bytecode.expression.BytecodeExpressions.invokeStatic; import static com.facebook.presto.bytecode.expression.BytecodeExpressions.lessThan; import static com.facebook.presto.bytecode.expression.BytecodeExpressions.newArray; import static com.facebook.presto.bytecode.expression.BytecodeExpressions.newInstance; import static com.facebook.presto.bytecode.expression.BytecodeExpressions.not; import static com.facebook.presto.operator.project.PageFieldsToInputParametersRewriter.rewritePageFieldsToInputParameters; import static com.facebook.presto.spi.StandardErrorCode.COMPILER_ERROR; import static com.facebook.presto.sql.gen.BytecodeUtils.generateWrite; import static com.facebook.presto.sql.gen.BytecodeUtils.invoke; import static com.facebook.presto.sql.gen.LambdaAndTryExpressionExtractor.extractLambdaAndTryExpressions; import static com.facebook.presto.sql.gen.TryCodeGenerator.defineTryMethod; import static com.google.common.base.MoreObjects.toStringHelper; import static com.google.common.base.Verify.verify; import static java.lang.String.format; import static java.util.Objects.requireNonNull; public class PageFunctionCompiler { private final Metadata metadata; private final DeterminismEvaluator determinismEvaluator; @Inject public PageFunctionCompiler(Metadata metadata) { this.metadata = requireNonNull(metadata, "metadata is null"); this.determinismEvaluator = new DeterminismEvaluator(metadata.getFunctionRegistry()); } public Supplier<PageProjection> compileProjection(RowExpression projection) { requireNonNull(projection, "projection is null"); if (projection instanceof InputReferenceExpression) { InputReferenceExpression input = (InputReferenceExpression) projection; InputPageProjection projectionFunction = new InputPageProjection(input.getField(), input.getType()); return () -> projectionFunction; } if (projection instanceof ConstantExpression) { ConstantExpression constant = (ConstantExpression) projection; ConstantPageProjection projectionFunction = new ConstantPageProjection(constant.getValue(), constant.getType()); return () -> projectionFunction; } PageFieldsToInputParametersRewriter.Result result = rewritePageFieldsToInputParameters(projection); CallSiteBinder callSiteBinder = new CallSiteBinder(); ClassDefinition classDefinition = defineProjectionClass(result.getRewrittenExpression(), result.getInputChannels(), callSiteBinder); Class<? extends PageProjection> projectionClass; try { projectionClass = defineClass(classDefinition, PageProjection.class, callSiteBinder.getBindings(), getClass().getClassLoader()); } catch (Exception e) { throw new PrestoException(COMPILER_ERROR, e); } return () -> { try { return projectionClass.newInstance(); } catch (ReflectiveOperationException e) { throw new PrestoException(COMPILER_ERROR, e); } }; } private ClassDefinition defineProjectionClass(RowExpression projection, InputChannels inputChannels, CallSiteBinder callSiteBinder) { ClassDefinition classDefinition = new ClassDefinition( a(PUBLIC, FINAL), makeClassName(PageProjection.class.getSimpleName()), type(Object.class), type(PageProjection.class)); FieldDefinition blockBuilderField = classDefinition.declareField(a(PRIVATE), "blockBuilder", BlockBuilder.class); CachedInstanceBinder cachedInstanceBinder = new CachedInstanceBinder(classDefinition, callSiteBinder); generatePageProjectMethod(classDefinition, blockBuilderField); PreGeneratedExpressions preGeneratedExpressions = generateMethodsForLambdaAndTry(classDefinition, callSiteBinder, cachedInstanceBinder, projection); generateProjectMethod(classDefinition, callSiteBinder, cachedInstanceBinder, preGeneratedExpressions, projection, blockBuilderField); // getType BytecodeExpression type = invoke(callSiteBinder.bind(projection.getType(), Type.class), "type"); classDefinition.declareMethod(a(PUBLIC), "getType", type(Type.class)) .getBody() .append(type.ret()); // isDeterministic classDefinition.declareMethod(a(PUBLIC), "isDeterministic", type(boolean.class)) .getBody() .append(constantBoolean(determinismEvaluator.isDeterministic(projection)).ret()); // getInputChannels classDefinition.declareMethod(a(PUBLIC), "getInputChannels", type(InputChannels.class)) .getBody() .append(invoke(callSiteBinder.bind(inputChannels, InputChannels.class), "getInputChannels").ret()); // toString String toStringResult = toStringHelper(classDefinition.getType() .getJavaClassName()) .add("projection", projection) .toString(); classDefinition.declareMethod(a(PUBLIC), "toString", type(String.class)) .getBody() // bind constant via invokedynamic to avoid constant pool issues due to large strings .append(invoke(callSiteBinder.bind(toStringResult, String.class), "toString").ret()); // constructor generateConstructor(classDefinition, cachedInstanceBinder, preGeneratedExpressions, method -> { Variable thisVariable = method.getThis(); BytecodeBlock body = method.getBody(); body.append(thisVariable.setField( blockBuilderField, type.invoke("createBlockBuilder", BlockBuilder.class, newInstance(BlockBuilderStatus.class), constantInt(1)))); }); return classDefinition; } private static MethodDefinition generatePageProjectMethod(ClassDefinition classDefinition, FieldDefinition blockBuilder) { Parameter session = arg("session", ConnectorSession.class); Parameter page = arg("page", Page.class); Parameter selectedPositions = arg("selectedPositions", SelectedPositions.class); MethodDefinition method = classDefinition.declareMethod( a(PUBLIC), "project", type(Block.class), ImmutableList.<Parameter>builder() .add(session) .add(page) .add(selectedPositions) .build()); Scope scope = method.getScope(); Variable thisVariable = method.getThis(); BytecodeBlock body = method.getBody(); Variable from = scope.declareVariable("from", body, selectedPositions.invoke("getOffset", int.class)); Variable to = scope.declareVariable("to", body, add(from, selectedPositions.invoke("size", int.class))); Variable positions = scope.declareVariable(int[].class, "positions"); Variable index = scope.declareVariable(int.class, "index"); IfStatement ifStatement = new IfStatement() .condition(selectedPositions.invoke("isList", boolean.class)); body.append(ifStatement); ifStatement.ifTrue(new BytecodeBlock() .append(positions.set(selectedPositions.invoke("getPositions", int[].class))) .append(new ForLoop("positions loop") .initialize(index.set(from)) .condition(lessThan(index, to)) .update(index.increment()) .body(thisVariable.invoke("project", void.class, session, page, positions.getElement(index))))); ifStatement.ifFalse(new ForLoop("range based loop") .initialize(index.set(from)) .condition(lessThan(index, to)) .update(index.increment()) .body(thisVariable.invoke("project", void.class, session, page, index))); Variable block = scope.declareVariable(Block.class, "block"); body.append(block.set(thisVariable.getField(blockBuilder).invoke("build", Block.class))) .append(thisVariable.setField( blockBuilder, thisVariable.getField(blockBuilder).invoke("newBlockBuilderLike", BlockBuilder.class, newInstance(BlockBuilderStatus.class)))) .append(block.ret()); return method; } private MethodDefinition generateProjectMethod( ClassDefinition classDefinition, CallSiteBinder callSiteBinder, CachedInstanceBinder cachedInstanceBinder, PreGeneratedExpressions preGeneratedExpressions, RowExpression projection, FieldDefinition blockBuilder) { Parameter session = arg("session", ConnectorSession.class); Parameter page = arg("page", Page.class); Parameter position = arg("position", int.class); MethodDefinition method = classDefinition.declareMethod( a(PUBLIC), "project", type(void.class), ImmutableList.<Parameter>builder() .add(session) .add(page) .add(position) .build()); method.comment("Projection: %s", projection.toString()); Scope scope = method.getScope(); BytecodeBlock body = method.getBody(); Variable thisVariable = method.getThis(); declareBlockVariables(projection, page, scope, body); Variable wasNullVariable = scope.declareVariable("wasNull", body, constantFalse()); BytecodeExpressionVisitor visitor = new BytecodeExpressionVisitor( callSiteBinder, cachedInstanceBinder, fieldReferenceCompiler(callSiteBinder), metadata.getFunctionRegistry(), preGeneratedExpressions); body.append(thisVariable.getField(blockBuilder)) .append(projection.accept(visitor, scope)) .append(generateWrite(callSiteBinder, scope, wasNullVariable, projection.getType())) .ret(); return method; } public Supplier<PageFilter> compileFilter(RowExpression filter) { requireNonNull(filter, "filter is null"); PageFieldsToInputParametersRewriter.Result result = rewritePageFieldsToInputParameters(filter); CallSiteBinder callSiteBinder = new CallSiteBinder(); ClassDefinition classDefinition = defineFilterClass(result.getRewrittenExpression(), result.getInputChannels(), callSiteBinder); Class<? extends PageFilter> functionClass; try { functionClass = defineClass(classDefinition, PageFilter.class, callSiteBinder.getBindings(), getClass().getClassLoader()); } catch (Exception e) { throw new PrestoException(COMPILER_ERROR, filter.toString(), e.getCause()); } return () -> { try { return functionClass.newInstance(); } catch (ReflectiveOperationException e) { throw new PrestoException(COMPILER_ERROR, e); } }; } private ClassDefinition defineFilterClass(RowExpression filter, InputChannels inputChannels, CallSiteBinder callSiteBinder) { ClassDefinition classDefinition = new ClassDefinition( a(PUBLIC, FINAL), makeClassName(PageFilter.class.getSimpleName()), type(Object.class), type(PageFilter.class)); CachedInstanceBinder cachedInstanceBinder = new CachedInstanceBinder(classDefinition, callSiteBinder); PreGeneratedExpressions preGeneratedExpressions = generateMethodsForLambdaAndTry(classDefinition, callSiteBinder, cachedInstanceBinder, filter); generateFilterMethod(classDefinition, callSiteBinder, cachedInstanceBinder, preGeneratedExpressions, filter); FieldDefinition selectedPositions = classDefinition.declareField(a(PRIVATE), "selectedPositions", boolean[].class); generatePageFilterMethod(classDefinition, selectedPositions); // isDeterministic classDefinition.declareMethod(a(PUBLIC), "isDeterministic", type(boolean.class)) .getBody() .append(constantBoolean(determinismEvaluator.isDeterministic(filter))) .retBoolean(); // getInputChannels classDefinition.declareMethod(a(PUBLIC), "getInputChannels", type(InputChannels.class)) .getBody() .append(invoke(callSiteBinder.bind(inputChannels, InputChannels.class), "getInputChannels")) .retObject(); // toString String toStringResult = toStringHelper(classDefinition.getType() .getJavaClassName()) .add("filter", filter) .toString(); classDefinition.declareMethod(a(PUBLIC), "toString", type(String.class)) .getBody() // bind constant via invokedynamic to avoid constant pool issues due to large strings .append(invoke(callSiteBinder.bind(toStringResult, String.class), "toString")) .retObject(); // constructor generateConstructor(classDefinition, cachedInstanceBinder, preGeneratedExpressions, method -> { Variable thisVariable = method.getScope().getThis(); method.getBody().append(thisVariable.setField(selectedPositions, newArray(type(boolean[].class), 0))); }); return classDefinition; } private static MethodDefinition generatePageFilterMethod(ClassDefinition classDefinition, FieldDefinition selectedPositionsField) { Parameter session = arg("session", ConnectorSession.class); Parameter page = arg("page", Page.class); MethodDefinition method = classDefinition.declareMethod( a(PUBLIC), "filter", type(SelectedPositions.class), ImmutableList.<Parameter>builder() .add(session) .add(page) .build()); Scope scope = method.getScope(); Variable thisVariable = method.getThis(); BytecodeBlock body = method.getBody(); Variable positionCount = scope.declareVariable("positionCount", body, page.invoke("getPositionCount", int.class)); body.append(new IfStatement("grow selectedPositions if necessary") .condition(lessThan(thisVariable.getField(selectedPositionsField).length(), positionCount)) .ifTrue(thisVariable.setField(selectedPositionsField, newArray(type(boolean[].class), positionCount)))); Variable selectedPositions = scope.declareVariable("selectedPositions", body, thisVariable.getField(selectedPositionsField)); Variable position = scope.declareVariable(int.class, "position"); body.append(new ForLoop() .initialize(position.set(constantInt(0))) .condition(lessThan(position, positionCount)) .update(position.increment()) .body(selectedPositions.setElement(position, thisVariable.invoke("filter", boolean.class, session, page, position)))); body.append(invokeStatic( PageFilter.class, "positionsArrayToSelectedPositions", SelectedPositions.class, selectedPositions, positionCount) .ret()); return method; } private MethodDefinition generateFilterMethod( ClassDefinition classDefinition, CallSiteBinder callSiteBinder, CachedInstanceBinder cachedInstanceBinder, PreGeneratedExpressions preGeneratedExpressions, RowExpression filter) { Parameter session = arg("session", ConnectorSession.class); Parameter page = arg("page", Page.class); Parameter position = arg("position", int.class); MethodDefinition method = classDefinition.declareMethod( a(PUBLIC), "filter", type(boolean.class), ImmutableList.<Parameter>builder() .add(session) .add(page) .add(position) .build()); method.comment("Filter: %s", filter.toString()); Scope scope = method.getScope(); BytecodeBlock body = method.getBody(); declareBlockVariables(filter, page, scope, body); Variable wasNullVariable = scope.declareVariable("wasNull", body, constantFalse()); BytecodeExpressionVisitor visitor = new BytecodeExpressionVisitor( callSiteBinder, cachedInstanceBinder, fieldReferenceCompiler(callSiteBinder), metadata.getFunctionRegistry(), preGeneratedExpressions); Variable result = scope.declareVariable(boolean.class, "result"); body.append(filter.accept(visitor, scope)) // store result so we can check for null .putVariable(result) .append(and(not(wasNullVariable), result).ret()); return method; } private PreGeneratedExpressions generateMethodsForLambdaAndTry( ClassDefinition containerClassDefinition, CallSiteBinder callSiteBinder, CachedInstanceBinder cachedInstanceBinder, RowExpression expression) { Set<RowExpression> lambdaAndTryExpressions = ImmutableSet.copyOf(extractLambdaAndTryExpressions(expression)); ImmutableMap.Builder<CallExpression, MethodDefinition> tryMethodMap = ImmutableMap.builder(); ImmutableMap.Builder<LambdaDefinitionExpression, LambdaExpressionField> lambdaFieldMap = ImmutableMap.builder(); int counter = 0; for (RowExpression lambdaOrTryExpression : lambdaAndTryExpressions) { if (lambdaOrTryExpression instanceof CallExpression) { CallExpression tryExpression = (CallExpression) lambdaOrTryExpression; verify(!Signatures.TRY.equals(tryExpression.getSignature().getName())); Parameter session = arg("session", ConnectorSession.class); List<Parameter> blocks = toBlockParameters(getInputChannels(tryExpression.getArguments())); Parameter position = arg("position", int.class); BytecodeExpressionVisitor innerExpressionVisitor = new BytecodeExpressionVisitor( callSiteBinder, cachedInstanceBinder, fieldReferenceCompiler(callSiteBinder), metadata.getFunctionRegistry(), new PreGeneratedExpressions(tryMethodMap.build(), lambdaFieldMap.build())); List<Parameter> inputParameters = ImmutableList.<Parameter>builder() .add(session) .addAll(blocks) .add(position) .build(); MethodDefinition tryMethod = defineTryMethod( innerExpressionVisitor, containerClassDefinition, "try_" + counter, inputParameters, Primitives.wrap(tryExpression.getType().getJavaType()), tryExpression, callSiteBinder); tryMethodMap.put(tryExpression, tryMethod); } else if (lambdaOrTryExpression instanceof LambdaDefinitionExpression) { LambdaDefinitionExpression lambdaExpression = (LambdaDefinitionExpression) lambdaOrTryExpression; PreGeneratedExpressions preGeneratedExpressions = new PreGeneratedExpressions(tryMethodMap.build(), lambdaFieldMap.build()); LambdaExpressionField lambdaExpressionField = LambdaBytecodeGenerator.preGenerateLambdaExpression( lambdaExpression, "lambda_" + counter, containerClassDefinition, preGeneratedExpressions, callSiteBinder, cachedInstanceBinder, metadata.getFunctionRegistry()); lambdaFieldMap.put(lambdaExpression, lambdaExpressionField); } else { throw new VerifyException(format("unexpected expression: %s", lambdaOrTryExpression.toString())); } counter++; } return new PreGeneratedExpressions(tryMethodMap.build(), lambdaFieldMap.build()); } private static void generateConstructor( ClassDefinition classDefinition, CachedInstanceBinder cachedInstanceBinder, PreGeneratedExpressions preGeneratedExpressions, Consumer<MethodDefinition> additionalStatements) { MethodDefinition constructorDefinition = classDefinition.declareConstructor(a(PUBLIC)); BytecodeBlock body = constructorDefinition.getBody(); Variable thisVariable = constructorDefinition.getThis(); body.comment("super();") .append(thisVariable) .invokeConstructor(Object.class); additionalStatements.accept(constructorDefinition); cachedInstanceBinder.generateInitializations(thisVariable, body); for (LambdaExpressionField field : preGeneratedExpressions.getLambdaFieldMap().values()) { field.generateInitialization(thisVariable, body); } body.ret(); } private static void declareBlockVariables(RowExpression expression, Parameter page, Scope scope, BytecodeBlock body) { for (int channel : getInputChannels(expression)) { scope.declareVariable("block_" + channel, body, page.invoke("getBlock", Block.class, constantInt(channel))); } } private static List<Integer> getInputChannels(Iterable<RowExpression> expressions) { TreeSet<Integer> channels = new TreeSet<>(); for (RowExpression expression : Expressions.subExpressions(expressions)) { if (expression instanceof InputReferenceExpression) { channels.add(((InputReferenceExpression) expression).getField()); } } return ImmutableList.copyOf(channels); } private static List<Integer> getInputChannels(RowExpression expression) { return getInputChannels(ImmutableList.of(expression)); } private static List<Parameter> toBlockParameters(List<Integer> inputChannels) { ImmutableList.Builder<Parameter> parameters = ImmutableList.builder(); for (int channel : inputChannels) { parameters.add(arg("block_" + channel, Block.class)); } return parameters.build(); } private static RowExpressionVisitor<Scope, BytecodeNode> fieldReferenceCompiler(CallSiteBinder callSiteBinder) { return new InputReferenceCompiler( (scope, field) -> scope.getVariable("block_" + field), (scope, field) -> scope.getVariable("position"), callSiteBinder); } }