/**
* Copyright 2014 National University of Ireland, Galway.
*
* This file is part of the SIREn project. Project and contact information:
*
* https://github.com/rdelbru/SIREn
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.sindice.siren.search.node;
import java.io.IOException;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.TermContext;
import org.apache.lucene.index.TermState;
import org.apache.lucene.index.TermsEnum;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.BoostAttribute;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoringRewrite;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.ByteBlockPool;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.BytesRefHash;
import org.apache.lucene.util.BytesRefHash.DirectBytesStartArray;
import org.apache.lucene.util.RamUsageEstimator;
import org.sindice.siren.search.node.MultiNodeTermQuery.RewriteMethod;
import org.sindice.siren.search.node.NodeBooleanClause.Occur;
/**
* Base rewrite method that translates each term into a query, and keeps
* the scores as computed by the query.
*
* <p>
*
* Code taken from {@link ScoringRewrite} and adapted for SIREn.
*/
abstract class NodeScoringRewrite<Q extends Query> extends NodeTermCollectingRewrite<Q> {
/**
 * A rewrite method that turns every matching term into a
 * {@link NodeBooleanClause.Occur#SHOULD} clause of a
 * {@link NodeBooleanQuery} and keeps the scores as computed by the query.
 * Such scores are typically meaningless to the user and require
 * non-trivial CPU to compute, so it is almost always better to use
 * {@link MultiNodeTermQuery#CONSTANT_SCORE_AUTO_REWRITE_DEFAULT} instead.
 *
 * <p><b>NOTE</b>: the term-count check of this rewrite method throws
 * {@link BooleanQuery.TooManyClauses} as soon as the number of collected
 * terms exceeds {@link BooleanQuery#getMaxClauseCount()}.
 *
 * @see MultiNodeTermQuery#setRewriteMethod
 **/
public final static NodeScoringRewrite<NodeBooleanQuery> SCORING_BOOLEAN_QUERY_REWRITE =
    new NodeScoringRewrite<NodeBooleanQuery>() {

  /** Creates the empty boolean query that will receive one clause per term. */
  @Override
  protected NodeBooleanQuery getTopLevelQuery() {
    return new NodeBooleanQuery();
  }

  /**
   * Wraps the term (together with its pre-collected term states) into a
   * boosted {@link NodeTermQuery} and registers it as an optional clause.
   * The {@code docCount} argument is not used by this implementation.
   */
  @Override
  protected void addClause(final NodeBooleanQuery topLevel, final Term term,
                           final int docCount, final float boost,
                           final TermContext states) {
    final NodeTermQuery termQuery = new NodeTermQuery(term, states);
    termQuery.setBoost(boost);
    topLevel.add(termQuery, Occur.SHOULD);
  }

  /**
   * Rejects the rewrite once more terms have been collected than Lucene's
   * global clause limit allows.
   */
  @Override
  protected void checkMaxClauseCount(final int count) {
    // NOTE(review): the limit and the exception both come from Lucene's
    // BooleanQuery, whereas earlier documentation referred to
    // NodeBooleanQuery -- confirm which one is intended.
    final boolean withinLimit = count <= BooleanQuery.getMaxClauseCount();
    if (!withinLimit) {
      throw new BooleanQuery.TooManyClauses();
    }
  }
};
/**
 * Like {@link #SCORING_BOOLEAN_QUERY_REWRITE} except that scores are not
 * computed. Instead, each matching document receives a constant score
 * equal to the query's boost.
 *
 * <p>The rewrite delegates to {@link #SCORING_BOOLEAN_QUERY_REWRITE} and
 * wraps a non-empty result into a {@link NodeConstantScoreQuery}; an empty
 * boolean query is returned unwrapped.
 *
 * <p><b>NOTE</b>: because of that delegation, this rewrite method fails
 * with {@link BooleanQuery.TooManyClauses} if the number of terms exceeds
 * {@link BooleanQuery#getMaxClauseCount()}.
 *
 * @see MultiNodeTermQuery#setRewriteMethod
 **/
public final static RewriteMethod CONSTANT_SCORE_BOOLEAN_QUERY_REWRITE = new RewriteMethod() {

  @Override
  public Query rewrite(final IndexReader reader, final MultiNodeTermQuery query) throws IOException {
    final NodeBooleanQuery scored = SCORING_BOOLEAN_QUERY_REWRITE.rewrite(reader, query);
    // TODO: if empty boolean query return NullQuery?
    if (!scored.clauses().isEmpty()) {
      // strip the scores off and carry over the boost of the original query
      final Query constantScore = new NodeConstantScoreQuery(scored);
      constantScore.setBoost(query.getBoost());
      return constantScore;
    }
    // no term matched: hand back the empty boolean query as is
    return scored;
  }
};
/**
 * Hook invoked each time a previously unseen term has been collected, so
 * that implementations can verify that the maximum number of clauses
 * (e.g. in NodeBooleanQuery) is not exceeded. Implementations signal a
 * violation by throwing the corresponding {@link RuntimeException}.
 *
 * @param count the number of distinct terms collected so far, including
 *              the one that has just been added
 * @throws IOException permitted by the signature for implementations that
 *                     need it; the implementation visible in this file
 *                     does not throw it
 */
protected abstract void checkMaxClauseCount(int count) throws IOException;
/**
 * Rewrites the multi-term query into the top-level query of this rewrite
 * method: all matching terms are collected, sorted with the comparator of
 * the last terms enum seen by the collector, and then added one by one as
 * clauses whose boost is the query boost multiplied by the term boost.
 *
 * @param reader the index reader the terms are collected from
 * @param query the multi-term query to rewrite
 * @return the top-level query, left without clauses if no term was collected
 * @throws IOException if collecting the terms fails
 */
@Override
public Q rewrite(final IndexReader reader, final MultiNodeTermQuery query)
    throws IOException {
  final Q topLevel = this.getTopLevelQuery();
  final ParallelArraysTermCollector collector = new ParallelArraysTermCollector();
  this.collectTerms(reader, query, collector);

  final int termCount = collector.terms.size();
  if (termCount <= 0) {
    // nothing matched: the freshly created top-level query stays empty
    return topLevel;
  }

  final int[] ordered = collector.terms.sort(collector.termsEnum.getComparator());
  final float[] termBoosts = collector.array.boost;
  final TermContext[] contexts = collector.array.termState;

  // only the first termCount entries of the sorted ids are consumed
  for (int i = 0; i < termCount; i++) {
    final int pos = ordered[i];
    // every Term gets its own BytesRef instance
    final BytesRef termBytes = col(collector, pos);
    final Term term = new Term(query.getField(), termBytes);
    assert reader.docFreq(term) == contexts[pos].docFreq();
    this.addClause(topLevel, term, contexts[pos].docFreq(),
                   query.getBoost() * termBoosts[pos], contexts[pos]);
  }
  return topLevel;
}

/** Fetches the bytes of the collected term stored at {@code pos} into a new {@link BytesRef}. */
private BytesRef col(final ParallelArraysTermCollector collector, final int pos) {
  return collector.terms.get(pos, new BytesRef());
}
/**
 * Term collector that stores each distinct term once in a
 * {@link BytesRefHash} and keeps, at the same index, the boost and the
 * {@link TermContext} of that term in the parallel arrays of
 * {@link TermFreqBoostByteStart}.
 */
final class ParallelArraysTermCollector extends TermCollector {

  /** Parallel per-term storage; also serves as the start array of {@link #terms}. */
  final TermFreqBoostByteStart array = new TermFreqBoostByteStart(16);

  /** The distinct terms collected so far. */
  final BytesRefHash terms =
      new BytesRefHash(new ByteBlockPool(new ByteBlockPool.DirectAllocator()), 16, array);

  /** The terms enum most recently passed to {@link #setNextEnum(TermsEnum)}. */
  TermsEnum termsEnum;

  /** Boost attribute of the current terms enum. */
  private BoostAttribute boostAtt;

  /** Remembers the new terms enum and looks up (or adds) its boost attribute. */
  @Override
  public void setNextEnum(final TermsEnum termsEnum) throws IOException {
    this.boostAtt = termsEnum.attributes().addAttribute(BoostAttribute.class);
    this.termsEnum = termsEnum;
  }

  /**
   * Records the given term. A term seen for the first time gets a new
   * boost / term-context entry and triggers the max-clause check; a term
   * that is already known only has the current term state registered on
   * its existing context.
   *
   * @return always {@code true}
   */
  @Override
  public boolean collect(final BytesRef bytes) throws IOException {
    final int slot = terms.add(bytes);
    final TermState state = termsEnum.termState();
    assert state != null;

    if (slot >= 0) {
      // new entry: we populate the entry initially
      array.boost[slot] = boostAtt.getBoost();
      array.termState[slot] = new TermContext(topReaderContext, state, readerContext.ord,
                                              termsEnum.docFreq(), termsEnum.totalTermFreq());
      NodeScoringRewrite.this.checkMaxClauseCount(terms.size());
      return true;
    }

    // duplicate term: a negative result encodes the existing position; update docFreq
    final int existing = -slot - 1;
    array.termState[existing].register(state, readerContext.ord,
                                       termsEnum.docFreq(), termsEnum.totalTermFreq());
    assert array.boost[existing] == boostAtt.getBoost() : "boost should be equal in all segment TermsEnums";
    return true;
  }
}
/**
 * Special implementation of BytesStartArray that keeps, next to the byte
 * start offsets, two parallel arrays holding the boost and the
 * {@link TermContext} of every entry.
 */
static final class TermFreqBoostByteStart extends DirectBytesStartArray {

  /** Boost per entry; {@code null} after {@link #clear()}. */
  float[] boost;

  /** Term context per entry; {@code null} after {@link #clear()}. */
  TermContext[] termState;

  public TermFreqBoostByteStart(final int initSize) {
    super(initSize);
  }

  /** Allocates the parallel arrays so that they are at least as long as the start array. */
  @Override
  public int[] init() {
    final int[] starts = super.init();
    final int required = starts.length;
    boost = new float[ArrayUtil.oversize(required, RamUsageEstimator.NUM_BYTES_FLOAT)];
    termState = new TermContext[ArrayUtil.oversize(required, RamUsageEstimator.NUM_BYTES_OBJECT_REF)];
    assert termState.length >= required && boost.length >= required;
    return starts;
  }

  /** Grows the parallel arrays, preserving their content, to match the grown start array. */
  @Override
  public int[] grow() {
    final int[] starts = super.grow();
    final int required = starts.length;
    boost = ArrayUtil.grow(boost, required);
    if (termState.length < required) {
      // enlarge with the usual over-allocation; existing entries are copied over
      termState = java.util.Arrays.copyOf(
          termState, ArrayUtil.oversize(required, RamUsageEstimator.NUM_BYTES_OBJECT_REF));
    }
    assert termState.length >= required && boost.length >= required;
    return starts;
  }

  /** Drops both parallel arrays before clearing the start array. */
  @Override
  public int[] clear() {
    termState = null;
    boost = null;
    return super.clear();
  }
}
}