/* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.solr.search.concordance; import java.util.ArrayList; import java.util.Arrays; import java.util.Comparator; import java.util.HashMap; import java.util.List; import java.util.Map.Entry; import java.util.Set; import org.apache.lucene.analysis.Analyzer; import org.apache.lucene.corpus.stats.IDFCalc; import org.apache.lucene.corpus.stats.TermIDF; import org.apache.lucene.index.IndexReader; import org.apache.lucene.queryparser.classic.ParseException; import org.apache.lucene.search.Filter; import org.apache.lucene.search.Query; import org.apache.lucene.search.concordance.charoffsets.TargetTokenNotFoundException; import org.apache.lucene.search.concordance.classic.ConcordanceSortOrder; import org.apache.lucene.search.concordance.classic.DocIdBuilder; import org.apache.lucene.search.concordance.classic.DocMetadataExtractor; import org.apache.lucene.search.concordance.classic.impl.FieldBasedDocIdBuilder; import org.apache.lucene.search.concordance.classic.impl.SimpleDocMetadataExtractor; import org.apache.lucene.search.concordance.windowvisitor.ConcordanceArrayWindowSearcher; import org.apache.lucene.search.concordance.windowvisitor.CooccurVisitor; import org.apache.lucene.search.concordance.windowvisitor.Grammer; import org.apache.lucene.search.concordance.windowvisitor.WGrammer; import org.apache.solr.cloud.RequestThreads; import org.apache.solr.cloud.RequestWorker; import org.apache.solr.cloud.ZkController; import org.apache.solr.common.params.CommonParams; import org.apache.solr.common.params.ModifiableSolrParams; import org.apache.solr.common.params.SolrParams; import org.apache.solr.common.util.NamedList; import org.apache.solr.common.util.SimpleOrderedMap; import org.apache.solr.request.SolrQueryRequest; import org.apache.solr.response.SolrQueryResponse; import org.apache.solr.schema.IndexSchema; import org.apache.solr.schema.SchemaField; import org.apache.solr.search.QParser; import org.apache.solr.search.SolrIndexSearcher; /** * <requestHandler name="/kwCooccur" class="org.apache.solr.search.concordance.KeywordCooccurRankHandler"> * <lst name="defaults"> * <str name="echoParams">explicit</str> * <str name="defType">spanquery</str> * <str name="f">content_txt</str> * <str name="df">content_txt</str> * <str name="wt">xml</str> * <str name="minNGram">1</str> * <str name="maxNGram">2</str> * <str name="minTF">3</str> * <str name="numResults">50</str> * <p> * <!-- More fields * <str name="maxWindows">500</str> * <str name="debug">false</str> * <str name="fl">metadata field1,metadata field2,metadata field3</str> * <str name="targetOverlaps">true</str> * <str name="contentDisplaySize">42</str> * <str name="targetDisplaySize">42</str> * <str name="tokensAfter">42</str> * <str name="tokensBefore">42</str> * <str name="sortOrder">TARGET_PRE</str> * --> * </lst> * <lst name="invariants"> * <str name="prop1">value1</str> * <int name="prop2">2</int> * <!-- ... more config items here ... --> * </lst> * </requestHandler> * * @author JRROBINSON */ public class KeywordCooccurRankHandler extends SolrConcordanceBase { public static final String DefaultName = "/kwCo"; public static final String NODE = "contextKeywords"; /** * Max number of request threads to spawn. Since this service wasn't intended to return * ALL possible results, it seems reasonable to cap this at something */ public final static int MAX_THREADS = 25; ; static public RequestThreads<CooccurConfig> initRequestPump(List<String> shards, SolrQueryRequest req) { return initRequestPump(shards, req, MAX_THREADS); } static public RequestThreads<CooccurConfig> initRequestPump(List<String> shards, SolrQueryRequest req, int maxThreads) { SolrParams params = req.getParams(); String field = SolrConcordanceBase.getField(params, req.getSchema().getDefaultSearchFieldName()); String q = params.get(CommonParams.Q); CooccurConfig config = configureParams(field, params); /**/ RequestThreads<CooccurConfig> threads = RequestThreads.<CooccurConfig>newFixedThreadPool(Math.min(shards.size(), maxThreads)) .setMetadata(config); String handler = getHandlerName(req, DefaultName, KeywordCooccurRankHandler.class); int partial = Math.round(config.getMaxWindows() / (float) shards.size()); ModifiableSolrParams p = getWorkerParams(field, q, params, partial); int i = 0; for (String node : shards) { if (i++ > maxThreads) break; //could be https, no? String url = "http://" + node; RequestWorker worker = new RequestWorker(url, handler, p).setName(node); threads.addExecute(worker); } threads.seal(); //disallow future requests (& execute return threads; } public static NamedList doLocalSearch(SolrQueryRequest req) throws Exception { return doLocalSearch(null, req); } //xx public static NamedList doLocalSearch(Query filter, SolrQueryRequest req) throws Exception { SolrParams params = req.getParams(); String field = getField(params); String fl = params.get(CommonParams.FL); DocMetadataExtractor metadataExtractor = (fl != null && fl.length() > 0) ? new SimpleDocMetadataExtractor(fl.split(",")) : new SimpleDocMetadataExtractor(); CooccurConfig config = configureParams(field, params); IndexSchema schema = req.getSchema(); SchemaField sf = schema.getField(field); Analyzer analyzer = sf.getType().getIndexAnalyzer(); Filter queryFilter = getFilterQuery(req); String q = params.get(CommonParams.Q); Query query = QParser.getParser(q, null, req).parse(); String solrUniqueKeyField = req.getSchema().getUniqueKeyField().getName(); SolrIndexSearcher solr = req.getSearcher(); IndexReader reader = solr.getIndexReader(); boolean allowDuplicates = false; boolean allowFieldSeparators = false; Grammer grammer = new WGrammer(config.getMinNGram(), config.getMaxNGram(), allowFieldSeparators); IDFCalc idfCalc = new IDFCalc(reader); CooccurVisitor visitor = new CooccurVisitor(field, config.getTokensBefore(), config.getTokensAfter() , grammer , idfCalc , config.getMaxWindows() , allowDuplicates); visitor.setMinTermFreq(config.getMinTermFreq()); try { ConcordanceArrayWindowSearcher searcher = new ConcordanceArrayWindowSearcher(); System.out.println("UNIQUE KEY FIELD: " + solrUniqueKeyField); DocIdBuilder docIdBuilder = new FieldBasedDocIdBuilder(solrUniqueKeyField); System.out.println("QUERY: " + query.toString()); searcher.search(reader, field, query, queryFilter, analyzer, visitor, docIdBuilder); } catch (IllegalArgumentException e) { // TODO Auto-generated catch block e.printStackTrace(); } catch (TargetTokenNotFoundException e) { // TODO Auto-generated catch block e.printStackTrace(); } List<TermIDF> overallResults = visitor.getResults(); NamedList results = toNamedList(overallResults); //needed for cloud computations, merging cores results.add("collectionSize", reader.numDocs()); results.add("numDocsVisited", visitor.getNumDocsVisited()); results.add("numWindowsVisited", visitor.getNumWindowsVisited()); results.add("numResults", overallResults.size()); results.add("minTF", visitor.getMinTermFreq()); return results; } public static ModifiableSolrParams getWorkerParams(String field, String q, SolrParams parent, Integer maxWindows) { ModifiableSolrParams params = new ModifiableSolrParams(); params.set("f", field); params.set("q", q); params.set("maxWindows", maxWindows); params.set("lq", true); //flag to disallow recursive zoo queries params.set("rows", 0); setParam("fq", params, parent); setParam("anType", params, parent); setParam("numResults", params, parent); setParam("minNGram", params, parent); setParam("maxNGram", params, parent); setParam("minTF", params, parent); setParam("minDF", params, parent); setParam("echoParams", params, parent); setParam("defType", params, parent); setParam("wt", params, parent); setParam("debug", params, parent); setParam("fl", params, parent); setParam("targetOverlaps", params, parent); setParam("contentDisplaySize", params, parent); setParam("targetDisplaySize", params, parent); setParam("tokensAfter", params, parent); setParam("tokensBefore", params, parent); setParam("sortOrder", params, parent); return params; } public static Results spinWait(RequestThreads<CooccurConfig> threads) { Results results = new Results(threads.getMetadata()); return spinWait(threads, results); } public static Results spinWait(RequestThreads<CooccurConfig> threads, Results results) { if (threads == null || threads.empty()) return results; while (!threads.isTerminated() && !threads.empty() && !results.hitMax) { RequestWorker req = threads.next(); if (!req.isRunning()) { NamedList nl = req.getResults(); if (nl != null) { results.add(nl, req.getName()); } threads.removeLast(); } } //force complete shutdown threads.shutdownNow(); //if not enough hits, check any remaining threads that haven't been collected //for(RequestWorker req : otherRequests) while (!threads.empty() && !results.hitMax) { RequestWorker req = threads.next(); if (req != null && !req.isRunning()) { NamedList nl = req.getResults(); if (nl != null) { results.add(nl, req.getName()); } threads.removeLast(); } } threads.clear(); threads = null; return results; } public static CooccurConfig configureParams(String field, SolrParams params) { CooccurConfig config = new CooccurConfig(field); String param = params.get("targetOverlaps"); if (param != null && param.length() > 0) { try { config.setAllowTargetOverlaps(Boolean.parseBoolean(param)); } catch (Exception e) { } } param = params.get("contentDisplaySize"); if (param != null && param.length() > 0) { try { config.setMaxContextDisplaySizeChars(Integer.parseInt(param)); } catch (Exception e) { } } param = params.get("targetDisplaySize"); if (param != null && param.length() > 0) { try { config.setMaxTargetDisplaySizeChars(Integer.parseInt(param)); } catch (Exception e) { } } param = params.get("maxWindows"); if (param != null && param.length() > 0) { try { config.setMaxWindows(Integer.parseInt(param)); } catch (Exception e) { } } param = params.get("tokensAfter"); if (param != null && param.length() > 0) { try { config.setTokensAfter(Integer.parseInt(param)); } catch (Exception e) { } } param = params.get("tokensBefore"); if (param != null && param.length() > 0) { try { config.setTokensBefore(Integer.parseInt(param)); } catch (Exception e) { } } param = params.get("sortOrder"); if (param != null && param.length() > 0) { try { config.setSortOrder(ConcordanceSortOrder.valueOf(param)); } catch (Exception e) { } } param = params.get("minNGram"); if (param != null && param.length() > 0) { try { config.setMinNGram(Integer.parseInt(param)); } catch (Exception e) { } } param = params.get("maxNGram"); if (param != null && param.length() > 0) { try { config.setMaxNGram(Integer.parseInt(param)); } catch (Exception e) { } } param = params.get("minTF"); if (param != null && param.length() > 0) { try { config.setMinTermFreq(Integer.parseInt(param)); } catch (Exception e) { } } param = params.get("tokensBefore"); if (param != null && param.length() > 0) { try { config.setTokensBefore(Integer.parseInt(param)); } catch (Exception e) { } } param = params.get("tokensAfter"); if (param != null && param.length() > 0) { try { config.setTokensAfter(Integer.parseInt(param)); } catch (Exception e) { } } param = params.get("numResults"); if (param != null && param.length() > 0) { try { config.setNumResults(Integer.parseInt(param)); } catch (Exception e) { } } return config; } @SuppressWarnings({"rawtypes", "unchecked"}) public static NamedList toNamedList(List<TermIDF> results) { SimpleOrderedMap ret = new SimpleOrderedMap(); if (results.size() > 0) { NamedList nlResults = new SimpleOrderedMap(); ret.add("results", nlResults); for (TermIDF result : results) { NamedList nl = new SimpleOrderedMap(); nl.add("term", result.getTerm()); //nl.add("value", result.getValue()); nl.add("tfidf", result.getTFIDF()); nl.add("tf", result.getTermFreq()); nl.add("idf", result.getIDF()); nl.add("df", result.getDocFreq()); nlResults.add("result", nl); } } return ret; } /* public void search(IndexReader reader, String fieldName, Query query, Filter filter, Analyzer analyzer, ArrayWindowVisitor visitor, DocIdBuilder docIdBuilder ) throws IllegalArgumentException { try { ConcordanceArrayWindowSearcher searcher = new ConcordanceArrayWindowSearcher(); searcher.search(reader, fieldName, query, filter, analyzer, visitor, docIdBuilder ); } catch (IllegalArgumentException e) { e.printStackTrace(); } catch (TargetTokenNotFoundException e) { e.printStackTrace(); } catch (IOException e) { e.printStackTrace(); } //xx copy consturctor instead? CooccurVisitor covisitor = (CooccurVisitor)visitor; List<TermIDF> overallResults = covisitor.getResults(); NamedList results = toNamedList(overallResults); results.add("collectionSize", reader.numDocs()); results.add("numDocsVisited", covisitor.getNumDocsVisited()); results.add("numWindowsVisited", covisitor.getNumWindowsVisited()); results.add("numResults", overallResults.size()); results.add("minTF", covisitor.getMinTermFreq()); //TODO: convert results to docIdBuilder? xx //xx return results; } */ public static String getField(SolrParams params) { String fieldName = params.get(CommonParams.FIELD); if (fieldName == null || fieldName.equalsIgnoreCase("null")) { if (fieldName == null || fieldName.equalsIgnoreCase("null")) fieldName = params.get(CommonParams.DF); if (fieldName == null || fieldName.equalsIgnoreCase("null")) { //check field list if not in field fieldName = params.get(CommonParams.FL); //TODO: change when/if request allows for multiple terms if (fieldName != null) fieldName = fieldName.split(",")[0].trim(); } } return fieldName; } @Override public void init(@SuppressWarnings("rawtypes") NamedList args) { super.init(args); // this.prop1 = invariants.get("prop1"); } @Override public String getDescription() { return "Returns tokens that frequently co-occur within concordance windows"; } @Override public String getSource() { return "https://issues.apache.org/jira/browse/SOLR-5411 - https://github.com/tballison/lucene-addons"; } @SuppressWarnings("unchecked") @Override public void handleRequestBody(SolrQueryRequest req, SolrQueryResponse rsp) throws Exception { boolean isDistrib = isDistributed(req); if (isDistrib) { System.out.println("DOING ZOO QUERY"); doZooQuery(req, rsp); } else { doQuery(req, rsp); } } @SuppressWarnings("unchecked") private void doZooQuery(SolrQueryRequest req, SolrQueryResponse rsp) throws Exception { SolrParams params = req.getParams(); String field = getField(params); CooccurConfig config = configureParams(field, params); boolean debug = params.getBool("debug", false); NamedList nlDebug = new SimpleOrderedMap(); if (debug) rsp.add("DEBUG", nlDebug); ZkController zoo = req.getCore().getCoreDescriptor().getCoreContainer().getZkController(); Set<String> nodes = zoo.getClusterState().getLiveNodes(); List<String> shards = new ArrayList<String>(nodes.size()); String thisUrl = req.getCore().getCoreDescriptor().getCoreContainer().getZkController().getBaseUrl(); for (String node : nodes) { String shard = node.replace("_", "/"); if (thisUrl.contains(shard)) continue; shard += "/" + req.getCore().getName(); shards.add(shard); } System.out.println("SHARDS SIZE: " + shards.size()); RequestThreads<CooccurConfig> threads = initRequestPump(shards, req); Results results = new Results(threads.getMetadata()); /* //skip local NamedList nl = doLocalSearch(req); for (int i = 0; i < nl.size(); i++) { System.out.println("RETURNED FROM SERVER: " + "LOCAL" + " : " + nl.getName(i) + " ; " + nl.getVal(i)); } results.add(nl, "local");*/ results = spinWait(threads, results); rsp.add(NODE, results.toNamedList()); } ; private void doQuery(SolrQueryRequest req, SolrQueryResponse rsp) throws Exception, IllegalArgumentException, ParseException, TargetTokenNotFoundException { NamedList results = doLocalSearch(req); rsp.add(NODE, results); } ; @Override protected String getHandlerName(SolrQueryRequest req) { return getHandlerName(req, DefaultName, this.getClass()); } public static class Results { long maxWindows = -1; int maxResults = -1; boolean hitMax = false; boolean maxTerms = false; int size = 0; long numDocs = 0; long numWindows = 0; int numResults = 0; HashMap<String, Keyword> keywords = new HashMap<String, Keyword>(); Results(CooccurConfig config) { this.maxWindows = config.getMaxWindows(); this.maxResults = config.getNumResults(); } Results(int maxWindows, int maxResults) { this.maxWindows = maxWindows; this.maxResults = maxResults; } void add(NamedList nl, String extra) { NamedList nlRS = (NamedList) nl.get(NODE); if (nlRS == null) nlRS = nl; numDocs += getInt("numDocs", nlRS); size += getInt("collectionSize", nlRS); numResults += getInt("numResults", nlRS); numWindows += getLong("numWindows", nlRS); hitMax = numWindows >= maxWindows; maxTerms = numResults >= maxResults; Object o = nlRS.get("results"); if (o != null) { NamedList nlRes = (NamedList) o; List<NamedList> res = nlRes.getAll("result"); for (NamedList nlTerm : res) { Keyword tmp = new Keyword(nlTerm); Keyword kw = keywords.get(tmp.term); if (kw == null) keywords.put(tmp.term, tmp); else { kw.tf += tmp.tf; kw.df += tmp.df; kw.minDF += tmp.minDF; } } } } NamedList toNamedList() { NamedList nl = new SimpleOrderedMap<>(); nl.add("hitMax", hitMax); nl.add("maxTerms", maxTerms); nl.add("numDocs", numDocs); nl.add("collectionSize", size); nl.add("numWindows", numWindows); nl.add("numResults", numResults); if (keywords.size() > 0) { //sort by new tf-idf's Integer[] idxs = new Integer[keywords.size()]; final double[] tfidfs = new double[keywords.size()]; final Keyword[] terms = new Keyword[keywords.size()]; int i = 0; for (Entry<String, Keyword> kv : keywords.entrySet()) { idxs[i] = i; Keyword kw = kv.getValue(); terms[i] = kw; tfidfs[i] = kw.tf * Math.log(size / kw.df); i++; } //System.out.println(terms); //System.out.println(Arrays.toString(terms)); //System.out.println(tfidfs); //System.out.println(Arrays.toString(tfidfs)); Arrays.sort(idxs, new Comparator<Integer>() { public int compare(Integer a, Integer b) { int ret = Double.compare(tfidfs[b], tfidfs[a]); if (ret == 0) ret = Double.compare(terms[b].df, terms[a].df); return ret; } }); NamedList<NamedList> nlResults = new SimpleOrderedMap<NamedList>(); for (i = 0; i < idxs.length && i < maxResults; i++) { Keyword kw = terms[idxs[i]]; NamedList nlKw = new SimpleOrderedMap<Object>(); nlKw.add("term", kw.term); nlKw.add("tfidf", tfidfs[idxs[i]]); nlKw.add("orig_tfidf", kw.tfidf); nlKw.add("tf", kw.tf); nlKw.add("df", kw.df); nlKw.add("minDF", kw.minDF); nlResults.add("result", nlKw); } nl.add("results", nlResults); } return nl; } } static class Keyword { String term; double tfidf = 0; long tf = 0; long df = 0; int minDF = 0; Keyword(NamedList nl) { term = nl.get("term").toString(); tf = getInt("tf", nl); df = getInt("df", nl); minDF = getInt("minDF", nl); tfidf = getDouble("tfidf", nl); } @Override public String toString() { return term; } } }