/* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.facebook.presto.sql.planner.sanity; import com.facebook.presto.Session; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.metadata.Signature; import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.TypeManager; import com.facebook.presto.spi.type.TypeSignature; import com.facebook.presto.sql.parser.SqlParser; import com.facebook.presto.sql.planner.SimplePlanVisitor; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.plan.AggregationNode; import com.facebook.presto.sql.planner.plan.PlanNode; import com.facebook.presto.sql.planner.plan.ProjectNode; import com.facebook.presto.sql.planner.plan.UnionNode; import com.facebook.presto.sql.planner.plan.WindowNode; import com.facebook.presto.sql.tree.Expression; import com.facebook.presto.sql.tree.FunctionCall; import com.facebook.presto.sql.tree.SymbolReference; import com.google.common.collect.ListMultimap; import java.util.List; import java.util.Map; import static com.facebook.presto.sql.analyzer.ExpressionAnalyzer.getExpressionTypes; import static com.facebook.presto.type.UnknownType.UNKNOWN; import static com.google.common.base.Preconditions.checkArgument; import static java.util.Collections.emptyList; import static java.util.Objects.requireNonNull; /** * Ensures that all the expressions and FunctionCalls matches their output symbols */ public final class TypeValidator implements PlanSanityChecker.Checker { public TypeValidator() {} @Override public void validate(PlanNode plan, Session session, Metadata metadata, SqlParser sqlParser, Map<Symbol, Type> types) { plan.accept(new Visitor(session, metadata, sqlParser, types), null); } private static class Visitor extends SimplePlanVisitor<Void> { private final Session session; private final Metadata metadata; private final SqlParser sqlParser; private final Map<Symbol, Type> types; public Visitor(Session session, Metadata metadata, SqlParser sqlParser, Map<Symbol, Type> types) { this.session = requireNonNull(session, "session is null"); this.metadata = requireNonNull(metadata, "metadata is null"); this.sqlParser = requireNonNull(sqlParser, "sqlParser is null"); this.types = requireNonNull(types, "types is null"); } @Override public Void visitAggregation(AggregationNode node, Void context) { visitPlan(node, context); AggregationNode.Step step = node.getStep(); switch (step) { case SINGLE: checkFunctionSignature(node.getFunctions()); checkFunctionCall(node.getAggregations()); break; case FINAL: checkFunctionSignature(node.getFunctions()); break; } return null; } @Override public Void visitWindow(WindowNode node, Void context) { visitPlan(node, context); checkWindowFunctions(node.getWindowFunctions()); return null; } @Override public Void visitProject(ProjectNode node, Void context) { visitPlan(node, context); for (Map.Entry<Symbol, Expression> entry : node.getAssignments().entrySet()) { Type expectedType = types.get(entry.getKey()); if (entry.getValue() instanceof SymbolReference) { SymbolReference symbolReference = (SymbolReference) entry.getValue(); verifyTypeSignature(entry.getKey(), expectedType.getTypeSignature(), types.get(Symbol.from(symbolReference)).getTypeSignature()); continue; } Type actualType = getExpressionTypes(session, metadata, sqlParser, types, entry.getValue(), emptyList() /* parameters already replaced */).get(entry.getValue()); verifyTypeSignature(entry.getKey(), expectedType.getTypeSignature(), actualType.getTypeSignature()); } return null; } @Override public Void visitUnion(UnionNode node, Void context) { visitPlan(node, context); ListMultimap<Symbol, Symbol> symbolMapping = node.getSymbolMapping(); for (Symbol keySymbol : symbolMapping.keySet()) { List<Symbol> valueSymbols = symbolMapping.get(keySymbol); Type expectedType = types.get(keySymbol); for (Symbol valueSymbol : valueSymbols) { verifyTypeSignature(keySymbol, expectedType.getTypeSignature(), types.get(valueSymbol).getTypeSignature()); } } return null; } private void checkWindowFunctions(Map<Symbol, WindowNode.Function> functions) { for (Map.Entry<Symbol, WindowNode.Function> entry : functions.entrySet()) { Signature signature = entry.getValue().getSignature(); FunctionCall call = entry.getValue().getFunctionCall(); checkSignature(entry.getKey(), signature); checkCall(entry.getKey(), call); } } private void checkSignature(Symbol symbol, Signature signature) { TypeSignature expectedTypeSignature = types.get(symbol).getTypeSignature(); TypeSignature actualTypeSignature = signature.getReturnType(); verifyTypeSignature(symbol, expectedTypeSignature, actualTypeSignature); } private void checkCall(Symbol symbol, FunctionCall call) { Type expectedType = types.get(symbol); Type actualType = getExpressionTypes(session, metadata, sqlParser, types, call, emptyList() /*parameters already replaced */).get(call); verifyTypeSignature(symbol, expectedType.getTypeSignature(), actualType.getTypeSignature()); } private void checkFunctionSignature(Map<Symbol, Signature> functions) { for (Map.Entry<Symbol, Signature> entry : functions.entrySet()) { checkSignature(entry.getKey(), entry.getValue()); } } private void checkFunctionCall(Map<Symbol, FunctionCall> functionCalls) { for (Map.Entry<Symbol, FunctionCall> entry : functionCalls.entrySet()) { checkCall(entry.getKey(), entry.getValue()); } } private void verifyTypeSignature(Symbol symbol, TypeSignature expected, TypeSignature actual) { // UNKNOWN should be considered as a wildcard type, which matches all the other types TypeManager typeManager = metadata.getTypeManager(); if (!actual.equals(UNKNOWN.getTypeSignature()) && !typeManager.isTypeOnlyCoercion(typeManager.getType(actual), typeManager.getType(expected))) { checkArgument(expected.equals(actual), "type of symbol '%s' is expected to be %s, but the actual type is %s", symbol, expected, actual); } } } }