/* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.facebook.presto.operator.scalar; import com.facebook.presto.bytecode.BytecodeBlock; import com.facebook.presto.bytecode.BytecodeNode; import com.facebook.presto.bytecode.ClassDefinition; import com.facebook.presto.bytecode.MethodDefinition; import com.facebook.presto.bytecode.Parameter; import com.facebook.presto.bytecode.Scope; import com.facebook.presto.bytecode.Variable; import com.facebook.presto.bytecode.control.ForLoop; import com.facebook.presto.bytecode.control.IfStatement; import com.facebook.presto.metadata.BoundVariables; import com.facebook.presto.metadata.FunctionKind; import com.facebook.presto.metadata.FunctionRegistry; import com.facebook.presto.metadata.Signature; import com.facebook.presto.metadata.SqlScalarFunction; import com.facebook.presto.spi.ConnectorSession; import com.facebook.presto.spi.PageBuilder; import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.block.BlockBuilder; import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.TypeManager; import com.facebook.presto.sql.gen.CallSiteBinder; import com.facebook.presto.type.ArrayType; import com.google.common.collect.ImmutableList; import com.google.common.primitives.Primitives; import java.lang.invoke.MethodHandle; import java.util.List; import java.util.Optional; import static com.facebook.presto.bytecode.Access.FINAL; import static com.facebook.presto.bytecode.Access.PRIVATE; import static com.facebook.presto.bytecode.Access.PUBLIC; import static com.facebook.presto.bytecode.Access.STATIC; import static com.facebook.presto.bytecode.Access.a; import static com.facebook.presto.bytecode.CompilerUtils.defineClass; import static com.facebook.presto.bytecode.CompilerUtils.makeClassName; import static com.facebook.presto.bytecode.Parameter.arg; import static com.facebook.presto.bytecode.ParameterizedType.type; import static com.facebook.presto.bytecode.expression.BytecodeExpressions.constantInt; import static com.facebook.presto.bytecode.expression.BytecodeExpressions.constantNull; import static com.facebook.presto.bytecode.expression.BytecodeExpressions.equal; import static com.facebook.presto.bytecode.expression.BytecodeExpressions.lessThan; import static com.facebook.presto.bytecode.expression.BytecodeExpressions.newInstance; import static com.facebook.presto.bytecode.expression.BytecodeExpressions.subtract; import static com.facebook.presto.bytecode.instruction.VariableInstruction.incrementVariable; import static com.facebook.presto.metadata.Signature.typeVariable; import static com.facebook.presto.spi.type.TypeSignature.parseTypeSignature; import static com.facebook.presto.sql.gen.SqlTypeBytecodeExpression.constantType; import static com.facebook.presto.type.UnknownType.UNKNOWN; import static com.facebook.presto.util.Reflection.methodHandle; public final class ArrayTransformFunction extends SqlScalarFunction { public static final ArrayTransformFunction ARRAY_TRANSFORM_FUNCTION = new ArrayTransformFunction(); private ArrayTransformFunction() { super(new Signature( "transform", FunctionKind.SCALAR, ImmutableList.of(typeVariable("T"), typeVariable("U")), ImmutableList.of(), parseTypeSignature("array(U)"), ImmutableList.of(parseTypeSignature("array(T)"), parseTypeSignature("function(T,U)")), false)); } @Override public boolean isHidden() { return false; } @Override public boolean isDeterministic() { return false; } @Override public String getDescription() { return "apply lambda to each element of the array"; } @Override public ScalarFunctionImplementation specialize(BoundVariables boundVariables, int arity, TypeManager typeManager, FunctionRegistry functionRegistry) { Type inputType = boundVariables.getTypeVariable("T"); Type outputType = boundVariables.getTypeVariable("U"); Class<?> generatedClass = generateTransform(inputType, outputType); return new ScalarFunctionImplementation( false, ImmutableList.of(false, false), ImmutableList.of(false, false), methodHandle(generatedClass, "transform", PageBuilder.class, ConnectorSession.class, Block.class, MethodHandle.class), Optional.of(methodHandle(generatedClass, "createPageBuilder")), isDeterministic()); } private static Class<?> generateTransform(Type inputType, Type outputType) { CallSiteBinder binder = new CallSiteBinder(); Class<?> inputJavaType = Primitives.wrap(inputType.getJavaType()); Class<?> outputJavaType = Primitives.wrap(outputType.getJavaType()); ClassDefinition definition = new ClassDefinition( a(PUBLIC, FINAL), makeClassName("ArrayTransform"), type(Object.class)); definition.declareDefaultConstructor(a(PRIVATE)); // define createPageBuilder MethodDefinition createPageBuilderMethod = definition.declareMethod(a(PUBLIC, STATIC), "createPageBuilder", type(PageBuilder.class)); createPageBuilderMethod.getBody() .append(newInstance(PageBuilder.class, constantType(binder, new ArrayType(outputType)).invoke("getTypeParameters", List.class)).ret()); // define transform method Parameter session = arg("session", ConnectorSession.class); Parameter pageBuilder = arg("pageBuilder", PageBuilder.class); Parameter block = arg("block", Block.class); Parameter function = arg("function", MethodHandle.class); MethodDefinition method = definition.declareMethod( a(PUBLIC, STATIC), "transform", type(Block.class), ImmutableList.of(pageBuilder, session, block, function)); BytecodeBlock body = method.getBody(); Scope scope = method.getScope(); Variable positionCount = scope.declareVariable(int.class, "positionCount"); Variable position = scope.declareVariable(int.class, "position"); Variable blockBuilder = scope.declareVariable(BlockBuilder.class, "blockBuilder"); Variable inputElement = scope.declareVariable(inputJavaType, "inputElement"); Variable outputElement = scope.declareVariable(outputJavaType, "outputElement"); // invoke block.getPositionCount() body.append(positionCount.set(block.invoke("getPositionCount", int.class))); // reset page builder if it is full body.append(new IfStatement() .condition(pageBuilder.invoke("isFull", boolean.class)) .ifTrue(pageBuilder.invoke("reset", void.class))); // get block builder body.append(blockBuilder.set(pageBuilder.invoke("getBlockBuilder", BlockBuilder.class, constantInt(0)))); BytecodeNode loadInputElement; if (!inputType.equals(UNKNOWN)) { loadInputElement = new IfStatement() .condition(block.invoke("isNull", boolean.class, position)) .ifTrue(inputElement.set(constantNull(inputJavaType))) .ifFalse(inputElement.set(constantType(binder, inputType).getValue(block, position).cast(inputJavaType))); } else { loadInputElement = new BytecodeBlock().append(inputElement.set(constantNull(inputJavaType))); } BytecodeNode writeOutputElement; if (!outputType.equals(UNKNOWN)) { writeOutputElement = new IfStatement() .condition(equal(outputElement, constantNull(outputJavaType))) .ifTrue(blockBuilder.invoke("appendNull", BlockBuilder.class).pop()) .ifFalse(constantType(binder, outputType).writeValue(blockBuilder, outputElement.cast(outputType.getJavaType()))); } else { writeOutputElement = new BytecodeBlock().append(blockBuilder.invoke("appendNull", BlockBuilder.class).pop()); } body.append(new ForLoop() .initialize(position.set(constantInt(0))) .condition(lessThan(position, positionCount)) .update(incrementVariable(position, (byte) 1)) .body(new BytecodeBlock() .append(loadInputElement) .append(outputElement.set(function.invoke("invokeExact", outputJavaType, session, inputElement))) .append(writeOutputElement))); body.append(pageBuilder.invoke("declarePositions", void.class, positionCount)); body.append(blockBuilder.invoke("getRegion", Block.class, subtract(blockBuilder.invoke("getPositionCount", int.class), positionCount), positionCount).ret()); return defineClass(definition, Object.class, binder.getBindings(), ArrayTransformFunction.class.getClassLoader()); } }