/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.solr.client.solrj.io.stream;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.stream.Stream;

import org.apache.solr.client.solrj.impl.CloudSolrClient;
import org.apache.solr.client.solrj.impl.HttpSolrClient;
import org.apache.solr.client.solrj.io.SolrClientCache;
import org.apache.solr.client.solrj.io.Tuple;
import org.apache.solr.client.solrj.io.comp.StreamComparator;
import org.apache.solr.client.solrj.io.stream.expr.Explanation;
import org.apache.solr.client.solrj.io.stream.expr.Expressible;
import org.apache.solr.client.solrj.io.stream.expr.StreamExplanation;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionNamedParameter;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionParameter;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionValue;
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
import org.apache.solr.client.solrj.request.QueryRequest;
import org.apache.solr.client.solrj.response.QueryResponse;
import org.apache.solr.common.cloud.ClusterState;
import org.apache.solr.common.cloud.Replica;
import org.apache.solr.common.cloud.Slice;
import org.apache.solr.common.cloud.ZkCoreNodeProps;
import org.apache.solr.common.cloud.ZkStateReader;
import org.apache.solr.common.params.ModifiableSolrParams;
import org.apache.solr.common.util.ExecutorUtil;
import org.apache.solr.common.util.NamedList;
import org.apache.solr.common.util.SolrjNamedThreadFactory;

import static org.apache.solr.common.params.CommonParams.DISTRIB;
import static org.apache.solr.common.params.CommonParams.ID;

/**
 * Stream that asks one active replica of every slice of a collection for its
 * top scoring terms (via an {@code fq={!igain}} request), sums the per-term
 * scores and document frequencies across the shards, and emits the
 * {@code numTerms} highest scoring terms as tuples followed by an EOF tuple.
 */
public class FeaturesSelectionStream extends TupleStream implements Expressible {

  private static final long serialVersionUID = 1;

  protected String zkHost;
  protected String collection;
  protected Map<String,String> params;
  protected Iterator<Tuple> tupleIterator;
  protected String field;
  protected String outcome;
  protected String featureSet;
  protected int positiveLabel;
  protected int numTerms;

  protected transient SolrClientCache cache;
  protected transient boolean isCloseCache;
  protected transient CloudSolrClient cloudSolrClient;

  protected transient StreamContext streamContext;
  // Runtime-only handle created in open(); kept out of the serialized form
  // like the other client/cache handles above.
  protected transient ExecutorService executorService;

  /**
   * Creates the stream from explicit values.
   *
   * @param zkHost         ZooKeeper host string used to obtain the cloud client
   * @param collectionName collection whose slices are queried
   * @param params         additional request parameters forwarded to every shard
   * @param field          value sent to the shards as the {@code field} parameter
   * @param outcome        value sent to the shards as the {@code outcome} parameter
   * @param featureSet     name used for the {@code featureSet_s} field and as id prefix of emitted tuples
   * @param positiveLabel  value sent to the shards as the {@code positiveLabel} parameter
   * @param numTerms       number of terms requested from each shard and maximum number of term tuples emitted
   */
  public FeaturesSelectionStream(String zkHost,
                                 String collectionName,
                                 Map params,
                                 String field,
                                 String outcome,
                                 String featureSet,
                                 int positiveLabel,
                                 int numTerms) throws IOException {
    init(collectionName, zkHost, params, field, outcome, featureSet, positiveLabel, numTerms);
  }

  /**
   * Creates the stream from a streaming expression of the form
   * {@code fn(collection, q="*:*", field="...", outcome="...", featureSet="...", positiveLabel=1, numTerms=N, zkHost="...")}.
   * <p>
   * {@code field}, {@code outcome}, {@code featureSet} and {@code numTerms} are required;
   * {@code positiveLabel} defaults to 1; {@code zkHost} falls back to the factory's
   * entry for the collection. Every other named parameter is forwarded to the shards.
   *
   * @throws IOException if the expression is malformed or a required parameter is missing
   */
  public FeaturesSelectionStream(StreamExpression expression, StreamFactory factory) throws IOException {
    // grab all parameters out
    String collectionName = factory.getValueOperand(expression, 0);
    List<StreamExpressionNamedParameter> namedParams = factory.getNamedOperands(expression);
    StreamExpressionNamedParameter zkHostExpression = factory.getNamedOperand(expression, "zkHost");

    // Validate there are no unknown parameters - zkHost is a named parameter so it is already counted
    if (expression.getParameters().size() != 1 + namedParams.size()) {
      throw new IOException(String.format(Locale.ROOT,"invalid expression %s - unknown operands found",expression));
    }

    // Collection Name
    if (null == collectionName) {
      throw new IOException(String.format(Locale.ROOT,"invalid expression %s - collectionName expected as first operand",expression));
    }

    // Named parameters - passed directly to solr as solrparams
    if (0 == namedParams.size()) {
      throw new IOException(String.format(Locale.ROOT,"invalid expression %s - at least one named parameter expected. eg. 'q=*:*'",expression));
    }

    Map<String,String> params = new HashMap<String,String>();
    for (StreamExpressionNamedParameter namedParam : namedParams) {
      if (!namedParam.getName().equals("zkHost")) {
        params.put(namedParam.getName(), namedParam.getParameter().toString().trim());
      }
    }

    // Stream-level parameters are pulled out of the map so they are not
    // forwarded twice; the shard call re-adds the ones the shards need.
    String fieldParam = removeRequiredParam(params, "field");
    String outcomeParam = removeRequiredParam(params, "outcome");
    String featureSetParam = removeRequiredParam(params, "featureSet");

    int positiveLabel = 1;
    String positiveLabelParam = params.remove("positiveLabel");
    if (positiveLabelParam != null) {
      positiveLabel = Integer.parseInt(positiveLabelParam);
    }

    int numTerms = Integer.parseInt(removeRequiredParam(params, "numTerms"));

    // zkHost, optional - if not provided then will look into factory list to get
    String zkHost = null;
    if (null == zkHostExpression) {
      zkHost = factory.getCollectionZkHost(collectionName);
    } else if (zkHostExpression.getParameter() instanceof StreamExpressionValue) {
      zkHost = ((StreamExpressionValue)zkHostExpression.getParameter()).getValue();
    }
    if (null == zkHost) {
      throw new IOException(String.format(Locale.ROOT,"invalid expression %s - zkHost not found for collection '%s'",expression,collectionName));
    }

    // We've got all the required items
    init(collectionName, zkHost, params, fieldParam, outcomeParam, featureSetParam, positiveLabel, numTerms);
  }

  /** Removes {@code name} from {@code params} and returns its value, failing if it is absent. */
  private static String removeRequiredParam(Map<String,String> params, String name) throws IOException {
    String value = params.remove(name);
    if (value == null) {
      throw new IOException(name + " param cannot be null for FeaturesSelectionStream");
    }
    return value;
  }

  /**
   * Rebuilds a streaming expression for this stream: the collection, the
   * pass-through parameters, then field, outcome, featureSet, positiveLabel,
   * numTerms and zkHost as named parameters.
   */
  @Override
  public StreamExpressionParameter toExpression(StreamFactory factory) throws IOException {
    // function name
    StreamExpression expression = new StreamExpression(factory.getFunctionName(this.getClass()));

    // collection
    expression.addParameter(collection);

    // pass-through parameters
    for (Map.Entry<String,String> param : params.entrySet()) {
      expression.addParameter(new StreamExpressionNamedParameter(param.getKey(), param.getValue()));
    }

    expression.addParameter(new StreamExpressionNamedParameter("field", field));
    expression.addParameter(new StreamExpressionNamedParameter("outcome", outcome));
    expression.addParameter(new StreamExpressionNamedParameter("featureSet", featureSet));
    expression.addParameter(new StreamExpressionNamedParameter("positiveLabel", String.valueOf(positiveLabel)));
    expression.addParameter(new StreamExpressionNamedParameter("numTerms", String.valueOf(numTerms)));

    // zkHost
    expression.addParameter(new StreamExpressionNamedParameter("zkHost", zkHost));

    return expression;
  }

  /** Stores the configuration shared by both constructors. */
  @SuppressWarnings({"rawtypes", "unchecked"}) // public constructor accepts a raw Map
  private void init(String collectionName,
                    String zkHost,
                    Map params,
                    String field,
                    String outcome,
                    String featureSet,
                    int positiveLabel,
                    int numTopTerms) throws IOException {
    this.zkHost = zkHost;
    this.collection = collectionName;
    this.params = params;
    this.field = field;
    this.outcome = outcome;
    this.featureSet = featureSet;
    this.positiveLabel = positiveLabel;
    this.numTerms = numTopTerms;
  }

  /** Remembers the context and adopts its client cache. */
  public void setStreamContext(StreamContext context) {
    this.cache = context.getSolrClientCache();
    this.streamContext = context;
  }

  /**
   * Opens the stream: obtains the cloud client for {@code zkHost} and creates
   * the thread pool used for the shard requests. If no client cache was
   * supplied through the stream context a private one is created, and this
   * stream becomes responsible for closing it.
   */
  public void open() throws IOException {
    if (cache == null) {
      isCloseCache = true;
      cache = new SolrClientCache();
    } else {
      isCloseCache = false;
    }

    this.cloudSolrClient = this.cache.getCloudSolrClient(zkHost);
    this.executorService = ExecutorUtil.newMDCAwareCachedThreadPool(new SolrjNamedThreadFactory("FeaturesSelectionStream"));
  }

  /** This stream has no child streams; returns {@code null}. */
  public List<TupleStream> children() {
    return null;
  }

  /**
   * Picks, for every slice of the collection, one randomly chosen replica that
   * is ACTIVE and hosted on a live node, and returns the core URLs.
   *
   * @throws IOException wrapping any failure, including a slice without a usable replica
   */
  private List<String> getShardUrls() throws IOException {
    try {
      ZkStateReader zkStateReader = cloudSolrClient.getZkStateReader();

      Collection<Slice> slices = CloudSolrStream.getSlices(this.collection, zkStateReader, false);

      ClusterState clusterState = zkStateReader.getClusterState();
      Set<String> liveNodes = clusterState.getLiveNodes();

      List<String> baseUrls = new ArrayList<>();
      for (Slice slice : slices) {
        List<Replica> candidates = new ArrayList<>();
        for (Replica replica : slice.getReplicas()) {
          if (replica.getState() == Replica.State.ACTIVE && liveNodes.contains(replica.getNodeName())) {
            candidates.add(replica);
          }
        }

        if (candidates.isEmpty()) {
          // Previously surfaced as an opaque IndexOutOfBoundsException from get(0).
          throw new IOException("No active replica on a live node for slice '" + slice.getName()
              + "' of collection '" + this.collection + "'");
        }

        Collections.shuffle(candidates, new Random());
        ZkCoreNodeProps zkProps = new ZkCoreNodeProps(candidates.get(0));
        baseUrls.add(zkProps.getCoreUrl());
      }

      return baseUrls;
    } catch (Exception e) {
      throw new IOException(e);
    }
  }

  /** Submits one {@link FeaturesSelectionCall} per shard URL and returns the pending results. */
  private List<Future<NamedList>> callShards(List<String> baseUrls) throws IOException {
    List<Future<NamedList>> futures = new ArrayList<>();
    for (String baseUrl : baseUrls) {
      FeaturesSelectionCall call = new FeaturesSelectionCall(baseUrl, this.params, this.field, this.outcome);
      futures.add(executorService.submit(call));
    }
    return futures;
  }

  /**
   * Closes the privately created client cache (if any) and shuts the thread
   * pool down. Safe to call when {@link #open()} was never invoked; the pool
   * is shut down even if closing the cache fails.
   */
  public void close() throws IOException {
    try {
      if (isCloseCache && cache != null) {
        cache.close();
      }
    } finally {
      if (executorService != null) {
        executorService.shutdown();
      }
    }
  }

  /** Return the stream sort - ie, the order in which records are returned. This stream reports none. */
  public StreamComparator getStreamSort() {
    return null;
  }

  @Override
  public Explanation toExplanation(StreamFactory factory) throws IOException {
    return new StreamExplanation(getStreamNodeId().toString())
        .withFunctionName(factory.getFunctionName(this.getClass()))
        .withImplementingClass(this.getClass().getName())
        .withExpressionType(Explanation.ExpressionType.STREAM_DECORATOR)
        .withExpression(toExpression(factory).toString());
  }

  /**
   * Returns the next tuple. The first call queries all shards and builds the
   * full result list (term tuples followed by an EOF tuple); later calls just
   * advance through it.
   *
   * @throws IOException wrapping any failure; if the calling thread was
   *         interrupted while waiting for a shard its interrupt status is restored
   */
  public Tuple read() throws IOException {
    try {
      if (tupleIterator == null) {
        tupleIterator = computeTuples().iterator();
      }
      return tupleIterator.next();
    } catch (InterruptedException e) {
      // Do not swallow the interrupt: restore the flag for callers further up.
      Thread.currentThread().interrupt();
      throw new IOException(e);
    } catch (Exception e) {
      throw new IOException(e);
    }
  }

  /** Gathers the shard responses, merges them per term and builds the ordered tuple list ending with EOF. */
  @SuppressWarnings({"rawtypes", "unchecked"}) // NamedList responses are untyped
  private List<Tuple> computeTuples() throws IOException, InterruptedException, ExecutionException {
    Map<String, Double> termScores = new HashMap<>();
    Map<String, Long> docFreqs = new HashMap<>();

    long numDocs = 0;
    for (Future<NamedList> shardCall : callShards(getShardUrls())) {
      NamedList resp = shardCall.get();

      NamedList<Double> shardTopTerms = (NamedList<Double>) resp.get("featuredTerms");
      // Read counts as Number so Integer and Long responses are both accepted.
      NamedList<Number> shardDocFreqs = (NamedList<Number>) resp.get("docFreq");
      numDocs += ((Number) resp.get("numDocs")).longValue();

      for (int i = 0; i < shardTopTerms.size(); i++) {
        String term = shardTopTerms.getName(i);
        double score = shardTopTerms.getVal(i);
        Number docFreq = shardDocFreqs.get(term);
        if (docFreq == null) {
          throw new IOException("Shard response contains no docFreq for term '" + term + "'");
        }
        termScores.merge(term, score, Double::sum);
        docFreqs.merge(term, docFreq.longValue(), Long::sum);
      }
    }

    List<Tuple> tuples = new ArrayList<>(numTerms);
    int index = 0;
    for (Map.Entry<String, Double> termScore : sortByValue(termScores).entrySet()) {
      if (tuples.size() == numTerms) {
        break;
      }
      index++;

      Map map = new HashMap();
      map.put(ID, featureSet + "_" + index);
      map.put("index_i", index);
      map.put("term_s", termScore.getKey());
      map.put("score_f", termScore.getValue());
      map.put("featureSet_s", featureSet);
      long docFreq = docFreqs.get(termScore.getKey());
      // +1 in the denominator keeps the ratio finite when docFreq is 0.
      double idf = Math.log(((double) numDocs / (double) (docFreq + 1)));
      map.put("idf_d", idf);
      tuples.add(new Tuple(map));
    }

    Map eof = new HashMap();
    eof.put("EOF", true);
    tuples.add(new Tuple(eof));

    return tuples;
  }

  /** Returns a copy of {@code map} whose iteration order is by descending value. */
  private <K, V extends Comparable<? super V>> Map<K, V> sortByValue(Map<K, V> map) {
    Map<K, V> result = new LinkedHashMap<>();
    Stream<Map.Entry<K, V>> st = map.entrySet().stream();

    st.sorted(Map.Entry.comparingByValue((c1, c2) -> c2.compareTo(c1)))
        .forEachOrdered(e -> result.put(e.getKey(), e.getValue()));

    return result;
  }

  /**
   * Request sent to a single shard: a non-distributed query with
   * {@code fq={!igain}}, the pass-through parameters and the stream's
   * outcome, positiveLabel, field and numTerms values.
   */
  protected class FeaturesSelectionCall implements Callable<NamedList> {

    private String baseUrl;
    private String outcome;
    private String field;
    private Map<String, String> paramsMap;

    public FeaturesSelectionCall(String baseUrl,
                                 Map<String, String> paramsMap,
                                 String field,
                                 String outcome) {
      this.baseUrl = baseUrl;
      this.outcome = outcome;
      this.field = field;
      this.paramsMap = paramsMap;
    }

    /** Executes the request against {@code baseUrl} and returns the raw response. */
    @SuppressWarnings({"rawtypes", "unchecked"})
    public NamedList<Double> call() throws Exception {
      ModifiableSolrParams params = new ModifiableSolrParams();
      HttpSolrClient solrClient = cache.getHttpSolrClient(baseUrl);

      params.add(DISTRIB, "false");
      params.add("fq","{!igain}");

      for (Map.Entry<String, String> entry : paramsMap.entrySet()) {
        params.add(entry.getKey(), entry.getValue());
      }

      params.add("outcome", outcome);
      params.add("positiveLabel", Integer.toString(positiveLabel));
      params.add("field", field);
      params.add("numTerms", String.valueOf(numTerms));

      QueryRequest request = new QueryRequest(params);
      QueryResponse response = request.process(solrClient);
      NamedList res = response.getResponse();
      return res;
    }
  }
}