package org.apache.solr.handler.clustering.carrot2; /** * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ import java.io.IOException; import java.util.*; import org.apache.commons.lang.StringUtils; import org.apache.lucene.document.FieldSelector; import org.apache.lucene.document.SetBasedFieldSelector; import org.apache.lucene.search.Query; import org.apache.solr.common.params.HighlightParams; import org.apache.solr.common.params.SolrParams; import org.apache.solr.common.util.NamedList; import org.apache.solr.common.util.SimpleOrderedMap; import org.apache.solr.common.SolrException; import org.apache.solr.core.SolrCore; import org.apache.solr.handler.clustering.SearchClusteringEngine; import org.apache.solr.highlight.SolrHighlighter; import org.apache.solr.request.LocalSolrQueryRequest; import org.apache.solr.request.SolrQueryRequest; import org.apache.solr.search.*; import org.apache.solr.util.RefCounted; import org.carrot2.core.*; import org.carrot2.core.attribute.AttributeNames; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import com.google.common.collect.Sets; /** * Search results clustering engine based on Carrot2 clustering algorithms. * <p/> * Output from this class is subject to change. * * @link http://project.carrot2.org */ @SuppressWarnings("unchecked") public class CarrotClusteringEngine extends SearchClusteringEngine { private transient static Logger log = LoggerFactory .getLogger(CarrotClusteringEngine.class); /** * Carrot2 controller that manages instances of clustering algorithms */ private Controller controller = ControllerFactory.createPooling(); private Class<? extends IClusteringAlgorithm> clusteringAlgorithmClass; private String idFieldName; public Object cluster(Query query, DocList docList, SolrQueryRequest sreq) { try { // Prepare attributes for Carrot2 clustering call Map<String, Object> attributes = new HashMap<String, Object>(); List<Document> documents = getDocuments(docList, query, sreq); attributes.put(AttributeNames.DOCUMENTS, documents); attributes.put(AttributeNames.QUERY, query.toString()); // Pass extra overriding attributes from the request, if any extractCarrotAttributes(sreq.getParams(), attributes); // Perform clustering and convert to named list return clustersToNamedList(controller.process(attributes, clusteringAlgorithmClass).getClusters(), sreq.getParams()); } catch (Exception e) { log.error("Carrot2 clustering failed", e); throw new RuntimeException(e); } } @Override public String init(NamedList config, final SolrCore core) { String result = super.init(config, core); SolrParams initParams = SolrParams.toSolrParams(config); // Initialize Carrot2 controller. Pass initialization attributes, if any. HashMap<String, Object> initAttributes = new HashMap<String, Object>(); extractCarrotAttributes(initParams, initAttributes); // Customize the language model factory. The implementation we provide here // is included in the code base of Solr, so that it's possible to refactor // the Lucene APIs the factory relies on if needed. initAttributes.put("PreprocessingPipeline.languageModelFactory", new LuceneLanguageModelFactory()); this.controller.init(initAttributes); this.idFieldName = core.getSchema().getUniqueKeyField().getName(); // Make sure the requested Carrot2 clustering algorithm class is available String carrotAlgorithmClassName = initParams.get(CarrotParams.ALGORITHM); Class<?> algorithmClass = core.getResourceLoader().findClass(carrotAlgorithmClassName); if (!IClusteringAlgorithm.class.isAssignableFrom(algorithmClass)) { throw new IllegalArgumentException("Class provided as " + CarrotParams.ALGORITHM + " must implement " + IClusteringAlgorithm.class.getName()); } this.clusteringAlgorithmClass = (Class<? extends IClusteringAlgorithm>) algorithmClass; return result; } /** * Prepares Carrot2 documents for clustering. */ private List<Document> getDocuments(DocList docList, Query query, final SolrQueryRequest sreq) throws IOException { SolrHighlighter highlighter = null; SolrParams solrParams = sreq.getParams(); SolrCore core = sreq.getCore(); // Names of fields to deliver content for clustering String urlField = solrParams.get(CarrotParams.URL_FIELD_NAME, "url"); String titleField = solrParams.get(CarrotParams.TITLE_FIELD_NAME, "title"); String snippetField = solrParams.get(CarrotParams.SNIPPET_FIELD_NAME, titleField); if (StringUtils.isBlank(snippetField)) { throw new SolrException(SolrException.ErrorCode.SERVER_ERROR, CarrotParams.SNIPPET_FIELD_NAME + " must not be blank."); } Set<String> fieldsToLoad = Sets.newHashSet(urlField, titleField, snippetField, idFieldName); // Get the documents DocIterator docsIter = docList.iterator(); boolean produceSummary = solrParams.getBool(CarrotParams.PRODUCE_SUMMARY, false); SolrQueryRequest req = null; String[] snippetFieldAry = null; if (produceSummary == true) { highlighter = core.getHighlighter(); if (highlighter != null){ Map args = new HashMap(); snippetFieldAry = new String[]{snippetField}; args.put(HighlightParams.FIELDS, snippetFieldAry); args.put(HighlightParams.HIGHLIGHT, "true"); args.put(HighlightParams.SIMPLE_PRE, ""); //we don't care about actually highlighting the area args.put(HighlightParams.SIMPLE_POST, ""); args.put(HighlightParams.FRAGSIZE, solrParams.getInt(CarrotParams.SUMMARY_FRAGSIZE, solrParams.getInt(HighlightParams.FRAGSIZE, 100))); req = new LocalSolrQueryRequest(core, query.toString(), "", 0, 1, args) { @Override public SolrIndexSearcher getSearcher() { return sreq.getSearcher(); } }; } else { log.warn("No highlighter configured, cannot produce summary"); produceSummary = false; } } SolrIndexSearcher searcher = sreq.getSearcher(); List<Document> result = new ArrayList<Document>(docList.size()); float[] scores = {1.0f}; int[] docsHolder = new int[1]; Query theQuery = query; while (docsIter.hasNext()) { Integer id = docsIter.next(); org.apache.lucene.document.Document doc = searcher.doc(id, fieldsToLoad); String snippet = getValue(doc, snippetField); if (produceSummary == true) { docsHolder[0] = id.intValue(); DocList docAsList = new DocSlice(0, 1, docsHolder, scores, 1, 1.0f); NamedList highlights = highlighter.doHighlighting(docAsList, theQuery, req, snippetFieldAry); if (highlights != null && highlights.size() == 1) {//should only be one value given our setup //should only be one document with one field NamedList tmp = (NamedList) highlights.getVal(0); String [] highlt = (String[]) tmp.get(snippetField); if (highlt != null && highlt.length == 1) { snippet = highlt[0]; } } } Document carrotDocument = new Document(getValue(doc, titleField), snippet, doc.get(urlField)); carrotDocument.setField("solrId", doc.get(idFieldName)); result.add(carrotDocument); } return result; } protected String getValue(org.apache.lucene.document.Document doc, String field) { StringBuilder result = new StringBuilder(); String[] vals = doc.getValues(field); for (int i = 0; i < vals.length; i++) { // Join multiple values with a period so that Carrot2 does not pick up // phrases that cross field value boundaries (in most cases it would // create useless phrases). result.append(vals[i]).append(" . "); } return result.toString().trim(); } private List clustersToNamedList(List<Cluster> carrotClusters, SolrParams solrParams) { List result = new ArrayList(); clustersToNamedList(carrotClusters, result, solrParams.getBool( CarrotParams.OUTPUT_SUB_CLUSTERS, true), solrParams.getInt( CarrotParams.NUM_DESCRIPTIONS, Integer.MAX_VALUE)); return result; } private void clustersToNamedList(List<Cluster> outputClusters, List parent, boolean outputSubClusters, int maxLabels) { for (Cluster outCluster : outputClusters) { NamedList cluster = new SimpleOrderedMap(); parent.add(cluster); List<String> labels = outCluster.getPhrases(); if (labels.size() > maxLabels) labels = labels.subList(0, maxLabels); cluster.add("labels", labels); List<Document> docs = outputSubClusters ? outCluster.getDocuments() : outCluster.getAllDocuments(); List docList = new ArrayList(); cluster.add("docs", docList); for (Document doc : docs) { docList.add(doc.getField("solrId")); } if (outputSubClusters) { List subclusters = new ArrayList(); cluster.add("clusters", subclusters); clustersToNamedList(outCluster.getSubclusters(), subclusters, outputSubClusters, maxLabels); } } } /** * Extracts parameters that can possibly match some attributes of Carrot2 algorithms. */ private void extractCarrotAttributes(SolrParams solrParams, Map<String, Object> attributes) { // Extract all non-predefined parameters. This way, we'll be able to set all // parameters of Carrot2 algorithms without defining their names as constants. for (Iterator<String> paramNames = solrParams.getParameterNamesIterator(); paramNames .hasNext();) { String paramName = paramNames.next(); if (!CarrotParams.CARROT_PARAM_NAMES.contains(paramName)) { attributes.put(paramName, solrParams.get(paramName)); } } } }