/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.type;
import com.facebook.presto.sql.parser.CaseInsensitiveStream;
import com.facebook.presto.sql.parser.ParsingException;
import com.facebook.presto.type.TypeCalculationParser.ArithmeticBinaryContext;
import com.facebook.presto.type.TypeCalculationParser.ArithmeticUnaryContext;
import com.facebook.presto.type.TypeCalculationParser.BinaryFunctionContext;
import com.facebook.presto.type.TypeCalculationParser.IdentifierContext;
import com.facebook.presto.type.TypeCalculationParser.NullLiteralContext;
import com.facebook.presto.type.TypeCalculationParser.NumericLiteralContext;
import com.facebook.presto.type.TypeCalculationParser.ParenthesizedExpressionContext;
import com.facebook.presto.type.TypeCalculationParser.TypeCalculationContext;
import org.antlr.v4.runtime.ANTLRInputStream;
import org.antlr.v4.runtime.BaseErrorListener;
import org.antlr.v4.runtime.CommonTokenStream;
import org.antlr.v4.runtime.ParserRuleContext;
import org.antlr.v4.runtime.RecognitionException;
import org.antlr.v4.runtime.Recognizer;
import org.antlr.v4.runtime.atn.PredictionMode;
import org.antlr.v4.runtime.misc.ParseCancellationException;
import java.math.BigInteger;
import java.util.Map;
import static com.facebook.presto.type.TypeCalculationParser.ASTERISK;
import static com.facebook.presto.type.TypeCalculationParser.MAX;
import static com.facebook.presto.type.TypeCalculationParser.MIN;
import static com.facebook.presto.type.TypeCalculationParser.MINUS;
import static com.facebook.presto.type.TypeCalculationParser.PLUS;
import static com.facebook.presto.type.TypeCalculationParser.SLASH;
import static com.google.common.base.Preconditions.checkState;
import static java.util.Objects.requireNonNull;
public final class TypeCalculation
{
private static final BaseErrorListener ERROR_LISTENER = new BaseErrorListener()
{
@Override
public void syntaxError(Recognizer<?, ?> recognizer, Object offendingSymbol, int line, int charPositionInLine, String message, RecognitionException e)
{
throw new ParsingException(message, e, line, charPositionInLine);
}
};
private TypeCalculation() {}
public static Long calculateLiteralValue(
String calculation,
Map<String, Long> inputs)
{
try {
ParserRuleContext tree = parseTypeCalculation(calculation);
CalculateTypeVisitor visitor = new CalculateTypeVisitor(inputs);
BigInteger result = visitor.visit(tree);
return result.longValueExact();
}
catch (StackOverflowError e) {
throw new ParsingException("Type calculation is too large (stack overflow while parsing)");
}
}
private static ParserRuleContext parseTypeCalculation(String calculation)
{
TypeCalculationLexer lexer = new TypeCalculationLexer(new CaseInsensitiveStream(new ANTLRInputStream(calculation)));
CommonTokenStream tokenStream = new CommonTokenStream(lexer);
TypeCalculationParser parser = new TypeCalculationParser(tokenStream);
lexer.removeErrorListeners();
lexer.addErrorListener(ERROR_LISTENER);
parser.removeErrorListeners();
parser.addErrorListener(ERROR_LISTENER);
ParserRuleContext tree;
try {
// first, try parsing with potentially faster SLL mode
parser.getInterpreter().setPredictionMode(PredictionMode.SLL);
tree = parser.typeCalculation();
}
catch (ParseCancellationException ex) {
// if we fail, parse with LL mode
tokenStream.reset(); // rewind input stream
parser.reset();
parser.getInterpreter().setPredictionMode(PredictionMode.LL);
tree = parser.typeCalculation();
}
return tree;
}
private static class IsSimpleExpressionVisitor
extends TypeCalculationBaseVisitor<Boolean>
{
@Override
public Boolean visitArithmeticBinary(ArithmeticBinaryContext ctx)
{
return false;
}
@Override
public Boolean visitArithmeticUnary(ArithmeticUnaryContext ctx)
{
return false;
}
@Override
protected Boolean defaultResult()
{
return true;
}
@Override
protected Boolean aggregateResult(Boolean aggregate, Boolean nextResult)
{
return aggregate && nextResult;
}
}
private static class CalculateTypeVisitor
extends TypeCalculationBaseVisitor<BigInteger>
{
private final Map<String, Long> inputs;
public CalculateTypeVisitor(Map<String, Long> inputs)
{
this.inputs = requireNonNull(inputs);
}
@Override
public BigInteger visitTypeCalculation(TypeCalculationContext ctx)
{
return visit(ctx.expression());
}
@Override
public BigInteger visitArithmeticBinary(ArithmeticBinaryContext ctx)
{
BigInteger left = visit(ctx.left);
BigInteger right = visit(ctx.right);
switch (ctx.operator.getType()) {
case PLUS:
return left.add(right);
case MINUS:
return left.subtract(right);
case ASTERISK:
return left.multiply(right);
case SLASH:
return left.divide(right);
default:
throw new IllegalStateException("Unsupported binary operator " + ctx.operator.getText());
}
}
@Override
public BigInteger visitArithmeticUnary(ArithmeticUnaryContext ctx)
{
BigInteger value = visit(ctx.expression());
switch (ctx.operator.getType()) {
case PLUS:
return value;
case MINUS:
return value.negate();
default:
throw new IllegalStateException("Unsupported unary operator " + ctx.operator.getText());
}
}
@Override
public BigInteger visitBinaryFunction(BinaryFunctionContext ctx)
{
BigInteger left = visit(ctx.left);
BigInteger right = visit(ctx.right);
switch (ctx.binaryFunctionName().name.getType()) {
case MIN:
return left.min(right);
case MAX:
return left.max(right);
default:
throw new IllegalArgumentException("Unsupported binary function " + ctx.binaryFunctionName().getText());
}
}
@Override
public BigInteger visitNumericLiteral(NumericLiteralContext ctx)
{
return new BigInteger(ctx.INTEGER_VALUE().getText());
}
@Override
public BigInteger visitNullLiteral(NullLiteralContext ctx)
{
return BigInteger.ZERO;
}
@Override
public BigInteger visitIdentifier(IdentifierContext ctx)
{
String identifier = ctx.getText();
Long value = inputs.get(identifier);
checkState(value != null, "value for variable '%s' is not specified in the inputs", identifier);
return BigInteger.valueOf(value);
}
@Override
public BigInteger visitParenthesizedExpression(ParenthesizedExpressionContext ctx)
{
return visit(ctx.expression());
}
}
}