/* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.facebook.presto.operator.scalar; import com.facebook.presto.bytecode.BytecodeBlock; import com.facebook.presto.bytecode.BytecodeNode; import com.facebook.presto.bytecode.ClassDefinition; import com.facebook.presto.bytecode.MethodDefinition; import com.facebook.presto.bytecode.Parameter; import com.facebook.presto.bytecode.Scope; import com.facebook.presto.bytecode.Variable; import com.facebook.presto.bytecode.control.ForLoop; import com.facebook.presto.bytecode.control.IfStatement; import com.facebook.presto.metadata.BoundVariables; import com.facebook.presto.metadata.FunctionKind; import com.facebook.presto.metadata.FunctionRegistry; import com.facebook.presto.metadata.Signature; import com.facebook.presto.metadata.SqlScalarFunction; import com.facebook.presto.spi.ConnectorSession; import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.block.BlockBuilder; import com.facebook.presto.spi.block.BlockBuilderStatus; import com.facebook.presto.spi.block.InterleavedBlockBuilder; import com.facebook.presto.spi.type.StandardTypes; import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.TypeManager; import com.facebook.presto.spi.type.TypeSignatureParameter; import com.facebook.presto.sql.gen.CallSiteBinder; import com.facebook.presto.sql.gen.SqlTypeBytecodeExpression; import com.google.common.collect.ImmutableList; import com.google.common.primitives.Primitives; import 
java.lang.invoke.MethodHandle;
import java.util.List;

import static com.facebook.presto.bytecode.Access.FINAL;
import static com.facebook.presto.bytecode.Access.PRIVATE;
import static com.facebook.presto.bytecode.Access.PUBLIC;
import static com.facebook.presto.bytecode.Access.STATIC;
import static com.facebook.presto.bytecode.Access.a;
import static com.facebook.presto.bytecode.CompilerUtils.defineClass;
import static com.facebook.presto.bytecode.CompilerUtils.makeClassName;
import static com.facebook.presto.bytecode.Parameter.arg;
import static com.facebook.presto.bytecode.ParameterizedType.type;
import static com.facebook.presto.bytecode.expression.BytecodeExpressions.add;
import static com.facebook.presto.bytecode.expression.BytecodeExpressions.and;
import static com.facebook.presto.bytecode.expression.BytecodeExpressions.constantInt;
import static com.facebook.presto.bytecode.expression.BytecodeExpressions.constantNull;
import static com.facebook.presto.bytecode.expression.BytecodeExpressions.lessThan;
import static com.facebook.presto.bytecode.expression.BytecodeExpressions.newInstance;
import static com.facebook.presto.bytecode.expression.BytecodeExpressions.notEqual;
import static com.facebook.presto.bytecode.instruction.VariableInstruction.incrementVariable;
import static com.facebook.presto.metadata.Signature.typeVariable;
import static com.facebook.presto.spi.type.TypeSignature.parseTypeSignature;
import static com.facebook.presto.sql.gen.SqlTypeBytecodeExpression.constantType;
import static com.facebook.presto.type.UnknownType.UNKNOWN;
import static com.facebook.presto.util.Reflection.methodHandle;

/**
 * Scalar function {@code map_filter(map(K,V), function(K,V,boolean)) -> map(K,V)}.
 * <p>
 * For each concrete (K, V) binding, a small class is generated at specialization time.
 * That class walks the map block and copies every entry for which the predicate
 * returns {@code true}.
 */
public final class MapFilterFunction
        extends SqlScalarFunction
{
    public static final MapFilterFunction MAP_FILTER_FUNCTION = new MapFilterFunction();

    private MapFilterFunction()
    {
        super(new Signature(
                "map_filter",
                FunctionKind.SCALAR,
                ImmutableList.of(typeVariable("K"), typeVariable("V")),
                ImmutableList.of(),
                parseTypeSignature("map(K,V)"),
                ImmutableList.of(parseTypeSignature("map(K,V)"), parseTypeSignature("function(K,V,boolean)")),
                false));
    }

    @Override
    public boolean isHidden()
    {
        return false;
    }

    /**
     * Always reports the function as non-deterministic.
     * <p>
     * NOTE(review): presumably because the determinism of the user-supplied predicate is not
     * known at this point; confirm.
     */
    @Override
    public boolean isDeterministic()
    {
        return false;
    }

    @Override
    public String getDescription()
    {
        return "return map containing entries that match the given predicate";
    }

    /**
     * Resolves the bound key and value types and returns a generated filter implementation.
     * <p>
     * Neither the result nor either argument is declared nullable.
     */
    @Override
    public ScalarFunctionImplementation specialize(BoundVariables boundVariables, int arity, TypeManager typeManager, FunctionRegistry functionRegistry)
    {
        Type boundKeyType = boundVariables.getTypeVariable("K");
        Type boundValueType = boundVariables.getTypeVariable("V");

        ImmutableList<TypeSignatureParameter> mapParameters = ImmutableList.of(
                TypeSignatureParameter.of(boundKeyType.getTypeSignature()),
                TypeSignatureParameter.of(boundValueType.getTypeSignature()));
        Type resolvedMapType = typeManager.getParameterizedType(StandardTypes.MAP, mapParameters);

        MethodHandle filterHandle = generateFilter(boundKeyType, boundValueType, resolvedMapType);
        return new ScalarFunctionImplementation(
                false,
                ImmutableList.of(false, false),
                filterHandle,
                isDeterministic());
    }

    /**
     * Generates a class with a static method
     * {@code Block filter(ConnectorSession session, Block block, MethodHandle function)}
     * and returns a handle to that method.
     * <p>
     * The generated method treats {@code block} as interleaved key/value pairs: the key is at
     * an even position {@code p} and its value at {@code p + 1}. For each pair it calls
     * {@code function.invokeExact(session, key, value)}, which returns a {@link Boolean}. Both
     * halves of the pair are copied into a new interleaved block only when that result is
     * non-null and {@code true}.
     */
    private static MethodHandle generateFilter(Type keyType, Type valueType, Type mapType)
    {
        CallSiteBinder binder = new CallSiteBinder();
        // Elements are held as boxed objects so they can be null and passed to the predicate.
        Class<?> boxedKeyClass = Primitives.wrap(keyType.getJavaType());
        Class<?> boxedValueClass = Primitives.wrap(valueType.getJavaType());

        ClassDefinition definition = new ClassDefinition(
                a(PUBLIC, FINAL),
                makeClassName("MapFilter"),
                type(Object.class));
        definition.declareDefaultConstructor(a(PRIVATE));

        Parameter session = arg("session", ConnectorSession.class);
        Parameter block = arg("block", Block.class);
        Parameter function = arg("function", MethodHandle.class);
        MethodDefinition method = definition.declareMethod(
                a(PUBLIC, STATIC),
                "filter",
                type(Block.class),
                ImmutableList.of(session, block, function));

        Scope scope = method.getScope();
        Variable positionCount = scope.declareVariable(int.class, "positionCount");
        Variable position = scope.declareVariable(int.class, "position");
        Variable blockBuilder = scope.declareVariable(BlockBuilder.class, "blockBuilder");
        Variable keyElement = scope.declareVariable(boxedKeyClass, "keyElement");
        Variable valueElement = scope.declareVariable(boxedValueClass, "valueElement");
        Variable keep = scope.declareVariable(Boolean.class, "keep");

        BytecodeBlock body = method.getBody();

        // positionCount = block.getPositionCount()
        body.append(positionCount.set(block.invoke("getPositionCount", int.class)));

        // The output builder uses the map's type parameters and is sized by the input position count.
        body.append(blockBuilder.set(newInstance(
                InterleavedBlockBuilder.class,
                constantType(binder, mapType).invoke("getTypeParameters", List.class),
                newInstance(BlockBuilderStatus.class),
                positionCount)));

        SqlTypeBytecodeExpression keySqlType = constantType(binder, keyType);
        SqlTypeBytecodeExpression valueSqlType = constantType(binder, valueType);

        BytecodeNode readKey = loadKeyElement(keyType, keySqlType, boxedKeyClass, block, position, keyElement);
        BytecodeNode readValue = loadValueElement(valueType, valueSqlType, boxedValueClass, block, position, valueElement);

        // Copies the current key (at position) and value (at position + 1) into the output.
        BytecodeBlock copyEntry = new BytecodeBlock()
                .append(keySqlType.invoke("appendTo", void.class, block, position, blockBuilder))
                .append(valueSqlType.invoke("appendTo", void.class, block, add(position, constantInt(1)), blockBuilder));

        BytecodeBlock loopBody = new BytecodeBlock()
                .append(readKey)
                .append(readValue)
                .append(keep.set(function.invoke("invokeExact", Boolean.class, session, keyElement, valueElement)))
                // A null predicate result is treated the same as false.
                .append(new IfStatement("if (keep != null && keep) ...")
                        .condition(and(notEqual(keep, constantNull(Boolean.class)), keep.cast(boolean.class)))
                        .ifTrue(copyEntry));

        // Each iteration consumes one key/value pair, hence the step of 2.
        body.append(new ForLoop()
                .initialize(position.set(constantInt(0)))
                .condition(lessThan(position, positionCount))
                .update(incrementVariable(position, (byte) 2))
                .body(loopBody));

        body.append(blockBuilder.invoke("build", Block.class).ret());

        Class<?> generatedClass = defineClass(definition, Object.class, binder.getBindings(), MapFilterFunction.class.getClassLoader());
        return methodHandle(generatedClass, "filter", ConnectorSession.class, Block.class, MethodHandle.class);
    }

    /**
     * Emits bytecode that stores the key at {@code position} into {@code keyElement}.
     * <p>
     * For the UNKNOWN type the key is simply set to null.
     */
    private static BytecodeNode loadKeyElement(
            Type keyType,
            SqlTypeBytecodeExpression keySqlType,
            Class<?> boxedKeyClass,
            Parameter block,
            Variable position,
            Variable keyElement)
    {
        if (keyType.equals(UNKNOWN)) {
            return new BytecodeBlock().append(keyElement.set(constantNull(boxedKeyClass)));
        }
        // The key element must be non-null, so it is read without an isNull check.
        return new BytecodeBlock().append(keyElement.set(keySqlType.getValue(block, position).cast(boxedKeyClass)));
    }

    /**
     * Emits bytecode that stores the value at {@code position + 1} into {@code valueElement}.
     * <p>
     * A null entry, or the UNKNOWN type, produces a null value.
     */
    private static BytecodeNode loadValueElement(
            Type valueType,
            SqlTypeBytecodeExpression valueSqlType,
            Class<?> boxedValueClass,
            Parameter block,
            Variable position,
            Variable valueElement)
    {
        if (valueType.equals(UNKNOWN)) {
            return new BytecodeBlock().append(valueElement.set(constantNull(boxedValueClass)));
        }
        return new IfStatement()
                .condition(block.invoke("isNull", boolean.class, add(position, constantInt(1))))
                .ifTrue(valueElement.set(constantNull(boxedValueClass)))
                .ifFalse(valueElement.set(valueSqlType.getValue(block, add(position, constantInt(1))).cast(boxedValueClass)));
    }
}