package org.rakam.automation; import com.facebook.presto.sql.parser.SqlParser; import com.facebook.presto.sql.tree.AstVisitor; import com.facebook.presto.sql.tree.ComparisonExpression; import com.facebook.presto.sql.tree.ComparisonExpressionType; import com.facebook.presto.sql.tree.Expression; import com.facebook.presto.sql.tree.IsNotNullPredicate; import com.facebook.presto.sql.tree.IsNullPredicate; import com.facebook.presto.sql.tree.LikePredicate; import com.facebook.presto.sql.tree.Literal; import com.facebook.presto.sql.tree.LogicalBinaryExpression; import com.facebook.presto.sql.tree.Node; import com.facebook.presto.sql.tree.QualifiedNameReference; import com.facebook.presto.sql.tree.StringLiteral; import com.google.common.base.Throwables; import net.openhft.compiler.CompilerUtils; import org.rakam.collection.Event; import java.util.function.Predicate; import static org.rakam.util.ValidationUtil.checkTableColumn; public final class ExpressionCompiler { private static final SqlParser sqlParser = new SqlParser(); private ExpressionCompiler() throws InstantiationException { throw new InstantiationException("The class is not created for instantiation"); } public static Predicate<Event> compile(String expressionStr) throws UnsupportedOperationException { final Expression expression; synchronized (sqlParser) { expression = sqlParser.createExpression(expressionStr); } final String javaExp = new JavaSourceAstVisitor().process(expression, false); String className = "org.rakam.automation.compiled.Predicate1"; String javaCode = String.format("package org.rakam.automation.compiled;\n" + "import org.rakam.collection.Event;\n" + "import org.apache.avro.generic.GenericRecord;\n" + "import java.lang.Comparable;\n" + "import java.util.function.Predicate;\n" + "public class Predicate1 implements Predicate<Event> {\n" + " public boolean test(Event event) {\n" + " GenericRecord props = event.properties();\n" + " return %s;\n" + " }\n" + "}\n", javaExp); try { Class aClass = CompilerUtils.CACHED_COMPILER.loadFromJava(className, javaCode); return (Predicate) aClass.newInstance(); } catch (ClassNotFoundException | InstantiationException | IllegalAccessException e) { throw Throwables.propagate(e); } } private static class JavaSourceAstVisitor extends AstVisitor<String, Boolean> { public String visitLogicalBinaryExpression(LogicalBinaryExpression node, Boolean context) { return formatBinaryExpression(getLogicalContext(node.getType()), node.getLeft(), node.getRight(), context); } @Override protected String visitComparisonExpression(ComparisonExpression node, Boolean context) { return String.format(getComparisonFormat(node.getType()), process(node.getRight(), context), process(node.getLeft(), context)); } private String formatBinaryExpression(String operator, Expression left, Expression right, boolean unmangleNames) { return '(' + process(left, unmangleNames) + ' ' + operator + ' ' + process(right, unmangleNames) + ')'; } @Override protected String visitLikePredicate(LikePredicate node, Boolean context) { StringBuilder builder = new StringBuilder(); // TODO: handle this in a proper way. if (!(node.getPattern() instanceof StringLiteral)) { throw new UnsupportedOperationException(); } String value = ((StringLiteral) node.getPattern()).getValue(); String process = process(node.getValue(), context); builder.append('(') .append(process).append(" instanceof String && ((String) ").append(process).append(")."); boolean starts = false; boolean ends = false; int length = value.length(); for (int i = -1; (i = value.indexOf('%', i + 1)) != -1; ) { if (i == 0) { starts = true; } else if (i + 1 == length) { ends = true; } else { throw new UnsupportedOperationException(); } } if (starts && ends) { builder.append("contains(\"").append(value.substring(1, length - 1)).append("\")"); } else if (ends) { builder.append("endsWith(\"").append(value.substring(0, length - 1)).append("\")"); } else if (starts) { builder.append("startsWith(\"").append(value.substring(1, length - 2)).append("\")"); } if (node.getEscape() != null) { throw new UnsupportedOperationException(); } builder.append(')'); return builder.toString(); } @Override protected String visitQualifiedNameReference(QualifiedNameReference node, Boolean unmangleNames) { if (node.getName().getPrefix().isPresent()) { throw new IllegalArgumentException("field reference is invalid"); } final String suffix = node.getName().getSuffix(); return "props.get(\"" + checkTableColumn(suffix, "field reference is invalid", '"') + "\")"; } @Override protected String visitLiteral(Literal node, Boolean context) { return node.toString(); } private String getLogicalContext(LogicalBinaryExpression.Type type) { switch (type) { case AND: return "&&"; case OR: return "||"; default: throw new IllegalStateException(); } } private String getComparisonFormat(ComparisonExpressionType type) { switch (type) { case EQUAL: return "%2$s instanceof Comparable && ((Comparable) %2$s).equals(%1$s)"; case NOT_EQUAL: return "%2$s instanceof Comparable && !((Comparable) %2$s).equals(%1$s)"; case LESS_THAN: return "%2$s instanceof Comparable && ((Comparable) %2$s).compareTo(%1$s) > 0"; case GREATER_THAN: return "%2$s instanceof Comparable && ((Comparable) %2$s).compareTo(%1$s) < 0"; case GREATER_THAN_OR_EQUAL: return "%2$s instanceof Comparable && ((Comparable) %2$s).compareTo(%1$s) <= 0"; case LESS_THAN_OR_EQUAL: return "%2$s instanceof Comparable && ((Comparable) %2$s).compareTo(%1$s) >= 0"; default: throw new IllegalStateException(); } } @Override protected String visitIsNotNullPredicate(IsNotNullPredicate node, Boolean context) { return process(node.getValue(), context) + " != null"; } @Override protected String visitIsNullPredicate(IsNullPredicate node, Boolean context) { return process(node.getValue(), context) + " == null"; } @Override protected String visitNode(Node node, Boolean context) { throw new UnsupportedOperationException(); } } }