package org.apache.solr.search.federated; import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Set; import org.apache.lucene.index.IndexableField; import org.apache.solr.common.SolrDocument; import org.apache.solr.common.SolrDocumentList; import org.apache.solr.common.params.CommonParams; import org.apache.solr.common.params.SolrParams; import org.apache.solr.common.util.NamedList; import org.apache.solr.core.SolrCore; import org.apache.solr.handler.component.ResponseBuilder; import org.apache.solr.handler.component.SearchComponent; import org.apache.solr.handler.component.ShardRequest; import org.apache.solr.handler.component.ShardResponse; import org.apache.solr.schema.CopyField; import org.apache.solr.schema.FieldType; import org.apache.solr.schema.IndexSchema; import org.apache.solr.schema.SchemaField; public class NumFoundSearchComponent extends SearchComponent { public static final String COMPONENT_NAME = "djoin"; //FIXME: surely these are defined somewhere else? public static final String SHARD_FIELD = "[shard]"; public static final String VERSION_FIELD = "_version_"; public static final String DJOIN_FIELD = "[" + COMPONENT_NAME + "]"; // initialisation parameters public static final String INIT_JOIN_FIELD = "joinField"; public static final String INIT_IGNORE_CONVERSION_ERRORS = "ignoreConversionErrors"; // request parameters public static final String DEBUG_PARAMETER = COMPONENT_NAME + ".debug"; private String joinField; private boolean ignoreConversionErrors = false; @Override @SuppressWarnings("rawtypes") public void init(NamedList args) { super.init(args); joinField = (String)args.get(INIT_JOIN_FIELD); Boolean b = args.getBooleanArg(INIT_IGNORE_CONVERSION_ERRORS); if (b != null) { ignoreConversionErrors = b.booleanValue(); } } private static List<String> getFieldList(SolrParams params) { return Arrays.asList(params.get(CommonParams.FL, "").split("\\s?,\\s?|\\s")); } @Override public void prepare(ResponseBuilder rb) throws IOException { System.out.println("===== PREPARE ====="); System.out.println("shards=" + rb.shards); // only do this on aggregator if (rb.shards == null) return; List<String> fl = getFieldList(rb.req.getParams()); rb.req.getContext().put(COMPONENT_NAME + CommonParams.FL, fl); Map<String, Long> numFounds = new HashMap<>(); rb.req.getContext().put(COMPONENT_NAME + "numFounds", numFounds); Set<Object> joinIds = new HashSet<>(); rb.req.getContext().put(COMPONENT_NAME + "joinIds", joinIds); } @SuppressWarnings("unchecked") private static boolean fieldListIncludes(ResponseBuilder rb, String fieldName) { List<String> fl = (List<String>)rb.req.getContext().get(COMPONENT_NAME + CommonParams.FL); return fl.contains(fieldName); } private static boolean isDebug(ResponseBuilder rb) { return rb.req.getParams().getBool(DEBUG_PARAMETER, false); } @Override @SuppressWarnings({ "unchecked", "rawtypes" }) public void process(ResponseBuilder rb) throws IOException { System.out.println("===== PROCESS ====="); System.out.println("shards=" + rb.shards); // only do this on shards if (rb.shards != null) return; // output list of all group values NamedList federated = new NamedList(); rb.rsp.getValues().add("federated_counts", federated); NamedList counts = (NamedList)rb.rsp.getValues().get("facet_counts"); if (counts == null) return; NamedList fields = (NamedList)counts.get("facet_fields"); if (fields == null) return; for (int i = 0; i < fields.size(); ++i) { String field = fields.getName(i); List values = new ArrayList(); federated.add(field, values); NamedList v = (NamedList)fields.get(field); for (int j = 0; j < v.size(); ++j) { values.add(v.getName(j)); } } // remove the unfederated results? if (! isDebug(rb)) { rb.rsp.getValues().remove("facet_counts"); } } /** not called on shards, i.e. only aggregator */ public int distributedProcess(ResponseBuilder rb) throws IOException { System.out.println("===== DISTRIBUTED PROCESS ====="); System.out.println("stage=" + rb.stage); return super.distributedProcess(rb); } @Override public void modifyRequest(ResponseBuilder rb, SearchComponent who, ShardRequest sreq) { System.out.println("===== MODIFY REQUEST ====="); System.out.println("who=" + who); System.out.println("purpose=" + sreq.purpose); if ((sreq.purpose & ShardRequest.PURPOSE_GET_FIELDS) > 0) { if (fieldListIncludes(rb, DJOIN_FIELD)) { Set<String> fl = new HashSet<>(getFieldList(sreq.params)); fl.add(SHARD_FIELD); sreq.params.set(CommonParams.FL, String.join(",", fl)); } // enable faceting on shards to get join ids sreq.params.set("facet", true); sreq.params.set("facet.field", joinField); } } @Override @SuppressWarnings({ "rawtypes", "unchecked" }) public void finishStage(ResponseBuilder rb) { System.out.println("===== FINISH STAGE ====="); System.out.println("shards=" + rb.shards); SolrCore core = rb.req.getCore(); IndexSchema schema = core.getLatestSchema(); String uniqueKeyField = schema.getUniqueKeyField().getName(); boolean includeShardId = fieldListIncludes(rb, DJOIN_FIELD); boolean includeShard = fieldListIncludes(rb, SHARD_FIELD); // only do this in final stage if (rb.stage != 3000) return; System.out.println("*** Federating results ***"); SolrDocumentList feds = new SolrDocumentList(); //rb.rsp.getValues().add("federated", feds); /*NamedList results = (NamedList)grouped.get(joinField); feds.setNumFound((Integer)results.get("matches")); feds.setStart(rb.getQueryCommand().getOffset()); List<NamedList> groups = (List<NamedList>)results.get("groups"); for (NamedList group : groups) { SolrDocumentList docs = (SolrDocumentList)group.get("doclist"); SolrDocument superDoc = new SolrDocument(); for (SolrDocument doc : docs) { for (String fieldName : doc.getFieldNames()) { if (fieldName.equals(SHARD_FIELD) && ! includeShard) { continue; } SchemaField field = schema.getField(fieldName); if (field == null || ! field.stored()) { continue; } Object value = doc.getFieldValue(fieldName); for (CopyField cp : schema.getCopyFieldsList(fieldName)) { addConvertedFieldValue(superDoc, value, cp.getDestination()); } if (fieldName.equals(uniqueKeyField)) { // the [djoin] field value is [shard]:[id]:_version_ if (includeShardId) { String shard = (String)doc.getFieldValue(SHARD_FIELD); String version = doc.getFieldValue(VERSION_FIELD).toString(); addFieldValue(superDoc, shard + ":" + value + ":" + version, null); } } else { addConvertedFieldValue(superDoc, value, field); } } } feds.add(superDoc); }*/ // for now, just add the numFounds for each shard to the results... /*NamedList details = new NamedList(); rb.rsp.getValues().add("federated", details); Map<String, Long> numFounds = (Map<String, Long>)rb.req.getContext().get(COMPONENT_NAME + "numFounds"); details.add("numFounds", numFounds);*/ // ... and the size of joinIds /*Set<Object> joinIds = (Set<Object>)rb.req.getContext().get(COMPONENT_NAME + "joinIds"); details.add("joinIds", joinIds); SolrDocumentList docs = (SolrDocumentList)rb.rsp.getValues().get("response"); //docs.setNumFound((long)joinIds.size()); numFounds.put("total", (long)joinIds.size());*/ } private void addConvertedFieldValue(SolrDocument superDoc, Object value, SchemaField field) { try { FieldType type = field.getType(); IndexableField indexable = type.createField(field, value, 1.0f); addFieldValue(superDoc, type.toObject(indexable), field); } catch (RuntimeException e) { if (! ignoreConversionErrors) { throw e; } } } @SuppressWarnings({ "unchecked", "rawtypes" }) private boolean addFieldValue(SolrDocument superDoc, Object value, SchemaField field) { if (value == null) return false; String fieldName = field != null ? field.getName() : DJOIN_FIELD; if (field == null || field.multiValued()) { List list = (List)superDoc.getFieldValue(fieldName); if (list == null) { list = new ArrayList(); superDoc.setField(fieldName, list); } if (! list.contains(value)) { list.add(value); } } else { Object docValue = superDoc.get(fieldName); if (docValue == null) { superDoc.setField(fieldName, value); } else if (! docValue.equals(value)) { throw new RuntimeException("Field not multi-valued: " + fieldName); } } return true; } @SuppressWarnings({ "rawtypes", "unchecked" }) @Override public void handleResponses(ResponseBuilder rb, ShardRequest req) { System.out.println("===== HANDLE RESPONSES ====="); System.out.println("purpose=" + req.purpose); System.out.println("Shards: " + (req.shards != null ? String.join(" ", req.shards) : "(null)")); if ((req.purpose & ShardRequest.PURPOSE_GET_FIELDS) > 0) { Map<String, Long> numFounds = (Map<String, Long>)rb.req.getContext().get(COMPONENT_NAME + "numFounds"); Set<Object> joinIds = (Set<Object>)rb.req.getContext().get(COMPONENT_NAME + "joinIds"); for (ShardResponse rsp : req.responses) { NamedList response = rsp.getSolrResponse().getResponse(); SolrDocumentList results = (SolrDocumentList)response.get("response"); numFounds.put(rsp.getShard(), results.getNumFound()); NamedList counts = (NamedList)response.get("facet_counts"); if (counts != null) { NamedList fields = (NamedList)counts.get("facet_fields"); NamedList values = (NamedList)fields.get(joinField); for (int i = 0; i < values.size(); ++i) { joinIds.add(values.getName(i)); } } } } } @Override public String getDescription() { return "$description"; } @Override public String getSource() { return "$source"; } }