package org.apache.lucene.search.join;

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

import java.io.IOException;
import java.util.Locale;
import java.util.Set;

import org.apache.lucene.index.AtomicReaderContext;
import org.apache.lucene.index.DocsEnum;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.Terms;
import org.apache.lucene.index.TermsEnum;
import org.apache.lucene.search.Collector;
import org.apache.lucene.search.ComplexExplanation;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.BulkScorer;
import org.apache.lucene.search.Weight;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.BytesRefHash;
import org.apache.lucene.util.FixedBitSet;

/**
 * Query that matches every document containing, in {@link #field}, one of the
 * join values held in {@link #terms}, and scores each hit with the score that
 * is stored for the join value that produced it.
 *
 * <p>The score of join value ordinal {@code o} is {@code scores[o]}; the join
 * values are visited in the sorted order given by {@link #ords}.
 */
class TermsIncludingScoreQuery extends Query {

  /** Name of the field whose terms are looked up against the join values. */
  final String field;
  /** Selects the multi-valued scorer variants when {@code true}. */
  final boolean multipleValuesPerDocument;
  /** The join values. */
  final BytesRefHash terms;
  /** Score per join value, indexed by the value's ordinal in {@link #terms}. */
  final float[] scores;
  /** Ordinals of {@link #terms} in sorted order. */
  final int[] ords;
  /** Query used for weight creation and term extraction; may be a rewritten form. */
  final Query originalQuery;
  /** The query as it was before any rewrite; used by toString/equals/hashCode. */
  final Query unwrittenOriginalQuery;

  /**
   * Creates the query and sorts {@code terms} (UTF-8 sorted as Unicode) to
   * obtain the ordinal visiting order.
   *
   * @param field field to look the join values up in
   * @param multipleValuesPerDocument whether the multi-valued scorers must be used
   * @param terms the join values
   * @param scores score per join value ordinal
   * @param originalQuery the query this join query was derived from
   */
  TermsIncludingScoreQuery(String field, boolean multipleValuesPerDocument, BytesRefHash terms,
                           float[] scores, Query originalQuery) {
    this.field = field;
    this.multipleValuesPerDocument = multipleValuesPerDocument;
    this.terms = terms;
    this.scores = scores;
    this.originalQuery = originalQuery;
    this.ords = terms.sort(BytesRef.getUTF8SortedAsUnicodeComparator());
    this.unwrittenOriginalQuery = originalQuery;
  }

  /**
   * Copy constructor used by {@link #rewrite(IndexReader)}: reuses the already
   * sorted {@code ords} and keeps track of the pre-rewrite query separately.
   */
  private TermsIncludingScoreQuery(String field, boolean multipleValuesPerDocument, BytesRefHash terms,
                                   float[] scores, int[] ords, Query originalQuery,
                                   Query unwrittenOriginalQuery) {
    this.field = field;
    this.multipleValuesPerDocument = multipleValuesPerDocument;
    this.terms = terms;
    this.scores = scores;
    this.originalQuery = originalQuery;
    this.ords = ords;
    this.unwrittenOriginalQuery = unwrittenOriginalQuery;
  }

  /** Renders the field and the pre-rewrite original query; the argument is not used. */
  @Override
  public String toString(String string) {
    return String.format(Locale.ROOT, "TermsIncludingScoreQuery{field=%s;originalQuery=%s}",
        field, unwrittenOriginalQuery);
  }

  /** Delegates term extraction to the original query. */
  @Override
  public void extractTerms(Set<Term> terms) {
    originalQuery.extractTerms(terms);
  }

  /**
   * Rewrites the original query. If that yields a different instance, returns a
   * copy of this query wrapping the rewritten form (boost preserved); otherwise
   * returns {@code this}.
   */
  @Override
  public Query rewrite(IndexReader reader) throws IOException {
    final Query originalQueryRewrite = originalQuery.rewrite(reader);
    if (originalQueryRewrite != originalQuery) {
      Query rewritten = new TermsIncludingScoreQuery(field, multipleValuesPerDocument, terms, scores,
          ords, originalQueryRewrite, originalQuery);
      rewritten.setBoost(getBoost());
      return rewritten;
    } else {
      return this;
    }
  }

  /**
   * Two instances are equal when {@code Query.equals} holds, the classes are
   * identical, and both the field and the pre-rewrite original query are equal.
   */
  @Override
  public boolean equals(Object obj) {
    if (this == obj) {
      return true;
    }
    if (!super.equals(obj)) {
      return false;
    }
    if (getClass() != obj.getClass()) {
      return false;
    }

    TermsIncludingScoreQuery other = (TermsIncludingScoreQuery) obj;
    if (!field.equals(other.field)) {
      return false;
    }
    if (!unwrittenOriginalQuery.equals(other.unwrittenOriginalQuery)) {
      return false;
    }
    return true;
  }

  /** Hash over the same members that {@link #equals(Object)} compares. */
  @Override
  public int hashCode() {
    final int prime = 31;
    int result = super.hashCode();
    result += prime * field.hashCode();
    result += prime * unwrittenOriginalQuery.hashCode();
    return result;
  }

  /**
   * Builds a weight that delegates normalization to the original query's weight
   * and produces the join scorers defined in this class.
   */
  @Override
  public Weight createWeight(IndexSearcher searcher) throws IOException {
    final Weight originalWeight = originalQuery.createWeight(searcher);
    return new Weight() {

      // Handed back to Terms.iterator(...) on each call so it can be reused.
      private TermsEnum segmentTermsEnum;

      /**
       * Explains {@code doc} by walking the join values with an out-of-order
       * scorer (no acceptDocs filter). Returns "Not a match" when the field has
       * no terms in this segment.
       */
      @Override
      public Explanation explain(AtomicReaderContext context, int doc) throws IOException {
        SVInnerScorer scorer = (SVInnerScorer) bulkScorer(context, false, null);
        if (scorer != null) {
          return scorer.explain(doc);
        }
        return new ComplexExplanation(false, 0.0f, "Not a match");
      }

      @Override
      public boolean scoresDocsOutOfOrder() {
        // We have optimized impls below if we are allowed
        // to score out-of-order:
        return true;
      }

      @Override
      public Query getQuery() {
        return TermsIncludingScoreQuery.this;
      }

      /** The original weight's value multiplied by this query's boost squared. */
      @Override
      public float getValueForNormalization() throws IOException {
        return originalWeight.getValueForNormalization()
            * TermsIncludingScoreQuery.this.getBoost()
            * TermsIncludingScoreQuery.this.getBoost();
      }

      /** Forwards to the original weight, folding this query's boost into the top-level boost. */
      @Override
      public void normalize(float norm, float topLevelBoost) {
        originalWeight.normalize(norm, topLevelBoost * TermsIncludingScoreQuery.this.getBoost());
      }

      /**
       * Returns an in-order scorer for the segment, or {@code null} when the
       * segment has no terms for the field.
       */
      @Override
      public Scorer scorer(AtomicReaderContext context, Bits acceptDocs) throws IOException {
        Terms terms = context.reader().terms(field);
        if (terms == null) {
          return null;
        }

        // what is the runtime...seems ok?
        final long cost = context.reader().maxDoc() * terms.size();

        segmentTermsEnum = terms.iterator(segmentTermsEnum);
        if (multipleValuesPerDocument) {
          return new MVInOrderScorer(this, acceptDocs, segmentTermsEnum, context.reader().maxDoc(), cost);
        } else {
          return new SVInOrderScorer(this, acceptDocs, segmentTermsEnum, context.reader().maxDoc(), cost);
        }
      }

      /**
       * When in-order scoring is requested, defers to the default bulk scorer.
       * Otherwise returns one of the out-of-order implementations, or
       * {@code null} when the segment has no terms for the field.
       */
      @Override
      public BulkScorer bulkScorer(AtomicReaderContext context, boolean scoreDocsInOrder, Bits acceptDocs)
          throws IOException {
        if (scoreDocsInOrder) {
          return super.bulkScorer(context, scoreDocsInOrder, acceptDocs);
        } else {
          Terms terms = context.reader().terms(field);
          if (terms == null) {
            return null;
          }
          // what is the runtime...seems ok?
          final long cost = context.reader().maxDoc() * terms.size();

          segmentTermsEnum = terms.iterator(segmentTermsEnum);
          // Optimized impls that take advantage of docs
          // being allowed to be out of order:
          if (multipleValuesPerDocument) {
            return new MVInnerScorer(this, acceptDocs, segmentTermsEnum, context.reader().maxDoc(), cost);
          } else {
            return new SVInnerScorer(this, acceptDocs, segmentTermsEnum, cost);
          }
        }
      }
    };
  }

  // This impl assumes that the 'join' values are used uniquely per doc per field.
  // Used for one to many relations.
  class SVInnerScorer extends BulkScorer {

    final BytesRef spare = new BytesRef();
    final Bits acceptDocs;
    final TermsEnum termsEnum;
    final long cost;

    // Index into ords of the next join value to seek.
    int upto;
    // Postings of the join value currently being consumed; null when a new value must be sought.
    DocsEnum docsEnum;
    DocsEnum reuse;
    // Index into ords of the join value most recently sought; selects the score.
    int scoreUpto;
    // Last doc returned by nextDocOutOfOrder(), or -1 before the first call.
    int doc;

    SVInnerScorer(Weight weight, Bits acceptDocs, TermsEnum termsEnum, long cost) {
      this.acceptDocs = acceptDocs;
      this.termsEnum = termsEnum;
      this.cost = cost;
      this.doc = -1;
    }

    /**
     * Collects docs, in join-value order rather than doc-id order, for as long
     * as the current doc is below {@code max}.
     *
     * @return {@code true} if more docs remain
     */
    @Override
    public boolean score(Collector collector, int max) throws IOException {
      FakeScorer fakeScorer = new FakeScorer();
      collector.setScorer(fakeScorer);
      if (doc == -1) {
        doc = nextDocOutOfOrder();
      }
      while (doc < max) {
        fakeScorer.doc = doc;
        fakeScorer.score = scores[ords[scoreUpto]];
        collector.collect(doc);
        doc = nextDocOutOfOrder();
      }

      return doc != DocsEnum.NO_MORE_DOCS;
    }

    /**
     * Advances to the next doc of the current join value, moving on to the next
     * join value found in the segment whenever the current postings run out.
     *
     * @return the next doc id, or {@code NO_MORE_DOCS} once every join value is consumed
     */
    int nextDocOutOfOrder() throws IOException {
      while (true) {
        if (docsEnum != null) {
          int docId = docsEnumNextDoc();
          if (docId == DocIdSetIterator.NO_MORE_DOCS) {
            docsEnum = null;
          } else {
            return doc = docId;
          }
        }

        if (upto == terms.size()) {
          return doc = DocIdSetIterator.NO_MORE_DOCS;
        }

        scoreUpto = upto;
        if (termsEnum.seekExact(terms.get(ords[upto++], spare))) {
          docsEnum = reuse = termsEnum.docs(acceptDocs, reuse, DocsEnum.FLAG_NONE);
        }
      }
    }

    /** Hook for subclasses that need to filter the docs of the current join value. */
    protected int docsEnumNextDoc() throws IOException {
      return docsEnum.nextDoc();
    }

    /**
     * Looks for {@code target} among the postings of the join values, one value
     * at a time, and explains the score of the first value that contains it.
     *
     * @param target doc id to explain
     * @return a matching explanation naming the join value, or a non-matching
     *         explanation when no join value's postings contain {@code target}
     */
    private Explanation explain(int target) throws IOException {
      while (true) {
        int docId = nextDocOutOfOrder();
        if (docId == DocIdSetIterator.NO_MORE_DOCS) {
          // Every join value was tried without hitting target. Reporting a match
          // here (with the last visited value's score) would be wrong.
          return new ComplexExplanation(false, 0.0f, "Not a match");
        }
        if (docId == target) {
          break;
        }
        if (docId < target && docsEnum.advance(target) == target) {
          break;
        }
        docsEnum = null; // goto the next ord.
      }

      return new ComplexExplanation(true, scores[ords[scoreUpto]],
          "Score based on join value " + termsEnum.term().utf8ToString());
    }

  }

  // This impl tracks whether a docid has already been emitted. This check makes sure that docs
  // aren't emitted twice for different join values. This means that the first encountered join
  // value determines the score of a document even if other join values yield a higher score.
  class MVInnerScorer extends SVInnerScorer {

    final FixedBitSet alreadyEmittedDocs;

    MVInnerScorer(Weight weight, Bits acceptDocs, TermsEnum termsEnum, int maxDoc, long cost) {
      super(weight, acceptDocs, termsEnum, cost);
      alreadyEmittedDocs = new FixedBitSet(maxDoc);
    }

    /** Skips docs that were already returned for an earlier join value. */
    @Override
    protected int docsEnumNextDoc() throws IOException {
      while (true) {
        int docId = docsEnum.nextDoc();
        if (docId == DocIdSetIterator.NO_MORE_DOCS) {
          return docId;
        }
        if (!alreadyEmittedDocs.getAndSet(docId)) {
          return docId; // if it wasn't previously set, return it
        }
      }
    }
  }

  /**
   * In-order scorer: the constructor eagerly records every matching doc in a
   * bit set together with its score, then iterates the bit set.
   */
  class SVInOrderScorer extends Scorer {

    final DocIdSetIterator matchingDocsIterator;
    // Score per doc id; note this hides the enclosing query's per-ordinal scores array.
    final float[] scores;
    final long cost;

    int currentDoc = -1;

    SVInOrderScorer(Weight weight, Bits acceptDocs, TermsEnum termsEnum, int maxDoc, long cost)
        throws IOException {
      super(weight);
      FixedBitSet matchingDocs = new FixedBitSet(maxDoc);
      this.scores = new float[maxDoc];
      fillDocsAndScores(matchingDocs, acceptDocs, termsEnum);
      this.matchingDocsIterator = matchingDocs.iterator();
      this.cost = cost;
    }

    /**
     * Marks every doc that contains one of the join values and stores the join
     * value's score for it. Called from the constructor.
     */
    protected void fillDocsAndScores(FixedBitSet matchingDocs, Bits acceptDocs, TermsEnum termsEnum)
        throws IOException {
      BytesRef spare = new BytesRef();
      DocsEnum docsEnum = null;
      for (int i = 0; i < terms.size(); i++) {
        if (termsEnum.seekExact(terms.get(ords[i], spare))) {
          docsEnum = termsEnum.docs(acceptDocs, docsEnum, DocsEnum.FLAG_NONE);
          float score = TermsIncludingScoreQuery.this.scores[ords[i]];
          for (int doc = docsEnum.nextDoc(); doc != DocIdSetIterator.NO_MORE_DOCS; doc = docsEnum.nextDoc()) {
            matchingDocs.set(doc);
            // In the case the same doc is also related to another doc, a score might be
            // overwritten. I think this can only happen in a many-to-many relation.
            scores[doc] = score;
          }
        }
      }
    }

    @Override
    public float score() throws IOException {
      return scores[currentDoc];
    }

    @Override
    public int freq() throws IOException {
      return 1;
    }

    @Override
    public int docID() {
      return currentDoc;
    }

    @Override
    public int nextDoc() throws IOException {
      return currentDoc = matchingDocsIterator.nextDoc();
    }

    @Override
    public int advance(int target) throws IOException {
      return currentDoc = matchingDocsIterator.advance(target);
    }

    @Override
    public long cost() {
      return cost;
    }
  }

  // This scorer deals with the fact that a document can have more than one score from multiple
  // related documents.
  class MVInOrderScorer extends SVInOrderScorer {

    MVInOrderScorer(Weight weight, Bits acceptDocs, TermsEnum termsEnum, int maxDoc, long cost)
        throws IOException {
      super(weight, acceptDocs, termsEnum, maxDoc, cost);
    }

    /**
     * Same walk as the single-valued variant, but a doc keeps the score of the
     * first join value that reaches it; later join values do not overwrite it.
     */
    @Override
    protected void fillDocsAndScores(FixedBitSet matchingDocs, Bits acceptDocs, TermsEnum termsEnum)
        throws IOException {
      BytesRef spare = new BytesRef();
      DocsEnum docsEnum = null;
      for (int i = 0; i < terms.size(); i++) {
        if (termsEnum.seekExact(terms.get(ords[i], spare))) {
          docsEnum = termsEnum.docs(acceptDocs, docsEnum, DocsEnum.FLAG_NONE);
          float score = TermsIncludingScoreQuery.this.scores[ords[i]];
          for (int doc = docsEnum.nextDoc(); doc != DocIdSetIterator.NO_MORE_DOCS; doc = docsEnum.nextDoc()) {
            // I prefer this:
            /*if (scores[doc] < score) {
              scores[doc] = score;
              matchingDocs.set(doc);
            }*/
            // But this behaves the same as MVInnerScorer and only then the tests will pass:
            if (!matchingDocs.get(doc)) {
              scores[doc] = score;
              matchingDocs.set(doc);
            }
          }
        }
      }
    }
  }

}