package experiments.collective.entdoccentric.LTR;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import org.apache.lucene.search.Collector;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.similarities.Similarity;
import experiments.collective.entdoccentric.LTR.LearnToRankQuery.LearnToRankWeight;
/**
* An alternative to BooleanScorer that also allows a minimum number of optional
* scorers that should match. <br>
* Implements skipTo(), and has no limitations on the numbers of added scorers. <br>
* Uses ConjunctionScorer, DisjunctionScorer, ReqOptScorer and ReqExclScorer.
*/
class LearnToRankScorer extends Scorer {
private final List<Scorer> requiredScorers;
private final List<Scorer> optionalScorers;
private final LearnToRankClause[] optionalClauses;
private final LearnToRankClause[] requiredClauses;
private final int docBase;
/**
* The scorer to which all scoring will be delegated, except for computing
* and using the coordination factor.
*/
private final Scorer countingSumScorer;
private int doc = -1;
/**
* Creates a {@link Scorer} with the given similarity and lists of required,
* prohibited and optional scorers. In no required scorers are added, at
* least one of the optional scorers will have to match during the search.
*
* @param weight
* The BooleanWeight to be used.
* @param disableCoord
* If this parameter is true, coordination level matching (
* {@link Similarity#coord(int, int)}) is not used.
* @param minNrShouldMatch
* The minimum number of optional added scorers that should match
* during the search. In case no required scorers are added, at
* least one of the optional scorers will have to match during
* the search.
* @param required
* the list of required scorers.
* @param prohibited
* the list of prohibited scorers.
* @param optional
* the list of optional scorers.
*/
public LearnToRankScorer(LearnToRankWeight weight, List<Scorer> optional,
List<Scorer> required, LearnToRankClause[] optionalClauses,
LearnToRankClause[] requiredClauses, int docBase)
throws IOException {
super(weight);
this.docBase = docBase;
this.optionalScorers = optional;
this.requiredScorers = required;
this.optionalClauses = optionalClauses;
this.requiredClauses = requiredClauses;
countingSumScorer = createSumScorer();
}
/** Count a scorer as a single match. */
private class SingleMatchScorer extends Scorer {
private Scorer scorer;
private int lastScoredDoc = -1;
// Save the score of lastScoredDoc, so that we don't compute it more
// than
// once in score().
private float lastDocScore = Float.NaN;
private LearnToRankClause clause;
SingleMatchScorer(Scorer scorer, LearnToRankClause clause) {
super(scorer.getWeight());
this.scorer = scorer;
this.clause = clause;
}
@Override
public float score() throws IOException {
int doc = docID();
if (doc > lastScoredDoc) {
lastDocScore = scorer.score();
lastScoredDoc = doc;
}
float val = lastDocScore * clause.getWeight();
clause.addFeatureValue(docBase, doc, val);
return val;
}
@Override
public int freq() throws IOException {
return 1;
}
@Override
public int docID() {
return scorer.docID();
}
@Override
public int nextDoc() throws IOException {
return scorer.nextDoc();
}
@Override
public int advance(int target) throws IOException {
return scorer.advance(target);
}
}
private Scorer createSumScorer() throws IOException {
return (requiredScorers.size() == 0) ? makeCountingSumScorerNoReq()
: makeCountingSumScorerSomeReq();
}
private Scorer countingDisjunctionSumScorer(final List<Scorer> scorers,
int minNrShouldMatch) throws IOException {
// each scorer from the list counted as a single matcher
return new DisjunctionSumScorer(weight, scorers, minNrShouldMatch,
optionalClauses, docBase) {
private int lastScoredDoc = -1;
// Save the score of lastScoredDoc, so that we don't compute it more
// than
// once in score().
private float lastDocScore = Float.NaN;
@Override
public float score() throws IOException {
int doc = docID();
if (doc > lastScoredDoc) {
lastDocScore = super.score();
lastScoredDoc = doc;
}
return lastDocScore;
}
};
}
private Scorer countingConjunctionSumScorer(List<Scorer> requiredScorers,
LearnToRankClause[] requiredClauses) throws IOException {
// each scorer from the list counted as a single matcher
return new ConjunctionScorer(weight, docBase, requiredScorers,
requiredClauses) {
private int lastScoredDoc = -1;
private float lastDocScore = Float.NaN;
@Override
public float score() throws IOException {
int doc = docID();
if (doc >= lastScoredDoc) {
if (doc > lastScoredDoc) {
lastDocScore = super.score();
lastScoredDoc = doc;
}
}
return lastDocScore;
}
};
}
/**
* Returns the scorer to be used for match counting and score summing. Uses
* requiredScorers, optionalScorers.
*/
private Scorer makeCountingSumScorerNoReq() throws IOException { // No // required
// scorers
// minNrShouldMatch optional scorers are required, but at least 1
int nrOptRequired = 1;
Scorer requiredCountingSumScorer;
if (optionalScorers.size() > nrOptRequired) {
requiredCountingSumScorer = countingDisjunctionSumScorer(
optionalScorers, nrOptRequired);
} else {
requiredCountingSumScorer = new SingleMatchScorer(
optionalScorers.get(0), optionalClauses[0]);
}
return requiredCountingSumScorer;
}
private Scorer makeCountingSumScorerSomeReq() throws IOException {
Scorer requiredCountingSumScorer = requiredScorers.size() == 1 ? new SingleMatchScorer(
requiredScorers.get(0), requiredClauses[0])
: countingConjunctionSumScorer(requiredScorers, requiredClauses);
if (optionalScorers.size() == 0) {
return requiredCountingSumScorer;
} else {
return new ReqOptSumScorer(requiredCountingSumScorer,
optionalScorers.size() == 1 ? new SingleMatchScorer(
optionalScorers.get(0), optionalClauses[0])
: countingDisjunctionSumScorer(optionalScorers, 1));
}
}
/**
* Scores and collects all matching documents.
*
* @param collector
* The collector to which all matching documents are passed
* through.
*/
@Override
public void score(Collector collector) throws IOException {
collector.setScorer(this);
while ((doc = countingSumScorer.nextDoc()) != NO_MORE_DOCS) {
collector.collect(doc);
}
}
@Override
public boolean score(Collector collector, int max, int firstDocID)
throws IOException {
doc = firstDocID;
collector.setScorer(this);
while (doc < max) {
collector.collect(doc);
doc = countingSumScorer.nextDoc();
}
return doc != NO_MORE_DOCS;
}
@Override
public int docID() {
return doc;
}
@Override
public int nextDoc() throws IOException {
return doc = countingSumScorer.nextDoc();
}
@Override
public float score() throws IOException {
float sum = countingSumScorer.score();
return sum;
}
@Override
public int freq() throws IOException {
return countingSumScorer.freq();
}
@Override
public int advance(int target) throws IOException {
return doc = countingSumScorer.advance(target);
}
@Override
public Collection<ChildScorer> getChildren() {
ArrayList<ChildScorer> children = new ArrayList<ChildScorer>();
for (Scorer s : optionalScorers) {
children.add(new ChildScorer(s, "SHOULD"));
}
return children;
}
}