/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.sql.planner;
import com.facebook.presto.connector.ConnectorId;
import com.facebook.presto.metadata.FunctionKind;
import com.facebook.presto.metadata.Signature;
import com.facebook.presto.metadata.TableHandle;
import com.facebook.presto.spi.ColumnHandle;
import com.facebook.presto.spi.predicate.TupleDomain;
import com.facebook.presto.spi.type.StandardTypes;
import com.facebook.presto.spi.type.TypeSignature;
import com.facebook.presto.spi.type.VarcharType;
import com.facebook.presto.sql.parser.SqlParser;
import com.facebook.presto.sql.planner.plan.AggregationNode;
import com.facebook.presto.sql.planner.plan.Assignments;
import com.facebook.presto.sql.planner.plan.PlanNode;
import com.facebook.presto.sql.planner.plan.PlanNodeId;
import com.facebook.presto.sql.planner.plan.ProjectNode;
import com.facebook.presto.sql.planner.plan.TableScanNode;
import com.facebook.presto.sql.planner.plan.UnionNode;
import com.facebook.presto.sql.planner.plan.WindowNode;
import com.facebook.presto.sql.planner.sanity.TypeValidator;
import com.facebook.presto.sql.tree.Cast;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.FrameBound;
import com.facebook.presto.sql.tree.FunctionCall;
import com.facebook.presto.sql.tree.QualifiedName;
import com.facebook.presto.sql.tree.WindowFrame;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableListMultimap;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.ListMultimap;
import org.testng.annotations.BeforeMethod;
import org.testng.annotations.Test;

import java.util.Map;
import java.util.Optional;
import java.util.UUID;

import static com.facebook.presto.SessionTestUtils.TEST_SESSION;
import static com.facebook.presto.metadata.MetadataManager.createTestMetadataManager;
import static com.facebook.presto.spi.type.BigintType.BIGINT;
import static com.facebook.presto.spi.type.DateType.DATE;
import static com.facebook.presto.spi.type.DoubleType.DOUBLE;
import static com.facebook.presto.spi.type.IntegerType.INTEGER;
import static com.facebook.presto.spi.type.VarcharType.VARCHAR;
import static com.facebook.presto.sql.planner.plan.AggregationNode.Step.SINGLE;
/**
 * Tests for {@link TypeValidator}: builds small plans on top of a single table scan and checks that
 * well-typed plans pass validation, while plans whose symbol types disagree with the types of the
 * expressions, function calls, function signatures or union mappings that produce them are rejected
 * with an {@link IllegalArgumentException}.
 */
@Test(singleThreaded = true)
public class TestTypeValidator
{
    private static final TableHandle TEST_TABLE_HANDLE = new TableHandle(new ConnectorId("test"), new TestingTableHandle());
    private static final SqlParser SQL_PARSER = new SqlParser();
    private static final TypeValidator TYPE_VALIDATOR = new TypeValidator();

    // Mutable per-test state; rebuilt in setUp() before every test method (hence singleThreaded).
    private SymbolAllocator symbolAllocator;
    private TableScanNode baseTableScan;
    private Symbol columnA;
    private Symbol columnB;
    private Symbol columnC;
    private Symbol columnD;
    private Symbol columnE;

    @BeforeMethod
    public void setUp()
    {
        symbolAllocator = new SymbolAllocator();
        columnA = symbolAllocator.newSymbol("a", BIGINT);
        columnB = symbolAllocator.newSymbol("b", INTEGER);
        columnC = symbolAllocator.newSymbol("c", DOUBLE);
        columnD = symbolAllocator.newSymbol("d", DATE);
        columnE = symbolAllocator.newSymbol("e", VarcharType.createVarcharType(3)); // varchar(3), to test type only coercion

        Map<Symbol, ColumnHandle> assignments = ImmutableMap.<Symbol, ColumnHandle>builder()
                .put(columnA, new TestingColumnHandle("a"))
                .put(columnB, new TestingColumnHandle("b"))
                .put(columnC, new TestingColumnHandle("c"))
                .put(columnD, new TestingColumnHandle("d"))
                .put(columnE, new TestingColumnHandle("e"))
                .build();

        baseTableScan = new TableScanNode(
                newId(),
                TEST_TABLE_HANDLE,
                ImmutableList.copyOf(assignments.keySet()),
                assignments,
                Optional.empty(),
                TupleDomain.all(),
                null);
    }

    @Test
    public void testValidProject()
            throws Exception
    {
        Expression expression1 = new Cast(columnB.toSymbolReference(), StandardTypes.BIGINT);
        Expression expression2 = new Cast(columnC.toSymbolReference(), StandardTypes.BIGINT);
        Assignments assignments = Assignments.builder()
                .put(symbolAllocator.newSymbol(expression1, BIGINT), expression1)
                .put(symbolAllocator.newSymbol(expression2, BIGINT), expression2)
                .build();
        PlanNode node = new ProjectNode(
                newId(),
                baseTableScan,
                assignments);

        assertTypesValid(node);
    }

    @Test
    public void testValidUnion()
            throws Exception
    {
        Symbol outputSymbol = symbolAllocator.newSymbol("output", DATE);
        ListMultimap<Symbol, Symbol> mappings = ImmutableListMultimap.<Symbol, Symbol>builder()
                .put(outputSymbol, columnD)
                .put(outputSymbol, columnD)
                .build();

        assertTypesValid(newUnionNode(mappings));
    }

    @Test
    public void testValidWindow()
            throws Exception
    {
        Symbol windowSymbol = symbolAllocator.newSymbol("sum", DOUBLE);
        Signature signature = sumSignature(FunctionKind.WINDOW, DOUBLE.getTypeSignature());

        assertTypesValid(newSumWindowNode(windowSymbol, signature, columnC));
    }

    @Test
    public void testValidAggregation()
            throws Exception
    {
        Symbol aggregationSymbol = symbolAllocator.newSymbol("sum", DOUBLE);
        Signature signature = sumSignature(FunctionKind.AGGREGATE, DOUBLE.getTypeSignature());

        assertTypesValid(newSumAggregationNode(aggregationSymbol, signature, columnC));
    }

    @Test
    public void testValidTypeOnlyCoercion()
            throws Exception
    {
        Expression expression = new Cast(columnB.toSymbolReference(), StandardTypes.BIGINT);
        Assignments assignments = Assignments.builder()
                .put(symbolAllocator.newSymbol(expression, BIGINT), expression)
                .put(symbolAllocator.newSymbol(columnE.toSymbolReference(), VARCHAR), columnE.toSymbolReference()) // implicit coercion from varchar(3) to varchar
                .build();
        PlanNode node = new ProjectNode(newId(), baseTableScan, assignments);

        assertTypesValid(node);
    }

    @Test(expectedExceptions = IllegalArgumentException.class, expectedExceptionsMessageRegExp = "type of symbol 'expr(_[0-9]+)?' is expected to be bigint, but the actual type is integer")
    public void testInvalidProject()
            throws Exception
    {
        Expression expression1 = new Cast(columnB.toSymbolReference(), StandardTypes.INTEGER);
        Expression expression2 = new Cast(columnA.toSymbolReference(), StandardTypes.INTEGER);
        Assignments assignments = Assignments.builder()
                .put(symbolAllocator.newSymbol(expression1, BIGINT), expression1) // should be INTEGER
                // symbol is derived from the expression it is bound to (was copy-pasted from expression1)
                .put(symbolAllocator.newSymbol(expression2, INTEGER), expression2)
                .build();
        PlanNode node = new ProjectNode(
                newId(),
                baseTableScan,
                assignments);

        assertTypesValid(node);
    }

    @Test(expectedExceptions = IllegalArgumentException.class, expectedExceptionsMessageRegExp = "type of symbol 'sum(_[0-9]+)?' is expected to be double, but the actual type is bigint")
    public void testInvalidAggregationFunctionCall()
            throws Exception
    {
        Symbol aggregationSymbol = symbolAllocator.newSymbol("sum", DOUBLE);
        Signature signature = sumSignature(FunctionKind.AGGREGATE, DOUBLE.getTypeSignature());

        assertTypesValid(newSumAggregationNode(aggregationSymbol, signature, columnA)); // argument should be columnC
    }

    @Test(expectedExceptions = IllegalArgumentException.class, expectedExceptionsMessageRegExp = "type of symbol 'sum(_[0-9]+)?' is expected to be double, but the actual type is bigint")
    public void testInvalidAggregationFunctionSignature()
            throws Exception
    {
        Symbol aggregationSymbol = symbolAllocator.newSymbol("sum", DOUBLE);
        Signature signature = sumSignature(FunctionKind.AGGREGATE, BIGINT.getTypeSignature()); // return type should be DOUBLE

        assertTypesValid(newSumAggregationNode(aggregationSymbol, signature, columnC));
    }

    @Test(expectedExceptions = IllegalArgumentException.class, expectedExceptionsMessageRegExp = "type of symbol 'sum(_[0-9]+)?' is expected to be double, but the actual type is bigint")
    public void testInvalidWindowFunctionCall()
            throws Exception
    {
        Symbol windowSymbol = symbolAllocator.newSymbol("sum", DOUBLE);
        Signature signature = sumSignature(FunctionKind.WINDOW, DOUBLE.getTypeSignature());

        assertTypesValid(newSumWindowNode(windowSymbol, signature, columnA)); // argument should be columnC
    }

    @Test(expectedExceptions = IllegalArgumentException.class, expectedExceptionsMessageRegExp = "type of symbol 'sum(_[0-9]+)?' is expected to be double, but the actual type is bigint")
    public void testInvalidWindowFunctionSignature()
            throws Exception
    {
        Symbol windowSymbol = symbolAllocator.newSymbol("sum", DOUBLE);
        Signature signature = sumSignature(FunctionKind.WINDOW, BIGINT.getTypeSignature()); // return type should be DOUBLE

        assertTypesValid(newSumWindowNode(windowSymbol, signature, columnC));
    }

    @Test(expectedExceptions = IllegalArgumentException.class, expectedExceptionsMessageRegExp = "type of symbol 'output(_[0-9]+)?' is expected to be date, but the actual type is bigint")
    public void testInvalidUnion()
            throws Exception
    {
        Symbol outputSymbol = symbolAllocator.newSymbol("output", DATE);
        ListMultimap<Symbol, Symbol> mappings = ImmutableListMultimap.<Symbol, Symbol>builder()
                .put(outputSymbol, columnD)
                .put(outputSymbol, columnA) // should be a symbol with DATE type
                .build();

        assertTypesValid(newUnionNode(mappings));
    }

    /**
     * Runs the type validator over {@code node}, using the types of every symbol allocated so far.
     * Throws (IllegalArgumentException in the negative tests above) when a type mismatch is found.
     */
    private void assertTypesValid(PlanNode node)
    {
        TYPE_VALIDATOR.validate(node, TEST_SESSION, createTestMetadataManager(), SQL_PARSER, symbolAllocator.getTypes());
    }

    /**
     * Builds the signature of a non-variadic {@code sum} function of the given kind that takes a
     * single DOUBLE argument and declares {@code returnType} as its result type.
     */
    private static Signature sumSignature(FunctionKind kind, TypeSignature returnType)
    {
        return new Signature(
                "sum",
                kind,
                ImmutableList.of(),
                ImmutableList.of(),
                returnType,
                ImmutableList.of(DOUBLE.getTypeSignature()),
                false);
    }

    /**
     * Builds the call {@code sum(argument)}.
     */
    private static FunctionCall sumCall(Symbol argument)
    {
        return new FunctionCall(QualifiedName.of("sum"), ImmutableList.of(argument.toSymbolReference()));
    }

    /**
     * Builds a window node over the base table scan that computes {@code sum(argument)} into
     * {@code windowSymbol}, with no partitioning/ordering and a RANGE frame from UNBOUNDED PRECEDING
     * to UNBOUNDED FOLLOWING.
     */
    private PlanNode newSumWindowNode(Symbol windowSymbol, Signature signature, Symbol argument)
    {
        WindowNode.Frame frame = new WindowNode.Frame(
                WindowFrame.Type.RANGE,
                FrameBound.Type.UNBOUNDED_PRECEDING,
                Optional.empty(),
                FrameBound.Type.UNBOUNDED_FOLLOWING,
                Optional.empty());
        WindowNode.Function function = new WindowNode.Function(sumCall(argument), signature, frame);
        WindowNode.Specification specification = new WindowNode.Specification(ImmutableList.of(), ImmutableList.of(), ImmutableMap.of());

        return new WindowNode(
                newId(),
                baseTableScan,
                specification,
                ImmutableMap.of(windowSymbol, function),
                Optional.empty(),
                ImmutableSet.of(),
                0);
    }

    /**
     * Builds a SINGLE-step aggregation node over the base table scan that computes
     * {@code sum(argument)} into {@code aggregationSymbol}, grouped by (columnA, columnB).
     */
    private PlanNode newSumAggregationNode(Symbol aggregationSymbol, Signature signature, Symbol argument)
    {
        Map<Symbol, FunctionCall> aggregations = ImmutableMap.of(aggregationSymbol, sumCall(argument));
        Map<Symbol, Signature> functions = ImmutableMap.of(aggregationSymbol, signature);

        return new AggregationNode(
                newId(),
                baseTableScan,
                aggregations,
                functions,
                ImmutableMap.of(),
                ImmutableList.of(ImmutableList.of(columnA, columnB)),
                SINGLE,
                Optional.empty(),
                Optional.empty());
    }

    /**
     * Builds a union of the base table scan with itself using the given output-to-input symbol
     * mappings; the union's outputs are the distinct mapping keys.
     */
    private PlanNode newUnionNode(ListMultimap<Symbol, Symbol> mappings)
    {
        return new UnionNode(
                newId(),
                ImmutableList.of(baseTableScan, baseTableScan),
                mappings,
                ImmutableList.copyOf(mappings.keySet()));
    }

    private static PlanNodeId newId()
    {
        return new PlanNodeId(UUID.randomUUID().toString());
    }
}