/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.search;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.OptionalLong;
import java.util.stream.Stream;
import org.apache.lucene.search.BooleanClause.Occur;
import org.apache.lucene.util.PriorityQueue;
final class Boolean2ScorerSupplier extends ScorerSupplier {
private final BooleanWeight weight;
private final Map<BooleanClause.Occur, Collection<ScorerSupplier>> subs;
private final boolean needsScores;
private final int minShouldMatch;
private long cost = -1;
Boolean2ScorerSupplier(BooleanWeight weight,
Map<Occur, Collection<ScorerSupplier>> subs,
boolean needsScores, int minShouldMatch) {
if (minShouldMatch < 0) {
throw new IllegalArgumentException("minShouldMatch must be positive, but got: " + minShouldMatch);
}
if (minShouldMatch != 0 && minShouldMatch >= subs.get(Occur.SHOULD).size()) {
throw new IllegalArgumentException("minShouldMatch must be strictly less than the number of SHOULD clauses");
}
if (needsScores == false && minShouldMatch == 0 && subs.get(Occur.SHOULD).size() > 0
&& subs.get(Occur.MUST).size() + subs.get(Occur.FILTER).size() > 0) {
throw new IllegalArgumentException("Cannot pass purely optional clauses if scores are not needed");
}
if (subs.get(Occur.SHOULD).size() + subs.get(Occur.MUST).size() + subs.get(Occur.FILTER).size() == 0) {
throw new IllegalArgumentException("There should be at least one positive clause");
}
this.weight = weight;
this.subs = subs;
this.needsScores = needsScores;
this.minShouldMatch = minShouldMatch;
}
private long computeCost() {
OptionalLong minRequiredCost = Stream.concat(
subs.get(Occur.MUST).stream(),
subs.get(Occur.FILTER).stream())
.mapToLong(ScorerSupplier::cost)
.min();
if (minRequiredCost.isPresent() && minShouldMatch == 0) {
return minRequiredCost.getAsLong();
} else {
final Collection<ScorerSupplier> optionalScorers = subs.get(Occur.SHOULD);
final long shouldCost = MinShouldMatchSumScorer.cost(
optionalScorers.stream().mapToLong(ScorerSupplier::cost),
optionalScorers.size(), minShouldMatch);
return Math.min(minRequiredCost.orElse(Long.MAX_VALUE), shouldCost);
}
}
@Override
public long cost() {
if (cost == -1) {
cost = computeCost();
}
return cost;
}
@Override
public Scorer get(boolean randomAccess) throws IOException {
// three cases: conjunction, disjunction, or mix
// pure conjunction
if (subs.get(Occur.SHOULD).isEmpty()) {
return excl(req(subs.get(Occur.FILTER), subs.get(Occur.MUST), randomAccess), subs.get(Occur.MUST_NOT));
}
// pure disjunction
if (subs.get(Occur.FILTER).isEmpty() && subs.get(Occur.MUST).isEmpty()) {
return excl(opt(subs.get(Occur.SHOULD), minShouldMatch, needsScores, randomAccess), subs.get(Occur.MUST_NOT));
}
// conjunction-disjunction mix:
// we create the required and optional pieces, and then
// combine the two: if minNrShouldMatch > 0, then it's a conjunction: because the
// optional side must match. otherwise it's required + optional
if (minShouldMatch > 0) {
boolean reqRandomAccess = true;
boolean msmRandomAccess = true;
if (randomAccess == false) {
// We need to figure out whether the MUST/FILTER or the SHOULD clauses would lead the iteration
final long reqCost = Stream.concat(
subs.get(Occur.MUST).stream(),
subs.get(Occur.FILTER).stream())
.mapToLong(ScorerSupplier::cost)
.min().getAsLong();
final long msmCost = MinShouldMatchSumScorer.cost(
subs.get(Occur.SHOULD).stream().mapToLong(ScorerSupplier::cost),
subs.get(Occur.SHOULD).size(), minShouldMatch);
reqRandomAccess = reqCost > msmCost;
msmRandomAccess = msmCost > reqCost;
}
Scorer req = excl(req(subs.get(Occur.FILTER), subs.get(Occur.MUST), reqRandomAccess), subs.get(Occur.MUST_NOT));
Scorer opt = opt(subs.get(Occur.SHOULD), minShouldMatch, needsScores, msmRandomAccess);
return new ConjunctionScorer(weight, Arrays.asList(req, opt), Arrays.asList(req, opt));
} else {
assert needsScores;
return new ReqOptSumScorer(
excl(req(subs.get(Occur.FILTER), subs.get(Occur.MUST), randomAccess), subs.get(Occur.MUST_NOT)),
opt(subs.get(Occur.SHOULD), minShouldMatch, needsScores, true));
}
}
/** Create a new scorer for the given required clauses. Note that
* {@code requiredScoring} is a subset of {@code required} containing
* required clauses that should participate in scoring. */
private Scorer req(Collection<ScorerSupplier> requiredNoScoring, Collection<ScorerSupplier> requiredScoring, boolean randomAccess) throws IOException {
if (requiredNoScoring.size() + requiredScoring.size() == 1) {
Scorer req = (requiredNoScoring.isEmpty() ? requiredScoring : requiredNoScoring).iterator().next().get(randomAccess);
if (needsScores == false) {
return req;
}
if (requiredScoring.isEmpty()) {
// Scores are needed but we only have a filter clause
// BooleanWeight expects that calling score() is ok so we need to wrap
// to prevent score() from being propagated
return new FilterScorer(req) {
@Override
public float score() throws IOException {
return 0f;
}
@Override
public int freq() throws IOException {
return 0;
}
};
}
return req;
} else {
long minCost = Math.min(
requiredNoScoring.stream().mapToLong(ScorerSupplier::cost).min().orElse(Long.MAX_VALUE),
requiredScoring.stream().mapToLong(ScorerSupplier::cost).min().orElse(Long.MAX_VALUE));
List<Scorer> requiredScorers = new ArrayList<>();
List<Scorer> scoringScorers = new ArrayList<>();
for (ScorerSupplier s : requiredNoScoring) {
requiredScorers.add(s.get(randomAccess || s.cost() > minCost));
}
for (ScorerSupplier s : requiredScoring) {
Scorer scorer = s.get(randomAccess || s.cost() > minCost);
requiredScorers.add(scorer);
scoringScorers.add(scorer);
}
return new ConjunctionScorer(weight, requiredScorers, scoringScorers);
}
}
private Scorer excl(Scorer main, Collection<ScorerSupplier> prohibited) throws IOException {
if (prohibited.isEmpty()) {
return main;
} else {
return new ReqExclScorer(main, opt(prohibited, 1, false, true));
}
}
private Scorer opt(Collection<ScorerSupplier> optional, int minShouldMatch,
boolean needsScores, boolean randomAccess) throws IOException {
if (optional.size() == 1) {
return optional.iterator().next().get(randomAccess);
} else if (minShouldMatch > 1) {
final List<Scorer> optionalScorers = new ArrayList<>();
final PriorityQueue<ScorerSupplier> pq = new PriorityQueue<ScorerSupplier>(subs.get(Occur.SHOULD).size() - minShouldMatch + 1) {
@Override
protected boolean lessThan(ScorerSupplier a, ScorerSupplier b) {
return a.cost() > b.cost();
}
};
for (ScorerSupplier scorer : subs.get(Occur.SHOULD)) {
ScorerSupplier overflow = pq.insertWithOverflow(scorer);
if (overflow != null) {
optionalScorers.add(overflow.get(true));
}
}
for (ScorerSupplier scorer : pq) {
optionalScorers.add(scorer.get(randomAccess));
}
return new MinShouldMatchSumScorer(weight, optionalScorers, minShouldMatch);
} else {
final List<Scorer> optionalScorers = new ArrayList<>();
for (ScorerSupplier scorer : optional) {
optionalScorers.add(scorer.get(randomAccess));
}
return new DisjunctionSumScorer(weight, optionalScorers, needsScores);
}
}
}