package edu.umn.cs.recsys.cbf;
import com.google.common.base.Throwables;
import com.google.common.cache.CacheBuilder;
import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;
import com.google.common.io.Closer;
import it.unimi.dsi.fastutil.longs.LongSortedSet;
import org.apache.lucene.document.Document;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.*;
import org.apache.lucene.search.similar.MoreLikeThis;
import org.apache.lucene.store.Directory;
import org.grouplens.grapht.annotation.DefaultProvider;
import org.grouplens.lenskit.collections.LongUtils;
import org.grouplens.lenskit.data.dao.ItemDAO;
import org.grouplens.lenskit.knn.item.ModelSize;
import org.grouplens.lenskit.knn.item.model.ItemItemModel;
import org.grouplens.lenskit.scored.ScoredId;
import org.grouplens.lenskit.scored.ScoredIdListBuilder;
import org.grouplens.lenskit.scored.ScoredIds;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import javax.annotation.Nonnull;
import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.ExecutionException;
/**
* The Lucene-backed CBF model.
* @author Michael Ekstrand
*/
@DefaultProvider(LuceneModelBuilder.class)
public class LuceneItemItemModel implements ItemItemModel {
private static Logger logger = LoggerFactory.getLogger(LuceneItemItemModel.class);
private final Directory luceneDir;
private final ItemDAO itemDAO;
private final int toFetch;
private final LoadingCache<Long,List<ScoredId>> cache;
LuceneItemItemModel(Directory dir, ItemDAO idao, @ModelSize int nnbrs) {
luceneDir = dir;
itemDAO = idao;
toFetch = nnbrs;
logger.debug("initializing indexed model with size {}", nnbrs);
cache = CacheBuilder.newBuilder()
.build(new LuceneCacheLoader());
}
@Override
public LongSortedSet getItemUniverse() {
return LongUtils.packedSet(itemDAO.getItemIds());
}
@Nonnull
@Override
public List<ScoredId> getNeighbors(long item) {
try {
return cache.get(item);
} catch (ExecutionException e) {
logger.error("error fetching neighborhood", e.getCause());
throw Throwables.propagate(e.getCause());
}
}
public List<ScoredId> getNeighborsImpl(long item) {
try {
Closer closer = Closer.create();
try {
IndexReader reader = closer.register(IndexReader.open(luceneDir));
IndexSearcher idx = closer.register(new IndexSearcher(reader));;
Term term = new Term("movie", Long.toString(item));
Query tq = new TermQuery(term);
TopDocs docs = idx.search(tq, 1);
if (docs.totalHits > 1) {
logger.warn("found multiple matches for {}", item);
} else if (docs.totalHits == 0) {
logger.warn("could not find movie {}", item);
return Collections.emptyList();
}
int docid = docs.scoreDocs[0].doc;
Document doc = idx.doc(docid);
Long mid = Long.parseLong(doc.get("movie"));
if (mid != item) {
logger.error("retrieved document doesn't match ({} != {})", mid, item);
return Collections.emptyList();
}
logger.trace("movie {} has index {}", item, docid);
logger.trace("finding neighbors for movie {} ({})", item, doc.get("title"));
MoreLikeThis mlt = new MoreLikeThis(idx.getIndexReader());
mlt.setFieldNames(new String[]{"title", "genres", "tags"});
Query q = mlt.like(docid);
TopDocs results = idx.search(q, toFetch + 1);
logger.trace("index returned {} of {} similar movies",
results.scoreDocs.length, results.totalHits);
ScoredIdListBuilder builder = ScoredIds.newListBuilder();
for (ScoreDoc sd: results.scoreDocs) {
Document nbrdoc = idx.doc(sd.doc);
long id = Long.parseLong(nbrdoc.get("movie"));
if (id != item) {
builder.add(id, sd.score);
}
}
logger.trace("returning {} neighbors", builder.size());
return builder.sort(ScoredIds.scoreOrder()).build();
} catch (Throwable th) {
throw closer.rethrow(th);
} finally {
closer.close();
}
} catch (IOException e) {
throw new RuntimeException("I/O error fetching neighbors", e);
}
}
private class LuceneCacheLoader extends CacheLoader<Long,List<ScoredId>> {
@Override
public List<ScoredId> load(Long key) throws Exception {
return getNeighborsImpl(key);
}
}
}