/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.operator;
import com.facebook.presto.RowPagesBuilder;
import com.facebook.presto.operator.HashBuilderOperator.HashBuilderOperatorFactory;
import com.facebook.presto.spi.Page;
import com.facebook.presto.spi.type.Type;
import com.facebook.presto.sql.gen.JoinProbeCompiler;
import com.facebook.presto.sql.planner.plan.PlanNodeId;
import com.facebook.presto.testing.TestingTaskContext;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.primitives.Ints;
import io.airlift.units.DataSize;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.BenchmarkMode;
import org.openjdk.jmh.annotations.Fork;
import org.openjdk.jmh.annotations.Measurement;
import org.openjdk.jmh.annotations.OutputTimeUnit;
import org.openjdk.jmh.annotations.Param;
import org.openjdk.jmh.annotations.Setup;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.annotations.TearDown;
import org.openjdk.jmh.annotations.Warmup;
import org.openjdk.jmh.runner.Runner;
import org.openjdk.jmh.runner.RunnerException;
import org.openjdk.jmh.runner.options.Options;
import org.openjdk.jmh.runner.options.OptionsBuilder;
import org.openjdk.jmh.runner.options.VerboseMode;

import java.util.Iterator;
import java.util.List;
import java.util.Optional;
import java.util.Random;
import java.util.concurrent.ExecutorService;

import static com.facebook.presto.RowPagesBuilder.rowPagesBuilder;
import static com.facebook.presto.SessionTestUtils.TEST_SESSION;
import static com.facebook.presto.spi.type.BigintType.BIGINT;
import static com.facebook.presto.spi.type.VarcharType.VARCHAR;
import static io.airlift.concurrent.Threads.daemonThreadsNamed;
import static io.airlift.units.DataSize.Unit.GIGABYTE;
import static java.lang.String.format;
import static java.util.concurrent.Executors.newCachedThreadPool;
import static java.util.concurrent.TimeUnit.MILLISECONDS;
import static org.openjdk.jmh.annotations.Mode.AverageTime;
import static org.openjdk.jmh.annotations.Scope.Thread;
/**
 * JMH benchmark measuring the cost of building a hash lookup source
 * ({@code benchmarkBuildHash}) and of probing it with an inner join
 * ({@code benchmarkJoinHash}). Build and probe inputs are synthetic
 * three-column ({@code VARCHAR, BIGINT, BIGINT}) pages; the set of hash
 * channels, the match rate, and the projected output columns are swept
 * via {@code @Param} annotations.
 */
@SuppressWarnings("MethodMayBeStatic")
@State(Thread)
@OutputTimeUnit(MILLISECONDS)
@BenchmarkMode(AverageTime)
@Fork(3)
@Warmup(iterations = 5)
@Measurement(iterations = 20)
public class BenchmarkHashBuildAndJoinOperators
{
    private static final int HASH_BUILD_OPERATOR_ID = 1;
    private static final int HASH_JOIN_OPERATOR_ID = 2;
    private static final PlanNodeId TEST_PLAN_NODE_ID = new PlanNodeId("test");
    private static final LookupJoinOperators LOOKUP_JOIN_OPERATORS = new LookupJoinOperators(new JoinProbeCompiler());

    /**
     * Per-trial state for the hash-build benchmark: generates the build-side
     * pages and owns the executor used by the task context.
     */
    @State(Thread)
    public static class BuildContext
    {
        protected static final int ROWS_PER_PAGE = 1024;
        protected static final int BUILD_ROWS_NUMBER = 700_000;

        // Which columns participate in the join hash.
        @Param({"varchar", "bigint", "all"})
        protected String hashColumns;

        // Whether pages carry a precomputed hash column.
        @Param({"false", "true"})
        protected boolean buildHashEnabled;

        // How many times each build-side key value repeats (1 = unique keys).
        @Param({"1", "5"})
        protected int buildRowsRepetition;

        protected ExecutorService executor;
        protected List<Page> buildPages;
        protected Optional<Integer> hashChannel;
        protected List<Type> types;
        protected List<Integer> hashChannels;

        @Setup
        public void setup()
        {
            switch (hashColumns) {
                case "varchar":
                    hashChannels = Ints.asList(0);
                    break;
                case "bigint":
                    hashChannels = Ints.asList(1);
                    break;
                case "all":
                    hashChannels = Ints.asList(0, 1, 2);
                    break;
                default:
                    throw new UnsupportedOperationException(format("Unknown hashColumns value [%s]", hashColumns));
            }
            executor = newCachedThreadPool(daemonThreadsNamed("test-%s"));
            initializeBuildPages();
        }

        @TearDown
        public void cleanup()
        {
            // Release the benchmark's thread pool so forked JVM trials don't
            // accumulate idle daemon threads.
            executor.shutdownNow();
        }

        public TaskContext createTaskContext()
        {
            return TestingTaskContext.createTaskContext(executor, TEST_SESSION, new DataSize(2, GIGABYTE));
        }

        public Optional<Integer> getHashChannel()
        {
            return hashChannel;
        }

        public List<Integer> getHashChannels()
        {
            return hashChannels;
        }

        public List<Type> getTypes()
        {
            return types;
        }

        public List<Page> getBuildPages()
        {
            return buildPages;
        }

        protected void initializeBuildPages()
        {
            RowPagesBuilder buildPagesBuilder = rowPagesBuilder(buildHashEnabled, hashChannels, ImmutableList.of(VARCHAR, BIGINT, BIGINT));

            // Wrapping sequence values at maxValue makes each key appear
            // roughly buildRowsRepetition times across the build input.
            int maxValue = BUILD_ROWS_NUMBER / buildRowsRepetition + 40;
            int rows = 0;
            while (rows < BUILD_ROWS_NUMBER) {
                int newRows = Math.min(BUILD_ROWS_NUMBER - rows, ROWS_PER_PAGE);
                buildPagesBuilder.addSequencePage(newRows, (rows + 20) % maxValue, (rows + 30) % maxValue, (rows + 40) % maxValue);
                buildPagesBuilder.pageBreak();
                rows += newRows;
            }

            types = buildPagesBuilder.getTypes();
            buildPages = buildPagesBuilder.build();
            hashChannel = buildPagesBuilder.getHashChannel();
        }
    }

    /**
     * Extends {@link BuildContext} with probe-side pages and a prebuilt
     * lookup source, so the join benchmark measures only the probe phase.
     */
    @State(Thread)
    public static class JoinContext
            extends BuildContext
    {
        protected static final int PROBE_ROWS_NUMBER = 700_000;

        // Expected fraction of probe rows that match a build row
        // (<1: some rows miss; >1: matching rows are repeated).
        @Param({"0.1", "1", "2"})
        protected double matchRate;

        // Which columns the join projects into its output.
        @Param({"bigint", "all"})
        protected String outputColumns;

        protected List<Page> probePages;
        protected List<Integer> outputChannels;
        protected LookupSourceFactory lookupSourceFactory;

        @Override
        @Setup
        public void setup()
        {
            super.setup();

            switch (outputColumns) {
                case "varchar":
                    outputChannels = Ints.asList(0);
                    break;
                case "bigint":
                    outputChannels = Ints.asList(1);
                    break;
                case "all":
                    outputChannels = Ints.asList(0, 1, 2);
                    break;
                default:
                    // Fixed: report the offending outputColumns value, not hashColumns.
                    throw new UnsupportedOperationException(format("Unknown outputColumns value [%s]", outputColumns));
            }

            lookupSourceFactory = new BenchmarkHashBuildAndJoinOperators().benchmarkBuildHash(this, outputChannels);
            initializeProbePages();
        }

        public LookupSourceFactory getLookupSourceFactory()
        {
            return lookupSourceFactory;
        }

        public List<Page> getProbePages()
        {
            return probePages;
        }

        public List<Integer> getOutputChannels()
        {
            return outputChannels;
        }

        protected void initializeProbePages()
        {
            RowPagesBuilder probePagesBuilder = rowPagesBuilder(buildHashEnabled, hashChannels, ImmutableList.of(VARCHAR, BIGINT, BIGINT));

            // Fixed seed keeps probe data identical across trials.
            Random random = new Random(42);
            int remainingRows = PROBE_ROWS_NUMBER;
            int rowsInPage = 0;
            while (remainingRows > 0) {
                double roll = random.nextDouble();

                int columnA = 20 + remainingRows;
                int columnB = 30 + remainingRows;
                int columnC = 40 + remainingRows;

                int rowsCount = 1;
                if (matchRate < 1) {
                    // each row has matchRate chance to join
                    if (roll > matchRate) {
                        // generate not matched row; negated values fall outside
                        // the non-negative build-side key range
                        columnA *= -1;
                        columnB *= -1;
                        columnC *= -1;
                    }
                }
                else if (matchRate > 1) {
                    // each row is repeated between 1 and 2*matchRate times
                    roll = roll * 2 * matchRate + 1;
                    // example for matchRate == 2:
                    // roll is within [1, 5) range
                    // rowsCount is within [1, 4] range, where each value has same probability
                    // so expected rowsCount is 2.5
                    rowsCount = (int) Math.floor(roll);
                }

                for (int i = 0; i < rowsCount; i++) {
                    if (rowsInPage >= ROWS_PER_PAGE) {
                        probePagesBuilder.pageBreak();
                        rowsInPage = 0;
                    }
                    probePagesBuilder.row(format("%d", columnA), columnB, columnC);
                    --remainingRows;
                    rowsInPage++;
                }
            }
            probePages = probePagesBuilder.build();
        }
    }

    /**
     * Measures building a lookup source over all three columns.
     */
    @Benchmark
    public LookupSourceFactory benchmarkBuildHash(BuildContext buildContext)
    {
        return benchmarkBuildHash(buildContext, ImmutableList.of(0, 1, 2));
    }

    /**
     * Feeds all build pages into a {@link HashBuilderOperator} and returns
     * the resulting lookup source factory. The build is expected to complete
     * synchronously once the operator is finished.
     */
    private LookupSourceFactory benchmarkBuildHash(BuildContext buildContext, List<Integer> outputChannels)
    {
        DriverContext driverContext = buildContext.createTaskContext().addPipelineContext(0, true, true).addDriverContext();
        HashBuilderOperatorFactory hashBuilderOperatorFactory = new HashBuilderOperatorFactory(
                HASH_BUILD_OPERATOR_ID,
                TEST_PLAN_NODE_ID,
                buildContext.getTypes(),
                outputChannels,
                ImmutableMap.of(),
                buildContext.getHashChannels(),
                buildContext.getHashChannel(),
                false,
                Optional.empty(),
                10_000,
                1,
                new PagesIndex.TestingFactory());

        Operator operator = hashBuilderOperatorFactory.createOperator(driverContext);
        for (Page page : buildContext.getBuildPages()) {
            operator.addInput(page);
        }
        operator.finish();

        // A single-partition build should produce a ready lookup source immediately.
        if (!hashBuilderOperatorFactory.getLookupSourceFactory().createLookupSource().isDone()) {
            throw new AssertionError("Expected lookup source to be done");
        }
        return hashBuilderOperatorFactory.getLookupSourceFactory();
    }

    /**
     * Measures probing the prebuilt lookup source with an inner join,
     * collecting every output page. The loop counter bounds the drive loop
     * so a stuck operator cannot hang the benchmark.
     */
    @Benchmark
    public List<Page> benchmarkJoinHash(JoinContext joinContext)
    {
        LookupSourceFactory lookupSourceFactory = joinContext.getLookupSourceFactory();
        OperatorFactory joinOperatorFactory = LOOKUP_JOIN_OPERATORS.innerJoin(
                HASH_JOIN_OPERATOR_ID,
                TEST_PLAN_NODE_ID,
                lookupSourceFactory,
                joinContext.getTypes(),
                joinContext.getHashChannels(),
                joinContext.getHashChannel(),
                Optional.of(joinContext.getOutputChannels()));

        DriverContext driverContext = joinContext.createTaskContext().addPipelineContext(0, true, true).addDriverContext();
        Operator joinOperator = joinOperatorFactory.createOperator(driverContext);

        Iterator<Page> input = joinContext.getProbePages().iterator();
        ImmutableList.Builder<Page> outputPages = ImmutableList.builder();

        boolean finishing = false;
        for (int loops = 0; !joinOperator.isFinished() && loops < 1_000_000; loops++) {
            if (joinOperator.needsInput()) {
                if (input.hasNext()) {
                    Page inputPage = input.next();
                    joinOperator.addInput(inputPage);
                }
                else if (!finishing) {
                    joinOperator.finish();
                    finishing = true;
                }
            }

            Page outputPage = joinOperator.getOutput();
            if (outputPage != null) {
                outputPages.add(outputPage);
            }
        }

        return outputPages.build();
    }

    public static void main(String[] args)
            throws RunnerException
    {
        Options options = new OptionsBuilder()
                .verbosity(VerboseMode.NORMAL)
                .include(".*" + BenchmarkHashBuildAndJoinOperators.class.getSimpleName() + ".*")
                .build();
        new Runner(options).run();
    }
}