/* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.facebook.presto.sql.gen; import com.facebook.presto.Session; import com.facebook.presto.bytecode.BytecodeBlock; import com.facebook.presto.bytecode.BytecodeNode; import com.facebook.presto.bytecode.ClassDefinition; import com.facebook.presto.bytecode.DynamicClassLoader; import com.facebook.presto.bytecode.FieldDefinition; import com.facebook.presto.bytecode.MethodDefinition; import com.facebook.presto.bytecode.OpCode; import com.facebook.presto.bytecode.Parameter; import com.facebook.presto.bytecode.Variable; import com.facebook.presto.bytecode.control.ForLoop; import com.facebook.presto.bytecode.control.IfStatement; import com.facebook.presto.bytecode.expression.BytecodeExpression; import com.facebook.presto.bytecode.instruction.LabelNode; import com.facebook.presto.operator.JoinHash; import com.facebook.presto.operator.JoinHashSupplier; import com.facebook.presto.operator.LookupSourceSupplier; import com.facebook.presto.operator.PagesHash; import com.facebook.presto.operator.PagesHashStrategy; import com.facebook.presto.spi.Page; import com.facebook.presto.spi.PageBuilder; import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.block.BlockBuilder; import com.facebook.presto.spi.type.BigintType; import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.gen.JoinFilterFunctionCompiler.JoinFilterFunctionFactory; import 
com.facebook.presto.sql.planner.SortExpressionExtractor.SortExpression; import com.google.common.base.Throwables; import com.google.common.cache.CacheBuilder; import com.google.common.cache.CacheLoader; import com.google.common.cache.LoadingCache; import com.google.common.collect.ImmutableList; import com.google.common.util.concurrent.ExecutionError; import com.google.common.util.concurrent.UncheckedExecutionException; import it.unimi.dsi.fastutil.longs.LongArrayList; import org.weakref.jmx.Managed; import org.weakref.jmx.Nested; import java.lang.reflect.Constructor; import java.util.ArrayList; import java.util.List; import java.util.Objects; import java.util.Optional; import java.util.concurrent.ExecutionException; import java.util.stream.IntStream; import static com.facebook.presto.bytecode.Access.FINAL; import static com.facebook.presto.bytecode.Access.PRIVATE; import static com.facebook.presto.bytecode.Access.PUBLIC; import static com.facebook.presto.bytecode.Access.a; import static com.facebook.presto.bytecode.CompilerUtils.defineClass; import static com.facebook.presto.bytecode.CompilerUtils.makeClassName; import static com.facebook.presto.bytecode.Parameter.arg; import static com.facebook.presto.bytecode.ParameterizedType.type; import static com.facebook.presto.bytecode.expression.BytecodeExpressions.constantFalse; import static com.facebook.presto.bytecode.expression.BytecodeExpressions.constantInt; import static com.facebook.presto.bytecode.expression.BytecodeExpressions.constantLong; import static com.facebook.presto.bytecode.expression.BytecodeExpressions.constantNull; import static com.facebook.presto.bytecode.expression.BytecodeExpressions.constantTrue; import static com.facebook.presto.bytecode.expression.BytecodeExpressions.newInstance; import static com.facebook.presto.bytecode.expression.BytecodeExpressions.notEqual; import static com.facebook.presto.sql.gen.SqlTypeBytecodeExpression.constantType; import static 
com.google.common.collect.ImmutableList.toImmutableList;
import static java.util.Objects.requireNonNull;

/**
 * Generates {@link PagesHashStrategy} classes, and {@link LookupSourceSupplierFactory} instances
 * built on them, specialized for a given combination of channel types, output channels, join
 * channels and optional sort channel. Compiled results are memoized in bounded caches whose
 * statistics are exported through JMX.
 */
public class JoinCompiler
{
    // Keyed by the full join "shape" (types, output channels, join channels, sort channel).
    private final LoadingCache<CacheKey, LookupSourceSupplierFactory> lookupSourceFactories = CacheBuilder.newBuilder()
            .recordStats()
            .maximumSize(1000)
            .build(new CacheLoader<CacheKey, LookupSourceSupplierFactory>()
            {
                @Override
                public LookupSourceSupplierFactory load(CacheKey key)
                        throws Exception
                {
                    return internalCompileLookupSourceFactory(key.getTypes(), key.getOutputChannels(), key.getJoinChannels(), key.getSortChannel());
                }
            });

    private final LoadingCache<CacheKey, Class<? extends PagesHashStrategy>> hashStrategies = CacheBuilder.newBuilder()
            .recordStats()
            .maximumSize(1000)
            .build(new CacheLoader<CacheKey, Class<? extends PagesHashStrategy>>()
            {
                @Override
                public Class<? extends PagesHashStrategy> load(CacheKey key)
                        throws Exception
                {
                    return internalCompileHashStrategy(key.getTypes(), key.getOutputChannels(), key.getJoinChannels(), key.getSortChannel());
                }
            });

    /**
     * Returns a (cached) lookup source supplier factory that outputs every channel in {@code types}.
     */
    public LookupSourceSupplierFactory compileLookupSourceFactory(List<? extends Type> types, List<Integer> joinChannels, Optional<SortExpression> sortChannel)
    {
        return compileLookupSourceFactory(types, joinChannels, sortChannel, Optional.empty());
    }

    @Managed
    @Nested
    public CacheStatsMBean getLookupSourceStats()
    {
        return new CacheStatsMBean(lookupSourceFactories);
    }

    @Managed
    @Nested
    public CacheStatsMBean getHashStrategiesStats()
    {
        return new CacheStatsMBean(hashStrategies);
    }

    /**
     * Returns a (cached) lookup source supplier factory.
     *
     * @param types types of all channels
     * @param joinChannels indexes (into {@code types}) of the channels the join hashes/compares on
     * @param sortChannel optional channel used by the generated {@code compare} method
     * @param outputChannels channels emitted by {@code appendTo}; when empty, all channels are output
     * @throws NullPointerException if any argument is null
     */
    public LookupSourceSupplierFactory compileLookupSourceFactory(List<? extends Type> types, List<Integer> joinChannels, Optional<SortExpression> sortChannel, Optional<List<Integer>> outputChannels)
    {
        requireNonNull(types, "types is null");
        requireNonNull(joinChannels, "joinChannels is null");
        requireNonNull(sortChannel, "sortChannel is null");
        requireNonNull(outputChannels, "outputChannels is null");

        try {
            return lookupSourceFactories.get(new CacheKey(
                    types,
                    // orElseGet: only build the default range when no output channels were given
                    outputChannels.orElseGet(() -> rangeList(types.size())),
                    joinChannels,
                    sortChannel));
        }
        catch (ExecutionException | UncheckedExecutionException | ExecutionError e) {
            throw propagateCause(e);
        }
    }

    /**
     * Returns a factory for a (cached) hash strategy class that outputs every channel in {@code types}.
     */
    public PagesHashStrategyFactory compilePagesHashStrategyFactory(List<Type> types, List<Integer> joinChannels)
    {
        return compilePagesHashStrategyFactory(types, joinChannels, Optional.empty());
    }

    /**
     * Returns a factory for a (cached) hash strategy class without a sort channel.
     *
     * @param outputChannels channels emitted by {@code appendTo}; when empty, all channels are output
     * @throws NullPointerException if any argument is null
     */
    public PagesHashStrategyFactory compilePagesHashStrategyFactory(List<Type> types, List<Integer> joinChannels, Optional<List<Integer>> outputChannels)
    {
        requireNonNull(types, "types is null");
        requireNonNull(joinChannels, "joinChannels is null");
        requireNonNull(outputChannels, "outputChannels is null");

        try {
            return new PagesHashStrategyFactory(hashStrategies.get(new CacheKey(
                    types,
                    outputChannels.orElseGet(() -> rangeList(types.size())),
                    joinChannels,
                    Optional.empty())));
        }
        catch (ExecutionException | UncheckedExecutionException | ExecutionError e) {
            throw propagateCause(e);
        }
    }

    /**
     * Rethrows the cause of a cache-loading failure. An unchecked cause (RuntimeException or Error)
     * is rethrown as-is; a checked cause is wrapped in a RuntimeException. This is the same behavior
     * as the deprecated {@code Throwables.propagate(cause)}. Never returns normally; the return type
     * lets callers write {@code throw propagateCause(e)}.
     */
    private static RuntimeException propagateCause(Throwable loadingFailure)
    {
        Throwable cause = loadingFailure.getCause();
        Throwables.throwIfUnchecked(cause);
        throw new RuntimeException(cause);
    }

    /**
     * Returns the immutable list {@code [0, 1, ..., endExclusive - 1]}.
     */
    private static List<Integer> rangeList(int endExclusive)
    {
        return IntStream.range(0, endExclusive)
                .boxed()
                .collect(toImmutableList());
    }

    private LookupSourceSupplierFactory internalCompileLookupSourceFactory(List<Type> types, List<Integer> outputChannels, List<Integer> joinChannels, Optional<SortExpression> sortChannel)
    {
        Class<? extends PagesHashStrategy> pagesHashStrategyClass = internalCompileHashStrategy(types, outputChannels, joinChannels, sortChannel);

        // Load the join hash classes into a fresh class loader for each compiled factory
        Class<? extends LookupSourceSupplier> joinHashSupplierClass = IsolatedClass.isolateClass(
                new DynamicClassLoader(getClass().getClassLoader()),
                LookupSourceSupplier.class,
                JoinHashSupplier.class,
                JoinHash.class,
                PagesHash.class);

        return new LookupSourceSupplierFactory(joinHashSupplierClass, new PagesHashStrategyFactory(pagesHashStrategyClass));
    }

    /**
     * Generates and defines a {@link PagesHashStrategy} implementation specialized for the given channels.
     */
    private Class<? extends PagesHashStrategy> internalCompileHashStrategy(List<Type> types, List<Integer> outputChannels, List<Integer> joinChannels, Optional<SortExpression> sortChannel)
    {
        CallSiteBinder callSiteBinder = new CallSiteBinder();

        ClassDefinition classDefinition = new ClassDefinition(
                a(PUBLIC, FINAL),
                makeClassName("PagesHashStrategy"),
                type(Object.class),
                type(PagesHashStrategy.class));

        FieldDefinition sizeField = classDefinition.declareField(a(PRIVATE, FINAL), "size", type(long.class));

        // One List<Block> field per input channel
        List<FieldDefinition> channelFields = new ArrayList<>();
        for (int i = 0; i < types.size(); i++) {
            FieldDefinition channelField = classDefinition.declareField(a(PRIVATE, FINAL), "channel_" + i, type(List.class, Block.class));
            channelFields.add(channelField);
        }

        // One List<Block> field per join channel, plus the matching type list
        List<Type> joinChannelTypes = new ArrayList<>();
        List<FieldDefinition> joinChannelFields = new ArrayList<>();
        for (int i = 0; i < joinChannels.size(); i++) {
            joinChannelTypes.add(types.get(joinChannels.get(i)));
            FieldDefinition channelField = classDefinition.declareField(a(PRIVATE, FINAL), "joinChannel_" + i, type(List.class, Block.class));
            joinChannelFields.add(channelField);
        }
        FieldDefinition hashChannelField = classDefinition.declareField(a(PRIVATE, FINAL), "hashChannel", type(List.class, Block.class));

        generateConstructor(classDefinition, joinChannels, sizeField, channelFields, joinChannelFields, hashChannelField);
        generateGetChannelCountMethod(classDefinition, outputChannels.size());
        generateGetSizeInBytesMethod(classDefinition, sizeField);
        generateAppendToMethod(classDefinition, callSiteBinder, types, outputChannels, channelFields);
        generateHashPositionMethod(classDefinition, callSiteBinder, joinChannelTypes, joinChannelFields, hashChannelField);
        generateHashRowMethod(classDefinition, callSiteBinder, joinChannelTypes);
        generateRowEqualsRowMethod(classDefinition, callSiteBinder, joinChannelTypes);
        generatePositionEqualsRowMethod(classDefinition, callSiteBinder, joinChannelTypes, joinChannelFields, true);
        generatePositionEqualsRowMethod(classDefinition, callSiteBinder, joinChannelTypes, joinChannelFields, false);
        generatePositionEqualsRowWithPageMethod(classDefinition, callSiteBinder, joinChannelTypes, joinChannelFields);
        generatePositionEqualsPositionMethod(classDefinition, callSiteBinder, joinChannelTypes, joinChannelFields, true);
        generatePositionEqualsPositionMethod(classDefinition, callSiteBinder, joinChannelTypes, joinChannelFields, false);
        generateIsPositionNull(classDefinition, joinChannelFields);
        generateCompareMethod(classDefinition, callSiteBinder, types, channelFields, sortChannel);

        return defineClass(classDefinition, PagesHashStrategy.class, callSiteBinder.getBindings(), getClass().getClassLoader());
    }

    /**
     * Generates {@code (List<List<Block>> channels, Optional<Integer> hashChannel)}. The generated
     * constructor does three things:
     * <ul>
     * <li>stores each channel into its field;</li>
     * <li>sums {@code getRetainedSizeInBytes()} of every block of every channel into {@code size};</li>
     * <li>stores the join channels, and the hash channel (or null when absent).</li>
     * </ul>
     */
    private static void generateConstructor(ClassDefinition classDefinition,
            List<Integer> joinChannels,
            FieldDefinition sizeField,
            List<FieldDefinition> channelFields,
            List<FieldDefinition> joinChannelFields,
            FieldDefinition hashChannelField)
    {
        Parameter channels = arg("channels", type(List.class, type(List.class, Block.class)));
        Parameter hashChannel = arg("hashChannel", type(Optional.class, Integer.class));
        MethodDefinition constructorDefinition = classDefinition.declareConstructor(a(PUBLIC), channels, hashChannel);

        Variable thisVariable = constructorDefinition.getThis();
        Variable blockIndex = constructorDefinition.getScope().declareVariable(int.class, "blockIndex");

        BytecodeBlock constructor = constructorDefinition
                .getBody()
                .comment("super();")
                .append(thisVariable)
                .invokeConstructor(Object.class);

        constructor.comment("this.size = 0")
                .append(thisVariable.setField(sizeField, constantLong(0L)));

        constructor.comment("Set channel fields");

        for (int index = 0; index < channelFields.size(); index++) {
            BytecodeExpression channel = channels.invoke("get", Object.class, constantInt(index))
                    .cast(type(List.class, Block.class));

            constructor.append(thisVariable.setField(channelFields.get(index), channel));

            BytecodeBlock loopBody = new BytecodeBlock();

            constructor.comment("for(blockIndex = 0; blockIndex < channel.size(); blockIndex++) { size += channel.get(i).getRetainedSizeInBytes() }")
                    .append(new ForLoop()
                            .initialize(blockIndex.set(constantInt(0)))
                            .condition(new BytecodeBlock()
                                    .append(blockIndex)
                                    .append(channel.invoke("size", int.class))
                                    .invokeStatic(CompilerOperations.class, "lessThan", boolean.class, int.class, int.class))
                            .update(new BytecodeBlock().incrementVariable(blockIndex, (byte) 1))
                            .body(loopBody));

            // this.size = this.size + (long) channel.get(blockIndex).getRetainedSizeInBytes()
            loopBody.append(thisVariable)
                    .append(thisVariable)
                    .getField(sizeField)
                    .append(
                            channel.invoke("get", Object.class, blockIndex)
                                    .cast(type(Block.class))
                                    .invoke("getRetainedSizeInBytes", int.class)
                                    .cast(long.class))
                    .longAdd()
                    .putField(sizeField);
        }

        constructor.comment("Set join channel fields");
        for (int index = 0; index < joinChannelFields.size(); index++) {
            BytecodeExpression joinChannel = channels.invoke("get", Object.class, constantInt(joinChannels.get(index)))
                    .cast(type(List.class, Block.class));

            constructor.append(thisVariable.setField(joinChannelFields.get(index), joinChannel));
        }

        constructor.comment("Set hashChannel");
        constructor.append(new IfStatement()
                .condition(hashChannel.invoke("isPresent", boolean.class))
                .ifTrue(thisVariable.setField(
                        hashChannelField,
                        channels.invoke("get", Object.class, hashChannel.invoke("get", Object.class).cast(Integer.class).cast(int.class))))
                .ifFalse(thisVariable.setField(
                        hashChannelField,
                        constantNull(hashChannelField.getType()))));
        constructor.ret();
    }

    /**
     * Generates {@code getChannelCount()}, which returns the number of output channels as a constant.
     */
    private static void generateGetChannelCountMethod(ClassDefinition classDefinition, int outputChannelCount)
    {
        classDefinition.declareMethod(
                a(PUBLIC),
                "getChannelCount",
                type(int.class))
                .getBody()
                .push(outputChannelCount)
                .retInt();
    }

    /**
     * Generates {@code getSizeInBytes()}, which returns the {@code size} field computed by the constructor.
     */
    private static void generateGetSizeInBytesMethod(ClassDefinition classDefinition, FieldDefinition sizeField)
    {
        MethodDefinition getSizeInBytesMethod = classDefinition.declareMethod(a(PUBLIC), "getSizeInBytes", type(long.class));

        Variable thisVariable = getSizeInBytesMethod.getThis();
        getSizeInBytesMethod.getBody()
                .append(thisVariable.getField(sizeField))
                .retLong();
    }

    /**
     * Generates {@code appendTo(blockIndex, blockPosition, pageBuilder, outputChannelOffset)}. For
     * each output channel, in order, the generated method copies the value at the given position
     * into {@code pageBuilder.getBlockBuilder(outputChannelOffset + i)}.
     */
    private static void generateAppendToMethod(ClassDefinition classDefinition, CallSiteBinder callSiteBinder, List<Type> types, List<Integer> outputChannels, List<FieldDefinition> channelFields)
    {
        Parameter blockIndex = arg("blockIndex", int.class);
        Parameter blockPosition = arg("blockPosition", int.class);
        Parameter pageBuilder = arg("pageBuilder", PageBuilder.class);
        Parameter outputChannelOffset = arg("outputChannelOffset", int.class);
        MethodDefinition appendToMethod = classDefinition.declareMethod(a(PUBLIC), "appendTo", type(void.class), blockIndex, blockPosition, pageBuilder, outputChannelOffset);

        Variable thisVariable = appendToMethod.getThis();
        BytecodeBlock appendToBody = appendToMethod.getBody();

        int pageBuilderOutputChannel = 0;
        for (int outputChannel : outputChannels) {
            Type type = types.get(outputChannel);
            BytecodeExpression typeExpression = constantType(callSiteBinder, type);

            BytecodeExpression block = thisVariable
                    .getField(channelFields.get(outputChannel))
                    .invoke("get", Object.class, blockIndex)
                    .cast(Block.class);

            // The comment arguments are evaluated before push(pageBuilderOutputChannel++) increments the counter
            appendToBody
                    .comment("%s.appendTo(channel_%s.get(outputChannel), blockPosition, pageBuilder.getBlockBuilder(outputChannelOffset + %s));", type.getClass(), outputChannel, pageBuilderOutputChannel)
                    .append(typeExpression)
                    .append(block)
                    .append(blockPosition)
                    .append(pageBuilder)
                    .append(outputChannelOffset)
                    .push(pageBuilderOutputChannel++)
                    .append(OpCode.IADD)
                    .invokeVirtual(PageBuilder.class, "getBlockBuilder", BlockBuilder.class, int.class)
                    .invokeInterface(Type.class, "appendTo", void.class, Block.class, int.class, BlockBuilder.class);
        }
        appendToBody.ret();
    }

    /**
     * Generates {@code isPositionNull(blockIndex, blockPosition)}, which returns true if any join
     * channel is null at that position.
     */
    private static void generateIsPositionNull(ClassDefinition classDefinition, List<FieldDefinition> joinChannelFields)
    {
        Parameter blockIndex = arg("blockIndex", int.class);
        Parameter blockPosition = arg("blockPosition", int.class);
        MethodDefinition isPositionNullMethod = classDefinition.declareMethod(
                a(PUBLIC),
                "isPositionNull",
                type(boolean.class),
                blockIndex,
                blockPosition);

        for (FieldDefinition joinChannelField : joinChannelFields) {
            BytecodeExpression block = isPositionNullMethod
                    .getThis()
                    .getField(joinChannelField)
                    .invoke("get", Object.class, blockIndex)
                    .cast(Block.class);

            IfStatement ifStatement = new IfStatement();
            ifStatement.condition(block.invoke(
                    "isNull",
                    boolean.class,
                    blockPosition));
            ifStatement.ifTrue(constantTrue().ret());
            isPositionNullMethod.getBody().append(ifStatement);
        }

        isPositionNullMethod
                .getBody()
                .append(constantFalse().ret());
    }

    /**
     * Generates {@code hashPosition(blockIndex, blockPosition)}, which works in one of two ways:
     * <ul>
     * <li>If the hash channel is present, it returns the BIGINT value stored there (presumably a
     * precomputed hash).</li>
     * <li>Otherwise it combines the join channel hashes as {@code result = result * 31 + hash(channel)}.
     * A null value hashes to 0.</li>
     * </ul>
     */
    private static void generateHashPositionMethod(ClassDefinition classDefinition, CallSiteBinder callSiteBinder, List<Type> joinChannelTypes, List<FieldDefinition> joinChannelFields, FieldDefinition hashChannelField)
    {
        Parameter blockIndex = arg("blockIndex", int.class);
        Parameter blockPosition = arg("blockPosition", int.class);
        MethodDefinition hashPositionMethod = classDefinition.declareMethod(
                a(PUBLIC),
                "hashPosition",
                type(long.class),
                blockIndex,
                blockPosition);

        Variable thisVariable = hashPositionMethod.getThis();
        BytecodeExpression hashChannel = thisVariable.getField(hashChannelField);
        BytecodeExpression bigintType = constantType(callSiteBinder, BigintType.BIGINT);

        IfStatement ifStatement = new IfStatement();
        ifStatement.condition(notEqual(hashChannel, constantNull(hashChannelField.getType())));
        ifStatement.ifTrue(
                bigintType.invoke(
                        "getLong",
                        long.class,
                        hashChannel.invoke("get", Object.class, blockIndex).cast(Block.class),
                        blockPosition)
                        .ret());

        hashPositionMethod
                .getBody()
                .append(ifStatement);

        Variable resultVariable = hashPositionMethod.getScope().declareVariable(long.class, "result");
        hashPositionMethod.getBody().push(0L).putVariable(resultVariable);

        for (int index = 0; index < joinChannelTypes.size(); index++) {
            BytecodeExpression type = constantType(callSiteBinder, joinChannelTypes.get(index));

            BytecodeExpression block = hashPositionMethod
                    .getThis()
                    .getField(joinChannelFields.get(index))
                    .invoke("get", Object.class, blockIndex)
                    .cast(Block.class);

            hashPositionMethod
                    .getBody()
                    .getVariable(resultVariable)
                    .push(31L)
                    .append(OpCode.LMUL)
                    .append(typeHashCode(type, block, blockPosition))
                    .append(OpCode.LADD)
                    .putVariable(resultVariable);
        }

        hashPositionMethod
                .getBody()
                .getVariable(resultVariable)
                .retLong();
    }

    /**
     * Generates {@code hashRow(position, page)}, which combines the hashes of the page's first
     * {@code joinChannelTypes.size()} blocks as {@code result = result * 31 + hash(block)}.
     */
    private static void generateHashRowMethod(ClassDefinition classDefinition, CallSiteBinder callSiteBinder, List<Type> joinChannelTypes)
    {
        Parameter position = arg("position", int.class);
        Parameter page = arg("page", Page.class);
        MethodDefinition hashRowMethod = classDefinition.declareMethod(a(PUBLIC), "hashRow", type(long.class), position, page);

        Variable resultVariable = hashRowMethod.getScope().declareVariable(long.class, "result");
        hashRowMethod.getBody().push(0L).putVariable(resultVariable);

        for (int index = 0; index < joinChannelTypes.size(); index++) {
            BytecodeExpression type = constantType(callSiteBinder, joinChannelTypes.get(index));

            BytecodeExpression block = page.invoke("getBlock", Block.class, constantInt(index));

            hashRowMethod
                    .getBody()
                    .getVariable(resultVariable)
                    .push(31L)
                    .append(OpCode.LMUL)
                    .append(typeHashCode(type, block, position))
                    .append(OpCode.LADD)
                    .putVariable(resultVariable);
        }

        hashRowMethod
                .getBody()
                .getVariable(resultVariable)
                .retLong();
    }

    /**
     * Emits {@code block.isNull(position) ? 0L : type.hash(block, position)}.
     */
    private static BytecodeNode typeHashCode(BytecodeExpression type, BytecodeExpression blockRef, BytecodeExpression blockPosition)
    {
        return new IfStatement()
                .condition(blockRef.invoke("isNull", boolean.class, blockPosition))
                .ifTrue(constantLong(0L))
                .ifFalse(type.invoke("hash", long.class, blockRef, blockPosition));
    }

    /**
     * Generates {@code rowEqualsRow(leftPosition, leftPage, rightPosition, rightPage)}. The generated
     * method compares the first {@code joinChannelTypes.size()} blocks of both pages. Two nulls are
     * treated as equal.
     */
    private static void generateRowEqualsRowMethod(
            ClassDefinition classDefinition,
            CallSiteBinder callSiteBinder,
            List<Type> joinChannelTypes)
    {
        Parameter leftPosition = arg("leftPosition", int.class);
        Parameter leftPage = arg("leftPage", Page.class);
        Parameter rightPosition = arg("rightPosition", int.class);
        Parameter rightPage = arg("rightPage", Page.class);
        MethodDefinition rowEqualsRowMethod = classDefinition.declareMethod(
                a(PUBLIC),
                "rowEqualsRow",
                type(boolean.class),
                leftPosition,
                leftPage,
                rightPosition,
                rightPage);

        for (int index = 0; index < joinChannelTypes.size(); index++) {
            BytecodeExpression type = constantType(callSiteBinder, joinChannelTypes.get(index));

            BytecodeExpression leftBlock = leftPage.invoke("getBlock", Block.class, constantInt(index));

            BytecodeExpression rightBlock = rightPage.invoke("getBlock", Block.class, constantInt(index));

            LabelNode checkNextField = new LabelNode("checkNextField");
            rowEqualsRowMethod
                    .getBody()
                    .append(typeEquals(
                            type,
                            leftBlock,
                            leftPosition,
                            rightBlock,
                            rightPosition))
                    .ifTrueGoto(checkNextField)
                    .push(false)
                    .retBoolean()
                    .visitLabel(checkNextField);
        }

        rowEqualsRowMethod
                .getBody()
                .push(true)
                .retInt();
    }

    /**
     * Generates one of two methods, both with the signature
     * {@code (leftBlockIndex, leftBlockPosition, rightPosition, rightPage)}. Each compares this
     * strategy's join channels against the first blocks of {@code rightPage}.
     * <ul>
     * <li>{@code positionEqualsRowIgnoreNulls} (when {@code ignoreNulls}) calls {@code equalTo} directly.</li>
     * <li>{@code positionEqualsRow} uses null-aware equality.</li>
     * </ul>
     */
    private static void generatePositionEqualsRowMethod(
            ClassDefinition classDefinition,
            CallSiteBinder callSiteBinder,
            List<Type> joinChannelTypes,
            List<FieldDefinition> joinChannelFields,
            boolean ignoreNulls)
    {
        Parameter leftBlockIndex = arg("leftBlockIndex", int.class);
        Parameter leftBlockPosition = arg("leftBlockPosition", int.class);
        Parameter rightPosition = arg("rightPosition", int.class);
        Parameter rightPage = arg("rightPage", Page.class);
        MethodDefinition positionEqualsRowMethod = classDefinition.declareMethod(
                a(PUBLIC),
                ignoreNulls ? "positionEqualsRowIgnoreNulls" : "positionEqualsRow",
                type(boolean.class),
                leftBlockIndex,
                leftBlockPosition,
                rightPosition,
                rightPage);

        Variable thisVariable = positionEqualsRowMethod.getThis();

        for (int index = 0; index < joinChannelTypes.size(); index++) {
            BytecodeExpression type = constantType(callSiteBinder, joinChannelTypes.get(index));

            BytecodeExpression leftBlock = thisVariable
                    .getField(joinChannelFields.get(index))
                    .invoke("get", Object.class, leftBlockIndex)
                    .cast(Block.class);

            BytecodeExpression rightBlock = rightPage.invoke("getBlock", Block.class, constantInt(index));

            BytecodeNode equalityCondition;
            if (ignoreNulls) {
                equalityCondition = typeEqualsIgnoreNulls(type, leftBlock, leftBlockPosition, rightBlock, rightPosition);
            }
            else {
                equalityCondition = typeEquals(type, leftBlock, leftBlockPosition, rightBlock, rightPosition);
            }

            LabelNode checkNextField = new LabelNode("checkNextField");
            positionEqualsRowMethod
                    .getBody()
                    .append(equalityCondition)
                    .ifTrueGoto(checkNextField)
                    .push(false)
                    .retBoolean()
                    .visitLabel(checkNextField);
        }

        positionEqualsRowMethod
                .getBody()
                .push(true)
                .retInt();
    }

    /**
     * Generates {@code positionEqualsRow(leftBlockIndex, leftBlockPosition, rightPosition, page, rightChannels)}.
     * Join channel {@code i} is compared, with null-aware equality, against
     * {@code page.getBlock(rightChannels[i])}.
     */
    private static void generatePositionEqualsRowWithPageMethod(
            ClassDefinition classDefinition,
            CallSiteBinder callSiteBinder,
            List<Type> joinChannelTypes,
            List<FieldDefinition> joinChannelFields)
    {
        Parameter leftBlockIndex = arg("leftBlockIndex", int.class);
        Parameter leftBlockPosition = arg("leftBlockPosition", int.class);
        Parameter rightPosition = arg("rightPosition", int.class);
        Parameter page = arg("page", Page.class);
        Parameter rightChannels = arg("rightChannels", int[].class);
        MethodDefinition positionEqualsRowMethod = classDefinition.declareMethod(
                a(PUBLIC),
                "positionEqualsRow",
                type(boolean.class),
                leftBlockIndex,
                leftBlockPosition,
                rightPosition,
                page,
                rightChannels);

        Variable thisVariable = positionEqualsRowMethod.getThis();
        BytecodeBlock body = positionEqualsRowMethod.getBody();

        for (int index = 0; index < joinChannelTypes.size(); index++) {
            BytecodeExpression type = constantType(callSiteBinder, joinChannelTypes.get(index));
            BytecodeExpression leftBlock = thisVariable
                    .getField(joinChannelFields.get(index))
                    .invoke("get", Object.class, leftBlockIndex)
                    .cast(Block.class);
            BytecodeExpression rightBlock = page.invoke("getBlock", Block.class, rightChannels.getElement(index));

            body.append(new IfStatement()
                    .condition(typeEquals(type, leftBlock, leftBlockPosition, rightBlock, rightPosition))
                    .ifFalse(constantFalse().ret()));
        }

        body.append(constantTrue().ret());
    }

    /**
     * Generates one of two methods, both with the signature
     * {@code (leftBlockIndex, leftBlockPosition, rightBlockIndex, rightBlockPosition)}. Each compares
     * two positions within this strategy's own join channels.
     * <ul>
     * <li>{@code positionEqualsPositionIgnoreNulls} (when {@code ignoreNulls}) calls {@code equalTo} directly.</li>
     * <li>{@code positionEqualsPosition} uses null-aware equality.</li>
     * </ul>
     */
    private static void generatePositionEqualsPositionMethod(
            ClassDefinition classDefinition,
            CallSiteBinder callSiteBinder,
            List<Type> joinChannelTypes,
            List<FieldDefinition> joinChannelFields,
            boolean ignoreNulls)
    {
        Parameter leftBlockIndex = arg("leftBlockIndex", int.class);
        Parameter leftBlockPosition = arg("leftBlockPosition", int.class);
        Parameter rightBlockIndex = arg("rightBlockIndex", int.class);
        Parameter rightBlockPosition = arg("rightBlockPosition", int.class);
        MethodDefinition positionEqualsPositionMethod = classDefinition.declareMethod(
                a(PUBLIC),
                ignoreNulls ? "positionEqualsPositionIgnoreNulls" : "positionEqualsPosition",
                type(boolean.class),
                leftBlockIndex,
                leftBlockPosition,
                rightBlockIndex,
                rightBlockPosition);

        Variable thisVariable = positionEqualsPositionMethod.getThis();
        for (int index = 0; index < joinChannelTypes.size(); index++) {
            BytecodeExpression type = constantType(callSiteBinder, joinChannelTypes.get(index));

            BytecodeExpression leftBlock = thisVariable
                    .getField(joinChannelFields.get(index))
                    .invoke("get", Object.class, leftBlockIndex)
                    .cast(Block.class);

            BytecodeExpression rightBlock = thisVariable
                    .getField(joinChannelFields.get(index))
                    .invoke("get", Object.class, rightBlockIndex)
                    .cast(Block.class);

            BytecodeNode equalityCondition;
            if (ignoreNulls) {
                equalityCondition = typeEqualsIgnoreNulls(type, leftBlock, leftBlockPosition, rightBlock, rightBlockPosition);
            }
            else {
                equalityCondition = typeEquals(type, leftBlock, leftBlockPosition, rightBlock, rightBlockPosition);
            }

            LabelNode checkNextField = new LabelNode("checkNextField");
            positionEqualsPositionMethod
                    .getBody()
                    .append(equalityCondition)
                    .ifTrueGoto(checkNextField)
                    .push(false)
                    .retBoolean()
                    .visitLabel(checkNextField);
        }

        positionEqualsPositionMethod
                .getBody()
                .push(true)
                .retInt();
    }

    /**
     * Generates {@code compare(leftBlockIndex, leftBlockPosition, rightBlockIndex, rightBlockPosition)}.
     * The generated method delegates to {@code Type.compareTo} on the sort channel. When there is no
     * sort channel, it throws {@link UnsupportedOperationException}.
     */
    private static void generateCompareMethod(
            ClassDefinition classDefinition,
            CallSiteBinder callSiteBinder,
            List<Type> types,
            List<FieldDefinition> channelFields,
            Optional<SortExpression> sortChannel)
    {
        Parameter leftBlockIndex = arg("leftBlockIndex", int.class);
        Parameter leftBlockPosition = arg("leftBlockPosition", int.class);
        Parameter rightBlockIndex = arg("rightBlockIndex", int.class);
        Parameter rightBlockPosition = arg("rightBlockPosition", int.class);
        MethodDefinition compareMethod = classDefinition.declareMethod(
                a(PUBLIC),
                "compare",
                type(int.class),
                leftBlockIndex,
                leftBlockPosition,
                rightBlockIndex,
                rightBlockPosition);

        if (!sortChannel.isPresent()) {
            compareMethod.getBody()
                    .append(newInstance(UnsupportedOperationException.class))
                    .throwObject();
            return;
        }

        Variable thisVariable = compareMethod.getThis();

        int index = sortChannel.get().getChannel();
        BytecodeExpression type = constantType(callSiteBinder, types.get(index));

        BytecodeExpression leftBlock = thisVariable
                .getField(channelFields.get(index))
                .invoke("get", Object.class, leftBlockIndex)
                .cast(Block.class);

        BytecodeExpression rightBlock = thisVariable
                .getField(channelFields.get(index))
                .invoke("get", Object.class, rightBlockIndex)
                .cast(Block.class);

        BytecodeNode comparison = type.invoke("compareTo", int.class, leftBlock, leftBlockPosition, rightBlock, rightBlockPosition).ret();

        compareMethod
                .getBody()
                .append(comparison);
    }

    /**
     * Emits null-aware equality, which yields:
     * <ul>
     * <li>true if both values are null;</li>
     * <li>false if exactly one is null;</li>
     * <li>otherwise the result of {@code type.equalTo(...)}.</li>
     * </ul>
     */
    private static BytecodeNode typeEquals(
            BytecodeExpression type,
            BytecodeExpression leftBlock,
            BytecodeExpression leftBlockPosition,
            BytecodeExpression rightBlock,
            BytecodeExpression rightBlockPosition)
    {
        IfStatement ifStatement = new IfStatement();
        ifStatement.condition()
                .append(leftBlock.invoke("isNull", boolean.class, leftBlockPosition))
                .append(rightBlock.invoke("isNull", boolean.class, rightBlockPosition))
                .append(OpCode.IOR);

        ifStatement.ifTrue()
                .append(leftBlock.invoke("isNull", boolean.class, leftBlockPosition))
                .append(rightBlock.invoke("isNull", boolean.class, rightBlockPosition))
                .append(OpCode.IAND);

        ifStatement.ifFalse().append(typeEqualsIgnoreNulls(type, leftBlock, leftBlockPosition, rightBlock, rightBlockPosition));

        return ifStatement;
    }

    /**
     * Emits {@code type.equalTo(leftBlock, leftBlockPosition, rightBlock, rightBlockPosition)} with no null checks.
     */
    private static BytecodeNode typeEqualsIgnoreNulls(
            BytecodeExpression type,
            BytecodeExpression leftBlock,
            BytecodeExpression leftBlockPosition,
            BytecodeExpression rightBlock,
            BytecodeExpression rightBlockPosition)
    {
        return type.invoke("equalTo", boolean.class, leftBlock, leftBlockPosition, rightBlock, rightBlockPosition);
    }

    /**
     * Creates {@link LookupSourceSupplier} instances reflectively from a compiled (isolated) supplier
     * class and a hash strategy factory.
     */
    public static class LookupSourceSupplierFactory
    {
        private final Constructor<? extends LookupSourceSupplier> constructor;
        private final PagesHashStrategyFactory pagesHashStrategyFactory;

        public LookupSourceSupplierFactory(Class<? extends LookupSourceSupplier> joinHashSupplierClass, PagesHashStrategyFactory pagesHashStrategyFactory)
        {
            this.pagesHashStrategyFactory = pagesHashStrategyFactory;
            try {
                constructor = joinHashSupplierClass.getConstructor(Session.class, PagesHashStrategy.class, LongArrayList.class, List.class, Optional.class);
            }
            catch (NoSuchMethodException e) {
                throw new RuntimeException(e);
            }
        }

        public LookupSourceSupplier createLookupSourceSupplier(
                Session session,
                LongArrayList addresses,
                List<List<Block>> channels,
                Optional<Integer> hashChannel,
                Optional<JoinFilterFunctionFactory> filterFunctionFactory)
        {
            PagesHashStrategy pagesHashStrategy = pagesHashStrategyFactory.createPagesHashStrategy(channels, hashChannel);
            try {
                return constructor.newInstance(session, pagesHashStrategy, addresses, channels, filterFunctionFactory);
            }
            catch (Exception e) {
                // Unchecked exceptions pass through unchanged; checked reflection exceptions are wrapped
                Throwables.throwIfUnchecked(e);
                throw new RuntimeException(e);
            }
        }
    }

    /**
     * Instantiates a compiled {@link PagesHashStrategy} class through its
     * {@code (List, Optional)} constructor.
     */
    public static class PagesHashStrategyFactory
    {
        private final Constructor<? extends PagesHashStrategy> constructor;

        public PagesHashStrategyFactory(Class<? extends PagesHashStrategy> pagesHashStrategyClass)
        {
            try {
                constructor = pagesHashStrategyClass.getConstructor(List.class, Optional.class);
            }
            catch (NoSuchMethodException e) {
                throw new RuntimeException(e);
            }
        }

        public PagesHashStrategy createPagesHashStrategy(List<? extends List<Block>> channels, Optional<Integer> hashChannel)
        {
            try {
                return constructor.newInstance(channels, hashChannel);
            }
            catch (Exception e) {
                Throwables.throwIfUnchecked(e);
                throw new RuntimeException(e);
            }
        }
    }

    /**
     * Immutable cache key describing the shape of a compiled join: channel types, output channels,
     * join channels and optional sort channel.
     */
    private static final class CacheKey
    {
        private final List<Type> types;
        private final List<Integer> outputChannels;
        private final List<Integer> joinChannels;
        private final Optional<SortExpression> sortChannel;

        private CacheKey(List<? extends Type> types, List<Integer> outputChannels, List<Integer> joinChannels, Optional<SortExpression> sortChannel)
        {
            this.types = ImmutableList.copyOf(requireNonNull(types, "types is null"));
            this.outputChannels = ImmutableList.copyOf(requireNonNull(outputChannels, "outputChannels is null"));
            this.joinChannels = ImmutableList.copyOf(requireNonNull(joinChannels, "joinChannels is null"));
            this.sortChannel = requireNonNull(sortChannel, "sortChannel is null");
        }

        private List<Type> getTypes()
        {
            return types;
        }

        private List<Integer> getOutputChannels()
        {
            return outputChannels;
        }

        private List<Integer> getJoinChannels()
        {
            return joinChannels;
        }

        private Optional<SortExpression> getSortChannel()
        {
            return sortChannel;
        }

        @Override
        public int hashCode()
        {
            return Objects.hash(types, outputChannels, joinChannels, sortChannel);
        }

        @Override
        public boolean equals(Object obj)
        {
            if (this == obj) {
                return true;
            }
            if (!(obj instanceof CacheKey)) {
                return false;
            }
            CacheKey other = (CacheKey) obj;
            return Objects.equals(this.types, other.types) &&
                    Objects.equals(this.outputChannels, other.outputChannels) &&
                    Objects.equals(this.joinChannels, other.joinChannels) &&
                    Objects.equals(this.sortChannel, other.sortChannel);
        }
    }
}