package querqy.lucene.rewrite.prms;

import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.core.WhitespaceTokenizer;
import org.apache.lucene.analysis.synonym.SynonymFilter;
import org.apache.lucene.analysis.synonym.SynonymMap;
import org.apache.lucene.codecs.Codec;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.search.*;
import org.apache.lucene.search.similarities.Similarity;
import org.apache.lucene.store.Directory;
import org.apache.lucene.util.CharsRef;
import org.apache.lucene.util.LuceneTestCase;
import org.junit.Before;
import org.junit.Test;
import org.mockito.ArgumentCaptor;
import org.mockito.Matchers;
import org.mockito.Mockito;

import querqy.lucene.rewrite.*;
import querqy.lucene.rewrite.SearchFieldsAndBoosting.FieldBoostModel;
import querqy.parser.WhiteSpaceQuerqyParser;

import static org.mockito.Matchers.any;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.times;

public class PRMSDisjunctionMaxQueryTest extends LuceneTestCase {

    Similarity similarity;
    Similarity.SimWeight simWeight;

    @Override
    @Before
    public void setUp() throws Exception {
        super.setUp();
        similarity = Mockito.mock(Similarity.class);
        simWeight = Mockito.mock(Similarity.SimWeight.class);
        Mockito.when(similarity.computeWeight(any(CollectionStatistics.class), Matchers.<TermStatistics>anyVararg()))
                .thenReturn(simWeight);
    }

    @Test
    public void testGetThatFieldProbabilityRatioIsReflectedInBoost() throws Exception {

        ArgumentCaptor<Float> normalizeCaptor = ArgumentCaptor.forClass(Float.class);

        DocumentFrequencyCorrection dfc = new DocumentFrequencyCorrection();

        Directory directory = newDirectory();

        SynonymMap.Builder builder = new SynonymMap.Builder(true);
        builder.add(new CharsRef("abc"), new CharsRef("def"), true);
        final SynonymMap synonyms = builder.build();

        Analyzer queryAnalyzer = new Analyzer() {
            @Override
            protected TokenStreamComponents createComponents(String fieldName) {
                WhitespaceTokenizer source = new WhitespaceTokenizer();
                TokenStream result = new SynonymFilter(source, synonyms, true);
                return new TokenStreamComponents(source, result);
            }
        };

        Analyzer indexAnalyzer = new Analyzer() {
            @Override
            protected TokenStreamComponents createComponents(String fieldName) {
                WhitespaceTokenizer source = new WhitespaceTokenizer();
                return new TokenStreamComponents(source, source);
            }
        };

        IndexWriterConfig conf = new IndexWriterConfig(indexAnalyzer);
        conf.setCodec(Codec.forName(TestUtil.LUCENE_CODEC));
        IndexWriter indexWriter = new IndexWriter(directory, conf);

        PRMSFieldBoostTest.addNumDocs("f1", "abc", indexWriter, 2);
        PRMSFieldBoostTest.addNumDocs("f1", "def", indexWriter, 8);
        PRMSFieldBoostTest.addNumDocs("f2", "abc", indexWriter, 6);
        PRMSFieldBoostTest.addNumDocs("f2", "def", indexWriter, 4);

        // Within a field, all disjuncts must have the same boost factor, reflecting
        // the max boost factor of the disjuncts.
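        //
        // Sketch of the PRMS arithmetic behind the expected boosts, derived from the
        // documents added above (assuming the PRMS field boost is the within-field
        // term probability):
        //   P(abc|f1) = 2/10 = 0.2,  P(def|f1) = 8/10 = 0.8  -> max boost for f1: 0.8
        //   P(abc|f2) = 6/10 = 0.6,  P(def|f2) = 4/10 = 0.4  -> max boost for f2: 0.6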
        //
        // Given a query 'abc' and synonym expansion 'def', we get a query:
        // DMQ(
        //      DMQ(                  // dmq1
        //          TQ(f1:abc),
        //          TQ(f1:def)
        //      ),
        //      DMQ(                  // dmq2
        //          TQ(f2:abc),
        //          TQ(f2:def)
        //      )
        // )
        // dmq1: max boost factor is 0.8 (8 of 10 terms in f1 equal "def")
        // dmq2: max boost factor is 0.6 (6 of 10 terms in f2 equal "abc")
        // ==> the ratio of the boost factors of the disjuncts of dmq1/dmq2 must equal 0.8/0.6

        indexWriter.close();

        IndexReader indexReader = DirectoryReader.open(directory);
        IndexSearcher indexSearcher = new IndexSearcher(indexReader);
        indexSearcher.setSimilarity(similarity);

        Map<String, Float> fields = new HashMap<>();
        fields.put("f1", 1f);
        fields.put("f2", 1f);
        SearchFieldsAndBoosting searchFieldsAndBoosting =
                new SearchFieldsAndBoosting(FieldBoostModel.PRMS, fields, fields, 0.8f);

        LuceneQueryBuilder queryBuilder =
                new LuceneQueryBuilder(dfc, queryAnalyzer, searchFieldsAndBoosting, 0.01f, null);

        WhiteSpaceQuerqyParser parser = new WhiteSpaceQuerqyParser();

        Query query = queryBuilder.createQuery(parser.parse("abc"));
        dfc.finishedUserQuery();
        query.createWeight(indexSearcher, true);

        assertTrue(query instanceof DisjunctionMaxQuery);

        DisjunctionMaxQuery dmq = (DisjunctionMaxQuery) query;
        List<Query> disjuncts = dmq.getDisjuncts();
        assertEquals(2, disjuncts.size());

        Query disjunct1 = disjuncts.get(0);
        assertTrue(disjunct1 instanceof DisjunctionMaxQuery);
        Query dmq1 = disjunct1.rewrite(indexReader);
        if (dmq1 instanceof BoostQuery) {
            dmq1 = ((BoostQuery) dmq1).getQuery();
        }

        Query disjunct2 = disjuncts.get(1);
        assertTrue(disjunct2 instanceof DisjunctionMaxQuery);
        Query dmq2 = disjunct2.rewrite(indexReader);
        if (dmq2 instanceof BoostQuery) {
            dmq2 = ((BoostQuery) dmq2).getQuery();
        }

        final Weight weight1 = dmq1.createWeight(indexSearcher, true);
        weight1.normalize(0.1f, 4f);
        final Weight weight2 = dmq2.createWeight(indexSearcher, true);
        weight2.normalize(0.1f, 4f);

        Mockito.verify(simWeight, times(4)).normalize(eq(0.1f), normalizeCaptor.capture());
        final List<Float> capturedBoosts = normalizeCaptor.getAllValues();

        // capturedBoosts = boosts of [dmq1.term1, dmq1.term2, dmq2.term1, dmq2.term2]
        assertEquals(capturedBoosts.get(0), capturedBoosts.get(1), 0.00001);
        assertEquals(capturedBoosts.get(2), capturedBoosts.get(3), 0.00001);
        assertEquals(0.8f / 0.6f, capturedBoosts.get(0) / capturedBoosts.get(3), 0.00001);

        indexReader.close();
        directory.close();
        indexAnalyzer.close();
        queryAnalyzer.close();
    }
}