/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.type;
import com.facebook.presto.RowPagesBuilder;
import com.facebook.presto.Session;
import com.facebook.presto.metadata.MetadataManager;
import com.facebook.presto.operator.project.PageProcessor;
import com.facebook.presto.spi.Page;
import com.facebook.presto.spi.type.BigintType;
import com.facebook.presto.spi.type.DecimalType;
import com.facebook.presto.spi.type.DoubleType;
import com.facebook.presto.spi.type.SqlDecimal;
import com.facebook.presto.spi.type.Type;
import com.facebook.presto.sql.gen.ExpressionCompiler;
import com.facebook.presto.sql.parser.SqlParser;
import com.facebook.presto.sql.planner.Symbol;
import com.facebook.presto.sql.planner.SymbolToInputRewriter;
import com.facebook.presto.sql.relational.RowExpression;
import com.facebook.presto.sql.relational.SqlToRowExpressionTranslator;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.util.maps.IdentityLinkedHashMap;
import com.google.common.collect.ImmutableList;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.Fork;
import org.openjdk.jmh.annotations.Measurement;
import org.openjdk.jmh.annotations.OutputTimeUnit;
import org.openjdk.jmh.annotations.Param;
import org.openjdk.jmh.annotations.Scope;
import org.openjdk.jmh.annotations.Setup;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.annotations.Warmup;
import org.openjdk.jmh.runner.Runner;
import org.openjdk.jmh.runner.RunnerException;
import org.openjdk.jmh.runner.options.Options;
import org.openjdk.jmh.runner.options.OptionsBuilder;
import org.openjdk.jmh.runner.options.VerboseMode;
import org.testng.annotations.Test;
import java.math.BigInteger;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Random;
import java.util.concurrent.TimeUnit;
import static com.facebook.presto.RowPagesBuilder.rowPagesBuilder;
import static com.facebook.presto.SessionTestUtils.TEST_SESSION;
import static com.facebook.presto.metadata.FunctionKind.SCALAR;
import static com.facebook.presto.metadata.MetadataManager.createTestMetadataManager;
import static com.facebook.presto.operator.scalar.FunctionAssertions.createExpression;
import static com.facebook.presto.spi.type.BigintType.BIGINT;
import static com.facebook.presto.spi.type.DecimalType.createDecimalType;
import static com.facebook.presto.spi.type.DoubleType.DOUBLE;
import static com.facebook.presto.sql.analyzer.ExpressionAnalyzer.getExpressionTypesFromInput;
import static com.facebook.presto.testing.TestingConnectorSession.SESSION;
import static com.facebook.presto.testing.TestingSession.testSessionBuilder;
import static com.google.common.collect.Iterables.getOnlyElement;
import static java.math.BigInteger.ONE;
import static java.math.BigInteger.ZERO;
import static java.util.Collections.emptyList;
import static java.util.stream.Collectors.toList;
import static java.util.stream.Collectors.toMap;
import static org.openjdk.jmh.annotations.Scope.Thread;
@State(Scope.Thread)
@OutputTimeUnit(TimeUnit.MILLISECONDS)
@Fork(value = 3)
@Warmup(iterations = 20, timeUnit = TimeUnit.MILLISECONDS)
@Measurement(iterations = 10, timeUnit = TimeUnit.MILLISECONDS)
public class BenchmarkDecimalOperators
{
// Number of rows placed in every generated input page.
private static final int PAGE_SIZE = 30000;
// DECIMAL(10, 0) and DECIMAL(20, 0): operand types used by the inequality benchmark state.
private static final DecimalType SHORT_DECIMAL_TYPE = createDecimalType(10, 0);
private static final DecimalType LONG_DECIMAL_TYPE = createDecimalType(20, 0);
// Parser passed to the expression analyzer when expression types are resolved.
private static final SqlParser SQL_PARSER = new SqlParser();
/**
 * Benchmark state for casting a random DOUBLE column {@code d1} either to
 * DECIMAL(precision, 2) or, when the parameter is {@code "BIGINT"}, to BIGINT.
 */
@State(Thread)
public static class CastDoubleToDecimalBenchmarkState
        extends BaseState
{
    private static final int SCALE = 2;

    @Param({"10", "35", "BIGINT"})
    private String precision = "10";

    /**
     * Registers the input column, bounds the magnitude of the generated doubles,
     * then builds the random input page and compiles the cast expression.
     */
    @Setup
    public void setup()
    {
        addSymbol("d1", DOUBLE);

        boolean bigintTarget = precision.equals("BIGINT");
        if (bigintTarget) {
            setDoubleMaxValue(Long.MAX_VALUE);
        }
        else {
            // 9^(precision - SCALE) is smaller than 10^(precision - SCALE), so the
            // generated magnitudes stay below 10^(precision - SCALE).
            setDoubleMaxValue(Math.pow(9, Integer.parseInt(precision) - SCALE));
        }

        String castExpression = bigintTarget
                ? "CAST(d1 AS BIGINT)"
                : String.format("CAST(d1 AS DECIMAL(%s, %d))", precision, SCALE);

        generateRandomInputPage();
        generateProcessor(castExpression);
    }
}
/**
 * Runs the compiled DOUBLE cast over the prepared input page.
 *
 * @param state prepared benchmark state
 * @return the processor output collected by {@code execute}
 */
@Benchmark
public Object castDoubleToDecimalBenchmark(CastDoubleToDecimalBenchmarkState state)
{
    Object output = execute(state);
    return output;
}
/**
 * Smoke test: sets up the state with its default parameter value and runs the
 * benchmark body once.
 */
@Test
public void testCastDoubleToDecimalBenchmark()
{
    CastDoubleToDecimalBenchmarkState benchmarkState = new CastDoubleToDecimalBenchmarkState();
    benchmarkState.setup();

    castDoubleToDecimalBenchmark(benchmarkState);
}
/**
 * Benchmark state for casting a random DECIMAL(precision, 10) column
 * {@code v1} to DOUBLE.
 */
@State(Thread)
public static class CastDecimalToDoubleBenchmarkState
        extends BaseState
{
    private static final int SCALE = 10;

    @Param({"15", "35"})
    private String precision = "15";

    /**
     * Registers the decimal input column, builds the random input page and
     * compiles the cast expression.
     */
    @Setup
    public void setup()
    {
        DecimalType inputType = createDecimalType(Integer.parseInt(precision), SCALE);
        addSymbol("v1", inputType);

        generateRandomInputPage();
        generateProcessor("CAST(v1 AS DOUBLE)");
    }
}
/**
 * Runs the compiled DECIMAL-to-DOUBLE cast over the prepared input page.
 *
 * @param state prepared benchmark state
 * @return the processor output collected by {@code execute}
 */
@Benchmark
public Object castDecimalToDoubleBenchmark(CastDecimalToDoubleBenchmarkState state)
{
    Object output = execute(state);
    return output;
}
/**
 * Smoke test: sets up the state with its default parameter value and runs the
 * benchmark body once.
 */
@Test
public void testCastDecimalToDoubleBenchmark()
{
    CastDecimalToDoubleBenchmarkState benchmarkState = new CastDecimalToDoubleBenchmarkState();
    benchmarkState.setup();

    castDecimalToDoubleBenchmark(benchmarkState);
}
/**
 * Benchmark state for casting a random DECIMAL(precision, 10) column
 * {@code v1} to VARCHAR.
 */
@State(Thread)
public static class CastDecimalToVarcharBenchmarkState
        extends BaseState
{
    private static final int SCALE = 10;

    @Param({"15", "35"})
    private String precision = "35";

    /**
     * Registers the decimal input column, builds the random input page and
     * compiles the cast expression.
     */
    @Setup
    public void setup()
    {
        DecimalType inputType = createDecimalType(Integer.parseInt(precision), SCALE);
        addSymbol("v1", inputType);

        generateRandomInputPage();
        generateProcessor("CAST(v1 AS VARCHAR)");
    }
}
/**
 * Runs the compiled DECIMAL-to-VARCHAR cast over the prepared input page.
 *
 * @param state prepared benchmark state
 * @return the processor output collected by {@code execute}
 */
@Benchmark
public Object castDecimalToVarcharBenchmark(CastDecimalToVarcharBenchmarkState state)
{
    Object output = execute(state);
    return output;
}
/**
 * Smoke test: sets up the state with its default parameter value and runs the
 * benchmark body once.
 */
@Test
public void testCastDecimalToVarcharBenchmark()
{
    CastDecimalToVarcharBenchmarkState benchmarkState = new CastDecimalToVarcharBenchmarkState();
    benchmarkState.setup();

    castDecimalToVarcharBenchmark(benchmarkState);
}
/**
 * Benchmark state for addition over DOUBLE columns ({@code d1..d4}) and
 * decimal columns of various precisions ({@code s1..s4}, {@code l1..l4}).
 */
@State(Thread)
public static class AdditionBenchmarkState
        extends BaseState
{
    @Param({"d1 + d2",
            "d1 + d2 + d3 + d4",
            "s1 + s2",
            "s1 + s2 + s3 + s4",
            "l1 + l2",
            "l1 + l2 + l3 + l4",
            "s2 + l3 + l1 + s4"})
    private String expression = "d1 + d2";

    /**
     * Registers every input column (registration order fixes the channel
     * order), builds the random input page and compiles the expression.
     */
    @Setup
    public void setup()
    {
        // d1..d4
        for (int index = 1; index <= 4; index++) {
            addSymbol("d" + index, DOUBLE);
        }

        // decimals with precision up to 12
        addSymbol("s1", createDecimalType(10, 5));
        addSymbol("s2", createDecimalType(7, 2));
        addSymbol("s3", createDecimalType(12, 2));
        addSymbol("s4", createDecimalType(2, 1));

        // decimals with precision 20 and above
        addSymbol("l1", createDecimalType(35, 10));
        addSymbol("l2", createDecimalType(25, 5));
        addSymbol("l3", createDecimalType(20, 6));
        addSymbol("l4", createDecimalType(25, 8));

        generateRandomInputPage();
        generateProcessor(expression);
    }
}
/**
 * Runs the compiled addition expression over the prepared input page.
 *
 * @param state prepared benchmark state
 * @return the processor output collected by {@code execute}
 */
@Benchmark
public Object additionBenchmark(AdditionBenchmarkState state)
{
    Object output = execute(state);
    return output;
}
/**
 * Smoke test: sets up the state with its default expression and runs the
 * benchmark body once.
 */
@Test
public void testAdditionBenchmark()
{
    AdditionBenchmarkState benchmarkState = new AdditionBenchmarkState();
    benchmarkState.setup();

    additionBenchmark(benchmarkState);
}
/**
 * Benchmark state for multiplication over DOUBLE ({@code d1..d4}), BIGINT
 * ({@code i1, i2}) and decimal ({@code s1..s6}, {@code l1, l2}) columns.
 */
@State(Thread)
public static class MultiplyBenchmarkState
        extends BaseState
{
    @Param({"d1 * d2",
            "d1 * d2 * d3 * d4",
            "i1 * i2",
            // short short -> short
            "s1 * s2",
            "s1 * s2 * s5 * s6",
            // short short -> long
            "s3 * s4",
            // long short -> long
            "l2 * s2",
            "l2 * s2 * s5 * s6",
            // short long -> long
            "s1 * l2",
            // long long -> long
            "l1 * l2"})
    private String expression = "d1 * d2";

    /**
     * Registers every input column (registration order fixes the channel
     * order), builds the random input page and compiles the expression.
     */
    @Setup
    public void setup()
    {
        // d1..d4
        for (int index = 1; index <= 4; index++) {
            addSymbol("d" + index, DOUBLE);
        }
        // i1, i2
        for (int index = 1; index <= 2; index++) {
            addSymbol("i" + index, BIGINT);
        }

        addSymbol("s1", createDecimalType(5, 2));
        addSymbol("s2", createDecimalType(3, 1));
        addSymbol("s3", createDecimalType(10, 5));
        addSymbol("s4", createDecimalType(10, 2));
        addSymbol("s5", createDecimalType(3, 2));
        addSymbol("s6", createDecimalType(2, 1));

        addSymbol("l1", createDecimalType(19, 10));
        addSymbol("l2", createDecimalType(19, 5));

        generateRandomInputPage();
        generateProcessor(expression);
    }
}
/**
 * Runs the compiled multiplication expression over the prepared input page.
 *
 * @param state prepared benchmark state
 * @return the processor output collected by {@code execute}
 */
@Benchmark
public Object multiplyBenchmark(MultiplyBenchmarkState state)
{
    Object output = execute(state);
    return output;
}
/**
 * Smoke test: sets up the state with its default expression and runs the
 * benchmark body once.
 */
@Test
public void testMultiplyBenchmark()
{
    MultiplyBenchmarkState benchmarkState = new MultiplyBenchmarkState();
    benchmarkState.setup();

    multiplyBenchmark(benchmarkState);
}
/**
 * Benchmark state for division over DOUBLE ({@code d1..d4}), BIGINT
 * ({@code i1..i4}) and decimal ({@code s1..s4}, {@code l1..l4}) columns.
 */
@State(Thread)
public static class DivisionBenchmarkState
        extends BaseState
{
    @Param({"d1 / d2",
            "d1 / d2 / d3 / d4",
            "i1 / i2",
            "i1 / i2 / i3 / i4",
            // short short -> short
            "s1 / s2",
            "s1 / s2 / s2 / s2",
            // short short -> long
            "s1 / s3",
            // short long -> short
            "s2 / l1",
            // long short -> long
            "l1 / s2",
            // short long -> long
            "s3 / l1",
            // long long -> long
            "l2 / l3",
            "l2 / l4 / l4 / l4",
            "l2 / s4 / s4 / s4"})
    private String expression = "d1 / d2";

    /**
     * Registers every input column (registration order fixes the channel
     * order), builds the random input page and compiles the expression.
     */
    @Setup
    public void setup()
    {
        // d1..d4
        for (int index = 1; index <= 4; index++) {
            addSymbol("d" + index, DOUBLE);
        }
        // i1..i4
        for (int index = 1; index <= 4; index++) {
            addSymbol("i" + index, BIGINT);
        }

        addSymbol("s1", createDecimalType(8, 3));
        addSymbol("s2", createDecimalType(6, 2));
        addSymbol("s3", createDecimalType(17, 7));
        addSymbol("s4", createDecimalType(3, 2));

        addSymbol("l1", createDecimalType(19, 3));
        addSymbol("l2", createDecimalType(20, 3));
        addSymbol("l3", createDecimalType(21, 10));
        addSymbol("l4", createDecimalType(19, 4));

        generateRandomInputPage();
        generateProcessor(expression);
    }
}
/**
 * Runs the compiled division expression over the prepared input page.
 *
 * @param state prepared benchmark state
 * @return the processor output collected by {@code execute}
 */
@Benchmark
public Object divisionBenchmark(DivisionBenchmarkState state)
{
    Object output = execute(state);
    return output;
}
/**
 * Smoke test: sets up the state with its default expression and runs the
 * benchmark body once.
 */
@Test
public void testDivisionBenchmark()
{
    DivisionBenchmarkState benchmarkState = new DivisionBenchmarkState();
    benchmarkState.setup();

    divisionBenchmark(benchmarkState);
}
/**
 * Benchmark state for the modulo operator over DOUBLE ({@code d1..d4}),
 * BIGINT ({@code i1..i4}) and decimal ({@code s1..s4}, {@code l1..l4}) columns.
 */
@State(Thread)
public static class ModuloBenchmarkState
        extends BaseState
{
    @Param({"d1 % d2",
            "d1 % d2 % d3 % d4",
            "i1 % i2",
            "i1 % i2 % i3 % i4",
            // short short -> short
            "s1 % s2",
            "s1 % s2 % s2 % s2",
            // short long -> short
            "s2 % l2",
            // long short -> long
            "l3 % s3",
            // short long -> long
            "s4 % l3",
            // long long -> long
            "l2 % l3",
            "l2 % l3 % l4 % l1"})
    private String expression = "d1 % d2";

    /**
     * Registers every input column (registration order fixes the channel
     * order), builds the random input page and compiles the expression.
     */
    @Setup
    public void setup()
    {
        // d1..d4
        for (int index = 1; index <= 4; index++) {
            addSymbol("d" + index, DOUBLE);
        }
        // i1..i4
        for (int index = 1; index <= 4; index++) {
            addSymbol("i" + index, BIGINT);
        }

        addSymbol("s1", createDecimalType(8, 3));
        addSymbol("s2", createDecimalType(6, 2));
        addSymbol("s3", createDecimalType(9, 0));
        addSymbol("s4", createDecimalType(12, 2));

        addSymbol("l1", createDecimalType(19, 3));
        addSymbol("l2", createDecimalType(20, 3));
        addSymbol("l3", createDecimalType(21, 10));
        addSymbol("l4", createDecimalType(19, 4));

        generateRandomInputPage();
        generateProcessor(expression);
    }
}
/**
 * Runs the compiled modulo expression over the prepared input page.
 *
 * @param state prepared benchmark state
 * @return the processor output collected by {@code execute}
 */
@Benchmark
public Object moduloBenchmark(ModuloBenchmarkState state)
{
    Object output = execute(state);
    return output;
}
/**
 * Smoke test: sets up the state with its default expression and runs the
 * benchmark body once.
 */
@Test
public void testModuloBenchmark()
{
    ModuloBenchmarkState benchmarkState = new ModuloBenchmarkState();
    benchmarkState.setup();

    moduloBenchmark(benchmarkState);
}
/**
 * Benchmark state for {@code <} comparisons over four DOUBLE, four
 * DECIMAL(10, 0) and four DECIMAL(20, 0) columns.
 */
@State(Thread)
public static class InequalityBenchmarkState
        extends BaseState
{
    @Param({"d1 < d2",
            "d1 < d2 AND d1 < d3 AND d1 < d4 AND d2 < d3 AND d2 < d4 AND d3 < d4",
            "s1 < s2",
            "s1 < s2 AND s1 < s3 AND s1 < s4 AND s2 < s3 AND s2 < s4 AND s3 < s4",
            "l1 < l2",
            "l1 < l2 AND l1 < l3 AND l1 < l4 AND l2 < l3 AND l2 < l4 AND l3 < l4"})
    private String expression = "d1 < d2";

    /**
     * Registers the twelve input columns, builds a sequence page with one
     * initial value per column, and compiles the expression.
     */
    @Setup
    public void setup()
    {
        // d1..d4
        for (int index = 1; index <= 4; index++) {
            addSymbol("d" + index, DOUBLE);
        }
        // s1..s4
        for (int index = 1; index <= 4; index++) {
            addSymbol("s" + index, SHORT_DECIMAL_TYPE);
        }
        // l1..l4
        for (int index = 1; index <= 4; index++) {
            addSymbol("l" + index, LONG_DECIMAL_TYPE);
        }

        // One initial value per registered column, in registration order.
        generateInputPage(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11);
        generateProcessor(expression);
    }
}
/**
 * Runs the compiled comparison expression over the prepared input page.
 *
 * @param state prepared benchmark state
 * @return the processor output collected by {@code execute}
 */
@Benchmark
public Object inequalityBenchmark(InequalityBenchmarkState state)
{
    Object output = execute(state);
    return output;
}
/**
 * Smoke test: sets up the state with its default expression and runs the
 * benchmark body once.
 */
@Test
public void testInequalityBenchmark()
{
    InequalityBenchmarkState benchmarkState = new InequalityBenchmarkState();
    benchmarkState.setup();

    inequalityBenchmark(benchmarkState);
}
/**
 * Benchmark state for casting decimal columns of several precisions and
 * scales down to DECIMAL(8, 0).
 */
@State(Thread)
public static class DecimalToShortDecimalCastBenchmarkState
        extends BaseState
{
    @Param({"cast(l_38_30 as decimal(8, 0))",
            "cast(l_26_18 as decimal(8, 0))",
            "cast(l_20_12 as decimal(8, 0))",
            "cast(l_20_8 as decimal(8, 0))",
            "cast(s_17_9 as decimal(8, 0))"})
    private String expression = "cast(l_38_30 as decimal(8, 0))";

    /**
     * Registers the five decimal input columns (each name spells out its
     * precision and scale), builds a sequence page and compiles the cast.
     */
    @Setup
    public void setup()
    {
        addSymbol("l_38_30", createDecimalType(38, 30));
        addSymbol("l_26_18", createDecimalType(26, 18));
        addSymbol("l_20_12", createDecimalType(20, 12));
        addSymbol("l_20_8", createDecimalType(20, 8));
        addSymbol("s_17_9", createDecimalType(17, 9));

        // Same initial value for each of the five columns.
        int initialValue = 10000;
        generateInputPage(initialValue, initialValue, initialValue, initialValue, initialValue);
        generateProcessor(expression);
    }
}
/**
 * Runs the compiled decimal narrowing cast over the prepared input page.
 *
 * @param state prepared benchmark state
 * @return the processor output collected by {@code execute}
 */
@Benchmark
public Object decimalToShortDecimalCastBenchmark(DecimalToShortDecimalCastBenchmarkState state)
{
    Object output = execute(state);
    return output;
}
/**
 * Smoke test: sets up the state with its default expression and runs the
 * benchmark body once.
 */
@Test
public void testDecimalToShortDecimalCastBenchmark()
{
    DecimalToShortDecimalCastBenchmarkState benchmarkState = new DecimalToShortDecimalCastBenchmarkState();
    benchmarkState.setup();

    decimalToShortDecimalCastBenchmark(benchmarkState);
}
/**
 * Feeds the state's input page through its compiled processor and collects
 * everything the processor yields into an immutable list.
 *
 * @param state state whose processor and input page have been generated
 * @return immutable list copy of the processor output
 */
private Object execute(BaseState state)
{
    PageProcessor pageProcessor = state.getProcessor();
    Page input = state.getInputPage();
    return ImmutableList.copyOf(pageProcessor.process(SESSION, input));
}
/**
 * Shared plumbing for the benchmark states: records the declared input
 * symbols, builds the input page, and compiles an SQL expression into a
 * {@link PageProcessor}.
 */
private static class BaseState
{
    private final MetadataManager metadata = createTestMetadataManager();
    // Not referenced by any method of this class.
    private final Session session = testSessionBuilder().build();
    private final Random random = new Random();

    // symbol name -> symbol
    protected final Map<String, Symbol> symbols = new HashMap<>();
    // symbol -> declared type
    protected final Map<Symbol, Type> symbolTypes = new HashMap<>();
    // symbol -> input channel (position in {@code types} at registration time)
    private final Map<Symbol, Integer> sourceLayout = new HashMap<>();
    // column types in registration order
    protected final List<Type> types = new LinkedList<>();

    protected Page inputPage;
    private PageProcessor processor;
    // Random doubles are generated as nextDouble() * 2 * max - max.
    private double doubleMaxValue = 2L << 31;

    /**
     * @return the page built by the last generate*InputPage call
     */
    public Page getInputPage()
    {
        return inputPage;
    }

    /**
     * @return the processor compiled by the last {@link #generateProcessor} call
     */
    public PageProcessor getProcessor()
    {
        return processor;
    }

    /**
     * Registers an input column; its channel is the number of columns
     * registered before it.
     *
     * @param name symbol name as used in the benchmark expression
     * @param type column type
     */
    protected void addSymbol(String name, Type type)
    {
        Symbol symbol = new Symbol(name);
        int channel = types.size();

        symbols.put(name, symbol);
        symbolTypes.put(symbol, type);
        sourceLayout.put(symbol, channel);
        types.add(type);
    }

    /**
     * Builds a single page of {@code PAGE_SIZE} rows with one random value per
     * registered column.
     */
    protected void generateRandomInputPage()
    {
        RowPagesBuilder pagesBuilder = rowPagesBuilder(types);
        for (int row = 0; row < PAGE_SIZE; row++) {
            List<Object> rowValues = types.stream()
                    .map(this::generateRandomValue)
                    .collect(toList());
            pagesBuilder.row(rowValues.toArray());
        }
        inputPage = getOnlyElement(pagesBuilder.build());
    }

    /**
     * Builds a single sequence page of {@code PAGE_SIZE} rows.
     *
     * @param initialValues values forwarded to {@code addSequencePage}
     */
    protected void generateInputPage(int... initialValues)
    {
        RowPagesBuilder pagesBuilder = rowPagesBuilder(types);
        pagesBuilder.addSequencePage(PAGE_SIZE, initialValues);
        inputPage = getOnlyElement(pagesBuilder.build());
    }

    /**
     * Compiles {@code expression} into a filter-less page processor with a
     * single projection.
     *
     * @param expression SQL expression over the registered symbols
     */
    protected void generateProcessor(String expression)
    {
        ExpressionCompiler compiler = new ExpressionCompiler(metadata);
        RowExpression projection = rowExpression(expression);
        processor = compiler.compilePageProcessor(Optional.empty(), ImmutableList.of(projection)).get();
    }

    /**
     * Sets the bound used when generating random DOUBLE values.
     *
     * @param doubleMaxValue new bound
     */
    protected void setDoubleMaxValue(double doubleMaxValue)
    {
        this.doubleMaxValue = doubleMaxValue;
    }

    /**
     * Parses the expression, rewrites symbol references to input channel
     * references, resolves expression types and translates the result into a
     * {@link RowExpression}.
     */
    private RowExpression rowExpression(String expression)
    {
        SymbolToInputRewriter rewriter = new SymbolToInputRewriter(sourceLayout);
        Expression inputReferenceExpression = rewriter.rewrite(createExpression(expression, metadata, symbolTypes));

        // channel -> type, derived from the symbol -> channel layout
        Map<Integer, Type> channelTypes = sourceLayout.entrySet().stream()
                .collect(toMap(Map.Entry::getValue, entry -> symbolTypes.get(entry.getKey())));

        IdentityLinkedHashMap<Expression, Type> expressionTypes = getExpressionTypesFromInput(
                TEST_SESSION,
                metadata,
                SQL_PARSER,
                channelTypes,
                inputReferenceExpression,
                emptyList());

        return SqlToRowExpressionTranslator.translate(
                inputReferenceExpression,
                SCALAR,
                expressionTypes,
                metadata.getFunctionRegistry(),
                metadata.getTypeManager(),
                TEST_SESSION,
                true);
    }

    /**
     * Produces one random value for a column of the given type.
     *
     * @throws UnsupportedOperationException for types other than DOUBLE, DECIMAL and BIGINT
     */
    private Object generateRandomValue(Type type)
    {
        if (type instanceof DoubleType) {
            return random.nextDouble() * (2L * doubleMaxValue) - doubleMaxValue;
        }
        if (type instanceof DecimalType) {
            return randomDecimal((DecimalType) type);
        }
        if (type instanceof BigintType) {
            // Zero is replaced by 1 (presumably so it is never used as a divisor).
            int candidate = random.nextInt();
            if (candidate == 0) {
                return 1;
            }
            return candidate;
        }
        throw new UnsupportedOperationException(type.toString());
    }

    /**
     * Draws a random non-zero decimal whose unscaled magnitude uses at most
     * floor(log2(10^precision)) bits; the sign is chosen at random.
     */
    private SqlDecimal randomDecimal(DecimalType type)
    {
        int precision = type.getPrecision();
        double maxMagnitude = Math.pow(10, precision);
        int maxBits = (int) (Math.log(maxMagnitude) / Math.log(2));

        BigInteger drawn = new BigInteger(maxBits, random);
        BigInteger unscaled = drawn.equals(ZERO) ? ONE : drawn;
        if (random.nextBoolean()) {
            unscaled = unscaled.negate();
        }
        return new SqlDecimal(unscaled, precision, type.getScale());
    }
}
/**
 * Runs every JMH benchmark whose name matches this class's simple name.
 *
 * @param args unused
 * @throws RunnerException if the JMH runner fails
 */
public static void main(String[] args)
        throws RunnerException
{
    String includePattern = ".*" + BenchmarkDecimalOperators.class.getSimpleName() + ".*";
    Options options = new OptionsBuilder()
            .verbosity(VerboseMode.NORMAL)
            .include(includePattern)
            .build();

    Runner runner = new Runner(options);
    runner.run();
}
}