/* * Licensed to Elasticsearch under one or more contributor * license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright * ownership. Elasticsearch licenses this file to you under * the Apache License, Version 2.0 (the "License"); you may * not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package org.elasticsearch.common.lucene.search.function; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.Scorer; import org.apache.lucene.search.TwoPhaseIterator; import org.apache.lucene.util.LuceneTestCase; import org.apache.lucene.util.TestUtil; import java.io.IOException; import java.util.Arrays; import java.util.HashSet; import java.util.Set; public class MinScoreScorerTests extends LuceneTestCase { private static DocIdSetIterator iterator(final int... docs) { return new DocIdSetIterator() { int i = -1; @Override public int nextDoc() throws IOException { if (i + 1 == docs.length) { return NO_MORE_DOCS; } else { return docs[++i]; } } @Override public int docID() { return i < 0 ? -1 : i == docs.length ? NO_MORE_DOCS : docs[i]; } @Override public long cost() { return docs.length; } @Override public int advance(int target) throws IOException { return slowAdvance(target); } }; } private static Scorer scorer(int maxDoc, final int[] docs, final float[] scores, final boolean twoPhase) { final DocIdSetIterator iterator = twoPhase ? DocIdSetIterator.all(maxDoc) : iterator(docs); return new Scorer(null) { public DocIdSetIterator iterator() { if (twoPhase) { return TwoPhaseIterator.asDocIdSetIterator(twoPhaseIterator()); } else { return iterator; } } public TwoPhaseIterator twoPhaseIterator() { if (twoPhase) { return new TwoPhaseIterator(iterator) { @Override public boolean matches() throws IOException { return Arrays.binarySearch(docs, iterator.docID()) >= 0; } @Override public float matchCost() { return 10; } }; } else { return null; } } @Override public int docID() { return iterator.docID(); } @Override public float score() throws IOException { final int idx = Arrays.binarySearch(docs, docID()); return scores[idx]; } @Override public int freq() throws IOException { return 1; } }; } public void doTestRandom(boolean twoPhase) throws IOException { final int maxDoc = TestUtil.nextInt(random(), 10, 10000); final int numDocs = TestUtil.nextInt(random(), 1, maxDoc / 2); final Set<Integer> uniqueDocs = new HashSet<>(); while (uniqueDocs.size() < numDocs) { uniqueDocs.add(random().nextInt(maxDoc)); } final int[] docs = new int[numDocs]; int i = 0; for (int doc : uniqueDocs) { docs[i++] = doc; } Arrays.sort(docs); final float[] scores = new float[numDocs]; for (i = 0; i < numDocs; ++i) { scores[i] = random().nextFloat(); } Scorer scorer = scorer(maxDoc, docs, scores, twoPhase); final float minScore = random().nextFloat(); Scorer minScoreScorer = new MinScoreScorer(null, scorer, minScore); int doc = -1; while (doc != DocIdSetIterator.NO_MORE_DOCS) { final int target; if (random().nextBoolean()) { target = doc + 1; doc = minScoreScorer.iterator().nextDoc(); } else { target = doc + TestUtil.nextInt(random(), 1, 10); doc = minScoreScorer.iterator().advance(target); } int idx = Arrays.binarySearch(docs, target); if (idx < 0) { idx = -1 - idx; } while (idx < docs.length && scores[idx] < minScore) { idx += 1; } if (idx == docs.length) { assertEquals(DocIdSetIterator.NO_MORE_DOCS, doc); } else { assertEquals(docs[idx], doc); assertEquals(scores[idx], scorer.score(), 0f); } } } public void testRegularIterator() throws IOException { final int iters = atLeast(5); for (int iter = 0; iter < iters; ++iter) { doTestRandom(false); } } public void testTwoPhaseIterator() throws IOException { final int iters = atLeast(5); for (int iter = 0; iter < iters; ++iter) { doTestRandom(true); } } }