/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.sql.gen;
import com.facebook.presto.bytecode.BytecodeBlock;
import com.facebook.presto.bytecode.ClassDefinition;
import com.facebook.presto.bytecode.DynamicClassLoader;
import com.facebook.presto.bytecode.FieldDefinition;
import com.facebook.presto.bytecode.MethodDefinition;
import com.facebook.presto.bytecode.Parameter;
import com.facebook.presto.bytecode.Variable;
import com.facebook.presto.bytecode.control.IfStatement;
import com.facebook.presto.bytecode.expression.BytecodeExpression;
import com.facebook.presto.bytecode.instruction.JumpInstruction;
import com.facebook.presto.bytecode.instruction.LabelNode;
import com.facebook.presto.operator.JoinProbe;
import com.facebook.presto.operator.JoinProbeFactory;
import com.facebook.presto.operator.LookupJoinOperator;
import com.facebook.presto.operator.LookupJoinOperatorFactory;
import com.facebook.presto.operator.LookupJoinOperators.JoinType;
import com.facebook.presto.operator.LookupSource;
import com.facebook.presto.operator.LookupSourceFactory;
import com.facebook.presto.operator.OperatorFactory;
import com.facebook.presto.operator.SimpleJoinProbe;
import com.facebook.presto.spi.Page;
import com.facebook.presto.spi.PageBuilder;
import com.facebook.presto.spi.block.Block;
import com.facebook.presto.spi.block.BlockBuilder;
import com.facebook.presto.spi.type.BigintType;
import com.facebook.presto.spi.type.Type;
import com.facebook.presto.sql.planner.plan.PlanNodeId;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Throwables;
import com.google.common.cache.CacheBuilder;
import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;
import com.google.common.collect.ImmutableList;
import com.google.common.util.concurrent.ExecutionError;
import com.google.common.util.concurrent.UncheckedExecutionException;
import org.weakref.jmx.Managed;
import org.weakref.jmx.Nested;
import java.lang.reflect.Constructor;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.ExecutionException;
import static com.facebook.presto.bytecode.Access.FINAL;
import static com.facebook.presto.bytecode.Access.PRIVATE;
import static com.facebook.presto.bytecode.Access.PUBLIC;
import static com.facebook.presto.bytecode.Access.a;
import static com.facebook.presto.bytecode.CompilerUtils.defineClass;
import static com.facebook.presto.bytecode.CompilerUtils.makeClassName;
import static com.facebook.presto.bytecode.Parameter.arg;
import static com.facebook.presto.bytecode.ParameterizedType.type;
import static com.facebook.presto.bytecode.expression.BytecodeExpressions.constantInt;
import static com.facebook.presto.bytecode.expression.BytecodeExpressions.constantLong;
import static com.facebook.presto.bytecode.expression.BytecodeExpressions.newInstance;
import static com.facebook.presto.sql.gen.SqlTypeBytecodeExpression.constantType;
import static com.google.common.collect.ImmutableList.toImmutableList;
public class JoinProbeCompiler
{
// Cache of compiled operator-factory builders. Entries are loaded on demand by compiling the probe
// described by the key; the cache keeps at most 1000 entries and records hit/miss statistics.
private final LoadingCache<JoinOperatorCacheKey, HashJoinOperatorFactoryFactory> joinProbeFactories = CacheBuilder.newBuilder()
        .maximumSize(1000)
        .recordStats()
        .build(new CacheLoader<JoinOperatorCacheKey, HashJoinOperatorFactoryFactory>()
        {
            @Override
            public HashJoinOperatorFactoryFactory load(JoinOperatorCacheKey cacheKey)
            {
                List<Type> types = cacheKey.getTypes();
                List<Integer> outputChannels = cacheKey.getProbeOutputChannels();
                List<Integer> joinChannels = cacheKey.getProbeChannels();
                Optional<Integer> hashChannel = cacheKey.getProbeHashChannel();
                return internalCompileJoinOperatorFactory(types, outputChannels, joinChannels, hashChannel);
            }
        });
/**
 * Exposes the statistics of the compiled join probe factory cache for management.
 *
 * @return a new {@code CacheStatsMBean} wrapping the {@code joinProbeFactories} cache
 */
@Managed
@Nested
public CacheStatsMBean getJoinProbeFactoriesStats()
{
return new CacheStatsMBean(joinProbeFactories);
}
/**
 * Creates an operator factory for a lookup join, reusing a cached compiled probe when one exists
 * for the same probe layout.
 *
 * @param operatorId id passed through to the created operator factory
 * @param planNodeId plan node id passed through to the created operator factory
 * @param lookupSourceFactory lookup source factory passed through to the created operator factory
 * @param probeTypes types of all probe-side channels
 * @param probeJoinChannel probe-side channels that take part in the join
 * @param probeHashChannel optional probe-side channel holding a precomputed hash
 * @param probeOutputChannels probe-side channels that are copied to the output
 * @param joinType join type; part of the cache key and passed to the created operator factory
 * @return the operator factory built by the cached (or freshly compiled) factory builder
 */
public OperatorFactory compileJoinOperatorFactory(int operatorId,
        PlanNodeId planNodeId,
        LookupSourceFactory lookupSourceFactory,
        List<? extends Type> probeTypes,
        List<Integer> probeJoinChannel,
        Optional<Integer> probeHashChannel,
        List<Integer> probeOutputChannels,
        JoinType joinType)
{
    try {
        // Types of the output channels, in the order the channels are listed.
        ImmutableList.Builder<Type> outputTypes = ImmutableList.builder();
        for (Integer channel : probeOutputChannels) {
            outputTypes.add(probeTypes.get(channel));
        }
        List<Type> probeOutputChannelTypes = outputTypes.build();

        JoinOperatorCacheKey cacheKey = new JoinOperatorCacheKey(
                probeTypes,
                probeOutputChannels,
                probeJoinChannel,
                probeHashChannel,
                joinType);
        HashJoinOperatorFactoryFactory factoryBuilder = joinProbeFactories.get(cacheKey);

        return factoryBuilder.createHashJoinOperatorFactory(
                operatorId,
                planNodeId,
                lookupSourceFactory,
                probeTypes,
                probeOutputChannelTypes,
                joinType);
    }
    catch (ExecutionException | UncheckedExecutionException | ExecutionError failure) {
        // Surface the failure raised while loading the cache entry rather than the cache wrapper.
        throw Throwables.propagate(failure.getCause());
    }
}
/**
 * Compiles a {@link JoinProbe} class for the given probe layout and wraps it in a builder of
 * operator factories.
 *
 * <p>A {@link JoinProbeFactory} class is generated whose {@code createJoinProbe} method instantiates
 * the compiled probe. When {@code probeJoinChannel} is empty, a {@code SimpleJoinProbeFactory} is
 * used instead of the generated factory. {@code LookupJoinOperatorFactory} and
 * {@code LookupJoinOperator} are isolated in a new {@link DynamicClassLoader} whose parent is the
 * class loader of the compiled probe class.
 *
 * @param types types of all probe-side channels
 * @param probeOutputChannels probe-side channels that are copied to the output
 * @param probeJoinChannel probe-side channels that take part in the join
 * @param probeHashChannel optional probe-side channel holding a precomputed hash
 * @return a builder that creates operator factories bound to the selected probe factory
 */
@VisibleForTesting
public HashJoinOperatorFactoryFactory internalCompileJoinOperatorFactory(List<Type> types, List<Integer> probeOutputChannels, List<Integer> probeJoinChannel, Optional<Integer> probeHashChannel)
{
    Class<? extends JoinProbe> joinProbeClass = compileJoinProbe(types, probeOutputChannels, probeJoinChannel, probeHashChannel);

    ClassDefinition classDefinition = new ClassDefinition(
            a(PUBLIC, FINAL),
            makeClassName("JoinProbeFactory"),
            type(Object.class),
            type(JoinProbeFactory.class));

    classDefinition.declareDefaultConstructor(a(PUBLIC));

    // Generated body: return new <joinProbeClass>(lookupSource, page);
    Parameter lookupSource = arg("lookupSource", LookupSource.class);
    Parameter page = arg("page", Page.class);
    MethodDefinition method = classDefinition.declareMethod(a(PUBLIC), "createJoinProbe", type(JoinProbe.class), lookupSource, page);
    method.getBody()
            .newObject(joinProbeClass)
            .dup()
            .append(lookupSource)
            .append(page)
            .invokeConstructor(joinProbeClass, LookupSource.class, Page.class)
            .retObject();

    DynamicClassLoader classLoader = new DynamicClassLoader(joinProbeClass.getClassLoader());

    JoinProbeFactory joinProbeFactory;
    if (probeJoinChannel.isEmpty()) {
        // see comment in PagesIndex#createLookupSource
        joinProbeFactory = new SimpleJoinProbe.SimpleJoinProbeFactory(types, probeOutputChannels, probeJoinChannel, probeHashChannel);
    }
    else {
        Class<? extends JoinProbeFactory> joinProbeFactoryClass = defineClass(classDefinition, JoinProbeFactory.class, classLoader);
        try {
            // Class.newInstance() is avoided: it rethrows checked constructor exceptions undeclared
            // and is deprecated. The generated class has a public no-arg constructor.
            joinProbeFactory = joinProbeFactoryClass.getConstructor().newInstance();
        }
        catch (ReflectiveOperationException e) {
            throw Throwables.propagate(e);
        }
    }

    Class<? extends OperatorFactory> operatorFactoryClass = IsolatedClass.isolateClass(
            classLoader,
            OperatorFactory.class,
            LookupJoinOperatorFactory.class,
            LookupJoinOperator.class);

    return new HashJoinOperatorFactoryFactory(joinProbeFactory, operatorFactoryClass);
}
/**
 * Compiles a {@link JoinProbe} class for the given probe layout and returns a factory that
 * instantiates it reflectively.
 *
 * @param types types of all probe-side channels
 * @param probeOutputChannels probe-side channels that are copied to the output
 * @param probeChannels probe-side channels that take part in the join
 * @param probeHashChannel optional probe-side channel holding a precomputed hash
 * @return a {@link ReflectionJoinProbeFactory} wrapping the compiled probe class
 */
@VisibleForTesting
public JoinProbeFactory internalCompileJoinProbe(List<Type> types, List<Integer> probeOutputChannels, List<Integer> probeChannels, Optional<Integer> probeHashChannel)
{
return new ReflectionJoinProbeFactory(compileJoinProbe(types, probeOutputChannels, probeChannels, probeHashChannel));
}
/**
 * Generates and defines a {@link JoinProbe} implementation specialized for the given probe layout.
 *
 * <p>The generated class has one {@code Block} field per probe channel ({@code block_N}), one per
 * join channel ({@code probeBlock_N}), plus fields for the lookup source, pages, hash block and
 * cursor position. Its methods are emitted by the {@code generate*} helpers of this class.
 *
 * @param types types of all probe-side channels
 * @param probeOutputChannels probe-side channels that are copied to the output
 * @param probeChannels probe-side channels that take part in the join
 * @param probeHashChannel optional probe-side channel holding a precomputed hash
 * @return the defined probe class, loaded through this compiler's class loader
 */
private Class<? extends JoinProbe> compileJoinProbe(List<Type> types, List<Integer> probeOutputChannels, List<Integer> probeChannels, Optional<Integer> probeHashChannel)
{
    CallSiteBinder binder = new CallSiteBinder();

    ClassDefinition probeClass = new ClassDefinition(
            a(PUBLIC, FINAL),
            makeClassName("JoinProbe"),
            type(Object.class),
            type(JoinProbe.class));

    // Field declarations
    FieldDefinition lookupSource = probeClass.declareField(a(PRIVATE, FINAL), "lookupSource", LookupSource.class);
    FieldDefinition positionCount = probeClass.declareField(a(PRIVATE, FINAL), "positionCount", int.class);

    int channelCount = types.size();
    List<FieldDefinition> allBlocks = new ArrayList<>();
    for (int channel = 0; channel < channelCount; channel++) {
        allBlocks.add(probeClass.declareField(a(PRIVATE, FINAL), "block_" + channel, Block.class));
    }

    int joinChannelCount = probeChannels.size();
    List<FieldDefinition> joinBlocks = new ArrayList<>();
    for (int channel = 0; channel < joinChannelCount; channel++) {
        joinBlocks.add(probeClass.declareField(a(PRIVATE, FINAL), "probeBlock_" + channel, Block.class));
    }

    FieldDefinition probeBlocksArray = probeClass.declareField(a(PRIVATE, FINAL), "probeBlocks", Block[].class);
    FieldDefinition probePage = probeClass.declareField(a(PRIVATE, FINAL), "probePage", Page.class);
    FieldDefinition page = probeClass.declareField(a(PRIVATE, FINAL), "page", Page.class);
    FieldDefinition position = probeClass.declareField(a(PRIVATE), "position", int.class);
    FieldDefinition probeHashBlock = probeClass.declareField(a(PRIVATE, FINAL), "probeHashBlock", Block.class);

    // Constructor and methods
    generateConstructor(
            probeClass,
            probeChannels,
            probeHashChannel,
            lookupSource,
            allBlocks,
            joinBlocks,
            probeBlocksArray,
            probePage,
            page,
            probeHashBlock,
            position,
            positionCount);
    generateGetChannelCountMethod(probeClass, probeOutputChannels.size());
    generateAppendToMethod(probeClass, binder, types, probeOutputChannels, allBlocks, position);
    generateAdvanceNextPosition(probeClass, position, positionCount);
    generateGetCurrentJoinPosition(probeClass, binder, lookupSource, probePage, page, probeHashChannel, probeHashBlock, position);
    generateCurrentRowContainsNull(probeClass, joinBlocks, position);
    generateGetPosition(probeClass, position);
    generateGetPage(probeClass, page);

    ClassLoader parentLoader = getClass().getClassLoader();
    return defineClass(probeClass, JoinProbe.class, binder.getBindings(), parentLoader);
}
/**
 * Emits the public {@code (LookupSource lookupSource, Page page)} constructor of the generated probe class.
 *
 * <p>The generated constructor calls {@code Object}'s constructor, stores the lookup source and the
 * page's position count, copies each block of the page into its {@code block_N} field, aliases the
 * join-channel blocks into the probe channel fields and the {@code probeBlocks} array, wraps that
 * array in a new {@code Page}, optionally stores the hash block, and sets {@code position} to -1.
 *
 * @param classDefinition class receiving the constructor
 * @param probeChannels for each probe channel field, the index into {@code blockFields} it aliases
 * @param probeHashChannel optional index into {@code blockFields} of the hash block
 * @param lookupSourceField field receiving the {@code lookupSource} argument
 * @param blockFields one field per block of the page, indexed by channel
 * @param probeChannelFields one field per join channel
 * @param probeBlocksArrayField field receiving the array of join-channel blocks
 * @param probePageField field receiving the page built from the join-channel blocks
 * @param pageField field receiving the {@code page} argument
 * @param probeHashBlockField field receiving the hash block; left unassigned when no hash channel is given
 * @param positionField field initialized to -1
 * @param positionCountField field receiving {@code page.getPositionCount()}
 */
private static void generateConstructor(ClassDefinition classDefinition,
List<Integer> probeChannels,
Optional<Integer> probeHashChannel,
FieldDefinition lookupSourceField,
List<FieldDefinition> blockFields,
List<FieldDefinition> probeChannelFields,
FieldDefinition probeBlocksArrayField,
FieldDefinition probePageField,
FieldDefinition pageField,
FieldDefinition probeHashBlockField,
FieldDefinition positionField,
FieldDefinition positionCountField)
{
Parameter lookupSource = arg("lookupSource", LookupSource.class);
Parameter page = arg("page", Page.class);
MethodDefinition constructorDefinition = classDefinition.declareConstructor(a(PUBLIC), lookupSource, page);
Variable thisVariable = constructorDefinition.getThis();
BytecodeBlock constructor = constructorDefinition
.getBody()
.comment("super();")
.append(thisVariable)
.invokeConstructor(Object.class);
constructor.comment("this.lookupSource = lookupSource;")
.append(thisVariable.setField(lookupSourceField, lookupSource));
constructor.comment("this.positionCount = page.getPositionCount();")
.append(thisVariable.setField(positionCountField, page.invoke("getPositionCount", int.class)));
constructor.comment("Set block fields");
// block_<index> = page.getBlock(<index>) for every declared block field
for (int index = 0; index < blockFields.size(); index++) {
constructor.append(thisVariable.setField(
blockFields.get(index),
page.invoke("getBlock", Block.class, constantInt(index))));
}
constructor.comment("Set probe channel fields");
// each probe channel field is read back from the block field selected by probeChannels
for (int index = 0; index < probeChannelFields.size(); index++) {
constructor.append(thisVariable.setField(
probeChannelFields.get(index),
thisVariable.getField(blockFields.get(probeChannels.get(index)))));
}
constructor.comment("this.probeBlocks = new Block[<probeChannelCount>];");
constructor
.append(thisVariable)
.push(probeChannelFields.size())
.newArray(Block.class)
.putField(probeBlocksArrayField);
// probeBlocks[index] = <probe channel field at index>; emitted as raw stack operations (array, index, value, store)
for (int index = 0; index < probeChannelFields.size(); index++) {
constructor
.append(thisVariable)
.getField(probeBlocksArrayField)
.push(index)
.append(thisVariable)
.getField(probeChannelFields.get(index))
.putObjectArrayElement();
}
constructor.comment("this.page = page")
.append(thisVariable.setField(pageField, page));
constructor.comment("this.probePage = new Page(probeBlocks)")
.append(thisVariable.setField(probePageField, newInstance(Page.class, thisVariable.getField(probeBlocksArrayField))));
// the hash block field is only assigned when a hash channel was supplied
if (probeHashChannel.isPresent()) {
Integer index = probeHashChannel.get();
constructor.comment("this.probeHashBlock = blocks[hashChannel.get()]")
.append(thisVariable.setField(
probeHashBlockField,
thisVariable.getField(blockFields.get(index))));
}
// position starts at -1; the generated advanceNextPosition() increments it before it is compared
constructor.comment("this.position = -1;")
.append(thisVariable.setField(positionField, constantInt(-1)));
constructor.ret();
}
/**
 * Emits {@code public int getOutputChannelCount()} on the generated class; the generated method
 * returns the given count as a constant.
 *
 * @param classDefinition class receiving the method
 * @param channelCount value the generated method returns
 */
private static void generateGetChannelCountMethod(ClassDefinition classDefinition, int channelCount)
{
    MethodDefinition method = classDefinition.declareMethod(a(PUBLIC), "getOutputChannelCount", type(int.class));
    BytecodeBlock body = method.getBody();
    body.push(channelCount)
            .retInt();
}
/**
 * Emits {@code public void appendTo(PageBuilder pageBuilder)} on the generated class.
 *
 * <p>For the i-th entry of {@code probeOutputChannels}, the generated code calls the channel type's
 * {@code appendTo} with that channel's block, the current position, and block builder {@code i} of
 * the page builder.
 *
 * @param classDefinition class receiving the method
 * @param callSiteBinder binder used to embed the {@link Type} constants
 * @param types types of all probe-side channels
 * @param probeOutputChannels probe-side channels to copy, in output order
 * @param blockFields one field per probe-side block, indexed by channel
 * @param positionField field holding the current position
 */
private static void generateAppendToMethod(
        ClassDefinition classDefinition,
        CallSiteBinder callSiteBinder,
        List<Type> types,
        List<Integer> probeOutputChannels,
        List<FieldDefinition> blockFields,
        FieldDefinition positionField)
{
    Parameter pageBuilder = arg("pageBuilder", PageBuilder.class);
    MethodDefinition method = classDefinition.declareMethod(a(PUBLIC), "appendTo", type(void.class), pageBuilder);
    Variable thisVariable = method.getThis();
    BytecodeBlock body = method.getBody();

    for (int targetChannel = 0; targetChannel < probeOutputChannels.size(); targetChannel++) {
        int sourceChannel = probeOutputChannels.get(targetChannel);
        Type type = types.get(sourceChannel);

        body.comment("%s.appendTo(block_%s, position, pageBuilder.getBlockBuilder(%s));", type.getClass(), sourceChannel, targetChannel);
        body.append(constantType(callSiteBinder, type).invoke("appendTo", void.class,
                thisVariable.getField(blockFields.get(sourceChannel)),
                thisVariable.getField(positionField),
                pageBuilder.invoke("getBlockBuilder", BlockBuilder.class, constantInt(targetChannel))));
    }

    body.ret();
}
/**
 * Emits {@code public boolean advanceNextPosition()} on the generated class.
 *
 * <p>The generated method increments the {@code position} field by one and then returns whether
 * the new position is less than the {@code positionCount} field.
 *
 * @param classDefinition class receiving the method
 * @param positionField field holding the current position
 * @param positionCountField field holding the number of positions
 */
private static void generateAdvanceNextPosition(ClassDefinition classDefinition, FieldDefinition positionField, FieldDefinition positionCountField)
{
MethodDefinition method = classDefinition.declareMethod(
a(PUBLIC),
"advanceNextPosition",
type(boolean.class));
Variable thisVariable = method.getThis();
// "this" is pushed twice: the second copy is consumed by getField, the first is the target of the final putField
method.getBody()
.comment("this.position = this.position + 1;")
.append(thisVariable)
.append(thisVariable)
.getField(positionField)
.push(1)
.intAdd()
.putField(positionField);
LabelNode lessThan = new LabelNode("lessThan");
LabelNode end = new LabelNode("end");
// compare position with positionCount and materialize the result as a boolean via the two labels
method.getBody()
.comment("return position < positionCount;")
.append(thisVariable)
.getField(positionField)
.append(thisVariable)
.getField(positionCountField)
.append(JumpInstruction.jumpIfIntLessThan(lessThan))
.push(false)
.gotoLabel(end)
.visitLabel(lessThan)
.push(true)
.visitLabel(end)
.retBoolean();
}
/**
 * Emits {@code public long getCurrentJoinPosition()} on the generated class.
 *
 * <p>The generated method returns -1 when {@code currentRowContainsNull()} is true. Otherwise it
 * returns {@code lookupSource.getJoinPosition(position, probePage, page)}, passing the long read
 * from the hash block at the current position as an extra argument when a hash channel is present.
 *
 * @param classDefinition class receiving the method
 * @param callSiteBinder binder used to embed the {@code BIGINT} type constant
 * @param lookupSourceField field holding the lookup source
 * @param probePageField field holding the page of join-channel blocks
 * @param pageField field holding the full probe page
 * @param probeHashChannel whether present decides which {@code getJoinPosition} overload is called
 * @param probeHashBlockField field holding the hash block
 * @param positionField field holding the current position
 */
private static void generateGetCurrentJoinPosition(ClassDefinition classDefinition,
        CallSiteBinder callSiteBinder,
        FieldDefinition lookupSourceField,
        FieldDefinition probePageField,
        FieldDefinition pageField,
        Optional<Integer> probeHashChannel,
        FieldDefinition probeHashBlockField,
        FieldDefinition positionField)
{
    MethodDefinition method = classDefinition.declareMethod(a(PUBLIC), "getCurrentJoinPosition", type(long.class));
    Variable thisVariable = method.getThis();

    // if (currentRowContainsNull()) return -1;
    IfStatement nullGuard = new IfStatement()
            .condition(thisVariable.invoke("currentRowContainsNull", boolean.class))
            .ifTrue(constantLong(-1).ret());
    BytecodeBlock body = method.getBody().append(nullGuard);

    BytecodeExpression position = thisVariable.getField(positionField);
    BytecodeExpression hashChannelsPage = thisVariable.getField(probePageField);
    BytecodeExpression allChannelsPage = thisVariable.getField(pageField);
    BytecodeExpression probeHashBlock = thisVariable.getField(probeHashBlockField);

    BytecodeExpression joinPosition;
    if (!probeHashChannel.isPresent()) {
        joinPosition = thisVariable.getField(lookupSourceField)
                .invoke("getJoinPosition", long.class, position, hashChannelsPage, allChannelsPage);
    }
    else {
        joinPosition = thisVariable.getField(lookupSourceField).invoke("getJoinPosition", long.class,
                position,
                hashChannelsPage,
                allChannelsPage,
                constantType(callSiteBinder, BigintType.BIGINT).invoke("getLong",
                        long.class,
                        probeHashBlock,
                        position));
    }

    body.append(joinPosition)
            .retLong();
}
/**
 * Emits {@code private boolean currentRowContainsNull()} on the generated class.
 *
 * <p>The generated method checks each probe block field in order and returns {@code true} as soon
 * as one reports {@code isNull(position)}; it returns {@code false} when none does (including
 * when there are no probe block fields).
 *
 * @param classDefinition class receiving the method
 * @param probeBlockFields fields holding the join-channel blocks to test
 * @param positionField field holding the current position
 */
private static void generateCurrentRowContainsNull(ClassDefinition classDefinition, List<FieldDefinition> probeBlockFields, FieldDefinition positionField)
{
    MethodDefinition method = classDefinition.declareMethod(
            a(PRIVATE),
            "currentRowContainsNull",
            type(boolean.class));

    Variable thisVariable = method.getThis();

    for (FieldDefinition probeBlockField : probeBlockFields) {
        // if (!block.isNull(position)) skip to the next field; otherwise return true
        LabelNode checkNextField = new LabelNode("checkNextField");
        method.getBody()
                .append(thisVariable.getField(probeBlockField).invoke("isNull", boolean.class, thisVariable.getField(positionField)))
                .ifFalseGoto(checkNextField)
                .push(true)
                .retBoolean()
                .visitLabel(checkNextField);
    }

    // The method is declared boolean, so return through retBoolean() like the early exit above.
    method.getBody()
            .push(false)
            .retBoolean();
}
/**
 * Emits {@code public int getPosition()} on the generated class; the generated method returns
 * the value of the position field.
 *
 * @param classDefinition class receiving the method
 * @param positionField field holding the current position
 */
private static void generateGetPosition(ClassDefinition classDefinition, FieldDefinition positionField)
{
    // Placeholder implementation for now: the compiled class is only used
    // in use cases where the result of this method is ignored.
    MethodDefinition method = classDefinition.declareMethod(a(PUBLIC), "getPosition", type(int.class));
    BytecodeExpression currentPosition = method.getThis().getField(positionField);
    method.getBody()
            .append(currentPosition)
            .retInt();
}
/**
 * Emits {@code public Page getPage()} on the generated class; the generated method returns the
 * value of the page field.
 *
 * @param classDefinition class receiving the method
 * @param pageField field holding the probe page
 */
private static void generateGetPage(ClassDefinition classDefinition, FieldDefinition pageField)
{
    // Placeholder implementation for now: the compiled class is only used
    // in use cases where the result of this method is ignored.
    MethodDefinition method = classDefinition.declareMethod(a(PUBLIC), "getPage", type(Page.class));
    BytecodeExpression currentPage = method.getThis().getField(pageField);
    method.getBody()
            .append(currentPage)
            .ret(Page.class);
}
/**
 * {@link JoinProbeFactory} that creates probes by reflectively invoking the
 * {@code (LookupSource, Page)} constructor of a given probe class.
 */
public static class ReflectionJoinProbeFactory
        implements JoinProbeFactory
{
    private final Constructor<? extends JoinProbe> constructor;

    /**
     * @param joinProbeClass probe class exposing a public {@code (LookupSource, Page)} constructor;
     * a missing constructor is rethrown through {@code Throwables.propagate}
     */
    public ReflectionJoinProbeFactory(Class<? extends JoinProbe> joinProbeClass)
    {
        this.constructor = lookupProbeConstructor(joinProbeClass);
    }

    /**
     * Instantiates the probe class for the given lookup source and page; any failure is rethrown
     * through {@code Throwables.propagate}.
     */
    @Override
    public JoinProbe createJoinProbe(LookupSource lookupSource, Page page)
    {
        JoinProbe probe;
        try {
            probe = constructor.newInstance(lookupSource, page);
        }
        catch (Exception failure) {
            throw Throwables.propagate(failure);
        }
        return probe;
    }

    // Resolves the public (LookupSource, Page) constructor of the probe class.
    private static Constructor<? extends JoinProbe> lookupProbeConstructor(Class<? extends JoinProbe> joinProbeClass)
    {
        try {
            return joinProbeClass.getConstructor(LookupSource.class, Page.class);
        }
        catch (NoSuchMethodException missing) {
            throw Throwables.propagate(missing);
        }
    }
}
/**
 * Key of the compiled join operator cache: the probe channel types, the output channels, the join
 * channels, the optional hash channel and the join type. The list arguments are copied into
 * immutable lists on construction.
 */
private static final class JoinOperatorCacheKey
{
    private final List<Type> types;
    private final List<Integer> probeOutputChannels;
    private final List<Integer> probeChannels;
    private final JoinType joinType;
    private final Optional<Integer> probeHashChannel;

    private JoinOperatorCacheKey(List<? extends Type> types,
            List<Integer> probeOutputChannels,
            List<Integer> probeChannels,
            Optional<Integer> probeHashChannel,
            JoinType joinType)
    {
        this.probeHashChannel = probeHashChannel;
        this.types = ImmutableList.copyOf(types);
        this.probeOutputChannels = ImmutableList.copyOf(probeOutputChannels);
        this.probeChannels = ImmutableList.copyOf(probeChannels);
        this.joinType = joinType;
    }

    private List<Type> getTypes()
    {
        return types;
    }

    private List<Integer> getProbeOutputChannels()
    {
        return probeOutputChannels;
    }

    private List<Integer> getProbeChannels()
    {
        return probeChannels;
    }

    private Optional<Integer> getProbeHashChannel()
    {
        return probeHashChannel;
    }

    @Override
    public int hashCode()
    {
        // Hash every field that equals() compares, including probeHashChannel, so keys that differ
        // only in the hash channel do not always land in the same bucket.
        return Objects.hash(types, probeOutputChannels, probeChannels, probeHashChannel, joinType);
    }

    @Override
    public boolean equals(Object obj)
    {
        if (this == obj) {
            return true;
        }
        if (!(obj instanceof JoinOperatorCacheKey)) {
            return false;
        }
        JoinOperatorCacheKey other = (JoinOperatorCacheKey) obj;
        return Objects.equals(this.types, other.types) &&
                Objects.equals(this.probeOutputChannels, other.probeOutputChannels) &&
                Objects.equals(this.probeChannels, other.probeChannels) &&
                Objects.equals(this.probeHashChannel, other.probeHashChannel) &&
                Objects.equals(this.joinType, other.joinType);
    }
}
/**
 * Builds {@link OperatorFactory} instances by reflectively invoking the seven-argument constructor
 * of an operator factory class, always passing the same {@link JoinProbeFactory} as last argument.
 */
private static class HashJoinOperatorFactoryFactory
{
    private final JoinProbeFactory joinProbeFactory;
    private final Constructor<? extends OperatorFactory> constructor;

    private HashJoinOperatorFactoryFactory(JoinProbeFactory joinProbeFactory, Class<? extends OperatorFactory> operatorFactoryClass)
    {
        this.joinProbeFactory = joinProbeFactory;

        Constructor<? extends OperatorFactory> resolved;
        try {
            resolved = operatorFactoryClass.getConstructor(
                    int.class,
                    PlanNodeId.class,
                    LookupSourceFactory.class,
                    List.class,
                    List.class,
                    JoinType.class,
                    JoinProbeFactory.class);
        }
        catch (NoSuchMethodException missing) {
            throw Throwables.propagate(missing);
        }
        this.constructor = resolved;
    }

    /**
     * Instantiates the operator factory class with the given arguments plus the stored probe
     * factory; any failure is rethrown through {@code Throwables.propagate}.
     */
    public OperatorFactory createHashJoinOperatorFactory(
            int operatorId,
            PlanNodeId planNodeId,
            LookupSourceFactory lookupSourceFactory,
            List<? extends Type> probeTypes,
            List<? extends Type> probeOutputTypes,
            JoinType joinType)
    {
        OperatorFactory operatorFactory;
        try {
            operatorFactory = constructor.newInstance(
                    operatorId,
                    planNodeId,
                    lookupSourceFactory,
                    probeTypes,
                    probeOutputTypes,
                    joinType,
                    joinProbeFactory);
        }
        catch (Exception failure) {
            throw Throwables.propagate(failure);
        }
        return operatorFactory;
    }
}
/**
 * Verifies that the two flags agree.
 *
 * @param left first flag
 * @param right second flag
 * @throws IllegalStateException if {@code left} and {@code right} differ
 */
public static void checkState(boolean left, boolean right)
{
    if (left == right) {
        return;
    }
    throw new IllegalStateException();
}
}