package org.apache.solr.search.federated;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.logging.Logger;
import org.apache.lucene.index.IndexableField;
import org.apache.solr.common.SolrDocument;
import org.apache.solr.common.SolrDocumentList;
import org.apache.solr.common.params.CommonParams;
import org.apache.solr.common.params.SolrParams;
import org.apache.solr.common.util.NamedList;
import org.apache.solr.core.SolrCore;
import org.apache.solr.handler.component.ResponseBuilder;
import org.apache.solr.handler.component.SearchComponent;
import org.apache.solr.handler.component.ShardRequest;
import org.apache.solr.handler.component.ShardResponse;
import org.apache.solr.schema.CopyField;
import org.apache.solr.schema.FieldType;
import org.apache.solr.schema.IndexSchema;
import org.apache.solr.schema.SchemaField;
public class NumFoundSearchComponent extends SearchComponent {
public static final String COMPONENT_NAME = "djoin";
//FIXME: surely these are defined somewhere else?
public static final String SHARD_FIELD = "[shard]";
public static final String VERSION_FIELD = "_version_";
public static final String DJOIN_FIELD = "[" + COMPONENT_NAME + "]";
// initialisation parameters
public static final String INIT_JOIN_FIELD = "joinField";
public static final String INIT_IGNORE_CONVERSION_ERRORS = "ignoreConversionErrors";
// request parameters
public static final String DEBUG_PARAMETER = COMPONENT_NAME + ".debug";
private String joinField;
private boolean ignoreConversionErrors = false;
@Override
@SuppressWarnings("rawtypes")
public void init(NamedList args) {
super.init(args);
joinField = (String)args.get(INIT_JOIN_FIELD);
Boolean b = args.getBooleanArg(INIT_IGNORE_CONVERSION_ERRORS);
if (b != null) {
ignoreConversionErrors = b.booleanValue();
}
}
private static List<String> getFieldList(SolrParams params) {
return Arrays.asList(params.get(CommonParams.FL, "").split("\\s?,\\s?|\\s"));
}
@Override
public void prepare(ResponseBuilder rb) throws IOException {
System.out.println("===== PREPARE =====");
System.out.println("shards=" + rb.shards);
// only do this on aggregator
if (rb.shards == null) return;
List<String> fl = getFieldList(rb.req.getParams());
rb.req.getContext().put(COMPONENT_NAME + CommonParams.FL, fl);
Map<String, Long> numFounds = new HashMap<>();
rb.req.getContext().put(COMPONENT_NAME + "numFounds", numFounds);
Set<Object> joinIds = new HashSet<>();
rb.req.getContext().put(COMPONENT_NAME + "joinIds", joinIds);
}
@SuppressWarnings("unchecked")
private static boolean fieldListIncludes(ResponseBuilder rb, String fieldName) {
List<String> fl = (List<String>)rb.req.getContext().get(COMPONENT_NAME + CommonParams.FL);
return fl.contains(fieldName);
}
private static boolean isDebug(ResponseBuilder rb) {
return rb.req.getParams().getBool(DEBUG_PARAMETER, false);
}
@Override
@SuppressWarnings({ "unchecked", "rawtypes" })
public void process(ResponseBuilder rb) throws IOException {
System.out.println("===== PROCESS =====");
System.out.println("shards=" + rb.shards);
// only do this on shards
if (rb.shards != null) return;
// output list of all group values
NamedList federated = new NamedList();
rb.rsp.getValues().add("federated_counts", federated);
NamedList counts = (NamedList)rb.rsp.getValues().get("facet_counts");
if (counts == null) return;
NamedList fields = (NamedList)counts.get("facet_fields");
if (fields == null) return;
for (int i = 0; i < fields.size(); ++i) {
String field = fields.getName(i);
List values = new ArrayList();
federated.add(field, values);
NamedList v = (NamedList)fields.get(field);
for (int j = 0; j < v.size(); ++j) {
values.add(v.getName(j));
}
}
// remove the unfederated results?
if (! isDebug(rb)) {
rb.rsp.getValues().remove("facet_counts");
}
}
/** not called on shards, i.e. only aggregator */
public int distributedProcess(ResponseBuilder rb) throws IOException {
System.out.println("===== DISTRIBUTED PROCESS =====");
System.out.println("stage=" + rb.stage);
return super.distributedProcess(rb);
}
@Override
public void modifyRequest(ResponseBuilder rb, SearchComponent who, ShardRequest sreq) {
System.out.println("===== MODIFY REQUEST =====");
System.out.println("who=" + who);
System.out.println("purpose=" + sreq.purpose);
if ((sreq.purpose & ShardRequest.PURPOSE_GET_FIELDS) > 0) {
if (fieldListIncludes(rb, DJOIN_FIELD)) {
Set<String> fl = new HashSet<>(getFieldList(sreq.params));
fl.add(SHARD_FIELD);
sreq.params.set(CommonParams.FL, String.join(",", fl));
}
// enable faceting on shards to get join ids
sreq.params.set("facet", true);
sreq.params.set("facet.field", joinField);
}
}
@Override
@SuppressWarnings({ "rawtypes", "unchecked" })
public void finishStage(ResponseBuilder rb) {
System.out.println("===== FINISH STAGE =====");
System.out.println("shards=" + rb.shards);
SolrCore core = rb.req.getCore();
IndexSchema schema = core.getLatestSchema();
String uniqueKeyField = schema.getUniqueKeyField().getName();
boolean includeShardId = fieldListIncludes(rb, DJOIN_FIELD);
boolean includeShard = fieldListIncludes(rb, SHARD_FIELD);
// only do this in final stage
if (rb.stage != 3000) return;
System.out.println("*** Federating results ***");
SolrDocumentList feds = new SolrDocumentList();
//rb.rsp.getValues().add("federated", feds);
/*NamedList results = (NamedList)grouped.get(joinField);
feds.setNumFound((Integer)results.get("matches"));
feds.setStart(rb.getQueryCommand().getOffset());
List<NamedList> groups = (List<NamedList>)results.get("groups");
for (NamedList group : groups) {
SolrDocumentList docs = (SolrDocumentList)group.get("doclist");
SolrDocument superDoc = new SolrDocument();
for (SolrDocument doc : docs) {
for (String fieldName : doc.getFieldNames()) {
if (fieldName.equals(SHARD_FIELD) && ! includeShard) {
continue;
}
SchemaField field = schema.getField(fieldName);
if (field == null || ! field.stored()) {
continue;
}
Object value = doc.getFieldValue(fieldName);
for (CopyField cp : schema.getCopyFieldsList(fieldName)) {
addConvertedFieldValue(superDoc, value, cp.getDestination());
}
if (fieldName.equals(uniqueKeyField)) {
// the [djoin] field value is [shard]:[id]:_version_
if (includeShardId) {
String shard = (String)doc.getFieldValue(SHARD_FIELD);
String version = doc.getFieldValue(VERSION_FIELD).toString();
addFieldValue(superDoc, shard + ":" + value + ":" + version, null);
}
} else {
addConvertedFieldValue(superDoc, value, field);
}
}
}
feds.add(superDoc);
}*/
// for now, just add the numFounds for each shard to the results...
/*NamedList details = new NamedList();
rb.rsp.getValues().add("federated", details);
Map<String, Long> numFounds = (Map<String, Long>)rb.req.getContext().get(COMPONENT_NAME + "numFounds");
details.add("numFounds", numFounds);*/
// ... and the size of joinIds
/*Set<Object> joinIds = (Set<Object>)rb.req.getContext().get(COMPONENT_NAME + "joinIds");
details.add("joinIds", joinIds);
SolrDocumentList docs = (SolrDocumentList)rb.rsp.getValues().get("response");
//docs.setNumFound((long)joinIds.size());
numFounds.put("total", (long)joinIds.size());*/
}
private void addConvertedFieldValue(SolrDocument superDoc, Object value, SchemaField field) {
try {
FieldType type = field.getType();
IndexableField indexable = type.createField(field, value, 1.0f);
addFieldValue(superDoc, type.toObject(indexable), field);
} catch (RuntimeException e) {
if (! ignoreConversionErrors) {
throw e;
}
}
}
@SuppressWarnings({ "unchecked", "rawtypes" })
private boolean addFieldValue(SolrDocument superDoc, Object value, SchemaField field) {
if (value == null) return false;
String fieldName = field != null ? field.getName() : DJOIN_FIELD;
if (field == null || field.multiValued()) {
List list = (List)superDoc.getFieldValue(fieldName);
if (list == null) {
list = new ArrayList();
superDoc.setField(fieldName, list);
}
if (! list.contains(value)) {
list.add(value);
}
} else {
Object docValue = superDoc.get(fieldName);
if (docValue == null) {
superDoc.setField(fieldName, value);
} else if (! docValue.equals(value)) {
throw new RuntimeException("Field not multi-valued: " + fieldName);
}
}
return true;
}
@SuppressWarnings({ "rawtypes", "unchecked" })
@Override
public void handleResponses(ResponseBuilder rb, ShardRequest req) {
System.out.println("===== HANDLE RESPONSES =====");
System.out.println("purpose=" + req.purpose);
System.out.println("Shards: " + (req.shards != null ? String.join(" ", req.shards) : "(null)"));
if ((req.purpose & ShardRequest.PURPOSE_GET_FIELDS) > 0) {
Map<String, Long> numFounds = (Map<String, Long>)rb.req.getContext().get(COMPONENT_NAME + "numFounds");
Set<Object> joinIds = (Set<Object>)rb.req.getContext().get(COMPONENT_NAME + "joinIds");
for (ShardResponse rsp : req.responses) {
NamedList response = rsp.getSolrResponse().getResponse();
SolrDocumentList results = (SolrDocumentList)response.get("response");
numFounds.put(rsp.getShard(), results.getNumFound());
NamedList counts = (NamedList)response.get("facet_counts");
if (counts != null) {
NamedList fields = (NamedList)counts.get("facet_fields");
NamedList values = (NamedList)fields.get(joinField);
for (int i = 0; i < values.size(); ++i) {
joinIds.add(values.getName(i));
}
}
}
}
}
@Override
public String getDescription() {
return "$description";
}
@Override
public String getSource() {
return "$source";
}
}