/* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.facebook.presto.sql.gen; import com.facebook.presto.SequencePageBuilder; import com.facebook.presto.block.BlockAssertions; import com.facebook.presto.operator.JoinProbe; import com.facebook.presto.operator.JoinProbeFactory; import com.facebook.presto.operator.LookupSource; import com.facebook.presto.operator.TaskContext; import com.facebook.presto.spi.Page; import com.facebook.presto.spi.PageBuilder; import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.type.BigintType; import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.gen.JoinCompiler.LookupSourceSupplierFactory; import com.facebook.presto.type.TypeUtils; import com.google.common.collect.ImmutableList; import com.google.common.primitives.Ints; import it.unimi.dsi.fastutil.longs.LongArrayList; import org.testng.annotations.AfterMethod; import org.testng.annotations.BeforeMethod; import org.testng.annotations.DataProvider; import org.testng.annotations.Test; import java.util.List; import java.util.Optional; import java.util.concurrent.ExecutorService; import static com.facebook.presto.SessionTestUtils.TEST_SESSION; import static com.facebook.presto.operator.PageAssertions.assertPageEquals; import static com.facebook.presto.operator.SyntheticAddress.encodeSyntheticAddress; import static com.facebook.presto.spi.type.DoubleType.DOUBLE; import static com.facebook.presto.spi.type.VarcharType.VARCHAR; import static com.facebook.presto.testing.TestingTaskContext.createTaskContext; import static io.airlift.concurrent.Threads.daemonThreadsNamed; import static java.util.concurrent.Executors.newCachedThreadPool; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertFalse; import static org.testng.Assert.assertTrue; public class TestJoinProbeCompiler { private static final JoinCompiler joinCompiler = new JoinCompiler(); private ExecutorService executor; private TaskContext taskContext; @BeforeMethod public void setUp() { executor = newCachedThreadPool(daemonThreadsNamed("test-%s")); taskContext = createTaskContext(executor, TEST_SESSION); } @AfterMethod public void tearDown() { executor.shutdownNow(); } @DataProvider(name = "hashEnabledValues") public static Object[][] hashEnabledValuesProvider() { return new Object[][] {{true}, {false}}; } @Test(dataProvider = "hashEnabledValues") public void testSingleChannel(boolean hashEnabled) throws Exception { taskContext.addPipelineContext(0, true, true).addDriverContext(); ImmutableList<Type> types = ImmutableList.of(VARCHAR, DOUBLE); ImmutableList<Type> outputTypes = ImmutableList.of(VARCHAR); List<Integer> outputChannels = ImmutableList.of(0); LookupSourceSupplierFactory lookupSourceSupplierFactory = joinCompiler.compileLookupSourceFactory(types, Ints.asList(0), Optional.empty()); // crate hash strategy with a single channel blocks -- make sure there is some overlap in values List<Block> varcharChannel = ImmutableList.of( BlockAssertions.createStringSequenceBlock(10, 20), BlockAssertions.createStringSequenceBlock(20, 30), BlockAssertions.createStringSequenceBlock(15, 25)); List<Block> extraUnusedDoubleChannel = ImmutableList.of( BlockAssertions.createDoubleSequenceBlock(10, 20), BlockAssertions.createDoubleSequenceBlock(20, 30), BlockAssertions.createDoubleSequenceBlock(15, 25)); LongArrayList addresses = new LongArrayList(); for (int blockIndex = 0; blockIndex < varcharChannel.size(); blockIndex++) { Block block = varcharChannel.get(blockIndex); for (int positionIndex = 0; positionIndex < block.getPositionCount(); positionIndex++) { addresses.add(encodeSyntheticAddress(blockIndex, positionIndex)); } } Optional<Integer> hashChannel = Optional.empty(); List<List<Block>> channels = ImmutableList.of(varcharChannel, extraUnusedDoubleChannel); if (hashEnabled) { ImmutableList.Builder<Block> hashChannelBuilder = ImmutableList.builder(); for (Block block : varcharChannel) { hashChannelBuilder.add(TypeUtils.getHashBlock(ImmutableList.<Type>of(VARCHAR), block)); } types = ImmutableList.of(VARCHAR, DOUBLE, BigintType.BIGINT); hashChannel = Optional.of(2); channels = ImmutableList.of(varcharChannel, extraUnusedDoubleChannel, hashChannelBuilder.build()); outputChannels = ImmutableList.of(0, 2); outputTypes = ImmutableList.of(VARCHAR, BigintType.BIGINT); } LookupSource lookupSource = lookupSourceSupplierFactory.createLookupSourceSupplier( taskContext.getSession(), addresses, channels, hashChannel, Optional.empty()) .get(); JoinProbeCompiler joinProbeCompiler = new JoinProbeCompiler(); JoinProbeFactory probeFactory = joinProbeCompiler.internalCompileJoinProbe( types, outputChannels, Ints.asList(0), hashChannel); Page page = SequencePageBuilder.createSequencePage(types, 10, 10, 10); Page outputPage = new Page(page.getBlock(0)); if (hashEnabled) { page = new Page(page.getBlock(0), page.getBlock(1), TypeUtils.getHashBlock(ImmutableList.of(VARCHAR), page.getBlock(0))); outputPage = new Page(page.getBlock(0), page.getBlock(2)); } JoinProbe joinProbe = probeFactory.createJoinProbe(lookupSource, page); // verify channel count assertEquals(joinProbe.getOutputChannelCount(), outputChannels.size()); PageBuilder pageBuilder = new PageBuilder(outputTypes); for (int position = 0; position < page.getPositionCount(); position++) { assertTrue(joinProbe.advanceNextPosition()); pageBuilder.declarePosition(); joinProbe.appendTo(pageBuilder); assertEquals(joinProbe.getCurrentJoinPosition(), lookupSource.getJoinPosition(position, page, page)); } assertFalse(joinProbe.advanceNextPosition()); assertPageEquals(outputTypes, pageBuilder.build(), outputPage); } }