package querqy.solr;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.*;
import org.apache.solr.common.SolrException;
import org.apache.solr.handler.component.MergeStrategy;
import org.apache.solr.search.QueryCommand;
import org.apache.solr.search.RankQuery;
import java.io.IOException;
import java.util.Set;
/**
 * A {@link RankQuery} that executes a main query and then re-scores the leading
 * {@code reRankNumDocs} hits with a second query.
 *
 * <p>For a re-ranked document the final score is
 * {@code firstPassScore + reRankWeight * secondPassScore} when the re-rank query matches it,
 * and the unchanged first-pass score otherwise.
 *
 * <p>Created by rene on 01/09/2016.
 */
public class QuerqyReRankQuery extends RankQuery {

    /** Main query in effect until another one is supplied; matches all documents. */
    protected static final Query DEFAULT_MAIN_QUERY = new MatchAllDocsQuery();

    /** First-pass query; set by the constructor and replaceable through {@link #wrap(Query)}. */
    protected Query mainQuery = DEFAULT_MAIN_QUERY;

    /** Query whose score is added (weighted) to the first-pass score of re-ranked documents. */
    protected final Query reRankQuery;

    /** Maximum number of first-pass hits that are re-scored. */
    protected final int reRankNumDocs;

    /** Multiplier applied to the re-rank query's score before it is added to the first-pass score. */
    protected final double reRankWeight;

    /**
     * Creates a re-rank query.
     *
     * @param mainQuery     first-pass query, must not be {@code null}
     * @param reRankQuery   second-pass query, must not be {@code null}
     * @param reRankNumDocs number of first-pass hits to re-score
     * @param reRankWeight  multiplier for the second-pass score
     * @throws IllegalArgumentException if {@code mainQuery} or {@code reRankQuery} is {@code null}
     */
    public QuerqyReRankQuery(final Query mainQuery, final Query reRankQuery, final int reRankNumDocs,
                             final double reRankWeight) {
        super();
        // Fail fast: every later operation (rewrite, createWeight, equals, hashCode)
        // dereferences reRankQuery.
        if (reRankQuery == null) {
            throw new IllegalArgumentException("reRankQuery must not be null");
        }
        this.reRankQuery = reRankQuery;
        this.reRankNumDocs = reRankNumDocs;
        this.reRankWeight = reRankWeight;
        // Assigned directly instead of via wrap(): wrap() is overridable and must not be
        // invoked on a partially constructed instance.
        this.mainQuery = requireMainQuery(mainQuery);
    }

    /**
     * Creates the collector that gathers first-pass hits and re-ranks them on demand.
     *
     * @param len      number of hits the caller wants returned
     * @param cmd      query command; only its sort is read
     * @param searcher searcher used to rewrite the sort and to run the re-rank query
     * @return a new {@link ReRankCollector}
     * @throws IOException if the collector cannot be created
     */
    @Override
    public TopDocsCollector getTopDocsCollector(int len, QueryCommand cmd, IndexSearcher searcher) throws IOException {
        return new ReRankCollector(reRankNumDocs, len, reRankQuery, reRankWeight, cmd, searcher);
    }

    /**
     * Replaces the main (first-pass) query.
     *
     * @param mainQuery the new main query, must not be {@code null}
     * @return this instance
     * @throws IllegalArgumentException if {@code mainQuery} is {@code null}
     */
    @Override
    public RankQuery wrap(Query mainQuery) {
        this.mainQuery = requireMainQuery(mainQuery);
        return this;
    }

    /**
     * @return always {@code null}; this query does not supply its own merge strategy
     */
    @Override
    public MergeStrategy getMergeStrategy() {
        return null;
    }

    /**
     * Rewrites the main and the re-rank query.
     *
     * @return a new {@code QuerqyReRankQuery} over the rewritten queries if either of them
     *         changed, otherwise the result of {@code super.rewrite(reader)}
     */
    @Override
    public Query rewrite(IndexReader reader) throws IOException {
        final Query rewrittenMain = mainQuery.rewrite(reader);
        final Query rewrittenReRank = reRankQuery.rewrite(reader);
        // Identity comparison: a query that needs no rewriting hands back the same instance.
        if (rewrittenMain != mainQuery || rewrittenReRank != reRankQuery) {
            return new QuerqyReRankQuery(rewrittenMain, rewrittenReRank, reRankNumDocs, reRankWeight);
        }
        return super.rewrite(reader);
    }

    /**
     * Creates a weight that scores with the main query and explains with the combined
     * (main + re-rank) score.
     */
    @Override
    public Weight createWeight(IndexSearcher searcher, boolean needsScores) throws IOException {
        return new ReRankWeight(mainQuery, reRankQuery, reRankWeight, searcher, needsScores);
    }

    /** Two instances are equal when both queries, the doc count and the weight are equal. */
    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (!sameClassAs(o)) {
            return false;
        }
        final QuerqyReRankQuery that = (QuerqyReRankQuery) o;
        return reRankNumDocs == that.reRankNumDocs
                && Double.compare(that.reRankWeight, reRankWeight) == 0
                && mainQuery.equals(that.mainQuery)
                && reRankQuery.equals(that.reRankQuery);
    }

    @Override
    public int hashCode() {
        final int prime = 31;
        int result = classHash();
        result = prime * result + mainQuery.hashCode();
        result = prime * result + reRankQuery.hashCode();
        result = prime * result + reRankNumDocs;
        final long weightBits = Double.doubleToLongBits(reRankWeight);
        result = prime * result + (int) (weightBits ^ (weightBits >>> 32));
        return result;
    }

    /** Returns {@code mainQuery}, rejecting {@code null} with the message used by {@link #wrap(Query)}. */
    private static Query requireMainQuery(final Query mainQuery) {
        if (mainQuery == null) {
            throw new IllegalArgumentException("Cannot wrap null");
        }
        return mainQuery;
    }

    /**
     * Builds the rescorer shared by scoring ({@link ReRankCollector#topDocs(int, int)}) and
     * explaining, so both always combine scores the same way.
     */
    private static QueryRescorer newRescorer(final Query reRankQuery, final double reRankWeight) {
        return new QueryRescorer(reRankQuery) {
            @Override
            protected float combine(float firstPassScore, boolean secondPassMatches, float secondPassScore) {
                float score = firstPassScore;
                if (secondPassMatches) {
                    // Compound assignment: the sum is computed in double and narrowed back to float.
                    score += reRankWeight * secondPassScore;
                }
                return score;
            }
        };
    }

    /**
     * Weight that delegates scoring to the main query's weight and produces explanations that
     * include the re-rank contribution.
     */
    private static class ReRankWeight extends Weight {

        private final Query reRankQuery;
        private final IndexSearcher searcher;
        private final Weight mainWeight;
        private final Weight rankWeight;
        private final double reRankWeight;

        public ReRankWeight(Query mainQuery, Query reRankQuery, double reRankWeight, IndexSearcher searcher,
                            boolean needsScores) throws IOException {
            super(mainQuery);
            this.reRankQuery = reRankQuery;
            this.searcher = searcher;
            this.reRankWeight = reRankWeight;
            this.mainWeight = mainQuery.createWeight(searcher, needsScores);
            // The re-rank weight is always created with scores enabled.
            this.rankWeight = reRankQuery.createWeight(searcher, true);
        }

        /** Collects the terms of both the main and the re-rank weight. */
        @Override
        public void extractTerms(Set<Term> terms) {
            this.mainWeight.extractTerms(terms);
            this.rankWeight.extractTerms(terms);
        }

        /** @return the sum of the normalization values of the main and the re-rank weight */
        @Override
        public float getValueForNormalization() throws IOException {
            return mainWeight.getValueForNormalization() + rankWeight.getValueForNormalization();
        }

        /** @return the main weight's scorer; the re-rank query takes no part in first-pass scoring */
        @Override
        public Scorer scorer(LeafReaderContext context) throws IOException {
            return mainWeight.scorer(context);
        }

        /**
         * Normalizes the main weight only.
         * NOTE(review): the re-rank weight contributes to {@link #getValueForNormalization()} but
         * is not normalized here - confirm this asymmetry is intended.
         */
        @Override
        public void normalize(float norm, float topLevelBoost) {
            mainWeight.normalize(norm, topLevelBoost);
        }

        /** Explains the main score and lets the shared rescorer add the re-rank contribution. */
        @Override
        public Explanation explain(LeafReaderContext context, int doc) throws IOException {
            final Explanation mainExplain = mainWeight.explain(context, doc);
            // doc is offset by the segment's docBase before it is handed to the rescorer.
            return newRescorer(reRankQuery, reRankWeight).explain(searcher, mainExplain, context.docBase + doc);
        }
    }

    /**
     * Collector that gathers first-pass hits through a delegate collector and re-scores the
     * leading {@code reRankNumDocs} of them when {@link #topDocs(int, int)} is called.
     */
    public class ReRankCollector extends TopDocsCollector {

        private final Query reRankQuery;
        private final TopDocsCollector<? extends ScoreDoc> mainCollector;
        private final IndexSearcher searcher;
        private final int reRankNumDocs;
        private final int length;
        private final double reRankWeight;

        /**
         * @param reRankNumDocs number of first-pass hits to re-score
         * @param length        number of hits the caller wants returned
         * @param reRankQuery   second-pass query
         * @param reRankWeight  multiplier for the second-pass score
         * @param cmd           query command; only its sort is read
         * @param searcher      searcher used to rewrite the sort and to run the re-rank query
         * @throws IOException if the sort cannot be rewritten or the delegate cannot be created
         */
        public ReRankCollector(int reRankNumDocs,
                               int length,
                               Query reRankQuery,
                               double reRankWeight,
                               QueryCommand cmd,
                               IndexSearcher searcher) throws IOException {
            // No priority queue of our own: all collecting is done by mainCollector.
            super(null);
            this.reRankQuery = reRankQuery;
            this.reRankNumDocs = reRankNumDocs;
            this.length = length;
            this.searcher = searcher;
            this.reRankWeight = reRankWeight;

            // Collect enough hits to serve both the re-rank window and the requested page.
            final int numHits = Math.max(reRankNumDocs, length);
            final Sort sort = cmd.getSort();
            if (sort == null) {
                this.mainCollector = TopScoreDocCollector.create(numHits);
            } else {
                this.mainCollector = TopFieldCollector.create(sort.rewrite(searcher), numHits, false, true, true);
            }
        }

        /** @return the total hit count reported by the delegate collector */
        @Override
        public int getTotalHits() {
            return mainCollector.getTotalHits();
        }

        @Override
        public LeafCollector getLeafCollector(LeafReaderContext context) throws IOException {
            return mainCollector.getLeafCollector(context);
        }

        /** @return always {@code true} */
        @Override
        public boolean needsScores() {
            return true;
        }

        /**
         * Returns the collected hits with the leading {@code reRankNumDocs} of them re-scored.
         *
         * @param start   not used by this implementation; results always begin at the first hit
         * @param howMany maximum number of hits to return; lowered to the number of collected
         *                hits if fewer were collected
         * @return the first-pass result unchanged when it is empty, otherwise the re-ranked hits,
         *         followed by first-pass hits beyond the re-rank window when {@code howMany}
         *         exceeds the number of re-ranked hits
         * @throws SolrException with {@code BAD_REQUEST} wrapping any exception raised here
         */
        @Override
        public TopDocs topDocs(int start, int howMany) {
            try {
                final TopDocs mainDocs = mainCollector.topDocs(0, Math.max(reRankNumDocs, length));
                if (mainDocs.totalHits == 0 || mainDocs.scoreDocs.length == 0) {
                    return mainDocs;
                }

                final ScoreDoc[] mainScoreDocs = mainDocs.scoreDocs;

                // Only the leading reRankNumDocs first-pass hits are handed to the rescorer.
                final ScoreDoc[] reRankScoreDocs = new ScoreDoc[Math.min(mainScoreDocs.length, reRankNumDocs)];
                System.arraycopy(mainScoreDocs, 0, reRankScoreDocs, 0, reRankScoreDocs.length);
                mainDocs.scoreDocs = reRankScoreDocs;

                final TopDocs rescoredDocs = newRescorer(reRankQuery, reRankWeight)
                        .rescore(searcher, mainDocs, mainDocs.scoreDocs.length);

                // Lower howMany to return if we've collected fewer documents.
                howMany = Math.min(howMany, mainScoreDocs.length);

                if (howMany == rescoredDocs.scoreDocs.length) {
                    return rescoredDocs; // exactly the re-ranked window was requested
                }

                final ScoreDoc[] page = new ScoreDoc[howMany];
                if (howMany > rescoredDocs.scoreDocs.length) {
                    // More requested than re-ranked: lay down the first-pass hits, then overlay
                    // the re-ranked ones at the front.
                    System.arraycopy(mainScoreDocs, 0, page, 0, page.length);
                    System.arraycopy(rescoredDocs.scoreDocs, 0, page, 0, rescoredDocs.scoreDocs.length);
                } else {
                    // More re-ranked than requested: keep only the leading howMany.
                    System.arraycopy(rescoredDocs.scoreDocs, 0, page, 0, howMany);
                }
                rescoredDocs.scoreDocs = page;
                return rescoredDocs;
            } catch (Exception e) {
                throw new SolrException(SolrException.ErrorCode.BAD_REQUEST, e);
            }
        }
    }
}