package querqy.lucene.rewrite.prms;
import java.io.IOException;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.standard.StandardAnalyzer;
import org.apache.lucene.codecs.Codec;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field.Store;
import org.apache.lucene.document.TextField;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.search.*;
import org.apache.lucene.search.similarities.Similarity;
import org.apache.lucene.store.Directory;
import org.apache.lucene.util.LuceneTestCase;
import org.junit.Before;
import org.junit.Test;
import org.mockito.ArgumentCaptor;
import org.mockito.Matchers;
import org.mockito.Mockito;
import querqy.lucene.rewrite.*;
import querqy.lucene.rewrite.SearchFieldsAndBoosting.FieldBoostModel;
import querqy.parser.WhiteSpaceQuerqyParser;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.times;
/**
 * Tests that, with the {@link FieldBoostModel#PRMS} field boost model, the boost factor handed to
 * the similarity differs per field according to how often the query term occurs in each field.
 *
 * <p>The {@link Similarity} and its {@link Similarity.SimWeight} are Mockito mocks, so that the boost
 * passed to {@code SimWeight.normalize(queryNorm, boost)} can be captured and inspected.
 */
public class PRMSFieldBoostTest extends LuceneTestCase {

    /** Mocked similarity that is installed on the searcher used by the test. */
    Similarity similarity;

    /** Mocked weight returned by {@link #similarity}; its {@code normalize} calls are verified. */
    Similarity.SimWeight simWeight;

    /**
     * Creates fresh mocks for every test and stubs {@code computeWeight} so that it always returns
     * {@link #simWeight}.
     */
    @Override
    @Before
    public void setUp() throws Exception {
        super.setUp();
        similarity = Mockito.mock(Similarity.class);
        simWeight = Mockito.mock(Similarity.SimWeight.class);
        Mockito.when(similarity.computeWeight(any(CollectionStatistics.class), Matchers.<TermStatistics>anyVararg())).thenReturn(simWeight);
    }

    /**
     * Indexes the term "abc" in 2 documents of field f1 and in 4 documents of field f2 (plus "def"
     * with the counts reversed), builds a query for "abc" over both fields and expects the boost
     * captured for f2 to be twice the boost captured for f1.
     */
    @Test
    public void testGetThatFieldProbabilityRatioIsReflectedInBoost() throws Exception {

        ArgumentCaptor<Float> normalizeCaptor = ArgumentCaptor.forClass(Float.class);
        DocumentFrequencyCorrection dfc = new DocumentFrequencyCorrection();

        // try-with-resources guarantees that analyzer, directory, writer and reader are released
        // even if an assertion fails. Resources close in reverse declaration order, i.e. reader,
        // directory, analyzer - the same order the explicit close() calls used before.
        try (Analyzer analyzer = new StandardAnalyzer();
             Directory directory = newDirectory()) {

            IndexWriterConfig conf = new IndexWriterConfig(analyzer);
            conf.setCodec(Codec.forName(TestUtil.LUCENE_CODEC));

            try (IndexWriter indexWriter = new IndexWriter(directory, conf)) {
                addNumDocs("f1", "abc", indexWriter, 2);
                addNumDocs("f1", "def", indexWriter, 4);
                addNumDocs("f2", "abc", indexWriter, 4);
                addNumDocs("f2", "def", indexWriter, 2);
            }

            try (IndexReader indexReader = DirectoryReader.open(directory)) {

                IndexSearcher indexSearcher = new IndexSearcher(indexReader);
                indexSearcher.setSimilarity(similarity);

                Map<String, Float> fields = new HashMap<>();
                fields.put("f1", 1f);
                fields.put("f2", 1f);
                SearchFieldsAndBoosting searchFieldsAndBoosting =
                        new SearchFieldsAndBoosting(FieldBoostModel.PRMS, fields, fields, 0.8f);

                LuceneQueryBuilder queryBuilder =
                        new LuceneQueryBuilder(dfc, analyzer, searchFieldsAndBoosting, 0.01f, null);

                WhiteSpaceQuerqyParser parser = new WhiteSpaceQuerqyParser();
                Query query = queryBuilder.createQuery(parser.parse("abc"));
                dfc.finishedUserQuery();

                // One term over two fields is expected to yield a dismax of two per-field queries.
                assertTrue(query instanceof DisjunctionMaxQuery);
                DisjunctionMaxQuery dmq = (DisjunctionMaxQuery) query;
                List<Query> disjuncts = dmq.getDisjuncts();
                assertEquals(2, disjuncts.size());

                Query disjunct1 = disjuncts.get(0);
                assertTrue(disjunct1 instanceof DependentTermQuery);
                DependentTermQuery dtq1 = (DependentTermQuery) disjunct1;

                Query disjunct2 = disjuncts.get(1);
                assertTrue(disjunct2 instanceof DependentTermQuery);
                DependentTermQuery dtq2 = (DependentTermQuery) disjunct2;

                final String field1 = dtq1.getTerm().field();
                final String field2 = dtq2.getTerm().field();
                assertNotEquals(field1, field2);

                final Weight weight1 = dtq1.createWeight(indexSearcher, true);
                final Weight weight2 = dtq2.createWeight(indexSearcher, true);
                weight1.normalize(0.1f, 5f);
                weight2.normalize(0.1f, 5f);

                Mockito.verify(simWeight, times(2)).normalize(eq(0.1f), normalizeCaptor.capture());
                final List<Float> capturedBoosts = normalizeCaptor.getAllValues();

                // Captured values arrive in invocation order (weight1, then weight2). Key them by
                // field name so the assertion does not depend on the order of the disjuncts.
                final Map<String, Float> boostByField = new HashMap<>();
                boostByField.put(field1, capturedBoosts.get(0));
                boostByField.put(field2, capturedBoosts.get(1));
                assertTrue(boostByField.containsKey("f1"));
                assertTrue(boostByField.containsKey("f2"));

                final float boostF1 = boostByField.get("f1");
                final float boostF2 = boostByField.get("f2");
                assertEquals(2f, boostF2 / boostF1, 0.00001);
            }
        }
    }

    /**
     * Adds {@code num} documents to the index, each holding a single stored text field.
     *
     * @param fieldname   name of the field to add to every document
     * @param value       text value of the field
     * @param indexWriter writer that receives the documents; it is neither committed nor closed here
     * @param num         number of documents to add; nothing is added if it is not positive
     * @throws IOException if the writer fails to add a document
     */
    public static void addNumDocs(String fieldname, String value, IndexWriter indexWriter, int num) throws IOException {
        for (int i = 0; i < num; i++) {
            Document doc = new Document();
            doc.add(new TextField(fieldname, value, Store.YES));
            indexWriter.addDocument(doc);
        }
    }
}