/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.solr.ltr;

import java.io.IOException;
import java.lang.invoke.MethodHandles;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.Callable;
import java.util.concurrent.Future;
import java.util.concurrent.FutureTask;
import java.util.concurrent.RunnableFuture;
import java.util.concurrent.Semaphore;

import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.DisiPriorityQueue;
import org.apache.lucene.search.DisiWrapper;
import org.apache.lucene.search.DisjunctionDISIApproximation;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.Weight;
import org.apache.solr.ltr.feature.Feature;
import org.apache.solr.ltr.model.LTRScoringModel;
import org.apache.solr.request.SolrQueryRequest;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * The ranking query that is run, reranking results using the
 * LTRScoringModel algorithm.
 */
public class LTRScoringQuery extends Query {

  private static final Logger log = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass());

  // contains a description of the model
  private final LTRScoringModel ltrScoringModel;
  private final boolean extractAllFeatures;
  private final LTRThreadModule ltrThreadMgr;

  // limits the number of threads per query, so that multiple requests can be serviced simultaneously
  private final Semaphore querySemaphore;

  // feature logger to output the features.
  private FeatureLogger fl;

  // Map of external parameters, such as query intent, that can be used by features
  private final Map<String,String[]> efi;

  // Original Solr query used to fetch matching documents
  private Query originalQuery;

  // Original Solr request
  private SolrQueryRequest request;

  public LTRScoringQuery(LTRScoringModel ltrScoringModel) {
    this(ltrScoringModel, Collections.<String,String[]>emptyMap(), false, null);
  }

  public LTRScoringQuery(LTRScoringModel ltrScoringModel, boolean extractAllFeatures) {
    this(ltrScoringModel, Collections.<String,String[]>emptyMap(), extractAllFeatures, null);
  }

  public LTRScoringQuery(LTRScoringModel ltrScoringModel,
      Map<String,String[]> externalFeatureInfo,
      boolean extractAllFeatures, LTRThreadModule ltrThreadMgr) {
    this.ltrScoringModel = ltrScoringModel;
    this.efi = externalFeatureInfo;
    this.extractAllFeatures = extractAllFeatures;
    this.ltrThreadMgr = ltrThreadMgr;
    if (this.ltrThreadMgr != null) {
      this.querySemaphore = this.ltrThreadMgr.createQuerySemaphore();
    } else {
      this.querySemaphore = null;
    }
  }

  public LTRScoringModel getScoringModel() {
    return ltrScoringModel;
  }

  public void setFeatureLogger(FeatureLogger fl) {
    this.fl = fl;
  }

  public FeatureLogger getFeatureLogger() {
    return fl;
  }

  public void setOriginalQuery(Query originalQuery) {
    this.originalQuery = originalQuery;
  }

  public Query getOriginalQuery() {
    return originalQuery;
  }

  public Map<String,String[]> getExternalFeatureInfo() {
    return efi;
  }

  public void setRequest(SolrQueryRequest request) {
    this.request = request;
  }

  public SolrQueryRequest getRequest() {
    return request;
  }
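
  /*
   * Illustrative usage sketch (not part of the original source): a caller
   * typically wraps a trained model and the query whose results are being
   * reranked. The variable names below are hypothetical.
   *
   *   LTRScoringModel model = ...;   // obtained from the model store
   *   LTRScoringQuery scoringQuery =
   *       new LTRScoringQuery(model, efiParams, false, threadModule);
   *   scoringQuery.setOriginalQuery(userQuery); // query being reranked
   *   scoringQuery.setRequest(req);             // current SolrQueryRequest
   */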

  @Override
  public int hashCode() {
    final int prime = 31;
    int result = classHash();
    result = (prime * result) + ((ltrScoringModel == null) ? 0 : ltrScoringModel.hashCode());
    result = (prime * result) + ((originalQuery == null) ? 0 : originalQuery.hashCode());
    if (efi == null) {
      result = (prime * result) + 0;
    } else {
      for (final Map.Entry<String,String[]> entry : efi.entrySet()) {
        final String key = entry.getKey();
        final String[] values = entry.getValue();
        result = (prime * result) + key.hashCode();
        result = (prime * result) + Arrays.hashCode(values);
      }
    }
    result = (prime * result) + this.toString().hashCode();
    return result;
  }

  @Override
  public boolean equals(Object o) {
    return sameClassAs(o) && equalsTo(getClass().cast(o));
  }

  private boolean equalsTo(LTRScoringQuery other) {
    if (ltrScoringModel == null) {
      if (other.ltrScoringModel != null) {
        return false;
      }
    } else if (!ltrScoringModel.equals(other.ltrScoringModel)) {
      return false;
    }
    if (originalQuery == null) {
      if (other.originalQuery != null) {
        return false;
      }
    } else if (!originalQuery.equals(other.originalQuery)) {
      return false;
    }
    if (efi == null) {
      if (other.efi != null) {
        return false;
      }
    } else {
      if (other.efi == null || efi.size() != other.efi.size()) {
        return false;
      }
      for (final Map.Entry<String,String[]> entry : efi.entrySet()) {
        final String key = entry.getKey();
        final String[] otherValues = other.efi.get(key);
        if (otherValues == null || !Arrays.equals(otherValues, entry.getValue())) {
          return false;
        }
      }
    }
    return true;
  }

  @Override
  public ModelWeight createWeight(IndexSearcher searcher, boolean needsScores, float boost)
      throws IOException {
    final Collection<Feature> modelFeatures = ltrScoringModel.getFeatures();
    final Collection<Feature> allFeatures = ltrScoringModel.getAllFeatures();
    int modelFeatSize = modelFeatures.size();

    Collection<Feature> features = null;
    if (this.extractAllFeatures) {
      features = allFeatures;
    } else {
      features = modelFeatures;
    }
    final Feature.FeatureWeight[] extractedFeatureWeights = new Feature.FeatureWeight[features.size()];
    final Feature.FeatureWeight[] modelFeaturesWeights = new Feature.FeatureWeight[modelFeatSize];
    List<Feature.FeatureWeight> featureWeights = new ArrayList<>(features.size());

    if (querySemaphore == null) {
      createWeights(searcher, needsScores, featureWeights, features);
    } else {
      createWeightsParallel(searcher, needsScores, featureWeights, features);
    }

    int i = 0, j = 0;
    if (this.extractAllFeatures) {
      for (final Feature.FeatureWeight fw : featureWeights) {
        extractedFeatureWeights[i++] = fw;
      }
      for (final Feature f : modelFeatures) {
        // we can look up by feature id because all features will be extracted
        // when this.extractAllFeatures is set
        modelFeaturesWeights[j++] = extractedFeatureWeights[f.getIndex()];
      }
    } else {
      for (final Feature.FeatureWeight fw : featureWeights) {
        extractedFeatureWeights[i++] = fw;
        modelFeaturesWeights[j++] = fw;
      }
    }
    return new ModelWeight(modelFeaturesWeights, extractedFeatureWeights, allFeatures.size());
  }

  private void createWeights(IndexSearcher searcher, boolean needsScores,
      List<Feature.FeatureWeight> featureWeights, Collection<Feature> features)
      throws IOException {
    final SolrQueryRequest req = getRequest();
    // since the feature store is a LinkedHashMap, order is preserved
    for (final Feature f : features) {
      try {
        Feature.FeatureWeight fw = f.createWeight(searcher, needsScores, req, originalQuery, efi);
        featureWeights.add(fw);
      } catch (final Exception e) {
        throw new RuntimeException("Exception from createWeight for " + f.toString() + " "
            + e.getMessage(), e);
      }
    }
  }
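
  /*
   * Note on the parallel path below (explanatory sketch, not from the
   * original source): two semaphores gate each feature's weight creation.
   * The per-query semaphore caps how many weights a single query may build
   * concurrently, and the thread module's global semaphore caps the total
   * across all concurrent queries. The ordering per feature is roughly:
   *
   *   querySemaphore.acquire();            // per-query cap first
   *   ltrThreadMgr.acquireLTRSemaphore();  // then the global cap
   *   ltrThreadMgr.execute(task);          // the task releases both when done
   *
   * Acquiring the per-query permit first keeps one expensive query from
   * monopolizing the global thread pool.
   */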

  private class CreateWeightCallable implements Callable<Feature.FeatureWeight> {
    private final Feature f;
    private final IndexSearcher searcher;
    private final boolean needsScores;
    private final SolrQueryRequest req;

    public CreateWeightCallable(Feature f, IndexSearcher searcher,
        boolean needsScores, SolrQueryRequest req) {
      this.f = f;
      this.searcher = searcher;
      this.needsScores = needsScores;
      this.req = req;
    }

    @Override
    public Feature.FeatureWeight call() throws Exception {
      try {
        Feature.FeatureWeight fw = f.createWeight(searcher, needsScores, req, originalQuery, efi);
        return fw;
      } catch (final Exception e) {
        throw new RuntimeException("Exception from createWeight for " + f.toString() + " "
            + e.getMessage(), e);
      } finally {
        querySemaphore.release();
        ltrThreadMgr.releaseLTRSemaphore();
      }
    }
  } // end of CreateWeightCallable

  private void createWeightsParallel(IndexSearcher searcher, boolean needsScores,
      List<Feature.FeatureWeight> featureWeights, Collection<Feature> features)
      throws RuntimeException {
    final SolrQueryRequest req = getRequest();
    List<Future<Feature.FeatureWeight>> futures = new ArrayList<>(features.size());
    try {
      for (final Feature f : features) {
        CreateWeightCallable callable = new CreateWeightCallable(f, searcher, needsScores, req);
        RunnableFuture<Feature.FeatureWeight> runnableFuture = new FutureTask<>(callable);
        // always acquire before the ltrSemaphore is acquired, to guarantee that
        // the current query is within the limit for max threads
        querySemaphore.acquire();
        ltrThreadMgr.acquireLTRSemaphore(); // may block and/or interrupt
        ltrThreadMgr.execute(runnableFuture); // releases semaphore when done
        futures.add(runnableFuture);
      }
      // Loop over futures to get the feature weight objects
      for (final Future<Feature.FeatureWeight> future : futures) {
        featureWeights.add(future.get()); // future.get() will block if the job is still running
      }
    } catch (Exception e) { // to catch InterruptedException and ExecutionException
      log.info("Error while creating weights in LTR: InterruptedException", e);
      throw new RuntimeException("Error while creating weights in LTR: " + e.getMessage(), e);
    }
  }

  @Override
  public String toString(String field) {
    return field;
  }

  public static class FeatureInfo {
    private final String name;
    private float value;
    private boolean used;

    FeatureInfo(String n, float v, boolean u) {
      name = n;
      value = v;
      used = u;
    }

    public void setValue(float value) {
      this.value = value;
    }

    public String getName() {
      return name;
    }

    public float getValue() {
      return value;
    }

    public boolean isUsed() {
      return used;
    }

    public void setUsed(boolean used) {
      this.used = used;
    }
  }
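
  /*
   * Lifecycle sketch for FeatureInfo (explanatory, not from the original
   * source): ModelWeight below keeps one FeatureInfo slot per feature in the
   * store, indexed by feature id. For each reranked document the flow is
   * roughly:
   *
   *   reset();                          // restore default values, used=false
   *   info.setValue(scorer.score());    // for each feature matching the doc
   *   info.setUsed(true);
   *   makeNormalizedFeaturesAndScore(); // defaults fill the unused slots
   */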

  public class ModelWeight extends Weight {

    // List of the model's features used for scoring. This is a subset of the
    // features used for logging.
    private final Feature.FeatureWeight[] modelFeatureWeights;
    private final float[] modelFeatureValuesNormalized;
    private final Feature.FeatureWeight[] extractedFeatureWeights;

    // List of all the feature names and values - used for both scoring and logging
    /*
     * What is the advantage of using a hashmap here instead of an array of objects?
     * A set of arrays was used earlier and the elements were accessed using the feature id.
     * With the updated logic to create weights selectively,
     * the number of elements in the array can be fewer than the total number of features.
     * When [features] are not requested, only the model features are extracted.
     * In this case, the indexing by feature id fails. For this reason,
     * we need a map which holds just the features that were triggered by the documents
     * in the result set.
     */
    private final FeatureInfo[] featuresInfo;

    /**
     * @param modelFeatureWeights
     *     should be the same size as the number of features used by the model
     * @param extractedFeatureWeights
     *     if features are requested from the same store as the model feature store,
     *     this will be the size of the total number of features in the model feature store;
     *     else, this will be the size of the modelFeatureWeights
     * @param allFeaturesSize
     *     total number of features in the feature store used by this model
     */
    public ModelWeight(Feature.FeatureWeight[] modelFeatureWeights,
        Feature.FeatureWeight[] extractedFeatureWeights, int allFeaturesSize) {
      super(LTRScoringQuery.this);
      this.extractedFeatureWeights = extractedFeatureWeights;
      this.modelFeatureWeights = modelFeatureWeights;
      this.modelFeatureValuesNormalized = new float[modelFeatureWeights.length];
      this.featuresInfo = new FeatureInfo[allFeaturesSize];
      setFeaturesInfo();
    }

    private void setFeaturesInfo() {
      for (int i = 0; i < extractedFeatureWeights.length; ++i) {
        String featName = extractedFeatureWeights[i].getName();
        int featId = extractedFeatureWeights[i].getIndex();
        float value = extractedFeatureWeights[i].getDefaultValue();
        featuresInfo[featId] = new FeatureInfo(featName, value, false);
      }
    }

    public FeatureInfo[] getFeaturesInfo() {
      return featuresInfo;
    }

    // for test use
    Feature.FeatureWeight[] getModelFeatureWeights() {
      return modelFeatureWeights;
    }

    // for test use
    float[] getModelFeatureValuesNormalized() {
      return modelFeatureValuesNormalized;
    }

    // for test use
    Feature.FeatureWeight[] getExtractedFeatureWeights() {
      return extractedFeatureWeights;
    }
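
    /*
     * Illustrative walk-through with hypothetical numbers (not from the
     * original source): suppose the model uses two features and only the
     * first matched the current document, with raw value 3.0f, while the
     * second keeps its default of 0.0f. The method below then computes:
     *
     *   modelFeatureValuesNormalized = {3.0f, 0.0f};
     *   ltrScoringModel.normalizeFeaturesInPlace(modelFeatureValuesNormalized);
     *   return ltrScoringModel.score(modelFeatureValuesNormalized);
     */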
    /**
     * Goes through all the stored feature values, and calculates the normalized
     * values for all the features that will be used for scoring.
     * Then calculates and returns the model's score.
     */
    private float makeNormalizedFeaturesAndScore() {
      int pos = 0;
      for (final Feature.FeatureWeight feature : modelFeatureWeights) {
        final int featureId = feature.getIndex();
        FeatureInfo fInfo = featuresInfo[featureId];
        // not checking for fInfo == null as that would be a bug we should catch
        if (fInfo.isUsed()) {
          modelFeatureValuesNormalized[pos] = fInfo.getValue();
        } else {
          modelFeatureValuesNormalized[pos] = feature.getDefaultValue();
        }
        pos++;
      }
      ltrScoringModel.normalizeFeaturesInPlace(modelFeatureValuesNormalized);
      return ltrScoringModel.score(modelFeatureValuesNormalized);
    }

    @Override
    public Explanation explain(LeafReaderContext context, int doc) throws IOException {
      final Explanation[] explanations = new Explanation[this.featuresInfo.length];
      for (final Feature.FeatureWeight feature : extractedFeatureWeights) {
        explanations[feature.getIndex()] = feature.explain(context, doc);
      }
      final List<Explanation> featureExplanations = new ArrayList<>();
      for (int idx = 0; idx < modelFeatureWeights.length; ++idx) {
        final Feature.FeatureWeight f = modelFeatureWeights[idx];
        Explanation e = ltrScoringModel.getNormalizerExplanation(explanations[f.getIndex()], idx);
        featureExplanations.add(e);
      }
      final ModelScorer bs = scorer(context);
      bs.iterator().advance(doc);
      final float finalScore = bs.score();
      return ltrScoringModel.explain(context, doc, finalScore, featureExplanations);
    }

    @Override
    public void extractTerms(Set<Term> terms) {
      for (final Feature.FeatureWeight feature : extractedFeatureWeights) {
        feature.extractTerms(terms);
      }
    }

    protected void reset() {
      for (int i = 0; i < extractedFeatureWeights.length; ++i) {
        int featId = extractedFeatureWeights[i].getIndex();
        float value = extractedFeatureWeights[i].getDefaultValue();
        // need to set the default value every time, as the default value is
        // used in 'dense' mode even if used=false
        featuresInfo[featId].setValue(value);
        featuresInfo[featId].setUsed(false);
      }
    }

    @Override
    public ModelScorer scorer(LeafReaderContext context) throws IOException {
      final List<Feature.FeatureWeight.FeatureScorer> featureScorers =
          new ArrayList<Feature.FeatureWeight.FeatureScorer>(extractedFeatureWeights.length);
      for (final Feature.FeatureWeight featureWeight : extractedFeatureWeights) {
        final Feature.FeatureWeight.FeatureScorer scorer = featureWeight.scorer(context);
        if (scorer != null) {
          featureScorers.add(scorer);
        }
      }
      // Always return a ModelScorer, even if no features match, because we
      // always need to call score on the model for every document, since zero
      // matching features could still produce a non-zero score for a given model.
      ModelScorer mscorer = new ModelScorer(this, featureScorers);
      return mscorer;
    }
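
    /*
     * Explanatory note on the traversal protocol below (not from the original
     * source): both inner scorers track two positions. targetDoc is the
     * document the reranker asked for; activeDoc is the next document on which
     * at least one feature scorer actually matched. The iterator deliberately
     * reports a hit for every target, e.g.:
     *
     *   scorer.iterator().advance(42);  // targetDoc = 42; activeDoc >= 42
     *   scorer.score();                 // real feature values if activeDoc == 42,
     *                                   // otherwise defaults only
     *
     * so every reranked document gets a model score even when no feature matches it.
     */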

    public class ModelScorer extends Scorer {

      private final DocInfo docInfo;
      private final Scorer featureTraversalScorer;

      public DocInfo getDocInfo() {
        return docInfo;
      }

      public ModelScorer(Weight weight, List<Feature.FeatureWeight.FeatureScorer> featureScorers) {
        super(weight);
        docInfo = new DocInfo();
        for (final Feature.FeatureWeight.FeatureScorer subScorer : featureScorers) {
          subScorer.setDocInfo(docInfo);
        }
        if (featureScorers.size() <= 1) {
          // future enhancement: allow the use of dense features in other cases
          featureTraversalScorer = new DenseModelScorer(weight, featureScorers);
        } else {
          featureTraversalScorer = new SparseModelScorer(weight, featureScorers);
        }
      }

      @Override
      public Collection<ChildScorer> getChildren() throws IOException {
        return featureTraversalScorer.getChildren();
      }

      @Override
      public int docID() {
        return featureTraversalScorer.docID();
      }

      @Override
      public float score() throws IOException {
        return featureTraversalScorer.score();
      }

      @Override
      public int freq() throws IOException {
        return featureTraversalScorer.freq();
      }

      @Override
      public DocIdSetIterator iterator() {
        return featureTraversalScorer.iterator();
      }

      private class SparseModelScorer extends Scorer {

        private final DisiPriorityQueue subScorers;
        private final ScoringQuerySparseIterator itr;

        private int targetDoc = -1;
        private int activeDoc = -1;

        private SparseModelScorer(Weight weight,
            List<Feature.FeatureWeight.FeatureScorer> featureScorers) {
          super(weight);
          if (featureScorers.size() <= 1) {
            throw new IllegalArgumentException("There must be at least 2 subScorers");
          }
          subScorers = new DisiPriorityQueue(featureScorers.size());
          for (final Scorer scorer : featureScorers) {
            final DisiWrapper w = new DisiWrapper(scorer);
            subScorers.add(w);
          }
          itr = new ScoringQuerySparseIterator(subScorers);
        }

        @Override
        public int docID() {
          return itr.docID();
        }

        @Override
        public float score() throws IOException {
          final DisiWrapper topList = subScorers.topList();
          // If the target doc we wanted to advance to matches the actual doc
          // the underlying features advanced to, perform the feature
          // calculations; otherwise just continue with the model's scoring
          // process with empty features.
          reset();
          if (activeDoc == targetDoc) {
            for (DisiWrapper w = topList; w != null; w = w.next) {
              final Scorer subScorer = w.scorer;
              Feature.FeatureWeight scFW = (Feature.FeatureWeight) subScorer.getWeight();
              final int featureId = scFW.getIndex();
              featuresInfo[featureId].setValue(subScorer.score());
              featuresInfo[featureId].setUsed(true);
            }
          }
          return makeNormalizedFeaturesAndScore();
        }

        @Override
        public int freq() throws IOException {
          final DisiWrapper subMatches = subScorers.topList();
          int freq = 1;
          for (DisiWrapper w = subMatches.next; w != null; w = w.next) {
            freq += 1;
          }
          return freq;
        }

        @Override
        public DocIdSetIterator iterator() {
          return itr;
        }

        @Override
        public final Collection<ChildScorer> getChildren() {
          final ArrayList<ChildScorer> children = new ArrayList<>();
          for (final DisiWrapper scorer : subScorers) {
            children.add(new ChildScorer(scorer.scorer, "SHOULD"));
          }
          return children;
        }

        private class ScoringQuerySparseIterator extends DisjunctionDISIApproximation {

          public ScoringQuerySparseIterator(DisiPriorityQueue subIterators) {
            super(subIterators);
          }

          @Override
          public final int nextDoc() throws IOException {
            if (activeDoc == targetDoc) {
              activeDoc = super.nextDoc();
            } else if (activeDoc < targetDoc) {
              activeDoc = super.advance(targetDoc + 1);
            }
            return ++targetDoc;
          }

          @Override
          public final int advance(int target) throws IOException {
            // If the target doc we wanted to advance to matches the actual doc
            // the underlying features advanced to, perform the feature
            // calculations; otherwise just continue with the model's scoring
            // process with empty features.
            if (activeDoc < target) {
              activeDoc = super.advance(target);
            }
            targetDoc = target;
            return targetDoc;
          }
        }
      }

      private class DenseModelScorer extends Scorer {

        private int activeDoc = -1; // the doc our scorers are actually at
        private int targetDoc = -1; // the doc we were most recently told to go to
        private int freq = -1;

        private final List<Feature.FeatureWeight.FeatureScorer> featureScorers;

        private DenseModelScorer(Weight weight,
            List<Feature.FeatureWeight.FeatureScorer> featureScorers) {
          super(weight);
          this.featureScorers = featureScorers;
        }

        @Override
        public int docID() {
          return targetDoc;
        }

        @Override
        public float score() throws IOException {
          reset();
          freq = 0;
          if (targetDoc == activeDoc) {
            for (final Scorer scorer : featureScorers) {
              if (scorer.docID() == activeDoc) {
                freq++;
                Feature.FeatureWeight scFW = (Feature.FeatureWeight) scorer.getWeight();
                final int featureId = scFW.getIndex();
                featuresInfo[featureId].setValue(scorer.score());
                featuresInfo[featureId].setUsed(true);
              }
            }
          }
          return makeNormalizedFeaturesAndScore();
        }

        @Override
        public final Collection<ChildScorer> getChildren() {
          final ArrayList<ChildScorer> children = new ArrayList<>();
          for (final Scorer scorer : featureScorers) {
            children.add(new ChildScorer(scorer, "SHOULD"));
          }
          return children;
        }

        @Override
        public int freq() throws IOException {
          return freq;
        }

        @Override
        public DocIdSetIterator iterator() {
          return new DenseIterator();
        }

        private class DenseIterator extends DocIdSetIterator {

          @Override
          public int docID() {
            return targetDoc;
          }

          @Override
          public int nextDoc() throws IOException {
            if (activeDoc <= targetDoc) {
              activeDoc = NO_MORE_DOCS;
              for (final Scorer scorer : featureScorers) {
                if (scorer.docID() != NO_MORE_DOCS) {
                  activeDoc = Math.min(activeDoc, scorer.iterator().nextDoc());
                }
              }
            }
            return ++targetDoc;
          }

          @Override
          public int advance(int target) throws IOException {
            if (activeDoc < target) {
              activeDoc = NO_MORE_DOCS;
              for (final Scorer scorer : featureScorers) {
                if (scorer.docID() != NO_MORE_DOCS) {
                  activeDoc = Math.min(activeDoc, scorer.iterator().advance(target));
                }
              }
            }
            targetDoc = target;
            return target;
          }

          @Override
          public long cost() {
            long sum = 0;
            for (int i = 0; i < featureScorers.size(); i++) {
              sum += featureScorers.get(i).iterator().cost();
            }
            return sum;
          }
        }
      }
    }
  }
}