/* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.solr.search.join; import org.apache.lucene.index.AtomicReaderContext; import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.Term; import org.apache.lucene.search.ComplexExplanation; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.Explanation; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.Query; import org.apache.lucene.search.Scorer; import org.apache.lucene.search.Weight; import org.apache.lucene.search.join.ScoreMode; import org.apache.lucene.util.ArrayUtil; import org.apache.lucene.util.Bits; import org.apache.lucene.util.OpenBitSet; import org.apache.solr.request.SolrRequestInfo; import org.apache.solr.search.BitDocSetNative; import org.apache.solr.search.DocSet; import org.apache.solr.search.QueryUtils; import org.apache.solr.search.SolrIndexSearcher; import java.io.IOException; import java.util.Collection; import java.util.Collections; import java.util.Locale; import java.util.Set; class BlockJoinParentQuery extends Query { private final Query parentList; private final Query childQuery; private final ScoreMode scoreMode; public BlockJoinParentQuery(Query childQuery, Query parentList, ScoreMode scoreMode) { this.childQuery = childQuery; this.parentList = parentList; this.scoreMode = scoreMode; } @Override public Weight createWeight(IndexSearcher searcher) throws IOException { return new BlockJoinWeight((SolrIndexSearcher)searcher); } private class BlockJoinWeight extends Weight { private SolrIndexSearcher searcher; private Weight childWeight; private BitDocSetNative parentBitSet; public BlockJoinWeight(SolrIndexSearcher searcher) throws IOException { this.searcher = searcher; this.childWeight = childQuery.createWeight(searcher); } @Override public Query getQuery() { return BlockJoinParentQuery.this; } @Override public float getValueForNormalization() throws IOException { return childWeight.getValueForNormalization() * getBoost() * getBoost(); } @Override public void normalize(float norm, float topLevelBoost) { childWeight.normalize(norm, topLevelBoost * getBoost()); } // NOTE: acceptDocs applies (and is checked) only in the // parent document space @Override public Scorer scorer(AtomicReaderContext readerContext, Bits acceptDocs) throws IOException { // Pass scoreDocsInOrder true, topScorer false to our sub and the live docs: final Scorer childScorer = childWeight.scorer(readerContext, readerContext.reader().getLiveDocs()); if (childScorer == null) { // No matches return null; } final int firstChildDoc = childScorer.nextDoc(); if (firstChildDoc == DocIdSetIterator.NO_MORE_DOCS) { // No matches return null; } if (parentBitSet == null) { // Query q = QueryUtils.makeQueryable(parentList); Query q = parentList; parentBitSet = searcher.getDocSetBits(q); SolrRequestInfo.getRequestInfo().addCloseHook(parentBitSet); // TODO: a better place to decref this } BitSetSlice parentBits = new BitSetSlice(parentBitSet, readerContext.docBase, readerContext.reader().maxDoc()); return new BlockJoinScorer(this, childScorer, parentBits, firstChildDoc, scoreMode, acceptDocs); } @Override public Explanation explain(AtomicReaderContext context, int doc) throws IOException { BlockJoinScorer scorer = (BlockJoinScorer) scorer(context, context.reader().getLiveDocs()); if (scorer != null && scorer.advance(doc) == doc) { return scorer.explain(context.docBase); } return new ComplexExplanation(false, 0.0f, "Not a match"); } @Override public boolean scoresDocsOutOfOrder() { return false; } } static class BlockJoinScorer extends Scorer { private final Scorer childScorer; private final BitSetSlice parentBits; private final ScoreMode scoreMode; private final Bits acceptDocs; private int parentDoc = -1; private int prevParentDoc; private float parentScore; private int parentFreq; private int nextChildDoc; private int[] pendingChildDocs = new int[5]; private float[] pendingChildScores; private int childDocUpto; public BlockJoinScorer(Weight weight, Scorer childScorer, BitSetSlice parentBits, int firstChildDoc, ScoreMode scoreMode, Bits acceptDocs) { super(weight); this.parentBits = parentBits; this.childScorer = childScorer; this.scoreMode = scoreMode; this.acceptDocs = acceptDocs; if (scoreMode != ScoreMode.None) { pendingChildScores = new float[5]; } nextChildDoc = firstChildDoc; } @Override public Collection<ChildScorer> getChildren() { return Collections.singleton(new ChildScorer(childScorer, "BLOCK_JOIN")); } @Override public int nextDoc() throws IOException { // Loop until we hit a parentDoc that's accepted while (true) { if (nextChildDoc == NO_MORE_DOCS) { return parentDoc = NO_MORE_DOCS; } // Gather all children sharing the same parent as // nextChildDoc parentDoc = parentBits.nextSetBit(nextChildDoc); // Parent & child docs are supposed to be // orthogonal: // TODO: think about relaxing this if (nextChildDoc == parentDoc) { throw new IllegalStateException("child query must only match non-parent docs, but parent docID=" + nextChildDoc + " matched childScorer=" + childScorer.getClass()); } assert parentDoc != -1; if (acceptDocs != null && !acceptDocs.get(parentDoc)) { // Parent doc not accepted; skip child docs until // we hit a new parent doc: do { nextChildDoc = childScorer.nextDoc(); } while (nextChildDoc < parentDoc); // Parent & child docs are supposed to be // orthogonal: if (nextChildDoc == parentDoc) { throw new IllegalStateException("child query must only match non-parent docs, but parent docID=" + nextChildDoc + " matched childScorer=" + childScorer.getClass()); } continue; } float totalScore = 0; float maxScore = Float.NEGATIVE_INFINITY; childDocUpto = 0; parentFreq = 0; do { if (pendingChildDocs.length == childDocUpto) { pendingChildDocs = ArrayUtil.grow(pendingChildDocs); } if (scoreMode != ScoreMode.None && pendingChildScores.length == childDocUpto) { pendingChildScores = ArrayUtil.grow(pendingChildScores); } pendingChildDocs[childDocUpto] = nextChildDoc; if (scoreMode != ScoreMode.None) { final float childScore = childScorer.score(); final int childFreq = childScorer.freq(); pendingChildScores[childDocUpto] = childScore; maxScore = Math.max(childScore, maxScore); totalScore += childScore; parentFreq += childFreq; } childDocUpto++; nextChildDoc = childScorer.nextDoc(); } while (nextChildDoc < parentDoc); // Parent & child docs are supposed to be // orthogonal: if (nextChildDoc == parentDoc) { throw new IllegalStateException("child query must only match non-parent docs, but parent docID=" + nextChildDoc + " matched childScorer=" + childScorer.getClass()); } switch(scoreMode) { case Avg: parentScore = totalScore / childDocUpto; break; case Max: parentScore = maxScore; break; case Total: parentScore = totalScore; break; case None: break; } return parentDoc; } } @Override public int docID() { return parentDoc; } @Override public float score() throws IOException { return parentScore; } @Override public int freq() { return parentFreq; } @Override public int advance(int parentTarget) throws IOException { if (parentTarget == NO_MORE_DOCS) { return parentDoc = NO_MORE_DOCS; } if (parentTarget == 0) { // Callers should only be passing in a docID from // the parent space, so this means this parent // has no children (it got docID 0), so it cannot // possibly match. We must handle this case // separately otherwise we pass invalid -1 to // prevSetBit below: return nextDoc(); } prevParentDoc = parentBits.prevSetBit(parentTarget-1); assert prevParentDoc >= parentDoc; if (prevParentDoc > nextChildDoc) { nextChildDoc = childScorer.advance(prevParentDoc); } // Parent & child docs are supposed to be orthogonal: if (nextChildDoc == prevParentDoc) { throw new IllegalStateException("child query must only match non-parent docs, but parent docID=" + nextChildDoc + " matched childScorer=" + childScorer.getClass()); } return nextDoc(); } public Explanation explain(int docBase) throws IOException { int start = docBase + prevParentDoc + 1; // +1 b/c prevParentDoc is previous parent doc int end = docBase + parentDoc - 1; // -1 b/c parentDoc is parent doc return new ComplexExplanation( true, score(), String.format(Locale.ROOT, "Score based on child doc range from %d to %d", start, end) ); } @Override public long cost() { return childScorer.cost(); } } @Override public void extractTerms(Set<Term> terms) { childQuery.extractTerms(terms); } @Override public Query rewrite(IndexReader reader) throws IOException { Query childRewrite = childQuery.rewrite(reader); Query parentRewrite = parentList.rewrite(reader); if (childRewrite != childQuery || parentRewrite != parentList) { Query rewritten = new BlockJoinParentQuery(childRewrite, parentRewrite, scoreMode); rewritten.setBoost(getBoost()); return rewritten; } else { return this; } } @Override public String toString(String field) { return "{!parent which='" + parentList + "'}" + childQuery; } @Override public boolean equals(Object o) { if (!(o instanceof BlockJoinParentQuery)) return false; final BlockJoinParentQuery other = (BlockJoinParentQuery) o; return childQuery.equals(other.childQuery) && parentList.equals(other.parentList) && scoreMode == other.scoreMode && getBoost() == other.getBoost(); } @Override public int hashCode() { int hash = 0x4c310a59; hash = hash*29 + childQuery.hashCode(); hash = hash*29 + parentList.hashCode(); hash = hash*29 + scoreMode.hashCode(); hash = hash*29 + Float.floatToRawIntBits(getBoost()); return hash; } }