package org.quaere.jpa;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import javax.persistence.EntityManager;
import javax.persistence.Query;

import org.quaere.QueryEngine;
import org.quaere.Queryable;
import org.quaere.expressions.BinaryExpression;
import org.quaere.expressions.Constant;
import org.quaere.expressions.DeclareClause;
import org.quaere.expressions.Expression;
import org.quaere.expressions.ExpressionTreeVisitor;
import org.quaere.expressions.FromClause;
import org.quaere.expressions.GroupClause;
import org.quaere.expressions.Identifier;
import org.quaere.expressions.Indexer;
import org.quaere.expressions.JoinClause;
import org.quaere.expressions.MethodCall;
import org.quaere.expressions.NewExpression;
import org.quaere.expressions.OrderByClause;
import org.quaere.expressions.Parameter;
import org.quaere.expressions.QueryBody;
import org.quaere.expressions.QueryBodyClause;
import org.quaere.expressions.QueryContinuation;
import org.quaere.expressions.QueryExpression;
import org.quaere.expressions.SelectClause;
import org.quaere.expressions.Statement;
import org.quaere.expressions.TernaryExpression;
import org.quaere.expressions.UnaryExpression;
import org.quaere.expressions.WhereClause;

/**
 * Query engine that walks a Quaere expression tree and renders it as a JPQL
 * string, which is then executed through an {@link EntityManager}.
 *
 * <p>Each {@code visit} method leaves the text it produced in
 * {@code currentFragment}; clause-level visits copy that text into the SELECT,
 * FROM or WHERE accumulator. Constants are never inlined: they are rendered as
 * positional parameters ({@code ?1}, {@code ?2}, ...) and bound on the
 * {@link Query} before execution.
 *
 * <p>Group, join, order-by, declare, continuation, ternary, unary, indexer,
 * parameter and {@code new} expressions are not supported and throw
 * {@link UnsupportedOperationException}.
 *
 * <p>The engine keeps mutable per-query state in instance fields and performs
 * no synchronization.
 */
public class QuaereForJPAQueryEngine implements ExpressionTreeVisitor, QueryEngine {
    private static final String NOT_IMPLEMENTED = "The method is not implemented";

    private final EntityManager entityManager;

    /** Range variable names introduced by the FROM clauses visited so far. */
    private final List<String> sourceNames = new ArrayList<String>();
    /** Sources registered through {@link #addSource}, keyed by identifier name. */
    private final Map<String, QueryableEntity> sources = new HashMap<String, QueryableEntity>();

    private final StringBuilder selectFragment = new StringBuilder();
    private final StringBuilder fromFragment = new StringBuilder();
    /** One entry per visited where clause; combined with AND when rendered. */
    private final List<String> whereConditions = new ArrayList<String>();

    /** Text produced by the most recent visit. */
    private String currentFragment;
    /**
     * "AND" / "OR" when {@code currentFragment} is a logical expression built
     * with that operator, otherwise {@code null}. Used to decide whether an
     * operand has to be parenthesised to keep the tree's grouping.
     */
    private String currentLogicalOperator;

    /** Index that will be given to the next positional parameter. */
    int parameterIndex = 1;
    private final Map<Integer, Object> parameterMap = new HashMap<Integer, Object>();

    /**
     * @param entityManager entity manager used to create and run the JPQL query
     */
    public QuaereForJPAQueryEngine(EntityManager entityManager) {
        this.entityManager = entityManager;
    }

    /**
     * Renders {@code <entity> AS <range variable>,} into the FROM accumulator.
     * If the visited source expression yields the name of a registered source,
     * that source's entity name is used; otherwise the rendered text is used
     * as is.
     */
    public void visit(FromClause expression) {
        // Register the range variable first so identifier visits recognise it.
        sourceNames.add(expression.identifier.name);

        expression.sourceExpression.accept(this);
        String entityName = currentFragment;
        QueryableEntity registeredSource = sources.get(entityName);
        if (registeredSource != null) {
            entityName = registeredSource.getEntityName();
        }

        expression.identifier.accept(this);
        fromFragment.append(entityName).append(" AS ").append(currentFragment).append(',');
        clearCurrentFragment();
    }

    public void visit(GroupClause expression) {
        throw new UnsupportedOperationException(NOT_IMPLEMENTED);
    }

    public void visit(JoinClause expression) {
        throw new UnsupportedOperationException(NOT_IMPLEMENTED);
    }

    public void visit(OrderByClause expression) {
        throw new UnsupportedOperationException(NOT_IMPLEMENTED);
    }

    public void visit(DeclareClause expression) {
        throw new UnsupportedOperationException(NOT_IMPLEMENTED);
    }

    /**
     * Records the rendered condition. Several where clauses in one query are
     * combined with AND when the JPQL is assembled.
     */
    public void visit(WhereClause expression) {
        expression.getExpression().accept(this);
        whereConditions.add(String.valueOf(currentFragment));
        clearCurrentFragment();
    }

    /** Appends the rendered select expression, followed by a comma, to the SELECT accumulator. */
    public void visit(SelectClause expression) {
        expression.getExpression().accept(this);
        selectFragment.append(currentFragment).append(',');
        clearCurrentFragment();
    }

    /** Visits the body clauses in order, then the select/group clause and the continuation if present. */
    public void visit(QueryBody expression) {
        for (QueryBodyClause clause : expression.getClauses()) {
            clause.accept(this);
        }
        if (expression.hasSelectOrGroupClause()) {
            expression.getSelectOrGroupClause().accept(this);
        }
        if (expression.hasContinuation()) {
            expression.getContinuation().accept(this);
        }
    }

    public void visit(QueryContinuation expression) {
        throw new UnsupportedOperationException(NOT_IMPLEMENTED);
    }

    /** Visits the from clause, then the query body. */
    public void visit(QueryExpression expression) {
        expression.getFrom().accept(this);
        expression.getQueryBody().accept(this);
    }

    /**
     * Renders {@code <left> <operator> <right>}. An operand that is itself an
     * AND/OR expression with a different operator than this one is wrapped in
     * parentheses so that JPQL operator precedence cannot regroup the tree.
     *
     * @throws UnsupportedOperationException if the operator has no JPQL mapping here
     */
    public void visit(BinaryExpression expression) {
        expression.leftExpression.accept(this);
        String leftFragment = currentFragment;
        String leftLogicalOperator = currentLogicalOperator;

        expression.rightExpression.accept(this);
        String rightFragment = currentFragment;
        String rightLogicalOperator = currentLogicalOperator;

        String operatorFragment;
        boolean logical = false;
        switch (expression.operator) {
            case AND:
                operatorFragment = "AND";
                logical = true;
                break;
            case OR:
                operatorFragment = "OR";
                logical = true;
                break;
            case EQUAL:
                operatorFragment = "=";
                break;
            case NOT_EQUAL:
                operatorFragment = "<>";
                break;
            case GREATER_THAN:
                operatorFragment = ">";
                break;
            case GREATER_THAN_OR_EQUAL:
                operatorFragment = ">=";
                break;
            case LESS_THAN:
                operatorFragment = "<";
                break;
            case LESS_THAN_OR_EQUAL:
                operatorFragment = "<=";
                break;
            default:
                throw new UnsupportedOperationException("Operator not supported yet!");
        }

        currentFragment = parenthesizeIfNeeded(leftFragment, leftLogicalOperator, operatorFragment)
                + " " + operatorFragment + " "
                + parenthesizeIfNeeded(rightFragment, rightLogicalOperator, operatorFragment);
        currentLogicalOperator = logical ? operatorFragment : null;
    }

    public void visit(TernaryExpression expression) {
        throw new UnsupportedOperationException(NOT_IMPLEMENTED);
    }

    public void visit(UnaryExpression expression) {
        throw new UnsupportedOperationException(NOT_IMPLEMENTED);
    }

    /** Renders the constant as the next positional parameter and remembers its value for binding. */
    public void visit(Constant expression) {
        currentFragment = "?" + parameterIndex;
        currentLogicalOperator = null;
        parameterMap.put(parameterIndex, expression.value);
        parameterIndex++;
    }

    /**
     * Emits the identifier by name when it is a known range variable or a
     * registered source; any other identifier is treated as a string constant
     * and bound as a parameter.
     */
    public void visit(Identifier expression) {
        // NOTE(review): registered sources are emitted by name so that
        // visit(FromClause) can look up their entity name. This assumes the
        // from clause's source expression refers to the identifier passed to
        // addSource -- confirm against the DSL that builds the tree.
        if (sourceNames.contains(expression.name) || sources.containsKey(expression.name)) {
            currentFragment = expression.name;
            currentLogicalOperator = null;
        } else {
            this.visit(new Constant(expression.name, String.class));
        }
    }

    /**
     * Translates a {@code getXxx} / {@code isXxx} method name into the
     * property access {@code .xxx}.
     *
     * @throws RuntimeException if the method name does not start with
     *         {@code get} or {@code is}, or has nothing after that prefix
     */
    public void visit(MethodCall expression) {
        String methodName = expression.getIdentifier().name;
        String propertyName = toPropertyName(methodName);
        if (propertyName == null) {
            throw new RuntimeException("Cannot translate method " + methodName + " to property");
        }
        currentFragment = "." + propertyName;
        currentLogicalOperator = null;
    }

    public void visit(Indexer expression) {
        throw new UnsupportedOperationException(NOT_IMPLEMENTED);
    }

    /** Concatenates the fragments of all sub-expressions, with no separator. */
    public void visit(Statement expression) {
        StringBuilder fragmentBuilder = new StringBuilder();
        int visited = 0;
        for (Expression e : expression.getExpressions()) {
            e.accept(this);
            fragmentBuilder.append(currentFragment);
            visited++;
        }
        currentFragment = fragmentBuilder.toString();
        if (visited != 1) {
            // A concatenation of several fragments is not a bare AND/OR expression.
            currentLogicalOperator = null;
        }
    }

    public void visit(Parameter expression) {
        throw new UnsupportedOperationException(NOT_IMPLEMENTED);
    }

    public void visit(NewExpression expression) {
        throw new UnsupportedOperationException(NOT_IMPLEMENTED);
    }

    /**
     * Registers a source under the identifier's name.
     *
     * @throws IllegalArgumentException if {@code source} is not a {@link QueryableEntity}
     */
    public void addSource(Identifier identifer, Queryable<?> source) {
        if (!(source instanceof QueryableEntity)) {
            throw new IllegalArgumentException("Only QueryableEntity can be used as a source");
        }
        this.sources.put(identifer.name, (QueryableEntity) source);
    }

    /**
     * Translates {@code query} to JPQL, runs it and returns the result list.
     * State left over from a previous evaluation is discarded first;
     * registered sources are kept.
     */
    public <T> T evaluate(Expression query) {
        resetQueryState();
        query.accept(this);
        return this.<T>executeQuery(getJPQL(), parameterMap);
    }

    /** Creates the JPA query, binds every positional parameter and returns its result list. */
    @SuppressWarnings("unchecked")
    private <T> T executeQuery(String jpql, Map<Integer, Object> parameters) {
        Query jpaQuery = entityManager.createQuery(jpql);
        for (Map.Entry<Integer, Object> parameter : parameters.entrySet()) {
            jpaQuery.setParameter(parameter.getKey(), parameter.getValue());
        }
        // The caller chooses T; the element type cannot be checked here.
        return (T) jpaQuery.getResultList();
    }

    /** Assembles the SELECT, FROM and WHERE parts, skipping the ones that are empty. */
    private String getJPQL() {
        return String.format("%s %s %s",
                getFragment("SELECT", selectFragment),
                getFragment("FROM", fromFragment),
                getFragment("WHERE", renderWhereConditions())).trim();
    }

    /**
     * Returns {@code "<clause> <content>"} without one trailing comma, or an
     * empty string when there is no content. The argument is not modified, so
     * rendering can be repeated.
     */
    private String getFragment(String clause, CharSequence content) {
        if (content.length() == 0) {
            return "";
        }
        String fragment = clause + " " + content;
        if (fragment.endsWith(",")) {
            fragment = fragment.substring(0, fragment.length() - 1);
        }
        return fragment;
    }

    /** A single condition is rendered as is; several are parenthesised and joined with AND. */
    private String renderWhereConditions() {
        if (whereConditions.size() == 1) {
            return whereConditions.get(0);
        }
        StringBuilder joined = new StringBuilder();
        for (String condition : whereConditions) {
            if (joined.length() > 0) {
                joined.append(" AND ");
            }
            joined.append('(').append(condition).append(')');
        }
        return joined.toString();
    }

    /** Wraps {@code fragment} in parentheses when it is an AND/OR expression using a different operator than its parent. */
    private static String parenthesizeIfNeeded(String fragment, String fragmentOperator, String parentOperator) {
        if (fragmentOperator == null || fragmentOperator.equals(parentOperator)) {
            return fragment;
        }
        return "(" + fragment + ")";
    }

    /** Strips a leading {@code get} or {@code is} and lower-cases the next character; {@code null} if that is not possible. */
    private static String toPropertyName(String methodName) {
        String capitalized;
        if (methodName.startsWith("get")) {
            capitalized = methodName.substring("get".length());
        } else if (methodName.startsWith("is")) {
            capitalized = methodName.substring("is".length());
        } else {
            return null;
        }
        if (capitalized.length() == 0) {
            return null;
        }
        // Character.toLowerCase does not depend on the default locale.
        return Character.toLowerCase(capitalized.charAt(0)) + capitalized.substring(1);
    }

    /** Forgets the text produced by the last visit. */
    private void clearCurrentFragment() {
        currentFragment = null;
        currentLogicalOperator = null;
    }

    /** Clears everything accumulated while translating a query; registered sources are kept. */
    private void resetQueryState() {
        sourceNames.clear();
        selectFragment.setLength(0);
        fromFragment.setLength(0);
        whereConditions.clear();
        parameterMap.clear();
        parameterIndex = 1;
        clearCurrentFragment();
    }
}