package org.molgenis.data.rsql;
import cz.jirutka.rsql.parser.ast.*;
import org.molgenis.data.Entity;
import org.molgenis.data.MolgenisQueryException;
import org.molgenis.data.Query;
import org.molgenis.data.UnknownAttributeException;
import org.molgenis.data.meta.model.Attribute;
import org.molgenis.data.meta.model.EntityType;
import org.molgenis.data.support.QueryImpl;
import java.util.Iterator;
import java.util.List;
import static java.util.stream.Collectors.toList;
/**
* RSQLVisitor implementation that creates {@link Query} objects for an RSQL tree.
*
* @see <a href="https://github.com/jirutka/rsql-parser">https://github.com/jirutka/rsql-parser</a>
*/
public class MolgenisRSQLVisitor extends NoArgRSQLVisitorAdapter<Query<Entity>>
{
private final QueryImpl<Entity> q = new QueryImpl<>();
private final EntityType entityType;
private final RSQLValueParser rsqlValueParser = new RSQLValueParser();
public MolgenisRSQLVisitor(EntityType entityType)
{
this.entityType = entityType;
}
@Override
public Query<Entity> visit(AndNode node)
{
q.nest(); // TODO only nest if more than one child
for (Iterator<Node> it = node.iterator(); it.hasNext(); )
{
Node child = it.next();
child.accept(this);
if (it.hasNext())
{
q.and();
}
}
q.unnest();
return q;
}
@Override
public Query<Entity> visit(OrNode node)
{
q.nest(); // TODO only nest if more than one child
for (Iterator<Node> it = node.iterator(); it.hasNext(); )
{
Node child = it.next();
child.accept(this);
if (it.hasNext())
{
q.or();
}
}
q.unnest();
return q;
}
@Override
public Query<Entity> visit(ComparisonNode node)
{
String attrName = node.getSelector();
String symbol = node.getOperator().getSymbol();
List<String> values = node.getArguments();
switch (symbol)
{
case "=notlike=":
String notLikeValue = values.get(0);
q.not().like(attrName, notLikeValue);
break;
case "=q=":
String searchValue = values.get(0);
if (attrName.equals("*"))
{
q.search(searchValue);
}
else
{
q.search(attrName, searchValue);
}
break;
case "==":
Object eqValue = rsqlValueParser.parse(values.get(0), getAttribute(node));
q.eq(attrName, eqValue);
break;
case "=in=":
Attribute inAttr = getAttribute(node);
q.in(attrName, values.stream().map(value -> rsqlValueParser.parse(value, inAttr)).collect(toList()));
break;
case "=lt=":
case "<":
Attribute ltAttr = getAttribute(node);
validateNumericOrDate(ltAttr);
Object ltValue = rsqlValueParser.parse(values.get(0), ltAttr);
q.lt(attrName, ltValue);
break;
case "=le=":
case "<=":
Attribute leAttr = getAttribute(node);
validateNumericOrDate(leAttr);
Object leValue = rsqlValueParser.parse(values.get(0), leAttr);
q.le(attrName, leValue);
break;
case "=gt=":
case ">":
Attribute gtAttr = getAttribute(node);
validateNumericOrDate(gtAttr);
Object gtValue = rsqlValueParser.parse(values.get(0), gtAttr);
q.gt(attrName, gtValue);
break;
case "=ge=":
case ">=":
Attribute geAttr = getAttribute(node);
validateNumericOrDate(geAttr);
Object geValue = rsqlValueParser.parse(values.get(0), geAttr);
q.ge(attrName, geValue);
break;
case "=rng=":
Attribute rngAttr = getAttribute(node);
validateNumericOrDate(rngAttr);
Object fromValue = values.get(0) != null ? rsqlValueParser.parse(values.get(0), rngAttr) : null;
Object toValue = values.get(1) != null ? rsqlValueParser.parse(values.get(1), rngAttr) : null;
q.rng(attrName, fromValue, toValue);
break;
case "=like=":
String likeValue = values.get(0);
q.like(attrName, likeValue);
break;
case "!=":
Object notEqValue = rsqlValueParser.parse(values.get(0), getAttribute(node));
q.not().eq(attrName, notEqValue);
break;
case "=should=":
throw new MolgenisQueryException("Unsupported RSQL query operator [" + symbol + "]");
case "=dismax=":
throw new MolgenisQueryException("Unsupported RSQL query operator [" + symbol + "]");
case "=fuzzy=":
throw new MolgenisQueryException("Unsupported RSQL query operator [" + symbol + "]");
default:
throw new MolgenisQueryException("Unknown RSQL query operator [" + symbol + "]");
}
return q;
}
private void validateNumericOrDate(Attribute attr)
{
switch (attr.getDataType())
{
case DATE:
case DATE_TIME:
case DECIMAL:
case INT:
case LONG:
break;
// $CASES-OMITTED$
default:
throw new IllegalArgumentException(
"Can't perform operator '\" + symbol + \"' on attribute '\"" + attr.getName() + "\"");
}
}
private Attribute getAttribute(ComparisonNode node)
{
String attrName = node.getSelector();
String[] attrTokens = attrName.split("\\.");
Attribute attr = entityType.getAttribute(attrTokens[0]);
if (attr == null)
{
throw new UnknownAttributeException("Unknown attribute [" + attrName + "]");
}
EntityType entityTypeAtDepth;
for (int i = 1; i < attrTokens.length; ++i)
{
entityTypeAtDepth = attr.getRefEntity();
attr = entityTypeAtDepth.getAttribute(attrTokens[i]);
if (attr == null)
{
throw new UnknownAttributeException("Unknown attribute [" + attrName + "]");
}
}
return attr;
}
}