/* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.facebook.presto.operator.scalar; import com.facebook.presto.bytecode.BytecodeBlock; import com.facebook.presto.bytecode.BytecodeNode; import com.facebook.presto.bytecode.ClassDefinition; import com.facebook.presto.bytecode.MethodDefinition; import com.facebook.presto.bytecode.Parameter; import com.facebook.presto.bytecode.Scope; import com.facebook.presto.bytecode.Variable; import com.facebook.presto.bytecode.control.ForLoop; import com.facebook.presto.bytecode.control.IfStatement; import com.facebook.presto.metadata.BoundVariables; import com.facebook.presto.metadata.FunctionKind; import com.facebook.presto.metadata.FunctionRegistry; import com.facebook.presto.metadata.Signature; import com.facebook.presto.metadata.SqlScalarFunction; import com.facebook.presto.operator.aggregation.TypedSet; import com.facebook.presto.spi.ConnectorSession; import com.facebook.presto.spi.ErrorCodeSupplier; import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.block.BlockBuilder; import com.facebook.presto.spi.block.BlockBuilderStatus; import com.facebook.presto.spi.block.InterleavedBlockBuilder; import com.facebook.presto.spi.type.StandardTypes; import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.TypeManager; import com.facebook.presto.spi.type.TypeSignatureParameter; import com.facebook.presto.sql.gen.CallSiteBinder; import 
com.facebook.presto.sql.gen.SqlTypeBytecodeExpression; import com.google.common.collect.ImmutableList; import com.google.common.primitives.Primitives; import java.lang.invoke.MethodHandle; import java.util.List; import static com.facebook.presto.bytecode.Access.FINAL; import static com.facebook.presto.bytecode.Access.PRIVATE; import static com.facebook.presto.bytecode.Access.PUBLIC; import static com.facebook.presto.bytecode.Access.STATIC; import static com.facebook.presto.bytecode.Access.a; import static com.facebook.presto.bytecode.CompilerUtils.defineClass; import static com.facebook.presto.bytecode.CompilerUtils.makeClassName; import static com.facebook.presto.bytecode.Parameter.arg; import static com.facebook.presto.bytecode.ParameterizedType.type; import static com.facebook.presto.bytecode.expression.BytecodeExpressions.add; import static com.facebook.presto.bytecode.expression.BytecodeExpressions.constantInt; import static com.facebook.presto.bytecode.expression.BytecodeExpressions.constantNull; import static com.facebook.presto.bytecode.expression.BytecodeExpressions.constantString; import static com.facebook.presto.bytecode.expression.BytecodeExpressions.divide; import static com.facebook.presto.bytecode.expression.BytecodeExpressions.equal; import static com.facebook.presto.bytecode.expression.BytecodeExpressions.getStatic; import static com.facebook.presto.bytecode.expression.BytecodeExpressions.invokeStatic; import static com.facebook.presto.bytecode.expression.BytecodeExpressions.lessThan; import static com.facebook.presto.bytecode.expression.BytecodeExpressions.newArray; import static com.facebook.presto.bytecode.expression.BytecodeExpressions.newInstance; import static com.facebook.presto.bytecode.instruction.VariableInstruction.incrementVariable; import static com.facebook.presto.metadata.Signature.typeVariable; import static com.facebook.presto.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT; import static 
com.facebook.presto.spi.type.TypeSignature.parseTypeSignature;
import static com.facebook.presto.sql.gen.SqlTypeBytecodeExpression.constantType;
import static com.facebook.presto.type.UnknownType.UNKNOWN;
import static com.facebook.presto.util.Reflection.methodHandle;

/**
 * Implements {@code transform_keys(map(K1,V), function(K1,V,K2)) -> map(K2,V)}.
 * <p>
 * Applies the lambda to every (key, value) entry of the input map and uses its result as the
 * entry's new key; values are copied through unchanged. A dedicated implementation is generated
 * as bytecode for each bound combination of K1, K2 and V. The generated code throws a
 * {@link PrestoException} with {@code INVALID_FUNCTION_ARGUMENT} when the lambda yields a null key
 * or when two entries are mapped to the same key.
 */
public final class MapTransformKeyFunction
        extends SqlScalarFunction
{
    public static final MapTransformKeyFunction MAP_TRANSFORM_KEY_FUNCTION = new MapTransformKeyFunction();

    private MapTransformKeyFunction()
    {
        super(new Signature(
                "transform_keys",
                FunctionKind.SCALAR,
                ImmutableList.of(typeVariable("K1"), typeVariable("K2"), typeVariable("V")),
                ImmutableList.of(),
                parseTypeSignature("map(K2,V)"),
                ImmutableList.of(parseTypeSignature("map(K1,V)"), parseTypeSignature("function(K1,V,K2)")),
                false));
    }

    @Override
    public boolean isHidden()
    {
        return false;
    }

    @Override
    public boolean isDeterministic()
    {
        // NOTE(review): reported as non-deterministic, presumably because the lambda argument
        // may itself be non-deterministic — confirm before changing.
        return false;
    }

    @Override
    public String getDescription()
    {
        return "apply lambda to each entry of the map and transform the key";
    }

    /**
     * Binds K1, K2 and V, builds the {@code map(K2,V)} result type, and returns an implementation
     * backed by a freshly generated {@code transform} method. Neither the result nor the two
     * arguments are declared nullable.
     */
    @Override
    public ScalarFunctionImplementation specialize(BoundVariables boundVariables, int arity, TypeManager typeManager, FunctionRegistry functionRegistry)
    {
        Type keyType = boundVariables.getTypeVariable("K1");
        Type transformedKeyType = boundVariables.getTypeVariable("K2");
        Type valueType = boundVariables.getTypeVariable("V");
        Type resultMapType = typeManager.getParameterizedType(StandardTypes.MAP, ImmutableList.of(
                TypeSignatureParameter.of(transformedKeyType.getTypeSignature()),
                TypeSignatureParameter.of(valueType.getTypeSignature())));
        return new ScalarFunctionImplementation(
                false,
                ImmutableList.of(false, false),
                generateTransformKey(keyType, transformedKeyType, valueType, resultMapType),
                isDeterministic());
    }

    /**
     * Generates and loads a class with a static
     * {@code Block transform(ConnectorSession session, Block block, MethodHandle function)} method
     * specialized for the given types, and returns a handle to that method.
     * <p>
     * The generated method reads the input block two positions at a time (key at {@code position},
     * value at {@code position + 1}). It invokes {@code function} with (session, key, value) and
     * writes (transformed key, original value) into a new interleaved block built from the
     * result map's type parameters.
     *
     * @param keyType the input map's key type (K1)
     * @param transformedKeyType the lambda's result type, i.e. the output map's key type (K2)
     * @param valueType the value type (V), shared by the input and output maps
     * @param resultMapType the {@code map(K2,V)} type; its type parameters define the output block layout
     * @return a handle to the generated {@code transform} method
     */
    private static MethodHandle generateTransformKey(Type keyType, Type transformedKeyType, Type valueType, Type resultMapType)
    {
        CallSiteBinder binder = new CallSiteBinder();

        // boxed Java types, so every local below can hold null
        Class<?> keyJavaType = Primitives.wrap(keyType.getJavaType());
        Class<?> transformedKeyJavaType = Primitives.wrap(transformedKeyType.getJavaType());
        Class<?> valueJavaType = Primitives.wrap(valueType.getJavaType());

        // bind each SQL type constant once and reuse it, instead of calling constantType repeatedly
        SqlTypeBytecodeExpression keySqlType = constantType(binder, keyType);
        SqlTypeBytecodeExpression transformedKeySqlType = constantType(binder, transformedKeyType);
        SqlTypeBytecodeExpression valueSqlType = constantType(binder, valueType);

        ClassDefinition definition = new ClassDefinition(
                a(PUBLIC, FINAL),
                makeClassName("MapTransformKey"),
                type(Object.class));
        definition.declareDefaultConstructor(a(PRIVATE));

        Parameter session = arg("session", ConnectorSession.class);
        Parameter block = arg("block", Block.class);
        Parameter function = arg("function", MethodHandle.class);
        MethodDefinition method = definition.declareMethod(
                a(PUBLIC, STATIC),
                "transform",
                type(Block.class),
                ImmutableList.of(session, block, function));

        BytecodeBlock body = method.getBody();
        Scope scope = method.getScope();
        Variable positionCount = scope.declareVariable(int.class, "positionCount");
        Variable position = scope.declareVariable(int.class, "position");
        Variable blockBuilder = scope.declareVariable(BlockBuilder.class, "blockBuilder");
        Variable typedSet = scope.declareVariable(TypedSet.class, "typedSet");
        Variable keyElement = scope.declareVariable(keyJavaType, "keyElement");
        Variable transformedKeyElement = scope.declareVariable(transformedKeyJavaType, "transformedKeyElement");
        Variable valueElement = scope.declareVariable(valueJavaType, "valueElement");

        // invoke block.getPositionCount()
        body.append(positionCount.set(block.invoke("getPositionCount", int.class)));

        // create the interleaved block builder for the result map
        body.append(blockBuilder.set(newInstance(
                InterleavedBlockBuilder.class,
                constantType(binder, resultMapType).invoke("getTypeParameters", List.class),
                newInstance(BlockBuilderStatus.class),
                positionCount)));

        // create the typed set used to detect duplicate transformed keys (sized to the entry count)
        body.append(typedSet.set(newInstance(
                TypedSet.class,
                transformedKeySqlType,
                divide(positionCount, constantInt(2)))));

        // throw null key exception block
        BytecodeNode throwNullKeyException = new BytecodeBlock()
                .append(newInstance(
                        PrestoException.class,
                        getStatic(INVALID_FUNCTION_ARGUMENT.getDeclaringClass(), "INVALID_FUNCTION_ARGUMENT").cast(ErrorCodeSupplier.class),
                        constantString("map key cannot be null")))
                .throwObject();

        BytecodeNode loadKeyElement;
        if (!keyType.equals(UNKNOWN)) {
            loadKeyElement = new BytecodeBlock().append(keyElement.set(keySqlType.getValue(block, position).cast(keyJavaType)));
        }
        else {
            // assign keyElement so the later invokeExact never reads an unassigned local in the
            // generated bytecode; actually reaching this point at runtime is an error
            loadKeyElement = new BytecodeBlock()
                    .append(keyElement.set(constantNull(keyJavaType)))
                    .append(throwNullKeyException);
        }

        BytecodeNode loadValueElement;
        if (!valueType.equals(UNKNOWN)) {
            loadValueElement = new IfStatement()
                    .condition(block.invoke("isNull", boolean.class, add(position, constantInt(1))))
                    .ifTrue(valueElement.set(constantNull(valueJavaType)))
                    .ifFalse(valueElement.set(valueSqlType.getValue(block, add(position, constantInt(1))).cast(valueJavaType)));
        }
        else {
            // assign valueElement so the later invokeExact never reads an unassigned local
            loadValueElement = new BytecodeBlock().append(valueElement.set(constantNull(valueJavaType)));
        }

        BytecodeNode writeKeyElement;
        BytecodeNode throwDuplicatedKeyException;
        if (!transformedKeyType.equals(UNKNOWN)) {
            // call the lambda, reject a null result, then write (transformed key, original value)
            writeKeyElement = new BytecodeBlock()
                    .append(transformedKeyElement.set(function.invoke("invokeExact", transformedKeyJavaType, session, keyElement, valueElement)))
                    .append(new IfStatement()
                            .condition(equal(transformedKeyElement, constantNull(transformedKeyJavaType)))
                            .ifTrue(throwNullKeyException)
                            .ifFalse(new BytecodeBlock()
                                    .append(transformedKeySqlType.writeValue(blockBuilder, transformedKeyElement.cast(transformedKeyType.getJavaType())))
                                    .append(valueSqlType.invoke("appendTo", void.class, block, add(position, constantInt(1)), blockBuilder))));

            // built only for a known key type, so getObjectValue is never invoked on the unknown type
            throwDuplicatedKeyException = new BytecodeBlock()
                    .append(newInstance(
                            PrestoException.class,
                            getStatic(INVALID_FUNCTION_ARGUMENT.getDeclaringClass(), "INVALID_FUNCTION_ARGUMENT").cast(ErrorCodeSupplier.class),
                            invokeStatic(
                                    String.class,
                                    "format",
                                    String.class,
                                    constantString("Duplicate keys (%s) are not allowed"),
                                    newArray(type(Object[].class), ImmutableList.of(transformedKeySqlType.invoke("getObjectValue", Object.class, session, blockBuilder.cast(Block.class), position))))))
                    .throwObject();
        }
        else {
            // K2 bound to unknown presumably means the lambda can only produce null, which is not a
            // valid map key; any entry reaching this point at runtime is an error
            writeKeyElement = throwNullKeyException;
            throwDuplicatedKeyException = throwNullKeyException;
        }

        // Each iteration consumes two input positions and appends exactly two output positions,
        // so the transformed key just written sits at `position` in blockBuilder as well.
        body.append(new ForLoop()
                .initialize(position.set(constantInt(0)))
                .condition(lessThan(position, positionCount))
                .update(incrementVariable(position, (byte) 2))
                .body(new BytecodeBlock()
                        .append(loadKeyElement)
                        .append(loadValueElement)
                        .append(writeKeyElement)
                        .append(new IfStatement()
                                .condition(typedSet.invoke("contains", boolean.class, blockBuilder.cast(Block.class), position))
                                .ifTrue(throwDuplicatedKeyException)
                                .ifFalse(typedSet.invoke("add", void.class, blockBuilder.cast(Block.class), position)))));

        body.append(blockBuilder.invoke("build", Block.class).ret());

        Class<?> generatedClass = defineClass(definition, Object.class, binder.getBindings(), MapTransformKeyFunction.class.getClassLoader());
        return methodHandle(generatedClass, "transform", ConnectorSession.class, Block.class, MethodHandle.class);
    }
}