/* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.facebook.presto.operator.scalar; import com.facebook.presto.bytecode.BytecodeBlock; import com.facebook.presto.bytecode.BytecodeNode; import com.facebook.presto.bytecode.ClassDefinition; import com.facebook.presto.bytecode.MethodDefinition; import com.facebook.presto.bytecode.Parameter; import com.facebook.presto.bytecode.Scope; import com.facebook.presto.bytecode.Variable; import com.facebook.presto.bytecode.control.ForLoop; import com.facebook.presto.bytecode.control.IfStatement; import com.facebook.presto.metadata.BoundVariables; import com.facebook.presto.metadata.FunctionKind; import com.facebook.presto.metadata.FunctionRegistry; import com.facebook.presto.metadata.Signature; import com.facebook.presto.metadata.SqlScalarFunction; import com.facebook.presto.spi.ConnectorSession; import com.facebook.presto.spi.ErrorCodeSupplier; import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.block.BlockBuilder; import com.facebook.presto.spi.block.BlockBuilderStatus; import com.facebook.presto.spi.block.InterleavedBlockBuilder; import com.facebook.presto.spi.type.StandardTypes; import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.TypeManager; import com.facebook.presto.spi.type.TypeSignatureParameter; import com.facebook.presto.sql.gen.CallSiteBinder; import com.facebook.presto.sql.gen.SqlTypeBytecodeExpression; import com.google.common.collect.ImmutableList; import com.google.common.primitives.Primitives; import java.lang.invoke.MethodHandle; import java.util.List; import static com.facebook.presto.bytecode.Access.FINAL; import static com.facebook.presto.bytecode.Access.PRIVATE; import static com.facebook.presto.bytecode.Access.PUBLIC; import static com.facebook.presto.bytecode.Access.STATIC; import static com.facebook.presto.bytecode.Access.a; import static com.facebook.presto.bytecode.CompilerUtils.defineClass; import static com.facebook.presto.bytecode.CompilerUtils.makeClassName; import static com.facebook.presto.bytecode.Parameter.arg; import static com.facebook.presto.bytecode.ParameterizedType.type; import static com.facebook.presto.bytecode.expression.BytecodeExpressions.add; import static com.facebook.presto.bytecode.expression.BytecodeExpressions.constantInt; import static com.facebook.presto.bytecode.expression.BytecodeExpressions.constantNull; import static com.facebook.presto.bytecode.expression.BytecodeExpressions.constantString; import static com.facebook.presto.bytecode.expression.BytecodeExpressions.equal; import static com.facebook.presto.bytecode.expression.BytecodeExpressions.getStatic; import static com.facebook.presto.bytecode.expression.BytecodeExpressions.lessThan; import static com.facebook.presto.bytecode.expression.BytecodeExpressions.newInstance; import static com.facebook.presto.bytecode.instruction.VariableInstruction.incrementVariable; import static com.facebook.presto.metadata.Signature.typeVariable; import static com.facebook.presto.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT; import static com.facebook.presto.spi.type.TypeSignature.parseTypeSignature; import static com.facebook.presto.sql.gen.SqlTypeBytecodeExpression.constantType; import static com.facebook.presto.type.UnknownType.UNKNOWN; import static com.facebook.presto.util.Reflection.methodHandle; public final class MapTransformValueFunction extends SqlScalarFunction { public static final MapTransformValueFunction MAP_TRANSFORM_VALUE_FUNCTION = new MapTransformValueFunction(); private MapTransformValueFunction() { super(new Signature( "transform_values", FunctionKind.SCALAR, ImmutableList.of(typeVariable("K"), typeVariable("V1"), typeVariable("V2")), ImmutableList.of(), parseTypeSignature("map(K,V2)"), ImmutableList.of(parseTypeSignature("map(K,V1)"), parseTypeSignature("function(K,V1,V2)")), false)); } @Override public boolean isHidden() { return false; } @Override public boolean isDeterministic() { return false; } @Override public String getDescription() { return "apply lambda to each entry of the map and transform the value"; } @Override public ScalarFunctionImplementation specialize(BoundVariables boundVariables, int arity, TypeManager typeManager, FunctionRegistry functionRegistry) { Type keyType = boundVariables.getTypeVariable("K"); Type valueType = boundVariables.getTypeVariable("V1"); Type transformedValueType = boundVariables.getTypeVariable("V2"); Type resultMapType = typeManager.getParameterizedType(StandardTypes.MAP, ImmutableList.of( TypeSignatureParameter.of(keyType.getTypeSignature()), TypeSignatureParameter.of(transformedValueType.getTypeSignature()))); return new ScalarFunctionImplementation( false, ImmutableList.of(false, false), generateTransform(keyType, valueType, transformedValueType, resultMapType), isDeterministic()); } private static MethodHandle generateTransform(Type keyType, Type valueType, Type transformedValueType, Type resultMapType) { CallSiteBinder binder = new CallSiteBinder(); Class<?> keyJavaType = Primitives.wrap(keyType.getJavaType()); Class<?> valueJavaType = Primitives.wrap(valueType.getJavaType()); Class<?> transformedValueJavaType = Primitives.wrap(transformedValueType.getJavaType()); ClassDefinition definition = new ClassDefinition( a(PUBLIC, FINAL), makeClassName("MapTransformValue"), type(Object.class)); definition.declareDefaultConstructor(a(PRIVATE)); // define transform method Parameter session = arg("session", ConnectorSession.class); Parameter block = arg("block", Block.class); Parameter function = arg("function", MethodHandle.class); MethodDefinition method = definition.declareMethod( a(PUBLIC, STATIC), "transform", type(Block.class), ImmutableList.of(session, block, function)); BytecodeBlock body = method.getBody(); Scope scope = method.getScope(); Variable positionCount = scope.declareVariable(int.class, "positionCount"); Variable position = scope.declareVariable(int.class, "position"); Variable blockBuilder = scope.declareVariable(BlockBuilder.class, "blockBuilder"); Variable keyElement = scope.declareVariable(keyJavaType, "keyElement"); Variable valueElement = scope.declareVariable(valueJavaType, "valueElement"); Variable transformedValueElement = scope.declareVariable(transformedValueJavaType, "transformedValueElement"); // invoke block.getPositionCount() body.append(positionCount.set(block.invoke("getPositionCount", int.class))); // create the interleaved block builder body.append(blockBuilder.set(newInstance( InterleavedBlockBuilder.class, constantType(binder, resultMapType).invoke("getTypeParameters", List.class), newInstance(BlockBuilderStatus.class), positionCount))); // throw null key exception block BytecodeNode throwNullKeyException = new BytecodeBlock() .append(newInstance( PrestoException.class, getStatic(INVALID_FUNCTION_ARGUMENT.getDeclaringClass(), "INVALID_FUNCTION_ARGUMENT").cast(ErrorCodeSupplier.class), constantString("map key cannot be null"))) .throwObject(); SqlTypeBytecodeExpression keySqlType = constantType(binder, keyType); BytecodeNode loadKeyElement; if (!keyType.equals(UNKNOWN)) { loadKeyElement = new BytecodeBlock().append(keyElement.set(keySqlType.getValue(block, position).cast(keyJavaType))); } else { // make sure invokeExact will not take uninitialized keys during compile time // but if we reach this point during runtime, it is an exception loadKeyElement = new BytecodeBlock() .append(keyElement.set(constantNull(keyJavaType))) .append(throwNullKeyException); } SqlTypeBytecodeExpression valueSqlType = constantType(binder, valueType); BytecodeNode loadValueElement; if (!valueType.equals(UNKNOWN)) { loadValueElement = new IfStatement() .condition(block.invoke("isNull", boolean.class, add(position, constantInt(1)))) .ifTrue(valueElement.set(constantNull(valueJavaType))) .ifFalse(valueElement.set(valueSqlType.getValue(block, add(position, constantInt(1))).cast(valueJavaType))); } else { loadValueElement = new BytecodeBlock().append(valueElement.set(constantNull(valueJavaType))); } BytecodeNode writeTransformedValueElement; if (!transformedValueType.equals(UNKNOWN)) { writeTransformedValueElement = new IfStatement() .condition(equal(transformedValueElement, constantNull(transformedValueJavaType))) .ifTrue(blockBuilder.invoke("appendNull", BlockBuilder.class).pop()) .ifFalse(constantType(binder, transformedValueType).writeValue(blockBuilder, transformedValueElement.cast(transformedValueType.getJavaType()))); } else { writeTransformedValueElement = new BytecodeBlock().append(blockBuilder.invoke("appendNull", BlockBuilder.class).pop()); } body.append(new ForLoop() .initialize(position.set(constantInt(0))) .condition(lessThan(position, positionCount)) .update(incrementVariable(position, (byte) 2)) .body(new BytecodeBlock() .append(loadKeyElement) .append(loadValueElement) .append(transformedValueElement.set(function.invoke("invokeExact", transformedValueJavaType, session, keyElement, valueElement))) .append(keySqlType.invoke("appendTo", void.class, block, position, blockBuilder)) .append(writeTransformedValueElement))); body.append(blockBuilder.invoke("build", Block.class).ret()); Class<?> generatedClass = defineClass(definition, Object.class, binder.getBindings(), MapTransformValueFunction.class.getClassLoader()); return methodHandle(generatedClass, "transform", ConnectorSession.class, Block.class, MethodHandle.class); } }