/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.sql.gen;
import com.facebook.presto.bytecode.BytecodeBlock;
import com.facebook.presto.bytecode.BytecodeNode;
import com.facebook.presto.bytecode.ClassDefinition;
import com.facebook.presto.bytecode.DynamicClassLoader;
import com.facebook.presto.bytecode.FieldDefinition;
import com.facebook.presto.bytecode.MethodDefinition;
import com.facebook.presto.bytecode.Parameter;
import com.facebook.presto.bytecode.Scope;
import com.facebook.presto.bytecode.Variable;
import com.facebook.presto.bytecode.control.IfStatement;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.operator.InternalJoinFilterFunction;
import com.facebook.presto.operator.JoinFilterFunction;
import com.facebook.presto.operator.StandardJoinFilterFunction;
import com.facebook.presto.spi.ConnectorSession;
import com.facebook.presto.spi.block.Block;
import com.facebook.presto.sql.gen.LambdaBytecodeGenerator.LambdaExpressionField;
import com.facebook.presto.sql.planner.SortExpressionExtractor.SortExpression;
import com.facebook.presto.sql.relational.CallExpression;
import com.facebook.presto.sql.relational.LambdaDefinitionExpression;
import com.facebook.presto.sql.relational.RowExpression;
import com.facebook.presto.sql.relational.RowExpressionVisitor;
import com.facebook.presto.sql.relational.Signatures;
import com.google.common.base.Throwables;
import com.google.common.base.VerifyException;
import com.google.common.cache.CacheBuilder;
import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.primitives.Primitives;
import it.unimi.dsi.fastutil.longs.LongArrayList;
import org.weakref.jmx.Managed;
import org.weakref.jmx.Nested;
import javax.inject.Inject;
import java.lang.reflect.Constructor;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import static com.facebook.presto.bytecode.Access.FINAL;
import static com.facebook.presto.bytecode.Access.PRIVATE;
import static com.facebook.presto.bytecode.Access.PUBLIC;
import static com.facebook.presto.bytecode.Access.a;
import static com.facebook.presto.bytecode.CompilerUtils.defineClass;
import static com.facebook.presto.bytecode.CompilerUtils.makeClassName;
import static com.facebook.presto.bytecode.Parameter.arg;
import static com.facebook.presto.bytecode.ParameterizedType.type;
import static com.facebook.presto.bytecode.expression.BytecodeExpressions.constantFalse;
import static com.facebook.presto.sql.gen.BytecodeUtils.invoke;
import static com.facebook.presto.sql.gen.LambdaAndTryExpressionExtractor.extractLambdaAndTryExpressions;
import static com.facebook.presto.sql.gen.TryCodeGenerator.defineTryMethod;
import static com.google.common.base.MoreObjects.toStringHelper;
import static com.google.common.base.Verify.verify;
import static java.lang.String.format;
import static java.util.Objects.requireNonNull;
public class JoinFilterFunctionCompiler
{
/** Engine metadata; its function registry is handed to the expression visitors during code generation. */
private final Metadata metadata;

/**
 * Creates a compiler backed by the given metadata.
 *
 * @param metadata metadata whose function registry is used while generating bytecode; must not be null
 * @throws NullPointerException if {@code metadata} is null
 */
@Inject
public JoinFilterFunctionCompiler(Metadata metadata)
{
    // Fail at construction time rather than at the first compilation that dereferences the field.
    this.metadata = requireNonNull(metadata, "metadata can not be null");
}
/**
 * Cache of compiled factories, holding at most 1000 entries and recording statistics.
 * A missing entry is produced by compiling the filter described by the key.
 */
private final LoadingCache<JoinFilterCacheKey, JoinFilterFunctionFactory> joinFilterFunctionFactories = CacheBuilder.newBuilder()
        .recordStats()
        .maximumSize(1000)
        .build(CacheLoader.from((JoinFilterCacheKey key) ->
                internalCompileFilterFunctionFactory(key.getFilter(), key.getLeftBlocksSize(), key.getSortChannel())));
/**
 * Wraps the factory cache in a {@link CacheStatsMBean} for management export.
 *
 * @return a new bean backed by {@code joinFilterFunctionFactories}
 */
@Managed
@Nested
public CacheStatsMBean getJoinFilterFunctionFactoryStats()
{
    CacheStatsMBean factoryCacheStats = new CacheStatsMBean(joinFilterFunctionFactories);
    return factoryCacheStats;
}
/**
 * Returns a factory for the given join filter, going through the factory cache.
 *
 * @param filter the filter expression to compile
 * @param leftBlocksSize number of input fields that belong to the left side; fields at or beyond
 *        this index are resolved against the right side
 * @param sortChannel optional sort expression handed to the created factory
 * @return the cached or newly compiled factory
 */
public JoinFilterFunctionFactory compileJoinFilterFunction(RowExpression filter, int leftBlocksSize, Optional<SortExpression> sortChannel)
{
    JoinFilterCacheKey cacheKey = new JoinFilterCacheKey(filter, leftBlocksSize, sortChannel);
    // getUnchecked: the loader declares no checked failures that callers are expected to handle
    return joinFilterFunctionFactories.getUnchecked(cacheKey);
}
/**
 * Cache-miss path: compiles the filter into a generated class and wraps it in a factory.
 *
 * @param filterExpression the filter expression to compile
 * @param leftBlocksSize number of input fields that belong to the left side
 * @param sortChannel optional sort expression forwarded to the factory
 * @return a factory that instantiates the generated filter class
 */
private JoinFilterFunctionFactory internalCompileFilterFunctionFactory(RowExpression filterExpression, int leftBlocksSize, Optional<SortExpression> sortChannel)
{
    return new IsolatedJoinFilterFunctionFactory(
            compileInternalJoinFilterFunction(filterExpression, leftBlocksSize),
            sortChannel);
}
/**
 * Generates and defines a class implementing {@link InternalJoinFilterFunction} for the given filter.
 *
 * <p>The generated class is public and final, extends {@code Object}, and gets its methods from
 * {@code generateMethods} plus a {@code toString} describing the filter and {@code leftBlocksSize}.
 *
 * @param filterExpression the filter expression to compile
 * @param leftBlocksSize number of input fields that belong to the left side
 * @return the freshly defined class
 */
private Class<? extends InternalJoinFilterFunction> compileInternalJoinFilterFunction(RowExpression filterExpression, int leftBlocksSize)
{
    ClassDefinition classDefinition = new ClassDefinition(
            a(PUBLIC, FINAL),
            makeClassName("JoinFilterFunction"),
            type(Object.class),
            type(InternalJoinFilterFunction.class));

    CallSiteBinder callSiteBinder = new CallSiteBinder();

    // Generate on this instance: creating a second compiler (and with it a second, unused
    // factory cache) per compilation is wasted work, since only the metadata is needed.
    generateMethods(classDefinition, callSiteBinder, filterExpression, leftBlocksSize);

    //
    // toString method
    //
    generateToString(
            classDefinition,
            callSiteBinder,
            toStringHelper(classDefinition.getType().getJavaClassName())
                    .add("filter", filterExpression)
                    .add("leftBlocksSize", leftBlocksSize)
                    .toString());

    return defineClass(classDefinition, InternalJoinFilterFunction.class, callSiteBinder.getBindings(), getClass().getClassLoader());
}
/**
 * Populates the class definition with the session field, the helper methods for lambda and try
 * expressions, the {@code filter} method and finally the constructor.
 *
 * @param classDefinition class being generated
 * @param callSiteBinder binder collecting the constants bound by the generated code
 * @param filter the filter expression to compile
 * @param leftBlocksSize number of input fields that belong to the left side
 */
private void generateMethods(ClassDefinition classDefinition, CallSiteBinder callSiteBinder, RowExpression filter, int leftBlocksSize)
{
    CachedInstanceBinder instanceBinder = new CachedInstanceBinder(classDefinition, callSiteBinder);
    FieldDefinition session = classDefinition.declareField(a(PRIVATE, FINAL), "session", ConnectorSession.class);

    // Lambda/try helpers come first so the filter method can reference them.
    PreGeneratedExpressions preGenerated = generateMethodsForLambdaAndTry(classDefinition, callSiteBinder, instanceBinder, leftBlocksSize, filter);
    generateFilterMethod(classDefinition, callSiteBinder, instanceBinder, preGenerated, filter, leftBlocksSize, session);

    // The constructor is emitted last; it initializes what the binder and the lambda fields registered above.
    generateConstructor(classDefinition, session, instanceBinder, preGenerated);
}
/**
 * Emits the public constructor {@code (ConnectorSession session)} of the generated class.
 *
 * <p>The constructor calls {@code Object}'s constructor, stores the session, then runs the
 * initializations registered with the cached-instance binder and those of every pre-generated
 * lambda field.
 *
 * @param classDefinition class being generated
 * @param sessionField field that receives the constructor argument
 * @param cachedInstanceBinder binder whose initializations are appended to the constructor body
 * @param preGeneratedExpressions source of the lambda fields to initialize
 */
private static void generateConstructor(
        ClassDefinition classDefinition,
        FieldDefinition sessionField,
        CachedInstanceBinder cachedInstanceBinder,
        PreGeneratedExpressions preGeneratedExpressions)
{
    Parameter sessionParameter = arg("session", ConnectorSession.class);
    MethodDefinition constructor = classDefinition.declareConstructor(a(PUBLIC), sessionParameter);

    BytecodeBlock constructorBody = constructor.getBody();
    Variable self = constructor.getThis();

    constructorBody.comment("super();");
    constructorBody.append(self);
    constructorBody.invokeConstructor(Object.class);

    constructorBody.append(self.setField(sessionField, sessionParameter));
    cachedInstanceBinder.generateInitializations(self, constructorBody);

    preGeneratedExpressions.getLambdaFieldMap()
            .values()
            .forEach(lambdaField -> lambdaField.generateInitialization(self, constructorBody));

    constructorBody.ret();
}
/**
 * Emits {@code public boolean filter(int leftPosition, Block[] leftBlocks, int rightPosition, Block[] rightBlocks)}.
 *
 * <p>The generated body evaluates the filter expression into {@code result} and returns
 * {@code false} when the {@code wasNull} flag is set afterwards, otherwise {@code result}.
 *
 * @param classDefinition class being generated
 * @param callSiteBinder binder collecting the constants bound by the generated code
 * @param cachedInstanceBinder binder for cached instances used by the expression visitor
 * @param preGeneratedExpressions previously generated try methods and lambda fields
 * @param filter the filter expression to compile
 * @param leftBlocksSize number of input fields that belong to the left side
 * @param sessionField field holding the session, exposed to the body as the local {@code session}
 */
private void generateFilterMethod(
        ClassDefinition classDefinition,
        CallSiteBinder callSiteBinder,
        CachedInstanceBinder cachedInstanceBinder,
        PreGeneratedExpressions preGeneratedExpressions,
        RowExpression filter,
        int leftBlocksSize,
        FieldDefinition sessionField)
{
    // int leftPosition, Block[] leftBlocks, int rightPosition, Block[] rightBlocks
    Parameter leftPosition = arg("leftPosition", int.class);
    Parameter leftBlocks = arg("leftBlocks", Block[].class);
    Parameter rightPosition = arg("rightPosition", int.class);
    Parameter rightBlocks = arg("rightBlocks", Block[].class);

    MethodDefinition filterMethod = classDefinition.declareMethod(
            a(PUBLIC),
            "filter",
            type(boolean.class),
            ImmutableList.of(leftPosition, leftBlocks, rightPosition, rightBlocks));
    filterMethod.comment("filter: %s", filter.toString());

    BytecodeBlock methodBody = filterMethod.getBody();
    Scope methodScope = filterMethod.getScope();

    // Locals are declared in this order: wasNull (initially false), then session.
    Variable wasNull = methodScope.declareVariable("wasNull", methodBody, constantFalse());
    methodScope.declareVariable("session", methodBody, filterMethod.getThis().getField(sessionField));

    RowExpressionVisitor<Scope, BytecodeNode> inputCompiler =
            fieldReferenceCompiler(callSiteBinder, leftPosition, leftBlocks, rightPosition, rightBlocks, leftBlocksSize);
    BytecodeExpressionVisitor expressionVisitor = new BytecodeExpressionVisitor(
            callSiteBinder,
            cachedInstanceBinder,
            inputCompiler,
            metadata.getFunctionRegistry(),
            preGeneratedExpressions);

    BytecodeNode filterBytecode = filter.accept(expressionVisitor, methodScope);

    Variable result = methodScope.declareVariable(boolean.class, "result");
    methodBody.append(filterBytecode);
    methodBody.putVariable(result);
    // a null outcome is reported as "no match"
    methodBody.append(new IfStatement()
            .condition(wasNull)
            .ifTrue(constantFalse().ret())
            .ifFalse(result.ret()));
}
/**
 * Pre-generates a helper for every lambda and try expression found in the filter.
 *
 * <p>Each try call becomes a method named {@code try_<n>} taking
 * {@code (session, leftPosition, leftBlocks, rightPosition, rightBlocks)}; each lambda becomes a
 * field generated under the name {@code lambda_<n>}. {@code n} is a counter shared by both kinds
 * and incremented per expression. Every helper is generated with a view of the helpers produced
 * before it in iteration order.
 * NOTE(review): this presumably relies on the extractor returning nested expressions before
 * their enclosing ones — confirm against {@code LambdaAndTryExpressionExtractor}.
 *
 * @param containerClassDefinition class that receives the generated methods and fields
 * @param callSiteBinder binder collecting the constants bound by the generated code
 * @param cachedInstanceBinder binder for cached instances used by the expression visitors
 * @param leftBlocksSize number of input fields that belong to the left side
 * @param filter the filter expression to scan
 * @return the try methods and lambda fields, keyed by the expression they were generated for
 * @throws VerifyException if an extracted call is not a TRY call, or an extracted expression is
 *         neither a call nor a lambda definition
 */
private PreGeneratedExpressions generateMethodsForLambdaAndTry(
        ClassDefinition containerClassDefinition,
        CallSiteBinder callSiteBinder,
        CachedInstanceBinder cachedInstanceBinder,
        int leftBlocksSize,
        RowExpression filter)
{
    Set<RowExpression> lambdaAndTryExpressions = ImmutableSet.copyOf(extractLambdaAndTryExpressions(filter));
    ImmutableMap.Builder<CallExpression, MethodDefinition> tryMethodMap = ImmutableMap.builder();
    ImmutableMap.Builder<LambdaDefinitionExpression, LambdaExpressionField> lambdaFieldMap = ImmutableMap.builder();

    int counter = 0;
    for (RowExpression expression : lambdaAndTryExpressions) {
        if (expression instanceof CallExpression) {
            CallExpression tryExpression = (CallExpression) expression;
            // The only call expressions expected here are TRY calls. The check was previously
            // negated, which rejected exactly the expressions this branch exists to handle.
            verify(Signatures.TRY.equals(tryExpression.getSignature().getName()), "expected a TRY call but got: %s", tryExpression);

            Parameter session = arg("session", ConnectorSession.class);
            Parameter leftPosition = arg("leftPosition", int.class);
            Parameter leftBlocks = arg("leftBlocks", Block[].class);
            Parameter rightPosition = arg("rightPosition", int.class);
            Parameter rightBlocks = arg("rightBlocks", Block[].class);

            // The visitor sees a snapshot of the helpers generated so far.
            BytecodeExpressionVisitor innerExpressionVisitor = new BytecodeExpressionVisitor(
                    callSiteBinder,
                    cachedInstanceBinder,
                    fieldReferenceCompiler(callSiteBinder, leftPosition, leftBlocks, rightPosition, rightBlocks, leftBlocksSize),
                    metadata.getFunctionRegistry(),
                    new PreGeneratedExpressions(tryMethodMap.build(), lambdaFieldMap.build()));

            List<Parameter> inputParameters = ImmutableList.<Parameter>builder()
                    .add(session)
                    .add(leftPosition)
                    .add(leftBlocks)
                    .add(rightPosition)
                    .add(rightBlocks)
                    .build();

            MethodDefinition tryMethod = defineTryMethod(
                    innerExpressionVisitor,
                    containerClassDefinition,
                    "try_" + counter,
                    inputParameters,
                    Primitives.wrap(tryExpression.getType().getJavaType()),
                    tryExpression,
                    callSiteBinder);

            tryMethodMap.put(tryExpression, tryMethod);
        }
        else if (expression instanceof LambdaDefinitionExpression) {
            LambdaDefinitionExpression lambdaExpression = (LambdaDefinitionExpression) expression;
            // Same snapshot semantics as for try methods above.
            PreGeneratedExpressions preGeneratedExpressions = new PreGeneratedExpressions(tryMethodMap.build(), lambdaFieldMap.build());
            LambdaExpressionField lambdaExpressionField = LambdaBytecodeGenerator.preGenerateLambdaExpression(
                    lambdaExpression,
                    "lambda_" + counter,
                    containerClassDefinition,
                    preGeneratedExpressions,
                    callSiteBinder,
                    cachedInstanceBinder,
                    metadata.getFunctionRegistry());
            lambdaFieldMap.put(lambdaExpression, lambdaExpressionField);
        }
        else {
            throw new VerifyException(format("unexpected expression: %s", expression.toString()));
        }
        counter++;
    }

    return new PreGeneratedExpressions(tryMethodMap.build(), lambdaFieldMap.build());
}
/**
 * Emits a public {@code toString()} on the generated class that returns the given text.
 *
 * @param classDefinition class being generated
 * @param callSiteBinder binder through which the string constant is bound
 * @param string the text the generated method returns
 */
private static void generateToString(ClassDefinition classDefinition, CallSiteBinder callSiteBinder, String string)
{
    MethodDefinition toStringMethod = classDefinition.declareMethod(a(PUBLIC), "toString", type(String.class));

    // bind constant via invokedynamic to avoid constant pool issues due to large strings
    BytecodeNode loadText = invoke(callSiteBinder.bind(string, String.class), "toString");

    toStringMethod.getBody()
            .append(loadText)
            .retObject();
}
/**
 * Creates {@link JoinFilterFunction} instances for a compiled join filter.
 */
public interface JoinFilterFunctionFactory
{
    /**
     * Builds a filter function for the given session and inputs.
     *
     * @param session session made available to the filter
     * @param addresses addresses passed through to the filter function
     * @param channels block lists passed through to the filter function
     * @return a new filter function
     */
    JoinFilterFunction create(ConnectorSession session, LongArrayList addresses, List<List<Block>> channels);

    /**
     * Sort expression associated with this factory.
     *
     * @return {@code Optional.empty()} unless an implementation overrides this method
     */
    default Optional<SortExpression> getSortChannel()
    {
        return Optional.empty();
    }
}
/**
 * Builds the visitor that resolves input field references against the left or right side.
 *
 * <p>Field indexes below {@code leftBlocksSize} map to {@code leftBlocks[field]} at
 * {@code leftPosition}; all others map to {@code rightBlocks[field - leftBlocksSize]} at
 * {@code rightPosition}.
 *
 * @param callSiteBinder binder handed to the underlying compiler
 * @param leftPosition variable holding the left position
 * @param leftBlocks variable holding the left block array
 * @param rightPosition variable holding the right position
 * @param rightBlocks variable holding the right block array
 * @param leftBlocksSize number of input fields that belong to the left side
 * @return an input-reference compiler configured with the two resolvers
 */
private static RowExpressionVisitor<Scope, BytecodeNode> fieldReferenceCompiler(
        CallSiteBinder callSiteBinder,
        Variable leftPosition,
        Variable leftBlocks,
        Variable rightPosition,
        Variable rightBlocks,
        int leftBlocksSize)
{
    return new InputReferenceCompiler(
            // block lookup
            (scope, field) -> {
                if (field < leftBlocksSize) {
                    return leftBlocks.getElement(field);
                }
                return rightBlocks.getElement(field - leftBlocksSize);
            },
            // position lookup
            (scope, field) -> {
                if (field < leftBlocksSize) {
                    return leftPosition;
                }
                return rightPosition;
            },
            callSiteBinder);
}
/**
 * Key of the factory cache: the filter expression, the number of left-side fields and the
 * optional sort expression. All three take part in equality, because the factory produced for a
 * key captures all three.
 */
private static final class JoinFilterCacheKey
{
    private final RowExpression filter;
    private final int leftBlocksSize;
    private final Optional<SortExpression> sortChannel;

    /**
     * @param filter the filter expression; must not be null
     * @param leftBlocksSize number of input fields that belong to the left side
     * @param sortChannel optional sort expression; must not be null (use {@code Optional.empty()})
     * @throws NullPointerException if {@code filter} or {@code sortChannel} is null
     */
    public JoinFilterCacheKey(RowExpression filter, int leftBlocksSize, Optional<SortExpression> sortChannel)
    {
        this.filter = requireNonNull(filter, "filter can not be null");
        this.leftBlocksSize = leftBlocksSize;
        this.sortChannel = requireNonNull(sortChannel, "sortChannel can not be null");
    }

    public RowExpression getFilter()
    {
        return filter;
    }

    public int getLeftBlocksSize()
    {
        return leftBlocksSize;
    }

    public Optional<SortExpression> getSortChannel()
    {
        return sortChannel;
    }

    @Override
    public boolean equals(Object o)
    {
        if (this == o) {
            return true;
        }
        if (o == null || getClass() != o.getClass()) {
            return false;
        }
        JoinFilterCacheKey that = (JoinFilterCacheKey) o;
        // sortChannel must be compared too: leaving it out lets a request with a different sort
        // expression receive a cached factory that reports the wrong sort channel.
        return leftBlocksSize == that.leftBlocksSize &&
                Objects.equals(filter, that.filter) &&
                Objects.equals(sortChannel, that.sortChannel);
    }

    @Override
    public int hashCode()
    {
        return Objects.hash(filter, leftBlocksSize, sortChannel);
    }

    @Override
    public String toString()
    {
        return toStringHelper(this)
                .add("filter", filter)
                .add("leftBlocksSize", leftBlocksSize)
                .add("sortChannel", sortChannel)
                .toString();
    }
}
/**
 * Factory that instantiates the generated {@link InternalJoinFilterFunction} and wraps it in an
 * isolated copy of {@link StandardJoinFilterFunction}, obtained from
 * {@code IsolatedClass.isolateClass} with a dedicated {@link DynamicClassLoader}.
 */
private static class IsolatedJoinFilterFunctionFactory
        implements JoinFilterFunctionFactory
{
    private final Constructor<? extends InternalJoinFilterFunction> internalJoinFilterFunctionConstructor;
    private final Constructor<? extends JoinFilterFunction> isolatedJoinFilterFunctionConstructor;
    private final Optional<SortExpression> sortChannel;

    /**
     * Resolves, once, the two constructors used by {@link #create}.
     *
     * @param internalJoinFilterFunction generated class exposing a {@code (ConnectorSession)} constructor
     * @param sortChannel optional sort expression passed to every created filter function
     * @throws RuntimeException wrapping {@link NoSuchMethodException} if an expected constructor is missing
     */
    public IsolatedJoinFilterFunctionFactory(Class<? extends InternalJoinFilterFunction> internalJoinFilterFunction, Optional<SortExpression> sortChannel)
    {
        this.sortChannel = sortChannel;
        try {
            this.internalJoinFilterFunctionConstructor = internalJoinFilterFunction.getConstructor(ConnectorSession.class);

            DynamicClassLoader isolationLoader = new DynamicClassLoader(getClass().getClassLoader());
            Class<? extends JoinFilterFunction> isolatedClass = IsolatedClass.isolateClass(
                    isolationLoader,
                    JoinFilterFunction.class,
                    StandardJoinFilterFunction.class);
            this.isolatedJoinFilterFunctionConstructor = isolatedClass.getConstructor(
                    InternalJoinFilterFunction.class,
                    LongArrayList.class,
                    List.class,
                    Optional.class);
        }
        catch (NoSuchMethodException e) {
            throw Throwables.propagate(e);
        }
    }

    /**
     * Instantiates the generated filter for the session and wraps it together with the inputs
     * and the sort expression.
     *
     * @throws RuntimeException wrapping any {@link ReflectiveOperationException} raised by the
     *         reflective constructor calls
     */
    @Override
    public JoinFilterFunction create(ConnectorSession session, LongArrayList addresses, List<List<Block>> channels)
    {
        try {
            InternalJoinFilterFunction compiledFilter = internalJoinFilterFunctionConstructor.newInstance(session);
            return isolatedJoinFilterFunctionConstructor.newInstance(compiledFilter, addresses, channels, sortChannel);
        }
        catch (ReflectiveOperationException e) {
            throw Throwables.propagate(e);
        }
    }

    @Override
    public Optional<SortExpression> getSortChannel()
    {
        return sortChannel;
    }
}
}