/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.sql.gen;
import com.facebook.presto.bytecode.BytecodeBlock;
import com.facebook.presto.bytecode.BytecodeNode;
import com.facebook.presto.bytecode.Scope;
import com.facebook.presto.bytecode.Variable;
import com.facebook.presto.bytecode.control.IfStatement;
import com.facebook.presto.bytecode.control.LookupSwitch;
import com.facebook.presto.bytecode.instruction.LabelNode;
import com.facebook.presto.metadata.FunctionRegistry;
import com.facebook.presto.metadata.Signature;
import com.facebook.presto.operator.scalar.ScalarFunctionImplementation;
import com.facebook.presto.spi.type.BigintType;
import com.facebook.presto.spi.type.DateType;
import com.facebook.presto.spi.type.IntegerType;
import com.facebook.presto.spi.type.Type;
import com.facebook.presto.sql.relational.ConstantExpression;
import com.facebook.presto.sql.relational.RowExpression;
import com.facebook.presto.util.FastutilSetHelper;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableListMultimap;
import com.google.common.collect.ImmutableSet;
import java.lang.invoke.MethodHandle;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Set;
import static com.facebook.presto.bytecode.control.LookupSwitch.lookupSwitchBuilder;
import static com.facebook.presto.bytecode.expression.BytecodeExpressions.constantFalse;
import static com.facebook.presto.bytecode.expression.BytecodeExpressions.constantTrue;
import static com.facebook.presto.bytecode.instruction.JumpInstruction.jump;
import static com.facebook.presto.metadata.Signature.internalOperator;
import static com.facebook.presto.spi.function.OperatorType.EQUAL;
import static com.facebook.presto.spi.function.OperatorType.HASH_CODE;
import static com.facebook.presto.spi.type.BigintType.BIGINT;
import static com.facebook.presto.spi.type.BooleanType.BOOLEAN;
import static com.facebook.presto.sql.gen.BytecodeUtils.ifWasNullPopAndGoto;
import static com.facebook.presto.sql.gen.BytecodeUtils.invoke;
import static com.facebook.presto.sql.gen.BytecodeUtils.loadConstant;
import static com.facebook.presto.util.FastutilSetHelper.toFastutilHashSet;
import static java.lang.Math.toIntExact;
import static java.util.Objects.requireNonNull;
public class InCodeGenerator
implements BytecodeGenerator
{
private final FunctionRegistry registry;
public InCodeGenerator(FunctionRegistry registry)
{
this.registry = requireNonNull(registry, "registry is null");
}
enum SwitchGenerationCase
{
DIRECT_SWITCH,
HASH_SWITCH,
SET_CONTAINS
}
@VisibleForTesting
static SwitchGenerationCase checkSwitchGenerationCase(Type type, List<RowExpression> values)
{
if (values.size() > 32) {
// 32 is chosen because
// * SET_CONTAINS performs worst when smaller than but close to power of 2
// * Benchmark shows performance of SET_CONTAINS is better at 50, but similar at 25.
return SwitchGenerationCase.SET_CONTAINS;
}
if (!(type instanceof IntegerType || type instanceof BigintType || type instanceof DateType)) {
return SwitchGenerationCase.HASH_SWITCH;
}
for (RowExpression expression : values) {
// For non-constant expressions, they will be added to the default case in the generated switch code. They do not affect any of
// the cases other than the default one. Therefore, it's okay to skip them when choosing between DIRECT_SWITCH and HASH_SWITCH.
// Same argument applies for nulls.
if (!(expression instanceof ConstantExpression)) {
continue;
}
Object constant = ((ConstantExpression) expression).getValue();
if (constant == null) {
continue;
}
long longConstant = ((Number) constant).longValue();
if (longConstant < Integer.MIN_VALUE || longConstant > Integer.MAX_VALUE) {
return SwitchGenerationCase.HASH_SWITCH;
}
}
return SwitchGenerationCase.DIRECT_SWITCH;
}
@Override
public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorContext generatorContext, Type returnType, List<RowExpression> arguments)
{
BytecodeNode value = generatorContext.generate(arguments.get(0));
List<RowExpression> values = arguments.subList(1, arguments.size());
ImmutableList.Builder<BytecodeNode> valuesBytecode = ImmutableList.builder();
for (int i = 1; i < arguments.size(); i++) {
BytecodeNode testNode = generatorContext.generate(arguments.get(i));
valuesBytecode.add(testNode);
}
Type type = arguments.get(0).getType();
Class<?> javaType = type.getJavaType();
SwitchGenerationCase switchGenerationCase = checkSwitchGenerationCase(type, values);
Signature hashCodeSignature = internalOperator(HASH_CODE, BIGINT, ImmutableList.of(type));
MethodHandle hashCodeFunction = generatorContext.getRegistry().getScalarFunctionImplementation(hashCodeSignature).getMethodHandle();
ImmutableListMultimap.Builder<Integer, BytecodeNode> hashBucketsBuilder = ImmutableListMultimap.builder();
ImmutableList.Builder<BytecodeNode> defaultBucket = ImmutableList.builder();
ImmutableSet.Builder<Object> constantValuesBuilder = ImmutableSet.builder();
for (RowExpression testValue : values) {
BytecodeNode testBytecode = generatorContext.generate(testValue);
if (testValue instanceof ConstantExpression && ((ConstantExpression) testValue).getValue() != null) {
ConstantExpression constant = (ConstantExpression) testValue;
Object object = constant.getValue();
switch (switchGenerationCase) {
case DIRECT_SWITCH:
case SET_CONTAINS:
constantValuesBuilder.add(object);
break;
case HASH_SWITCH:
try {
int hashCode = toIntExact(Long.hashCode((Long) hashCodeFunction.invoke(object)));
hashBucketsBuilder.put(hashCode, testBytecode);
}
catch (Throwable throwable) {
throw new IllegalArgumentException("Error processing IN statement: error calculating hash code for " + object, throwable);
}
break;
default:
throw new IllegalArgumentException("Not supported switch generation case: " + switchGenerationCase);
}
}
else {
defaultBucket.add(testBytecode);
}
}
ImmutableListMultimap<Integer, BytecodeNode> hashBuckets = hashBucketsBuilder.build();
ImmutableSet<Object> constantValues = constantValuesBuilder.build();
LabelNode end = new LabelNode("end");
LabelNode match = new LabelNode("match");
LabelNode noMatch = new LabelNode("noMatch");
LabelNode defaultLabel = new LabelNode("default");
Scope scope = generatorContext.getScope();
BytecodeNode switchBlock;
BytecodeBlock switchCaseBlocks = new BytecodeBlock();
LookupSwitch.LookupSwitchBuilder switchBuilder = lookupSwitchBuilder();
switch (switchGenerationCase) {
case DIRECT_SWITCH:
// A white-list is used to select types eligible for DIRECT_SWITCH.
// For these types, it's safe to not use presto HASH_CODE and EQUAL operator.
for (Object constantValue : constantValues) {
switchBuilder.addCase(toIntExact((Long) constantValue), match);
}
switchBuilder.defaultCase(defaultLabel);
switchBlock = new BytecodeBlock()
.comment("lookupSwitch(<stackValue>))")
.dup(javaType)
.append(new IfStatement()
.condition(new BytecodeBlock()
.dup(javaType)
.invokeStatic(InCodeGenerator.class, "isInteger", boolean.class, long.class))
.ifFalse(new BytecodeBlock()
.pop(javaType)
.gotoLabel(defaultLabel)))
.longToInt()
.append(switchBuilder.build());
break;
case HASH_SWITCH:
for (Map.Entry<Integer, Collection<BytecodeNode>> bucket : hashBuckets.asMap().entrySet()) {
LabelNode label = new LabelNode("inHash" + bucket.getKey());
switchBuilder.addCase(bucket.getKey(), label);
Collection<BytecodeNode> testValues = bucket.getValue();
BytecodeBlock caseBlock = buildInCase(generatorContext, scope, type, label, match, defaultLabel, testValues, false);
switchCaseBlocks.append(caseBlock.setDescription("case " + bucket.getKey()));
}
switchBuilder.defaultCase(defaultLabel);
Binding hashCodeBinding = generatorContext
.getCallSiteBinder()
.bind(hashCodeFunction);
switchBlock = new BytecodeBlock()
.comment("lookupSwitch(hashCode(<stackValue>))")
.dup(javaType)
.append(invoke(hashCodeBinding, hashCodeSignature))
.invokeStatic(Long.class, "hashCode", int.class, long.class)
.append(switchBuilder.build())
.append(switchCaseBlocks);
break;
case SET_CONTAINS:
Set<?> constantValuesSet = toFastutilHashSet(constantValues, type, registry);
Binding constant = generatorContext.getCallSiteBinder().bind(constantValuesSet, constantValuesSet.getClass());
switchBlock = new BytecodeBlock()
.comment("inListSet.contains(<stackValue>)")
.append(new IfStatement()
.condition(new BytecodeBlock()
.comment("value")
.dup(javaType)
.comment("set")
.append(loadConstant(constant))
// TODO: use invokeVirtual on the set instead. This requires swapping the two elements in the stack
.invokeStatic(FastutilSetHelper.class, "in", boolean.class, javaType.isPrimitive() ? javaType : Object.class, constantValuesSet.getClass()))
.ifTrue(jump(match)));
break;
default:
throw new IllegalArgumentException("Not supported switch generation case: " + switchGenerationCase);
}
BytecodeBlock defaultCaseBlock = buildInCase(generatorContext, scope, type, defaultLabel, match, noMatch, defaultBucket.build(), true).setDescription("default");
BytecodeBlock block = new BytecodeBlock()
.comment("IN")
.append(value)
.append(ifWasNullPopAndGoto(scope, end, boolean.class, javaType))
.append(switchBlock)
.append(defaultCaseBlock);
BytecodeBlock matchBlock = new BytecodeBlock()
.setDescription("match")
.visitLabel(match)
.pop(javaType)
.append(generatorContext.wasNull().set(constantFalse()))
.push(true)
.gotoLabel(end);
block.append(matchBlock);
BytecodeBlock noMatchBlock = new BytecodeBlock()
.setDescription("noMatch")
.visitLabel(noMatch)
.pop(javaType)
.push(false)
.gotoLabel(end);
block.append(noMatchBlock);
block.visitLabel(end);
return block;
}
public static boolean isInteger(long value)
{
return value == (int) value;
}
private static BytecodeBlock buildInCase(BytecodeGeneratorContext generatorContext,
Scope scope,
Type type,
LabelNode caseLabel,
LabelNode matchLabel,
LabelNode noMatchLabel,
Collection<BytecodeNode> testValues,
boolean checkForNulls)
{
Variable caseWasNull = null; // caseWasNull is set to true the first time a null in `testValues` is encountered
if (checkForNulls) {
caseWasNull = scope.createTempVariable(boolean.class);
}
BytecodeBlock caseBlock = new BytecodeBlock()
.visitLabel(caseLabel);
if (checkForNulls) {
caseBlock.putVariable(caseWasNull, false);
}
LabelNode elseLabel = new LabelNode("else");
BytecodeBlock elseBlock = new BytecodeBlock()
.visitLabel(elseLabel);
Variable wasNull = generatorContext.wasNull();
if (checkForNulls) {
elseBlock.append(wasNull.set(caseWasNull));
}
elseBlock.gotoLabel(noMatchLabel);
ScalarFunctionImplementation operator = generatorContext.getRegistry().getScalarFunctionImplementation(internalOperator(EQUAL, BOOLEAN, ImmutableList.of(type, type)));
Binding equalsFunction = generatorContext
.getCallSiteBinder()
.bind(operator.getMethodHandle());
BytecodeNode elseNode = elseBlock;
for (BytecodeNode testNode : testValues) {
LabelNode testLabel = new LabelNode("test");
IfStatement test = new IfStatement();
test.condition()
.visitLabel(testLabel)
.dup(type.getJavaType())
.append(testNode);
if (checkForNulls) {
IfStatement wasNullCheck = new IfStatement("if wasNull, set caseWasNull to true, clear wasNull, pop 2 values of type, and goto next test value");
wasNullCheck.condition(wasNull);
wasNullCheck.ifTrue(new BytecodeBlock()
.append(caseWasNull.set(constantTrue()))
.append(wasNull.set(constantFalse()))
.pop(type.getJavaType())
.pop(type.getJavaType())
.gotoLabel(elseLabel));
test.condition().append(wasNullCheck);
}
test.condition()
.append(invoke(equalsFunction, EQUAL.name()));
test.ifTrue().gotoLabel(matchLabel);
test.ifFalse(elseNode);
elseNode = test;
elseLabel = testLabel;
}
caseBlock.append(elseNode);
return caseBlock;
}
}