package org.rakam.analysis.stream;
import com.facebook.presto.Session;
import com.facebook.presto.bytecode.Access;
import com.facebook.presto.bytecode.ClassDefinition;
import com.facebook.presto.bytecode.CompilerUtils;
import com.facebook.presto.bytecode.ParameterizedType;
import com.facebook.presto.metadata.FunctionKind;
import com.facebook.presto.metadata.FunctionRegistry;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.metadata.SessionPropertyManager;
import com.facebook.presto.spi.ConnectorSession;
import com.facebook.presto.spi.QueryId;
import com.facebook.presto.spi.RecordCursor;
import com.facebook.presto.spi.block.BlockEncodingSerde;
import com.facebook.presto.spi.security.Identity;
import com.facebook.presto.spi.transaction.IsolationLevel;
import com.facebook.presto.spi.type.TimeZoneKey;
import com.facebook.presto.spi.type.Type;
import com.facebook.presto.spi.type.TypeManager;
import com.facebook.presto.sql.analyzer.FeaturesConfig;
import com.facebook.presto.sql.gen.CachedInstanceBinder;
import com.facebook.presto.sql.gen.CallSiteBinder;
import com.facebook.presto.sql.gen.CursorProcessorCompiler;
import com.facebook.presto.sql.parser.SqlParser;
import com.facebook.presto.sql.planner.Symbol;
import com.facebook.presto.sql.planner.SymbolToInputRewriter;
import com.facebook.presto.sql.relational.RowExpression;
import com.facebook.presto.sql.relational.SqlToRowExpressionTranslator;
import com.facebook.presto.sql.relational.optimizer.ExpressionOptimizer;
import com.facebook.presto.sql.tree.DefaultTraversalVisitor;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.ExpressionTreeRewriter;
import com.facebook.presto.sql.tree.SymbolReference;
import com.facebook.presto.transaction.TransactionManager;
import com.google.common.base.Throwables;
import com.google.common.collect.ImmutableList;
import org.apache.avro.generic.GenericRecord;
import javax.inject.Inject;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.IdentityHashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;
import java.util.function.Predicate;
import static com.facebook.presto.bytecode.Access.FINAL;
import static com.facebook.presto.bytecode.Access.PUBLIC;
import static com.facebook.presto.bytecode.Access.a;
import static com.facebook.presto.bytecode.CompilerUtils.defineClass;
import static com.facebook.presto.sql.ExpressionUtils.rewriteQualifiedNamesToSymbolReferences;
import static com.facebook.presto.sql.analyzer.ExpressionAnalyzer.getExpressionTypesFromInput;
import static com.google.common.collect.ImmutableList.copyOf;
import static com.google.common.collect.Iterables.concat;
import static java.util.Collections.singleton;
public class ExpressionCompiler
{
private final BlockEncodingSerde serde;
private final Metadata metadata;
private final SqlParser sqlParser = new SqlParser();
private final Session session;
private final TypeManager typeManager;
private final FeaturesConfig featuresConfig;
private final ExpressionOptimizer expressionOptimizer;
@Inject
public ExpressionCompiler(Metadata metadata, TransactionManager transactionManager, FeaturesConfig featuresConfig)
{
this.serde = metadata.getBlockEncodingSerde();
this.metadata = metadata;
this.featuresConfig = featuresConfig;
this.typeManager = metadata.getTypeManager();
this.session = Session.builder(new SessionPropertyManager())
.setIdentity(new Identity("user", Optional.empty()))
.setTimeZoneKey(TimeZoneKey.UTC_KEY)
.setLocale(Locale.ENGLISH)
.setQueryId(QueryId.valueOf("row_expression_compiler"))
.setTransactionId(transactionManager.beginTransaction(IsolationLevel.REPEATABLE_READ, true, true))
.build();
this.expressionOptimizer = new ExpressionOptimizer(metadata.getFunctionRegistry(), metadata.getTypeManager(), session);
}
public Predicate<GenericRecord> generate(Expression expression, List<Map.Entry<String, Type>> columns)
{
FilterContext filterContext = analyze(expression, columns);
ImmutableList<Type> types = copyOf(filterContext.sourceTypes.values());
AvroRecordCursor cursor = new AvroRecordCursor(types, filterContext.projections);
Filter filter = filterContext.filter;
ConnectorSession connectorSession = session.toConnectorSession();
return genericRecord -> {
cursor.setRecord(genericRecord);
return filter.filter(connectorSession, cursor);
};
}
private static class FilterContext
{
public final Map<Integer, Type> sourceTypes;
public final int[] projections;
public final Filter filter;
private FilterContext(Map<Integer, Type> sourceTypes, int[] projections, Filter filter)
{
this.sourceTypes = sourceTypes;
this.projections = projections;
this.filter = filter;
}
}
private FilterContext analyze(Expression filterExpression, List<Map.Entry<String, Type>> columns)
{
filterExpression = rewriteQualifiedNamesToSymbolReferences(filterExpression);
List<Expression> projectionExpressions = new ArrayList<>();
Map<Symbol, Integer> sourceLayout = new HashMap<>();
Map<Integer, Type> sourceTypes = new HashMap<>();
int[] projectionProxies = new int[columns.size()];
new DefaultTraversalVisitor<Void, Void>()
{
@Override
protected Void visitSymbolReference(SymbolReference node, Void context)
{
projectionExpressions.add(node);
int idx = sourceLayout.size();
sourceLayout.put(new Symbol(node.getName().toString()), idx);
int index = 0;
for (int i = 0; i < columns.size(); i++) {
Map.Entry<String, Type> entry = columns.get(i);
if (entry.getKey().equals(node.getName())) {
sourceTypes.put(index++, entry.getValue());
projectionProxies[idx] = i;
break;
}
}
return null;
}
}.process(filterExpression, null);
// compiler uses inputs instead of symbols, so rewrite the expressions first
SymbolToInputRewriter symbolToInputRewriter = new SymbolToInputRewriter(sourceLayout);
Expression rewrittenFilter = ExpressionTreeRewriter.rewriteWith(symbolToInputRewriter, filterExpression);
List<Expression> rewrittenProjections = new ArrayList<>();
for (Expression projection : projectionExpressions) {
rewrittenProjections.add(ExpressionTreeRewriter.rewriteWith(symbolToInputRewriter, projection));
}
IdentityHashMap<Expression, Type> expressionTypes = getExpressionTypesFromInput(
session,
metadata,
sqlParser,
sourceTypes,
concat(singleton(rewrittenFilter), rewrittenProjections),
ImmutableList.of());
FunctionRegistry functionRegistry = new FunctionRegistry(typeManager, serde, featuresConfig);
RowExpression filter = SqlToRowExpressionTranslator.translate(
rewrittenFilter, FunctionKind.SCALAR, expressionTypes, functionRegistry, typeManager, session, true);
filter = expressionOptimizer.optimize(filter);
return new FilterContext(sourceTypes, projectionProxies, compileRowExpression(filter));
}
private Filter compileRowExpression(RowExpression filter)
{
ParameterizedType className = CompilerUtils.makeClassName(Filter.class.getSimpleName());
ParameterizedType[] interfaces = {ParameterizedType.type(Filter.class)};
ParameterizedType type = ParameterizedType.type(Object.class);
EnumSet<Access> accessList = a(new Access[] {PUBLIC, FINAL});
ClassDefinition classDefinition = new ClassDefinition(accessList, className, type, interfaces);
classDefinition.declareDefaultConstructor(a(PUBLIC));
CallSiteBinder callSiteBinder = new CallSiteBinder();
CursorProcessorCompiler cursorProcessorCompiler = new CursorProcessorCompiler(metadata);
Method method;
try {
method = cursorProcessorCompiler.getClass().getDeclaredMethod("generateFilterMethod", ClassDefinition.class, CallSiteBinder.class, CachedInstanceBinder.class, RowExpression.class);
method.setAccessible(true);
}
catch (NoSuchMethodException e) {
throw Throwables.propagate(e);
}
try {
method.invoke(cursorProcessorCompiler, classDefinition, callSiteBinder,
new CachedInstanceBinder(classDefinition, callSiteBinder), filter);
}
catch (IllegalAccessException | InvocationTargetException e) {
throw Throwables.propagate(e);
}
Class<? extends Filter> aClass = defineClass(classDefinition, Filter.class, callSiteBinder.getBindings(), getClass().getClassLoader());
try {
return aClass.newInstance();
}
catch (InstantiationException | IllegalAccessException e) {
throw new RuntimeException("Couldn't compile expression", e);
}
}
public interface Filter
{
boolean filter(ConnectorSession session, RecordCursor cursor);
}
}