package org.wikibrain.lucene;
import gnu.trove.set.TIntSet;
import gnu.trove.set.hash.TIntHashSet;
import org.apache.lucene.document.Document;
import org.apache.lucene.index.AtomicReader;
import org.apache.lucene.index.AtomicReaderContext;
import org.apache.lucene.search.DocIdSet;
import org.apache.lucene.search.Filter;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.DocIdBitSet;
import java.io.IOException;
import java.util.*;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * A Lucene filter that only admits documents whose Wikipedia id (the stored
 * {@code LuceneOptions.LOCAL_ID_FIELD_NAME} field) belongs to a fixed set of ids.
 *
 * <p>The first call to {@link #getDocIdSet} for each segment reader is EXPENSIVE: it reads
 * the stored id field of every live document in that segment. The matching Lucene doc ids
 * are then cached per reader, so instances should be reused across searches.
 *
 * <p>The cache holds its readers weakly, so a long-lived filter does not keep closed or
 * replaced readers reachable.
 *
 * TODO: Perform a search when there are relatively few wpIds.
 */
public class WpIdFilter extends Filter {
    private static final Logger LOG = LoggerFactory.getLogger(WpIdFilter.class);

    /** Private copy of the Wikipedia ids admitted by this filter (used for logging). */
    private final int[] wpIds;

    /** The same ids as {@link #wpIds}, as a set for constant-time membership tests. */
    private final TIntSet wpIdSet;

    /**
     * Segment reader -> segment-local Lucene doc ids whose Wikipedia id is admitted.
     * Weak keys so readers that are no longer referenced elsewhere can be collected;
     * this relies on IndexReader using identity equality. Guarded by {@code this}.
     */
    private final Map<AtomicReader, int[]> allowedLuceneIds = new WeakHashMap<AtomicReader, int[]>();

    /**
     * Creates a filter admitting the given Wikipedia ids.
     *
     * @param wpIds Wikipedia ids to admit. The array is copied, so later changes made by
     *              the caller do not affect this filter.
     * @throws IOException not thrown by this implementation; still declared so existing
     *                     callers that catch it keep compiling.
     * @throws NullPointerException if {@code wpIds} is null.
     */
    public WpIdFilter(int wpIds[]) throws IOException {
        this.wpIds = wpIds.clone();
        this.wpIdSet = new TIntHashSet(this.wpIds);
    }

    /**
     * Returns the documents of {@code context}'s segment whose Wikipedia id is admitted,
     * further restricted to the documents allowed by {@code acceptDocs}.
     *
     * @param context    the segment being searched.
     * @param acceptDocs documents that may be returned, or {@code null} to accept all.
     * @return a bit set of matching segment-local doc ids.
     * @throws IOException if reading stored fields fails while building the per-reader cache.
     */
    @Override
    public DocIdSet getDocIdSet(AtomicReaderContext context, Bits acceptDocs) throws IOException {
        // Presize to the segment's doc-id space so set() never has to grow the set.
        BitSet bits = new BitSet(context.reader().maxDoc());
        for (int id : getAllowedLuceneIds(context)) {
            if (acceptDocs == null || acceptDocs.get(id)) {
                bits.set(id);
            }
        }
        return new DocIdBitSet(bits);
    }

    /**
     * Returns (building and caching on first use) the doc ids in {@code context}'s reader
     * whose stored Wikipedia id is in {@link #wpIdSet}. Deleted documents, and documents
     * without a stored id, are never included.
     *
     * @param context the segment whose reader is scanned.
     * @return matching segment-local doc ids, in no particular order.
     * @throws IOException if reading stored fields fails.
     * @throws NumberFormatException if a stored id is not a valid integer.
     */
    private synchronized int[] getAllowedLuceneIds(AtomicReaderContext context) throws IOException {
        AtomicReader reader = context.reader();
        int[] cached = allowedLuceneIds.get(reader);
        if (cached != null) {
            return cached;
        }
        LOG.debug("building WpId filter for {} ids with hash {}", wpIds.length, Arrays.hashCode(wpIds));

        // null means the segment has no deletions.
        Bits liveDocs = reader.getLiveDocs();
        Set<String> fields = Collections.singleton(LuceneOptions.LOCAL_ID_FIELD_NAME);
        TIntSet luceneIdSet = new TIntHashSet();
        // Doc ids span [0, maxDoc()). numDocs() excludes deletions, so looping to it would
        // skip the highest ids whenever the segment contains deleted documents.
        for (int i = 0; i < reader.maxDoc(); i++) {
            if (liveDocs != null && !liveDocs.get(i)) {
                continue;
            }
            Document d = reader.document(i, fields);
            String value = d.get(LuceneOptions.LOCAL_ID_FIELD_NAME);
            if (value == null) {
                continue;   // no stored id, so it cannot match any requested wpId
            }
            if (wpIdSet.contains(Integer.parseInt(value))) {
                luceneIdSet.add(i);
            }
        }
        int[] luceneIds = luceneIdSet.toArray();
        LOG.debug("WpId filter matched {} ids.", luceneIds.length);
        allowedLuceneIds.put(reader, luceneIds);
        return luceneIds;
    }
}