package org.vertexium.elasticsearch2;

import org.elasticsearch.client.Client;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.vertexium.*;
import org.vertexium.elasticsearch2.score.ScoringStrategy;
import org.vertexium.query.VertexQuery;

import java.util.ArrayList;
import java.util.EnumSet;
import java.util.List;
import java.util.stream.Stream;

import static org.vertexium.util.StreamUtils.stream;

/**
 * A {@link VertexQuery} backed by Elasticsearch that adds, on top of the filters
 * produced by {@link ElasticsearchSearchQueryBase}, a filter tying results to a
 * source vertex: vertex documents are limited to the ids reported by the source
 * vertex's edge infos, and edge documents are limited to those whose in/out
 * vertex id field holds the source vertex id.
 */
public class ElasticsearchSearchVertexQuery extends ElasticsearchSearchQueryBase implements VertexQuery {
    private final Vertex sourceVertex;
    private Direction direction = Direction.BOTH;
    private String otherVertexId;

    /**
     * Creates the query. All arguments except {@code sourceVertex} are handed
     * unchanged to the superclass constructor.
     *
     * @param sourceVertex the vertex whose edge infos and id drive the extra filters
     */
    public ElasticsearchSearchVertexQuery(
        Client client,
        Graph graph,
        Vertex sourceVertex,
        String queryString,
        ScoringStrategy scoringStrategy,
        IndexSelectionStrategy indexSelectionStrategy,
        int pageSize,
        Authorizations authorizations
    ) {
        super(client, graph, queryString, scoringStrategy, indexSelectionStrategy, pageSize, authorizations);
        this.sourceVertex = sourceVertex;
    }

    /**
     * Returns the superclass filters with one additional entry: the OR of the
     * vertex-side filter (when vertex or vertex extended data documents are
     * requested) and the edge-side filter (when edge or edge extended data
     * documents are requested).
     *
     * @param elementTypes the document types being searched
     * @return the list obtained from the superclass, with the extra filter appended
     */
    @Override
    protected List<QueryBuilder> getFilters(EnumSet<ElasticsearchDocumentType> elementTypes) {
        List<QueryBuilder> allFilters = super.getFilters(elementTypes);

        boolean vertexDocsRequested = elementTypes.contains(ElasticsearchDocumentType.VERTEX)
            || elementTypes.contains(ElasticsearchDocumentType.VERTEX_EXTENDED_DATA);
        boolean edgeDocsRequested = elementTypes.contains(ElasticsearchDocumentType.EDGE)
            || elementTypes.contains(ElasticsearchDocumentType.EDGE_EXTENDED_DATA);

        List<QueryBuilder> sourceVertexFilters = new ArrayList<>();
        if (vertexDocsRequested) {
            sourceVertexFilters.add(getVertexFilter(elementTypes));
        }
        if (edgeDocsRequested) {
            sourceVertexFilters.add(getEdgeFilter());
        }

        allFilters.add(orFilters(sourceVertexFilters));
        return allFilters;
    }

    /**
     * Builds the edge-document filter for the configured direction.
     *
     * @throws VertexiumException if the direction is not one of IN, OUT or BOTH
     */
    private QueryBuilder getEdgeFilter() {
        switch (direction) {
            case IN:
                return getDirectionInEdgeFilter();
            case OUT:
                return getDirectionOutEdgeFilter();
            case BOTH:
                // Arguments are evaluated left to right: IN filter first, then OUT filter.
                return QueryBuilders.orQuery(getDirectionInEdgeFilter(), getDirectionOutEdgeFilter());
            default:
                throw new VertexiumException("unexpected direction: " + direction);
        }
    }

    /**
     * Filter for edges whose in-vertex id field holds the source vertex id and,
     * when an other vertex id is set, whose out-vertex id field holds that id.
     */
    private QueryBuilder getDirectionInEdgeFilter() {
        return edgeEndpointFilter(
            Elasticsearch2SearchIndex.IN_VERTEX_ID_FIELD_NAME,
            Elasticsearch2SearchIndex.OUT_VERTEX_ID_FIELD_NAME
        );
    }

    /**
     * Filter for edges whose out-vertex id field holds the source vertex id and,
     * when an other vertex id is set, whose in-vertex id field holds that id.
     */
    private QueryBuilder getDirectionOutEdgeFilter() {
        return edgeEndpointFilter(
            Elasticsearch2SearchIndex.OUT_VERTEX_ID_FIELD_NAME,
            Elasticsearch2SearchIndex.IN_VERTEX_ID_FIELD_NAME
        );
    }

    /**
     * Shared implementation of the two directional edge filters.
     *
     * @param sourceSideField field that must equal the source vertex id
     * @param otherSideField  field that must equal {@code otherVertexId}, checked only when that id is non-null
     */
    private QueryBuilder edgeEndpointFilter(String sourceSideField, String otherSideField) {
        QueryBuilder sourceSideTerm = QueryBuilders.termQuery(sourceSideField, sourceVertex.getId());
        if (otherVertexId == null) {
            return sourceSideTerm;
        }
        QueryBuilder otherSideTerm = QueryBuilders.termQuery(otherSideField, otherVertexId);
        return QueryBuilders.andQuery(sourceSideTerm, otherSideTerm);
    }

    /**
     * Builds the vertex-side filter from the vertex ids found in the source
     * vertex's edge infos (for the configured direction, the query's edge labels
     * and authorizations), optionally narrowed to {@code otherVertexId}.
     *
     * @param elementTypes the document types being searched; decides whether an ids
     *                     filter, per-vertex extended data filters, or both are produced
     */
    private QueryBuilder getVertexFilter(EnumSet<ElasticsearchDocumentType> elementTypes) {
        List<String> labelList = getParameters().getEdgeLabels();
        String[] labels;
        if (labelList == null || labelList.isEmpty()) {
            // A null array is passed through when no edge labels were given.
            labels = null;
        } else {
            labels = labelList.toArray(new String[0]);
        }

        Stream<EdgeInfo> connected = stream(sourceVertex.getEdgeInfos(
            direction,
            labels,
            getParameters().getAuthorizations()
        ));
        Stream<EdgeInfo> matching = otherVertexId == null
            ? connected
            : connected.filter(info -> info.getVertexId().equals(otherVertexId));
        String[] vertexIds = matching.map(EdgeInfo::getVertexId).toArray(String[]::new);

        List<QueryBuilder> vertexFilters = new ArrayList<>();
        if (elementTypes.contains(ElasticsearchDocumentType.VERTEX)) {
            vertexFilters.add(QueryBuilders.idsQuery().ids(vertexIds));
        }
        if (elementTypes.contains(ElasticsearchDocumentType.VERTEX_EXTENDED_DATA)) {
            for (String id : vertexIds) {
                vertexFilters.add(vertexExtendedDataFilter(id));
            }
        }
        return orFilters(vertexFilters);
    }

    /**
     * Filter matching documents whose element type field is the vertex extended
     * data key and whose extended data element id field equals {@code vertexId}.
     */
    private QueryBuilder vertexExtendedDataFilter(String vertexId) {
        return QueryBuilders.andQuery(
            QueryBuilders.termQuery(
                Elasticsearch2SearchIndex.ELEMENT_TYPE_FIELD_NAME,
                ElasticsearchDocumentType.VERTEX_EXTENDED_DATA.getKey()
            ),
            QueryBuilders.termQuery(Elasticsearch2SearchIndex.EXTENDED_DATA_ELEMENT_ID_FIELD_NAME, vertexId)
        );
    }

    /**
     * Combines filters with OR. A single filter is returned as-is rather than
     * being wrapped.
     * <p>
     * NOTE(review): an empty list is forwarded to {@code orQuery} with no clauses;
     * confirm how Elasticsearch treats that case.
     */
    private QueryBuilder orFilters(List<QueryBuilder> filters) {
        return filters.size() == 1
            ? filters.get(0)
            : QueryBuilders.orQuery(filters.toArray(new QueryBuilder[0]));
    }

    /**
     * Sets the edge direction used when building the filters.
     *
     * @return this query
     */
    @Override
    public VertexQuery hasDirection(Direction direction) {
        this.direction = direction;
        return this;
    }

    /**
     * Sets the id of the vertex required at the opposite end from the source vertex.
     *
     * @return this query
     */
    @Override
    public VertexQuery hasOtherVertexId(String otherVertexId) {
        this.otherVertexId = otherVertexId;
        return this;
    }
}