/* * Licensed to CRATE Technology GmbH ("Crate") under one or more contributor * license agreements. See the NOTICE file distributed with this work for * additional information regarding copyright ownership. Crate licenses * this file to you under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. You may * obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the * License for the specific language governing permissions and limitations * under the License. * * However, if you have executed another commercial license agreement * with Crate these terms will supersede the license and you may use the * software solely pursuant to the terms of the relevant commercial agreement. */ package io.crate.operation; import com.google.common.collect.ImmutableMap; import io.crate.analyze.symbol.*; import io.crate.data.Input; import io.crate.data.Row; import io.crate.data.RowN; import io.crate.metadata.FunctionIdent; import io.crate.metadata.FunctionImplementation; import io.crate.metadata.FunctionInfo; import io.crate.operation.aggregation.FunctionExpression; import io.crate.operation.collect.CollectExpression; import io.crate.operation.scalar.arithmetic.ArithmeticFunctions; import io.crate.test.integration.CrateUnitTest; import io.crate.testing.SqlExpressions; import io.crate.testing.T3; import io.crate.types.DataTypes; import org.junit.Test; import java.util.ArrayList; import java.util.Arrays; import java.util.List; import static org.hamcrest.Matchers.not; import static org.hamcrest.Matchers.sameInstance; import static org.hamcrest.core.Is.is; public class InputFactoryTest extends CrateUnitTest { private SqlExpressions expressions = new SqlExpressions(ImmutableMap.of(T3.T1, T3.TR_1), T3.TR_1); private InputFactory factory = new InputFactory(expressions.functions()); @Test public void testAggregationSymbolsInputReuse() throws Exception { Function countX = (Function) expressions.asSymbol("count(x)"); Function avgX = (Function) expressions.asSymbol("avg(x)"); List<Symbol> aggregations = Arrays.asList( new Aggregation(countX.info(), countX.info().returnType(), Arrays.asList(new InputColumn(0))), new Aggregation(avgX.info(), countX.info().returnType(), Arrays.asList(new InputColumn(0))) ); InputFactory.Context<CollectExpression<Row, ?>> ctx = factory.ctxForAggregations(); ctx.add(aggregations); List<AggregationContext> aggregationContexts = ctx.aggregations(); Input<?> inputCount = aggregationContexts.get(0).inputs()[0]; Input<?> inputAverage = aggregationContexts.get(1).inputs()[0]; assertSame(inputCount, inputAverage); } @Test public void testProcessGroupByProjectionSymbols() throws Exception { // select x, y * 2 ... group by x, y * 2 // keys: [ in(0), in(1) + 10 ] Function add = ArithmeticFunctions.of( ArithmeticFunctions.Names.ADD, new InputColumn(1, DataTypes.INTEGER), Literal.of(10), FunctionInfo.DETERMINISTIC_AND_COMPARISON_REPLACEMENT ); List<Symbol> keys = Arrays.asList(new InputColumn(0, DataTypes.LONG), add); InputFactory.Context<CollectExpression<Row, ?>> ctx = factory.ctxForAggregations(); ctx.add(keys); ArrayList<CollectExpression<Row, ?>> expressions = new ArrayList<>(ctx.expressions()); assertThat(expressions.size(), is(2)); // keyExpressions: [ in0, in1 ] RowN row = new RowN(new Object[]{1L, 2L}); for (CollectExpression<Row, ?> expression : expressions) { expression.setNextRow(row); } assertThat((Long) expressions.get(0).value(), is(1L)); assertThat((Long) expressions.get(1).value(), is(2L)); // raw input value // inputs: [ x, add ] List<Input<?>> inputs = ctx.topLevelInputs(); assertThat(inputs.size(), is(2)); assertThat((Long) inputs.get(0).value(), is(1L)); assertThat((Long) inputs.get(1).value(), is(12L)); // + 10 } @Test public void testProcessGroupByProjectionSymbolsAggregation() throws Exception { // select count(x), x, y * 2 ... group by x, y * 2 // keys: [ in(0), in(1) + 10 ] Function add = ArithmeticFunctions.of( ArithmeticFunctions.Names.ADD, new InputColumn(1, DataTypes.INTEGER), Literal.of(10), FunctionInfo.DETERMINISTIC_AND_COMPARISON_REPLACEMENT); List<Symbol> keys = Arrays.asList(new InputColumn(0, DataTypes.LONG), add); Function countX = (Function) expressions.asSymbol("count(x)"); // values: [ count(in(0)) ] List<Aggregation> values = Arrays.asList(new Aggregation( countX.info(), countX.valueType(), Arrays.<Symbol>asList(new InputColumn(0)) )); InputFactory.Context<CollectExpression<Row, ?>> ctx = factory.ctxForAggregations(); ctx.add(keys); // inputs: [ x, add ] List<Input<?>> keyInputs = ctx.topLevelInputs(); ctx.add(values); List<AggregationContext> aggregations = ctx.aggregations(); assertThat(aggregations.size(), is(1)); // collectExpressions: [ in0, in1 ] List<CollectExpression<Row, ?>> expressions = new ArrayList<>(ctx.expressions()); assertThat(expressions.size(), is(2)); List<Input<?>> allInputs = ctx.topLevelInputs(); assertThat(allInputs.size(), is(2)); // only 2 because count is no input RowN row = new RowN(new Object[]{1L, 2L}); for (CollectExpression<Row, ?> expression : expressions) { expression.setNextRow(row); } assertThat((Long) expressions.get(0).value(), is(1L)); assertThat((Long) expressions.get(1).value(), is(2L)); // raw input value assertThat(keyInputs.size(), is(2)); assertThat((Long) keyInputs.get(0).value(), is(1L)); assertThat((Long) keyInputs.get(1).value(), is(12L)); // 2 + 10 } @Test public void testCompiled() throws Exception { Function function = (Function) expressions.normalize(expressions.asSymbol("a like 'f%'")); InputFactory.Context<Input<?>> ctx = factory.ctxForRefs(i -> Literal.of("foo")); Input<?> input = ctx.add(function); FunctionExpression expression = (FunctionExpression) input; java.lang.reflect.Field f = FunctionExpression.class.getDeclaredField("functionImplementation"); f.setAccessible(true); FunctionImplementation impl = (FunctionImplementation) f.get(expression); assertThat(impl.info(), is(function.info())); FunctionIdent ident = function.info().ident(); FunctionImplementation uncompiled = expressions.functions().getBuiltin(ident.name(), ident.argumentTypes()); assertThat(uncompiled, not(sameInstance(impl))); } }