package com.ctp.cdi.query.builder.postprocessor;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import javax.persistence.Query;

import org.jboss.solder.logging.Logger;

import com.ctp.cdi.query.handler.JpaQueryPostProcessor;
import com.ctp.cdi.query.handler.CdiQueryInvocationContext;
import com.ctp.cdi.query.param.Parameters;
import com.ctp.cdi.query.util.QueryUtils;
import com.ctp.cdi.query.util.jpa.QueryStringExtractorFactory;
/**
 * {@link JpaQueryPostProcessor} that replaces the query of a repository method invocation with a
 * {@code select count(...)} query over the same {@code from} and {@code where} clauses, and binds
 * the invocation's parameters to it.
 * <p>
 * The rewrite is textual: the query string is split on its {@code select}, {@code from},
 * {@code where} and {@code order by} keywords. An {@code order by} clause is dropped because
 * ordering does not affect a count.
 */
public class CountQueryPostProcessor implements JpaQueryPostProcessor {

    private static final Logger log = Logger.getLogger(CountQueryPostProcessor.class);

    private final QueryStringExtractorFactory factory = new QueryStringExtractorFactory();

    /**
     * Creates the count query for the current invocation.
     *
     * @param context the invocation; supplies the query string (when it carries one), the entity
     *                manager used to create the count query, and the parameters applied to it
     * @param query   the query created so far; its string is extracted only when the context
     *                carries no query string
     * @return a new query counting what the original query selects, with parameters applied
     * @throws IllegalArgumentException if the query string contains no {@code from} clause
     */
    @Override
    public Query postProcess(CdiQueryInvocationContext context, Query query) {
        String queryString = getQueryString(context, query);
        String count = new QueryExtraction(queryString).rewriteToCount();
        log.debugv("Rewrote query {0} to {1}", queryString, count);
        Query result = context.getEntityManager().createQuery(count);
        Parameters params = context.getParams();
        params.applyTo(result);
        return result;
    }

    /**
     * Returns the query string held by the context when it is not empty; otherwise extracts the
     * string from the already created {@link Query}.
     */
    private String getQueryString(CdiQueryInvocationContext context, Query query) {
        if (QueryUtils.isNotEmpty(context.getQueryString())) {
            return context.getQueryString();
        }
        return factory.select(query).extractFrom(query);
    }

    /**
     * Splits one JPQL query string into its clauses and reassembles them as a count query.
     * Create one instance per query string.
     */
    private static class QueryExtraction {

        // Keywords are matched case-insensitively, as whole words, directly on the original
        // string. This keeps every index valid for the original string (lower-casing a copy can
        // change its length, e.g. for 'İ') and avoids hits inside identifiers like "e.fromDate".
        private static final Pattern SELECT =
                Pattern.compile("\\bselect\\b", Pattern.CASE_INSENSITIVE);
        private static final Pattern FROM =
                Pattern.compile("\\bfrom\\b", Pattern.CASE_INSENSITIVE);
        private static final Pattern WHERE =
                Pattern.compile("\\bwhere\\b", Pattern.CASE_INSENSITIVE);
        private static final Pattern ORDER_BY =
                Pattern.compile("\\border\\s+by\\b", Pattern.CASE_INSENSITIVE);
        // Tokens that can directly follow an entity name that has no identification variable.
        private static final Pattern JOIN_START =
                Pattern.compile("(?:join|left|right|inner|outer|full|cross)", Pattern.CASE_INSENSITIVE);

        private final String query;
        private String select;
        private String from;
        // Empty (never null) so a query without a where clause concatenates cleanly.
        private String where = "";

        public QueryExtraction(String query) {
            this.query = query;
        }

        /**
         * Builds the count query. When the original query has a select clause, its expression
         * is counted; otherwise the root entity's identification variable is counted.
         *
         * @throws IllegalArgumentException if the query has no {@code from} clause
         */
        public String rewriteToCount() {
            splitQuery();
            String counted = select != null ? select : rootVariable();
            return "select count(" + counted + ") " + (from + where).trim();
        }

        /** Fills {@link #select}, {@link #from} and {@link #where}; drops any order by clause. */
        private void splitQuery() {
            int fromIndex = indexOf(FROM, 0, query.length());
            if (fromIndex < 0) {
                throw new IllegalArgumentException(
                        "Cannot derive a count query, no FROM clause in: " + query);
            }
            // Only a select keyword before the from clause belongs to the outer query.
            int selectIndex = indexOf(SELECT, 0, fromIndex);
            if (selectIndex >= 0) {
                select = query.substring(selectIndex + "select".length(), fromIndex);
            }
            // Everything from "order by" on is irrelevant for a count.
            int orderByIndex = indexOf(ORDER_BY, fromIndex, query.length());
            int end = orderByIndex >= 0 ? orderByIndex : query.length();
            // Search for where only after from, so a match earlier in the query is never used.
            int whereIndex = indexOf(WHERE, fromIndex, end);
            int fromEnd = whereIndex >= 0 ? whereIndex : end;
            from = query.substring(fromIndex, fromEnd);
            where = query.substring(fromEnd, end);
        }

        /**
         * Returns the start of the first match of {@code keyword} inside {@code query[start, end)},
         * or -1 if there is none. Word boundaries are evaluated against the surrounding text.
         */
        private int indexOf(Pattern keyword, int start, int end) {
            Matcher matcher = keyword.matcher(query).region(start, end).useTransparentBounds(true);
            return matcher.find() ? matcher.start() : -1;
        }

        /**
         * Returns the identification variable of the first entity in the from clause
         * ({@code e} in {@code from Entity e join e.items i}), the entity name when it has no
         * variable, or {@code "*"} when the from clause names no entity.
         */
        private String rootVariable() {
            // tokens[0] is the FROM keyword, tokens[1] the root entity, then an optional [AS] variable.
            String[] tokens = from.trim().split("\\s+");
            if (tokens.length < 2) {
                return "*";
            }
            String entity = tokens[1];
            if (entity.indexOf(',') >= 0) {
                // "from Entity, Other o": the root entity declares no variable of its own.
                return beforeComma(entity);
            }
            int next = 2;
            if (tokens.length > next + 1 && "as".equalsIgnoreCase(tokens[next])) {
                next++;
            }
            if (tokens.length <= next || JOIN_START.matcher(tokens[next]).matches()) {
                return entity;
            }
            String variable = beforeComma(tokens[next]);
            return variable.isEmpty() ? entity : variable;
        }

        /** Returns the part of {@code token} before its first comma, or the whole token. */
        private static String beforeComma(String token) {
            int comma = token.indexOf(',');
            return comma >= 0 ? token.substring(0, comma) : token;
        }
    }
}