package org.wikibrain.core.dao.sql;
import gnu.trove.map.hash.TIntDoubleHashMap;
import gnu.trove.set.TIntSet;
import org.wikibrain.core.dao.DaoException;
import org.wikibrain.core.dao.LocalCategoryMemberDao;
import org.wikibrain.core.lang.Language;
import org.wikibrain.core.model.CategoryGraph;
import org.wikibrain.core.model.NameSpace;
import java.util.Collection;
import java.util.PriorityQueue;
/**
 * Conducts Dijkstra on the category hierarchy from a starting document.
 * Pages attached to visited categories are recorded, and iterations stop
 * when a certain number of unique pages have been discovered.
 *
 * <p>Typical use: construct, then call {@link #step()} while
 * {@link #hasMoreResults()} is true. Distances found so far are available
 * through the {@code get*Distance} accessors.
 *
 * @author Shilad Sen
 * @author Matt Lesicko
 */
public class CategoryBfs {
    private final CategoryGraph graph;
    private final int startPage;
    private final int maxResults;
    private final LocalCategoryMemberDao categoryMemberDao;
    private final Language language;

    /**
     * Observed distances to visited categories, keyed by category index
     * (the index into the graph's arrays, not the category's page id).
     */
    private final TIntDoubleHashMap catDistances = new TIntDoubleHashMap();

    /**
     * Observed distances to visited pages, keyed by page id.
     */
    private final TIntDoubleHashMap pageDistances = new TIntDoubleHashMap();

    /**
     * Categories that have been seen, but not visited. The same category may
     * be queued several times (once per path that reached it); stale entries
     * are discarded when polled.
     */
    private final PriorityQueue<CategoryDistance> openCats = new PriorityQueue<CategoryDistance>();

    /**
     * Results of the current iteration. This single instance is cleared and
     * refilled by every call to {@link #step()}.
     */
    private final BfsVisited visited = new BfsVisited();

    /**
     * If true, tracks pages visited along the way.
     */
    public boolean addPages = true;

    /**
     * If true, explore paths that travel up to an ancestor and back down to a descendant.
     * If false, only travel upwards.
     */
    public boolean exploreChildren = true;

    /**
     * Wikipedia ids that can be traversed in the result set.
     * When null, every page id is accepted.
     */
    private final TIntSet validWpIds;

    /** Number of times {@link #step()} has been called. */
    int numSteps = 0;

    /**
     * Starts a search from an article, travelling upwards first.
     * Equivalent to the full constructor with {@link NameSpace#ARTICLE} and direction +1.
     *
     * @param graph             category graph to traverse
     * @param startCatId        id of the starting page (treated as an article id)
     * @param language          language used when looking up the start page's categories
     * @param maxResults        search stops once this many pages have a distance
     * @param validWpIds        page ids allowed in the results, or null for no restriction
     * @param categoryMemberDao dao used to look up the categories of the start page
     * @throws DaoException if the category lookup fails
     */
    public CategoryBfs(CategoryGraph graph, int startCatId, Language language, int maxResults, TIntSet validWpIds, LocalCategoryMemberDao categoryMemberDao) throws DaoException {
        this(graph, startCatId, NameSpace.ARTICLE, language, maxResults, validWpIds, categoryMemberDao, +1);
    }

    /**
     * Starts a search from an article or a category.
     *
     * <p>For an article, each of its categories known to the graph is queued at that
     * category's cost. For a category, the category itself is queued at a tiny
     * positive distance. In both cases the start id is recorded in the page
     * distances with distance 0.
     *
     * @param graph             category graph to traverse
     * @param startId           id of the starting article or category
     * @param startNamespace    {@link NameSpace#ARTICLE} or {@link NameSpace#CATEGORY}
     * @param language          language used when looking up the start page's categories
     * @param maxResults        search stops once this many pages have a distance
     * @param validWpIds        page ids allowed in the results, or null for no restriction
     * @param categoryMemberDao dao used to look up the categories of the start page
     * @param direction         initial travel direction; +1 allows moving to parents
     *                          (narrowed to a byte)
     * @throws DaoException             if the category lookup fails
     * @throws IllegalArgumentException if the namespace is neither ARTICLE nor CATEGORY
     */
    public CategoryBfs(CategoryGraph graph, int startId, NameSpace startNamespace, Language language, int maxResults, TIntSet validWpIds, LocalCategoryMemberDao categoryMemberDao, int direction) throws DaoException {
        this.startPage = startId;
        this.maxResults = maxResults;
        this.graph = graph;
        this.validWpIds = validWpIds;
        this.categoryMemberDao = categoryMemberDao;
        this.language = language;
        pageDistances.put(startPage, 0.0);

        if (startNamespace == NameSpace.ARTICLE) {
            Collection<Integer> cats = categoryMemberDao.getCategoryIds(language, startId);
            if (cats != null) {
                for (int catId : cats) {
                    int ci = graph.catIdToIndex(catId);
                    // Negative index: the category is not part of the graph.
                    if (ci >= 0) {
                        openCats.add(new CategoryDistance(ci, graph.cats[ci], graph.catCosts[ci], (byte) direction));
                    }
                }
            }
        } else if (startNamespace == NameSpace.CATEGORY) {
            int ci = graph.catIdToIndex(startId);
            if (ci >= 0) {
                openCats.add(new CategoryDistance(ci, graph.cats[ci], 0.000000001, (byte) direction));
            }
        } else {
            throw new IllegalArgumentException("Unsupported start namespace: " + startNamespace);
        }
    }

    public void setAddPages(boolean addPages) {
        this.addPages = addPages;
    }

    public void setExploreChildren(boolean exploreChildren) {
        this.exploreChildren = exploreChildren;
    }

    /**
     * @return true while there are queued categories and fewer than
     *         {@code maxResults} pages have been assigned a distance.
     */
    public boolean hasMoreResults() {
        return openCats.size() > 0 && pageDistances.size() < maxResults;
    }

    /**
     * Runs one step of Dijkstra by visiting the closest unvisited category.
     *
     * <p>The returned object is reused: it is cleared at the start of every call,
     * so callers must copy anything they need before stepping again.
     *
     * @return A BfsVisited object that captures all pages and categories visited in the step.
     *         It is empty when nothing is left to visit.
     */
    public BfsVisited step() {
        numSteps++;
        visited.clear();

        CategoryDistance cs = pollClosestUnvisited();
        if (cs == null) {
            // Either the search was already finished, or the queue only held
            // stale entries for categories that were visited earlier. Visiting
            // such an entry again would overwrite its distance with a longer one.
            return visited;
        }

        int catIndex = cs.getCatIndex();
        visited.cats.put(catIndex, cs.getDistance());
        catDistances.put(catIndex, cs.getDistance());

        if (addPages) {
            recordPages(cs);
        }
        if (exploreChildren) {
            enqueueChildren(cs);
        }
        // Upward moves are only allowed while the path has not yet turned downwards.
        if (cs.getDirection() == +1) {
            enqueueParents(cs);
        }
        return visited;
    }

    /** Removes queue entries until an unvisited category is found; null if none remains. */
    private CategoryDistance pollClosestUnvisited() {
        while (hasMoreResults()) {
            CategoryDistance candidate = openCats.poll();
            if (!catDistances.containsKey(candidate.getCatIndex())) {
                return candidate;
            }
        }
        return null;
    }

    /** Records the pages directly attached to the category, stopping once maxResults is reached. */
    private void recordPages(CategoryDistance cs) {
        double distance = cs.getDistance();
        for (int pageId : graph.catPages[cs.getCatIndex()]) {
            if (validWpIds != null && !validWpIds.contains(pageId)) {
                continue;
            }
            if (!pageDistances.containsKey(pageId) || pageDistances.get(pageId) > distance) {
                pageDistances.put(pageId, distance);
                visited.pages.put(pageId, distance);
            }
            if (pageDistances.size() >= maxResults) {
                break;  // may be an issue for huge categories
            }
        }
    }

    /** Queues the unvisited children of the category, marked as travelling downwards. */
    private void enqueueChildren(CategoryDistance cs) {
        for (int child : graph.catChildren[cs.getCatIndex()]) {
            if (!catDistances.containsKey(child)) {
                double d = cs.getDistance() + graph.catCosts[child];
                openCats.add(new CategoryDistance(child, graph.cats[child], d, (byte) -1));
            }
        }
    }

    /** Queues the unvisited parents of the category, marked as still travelling upwards. */
    private void enqueueParents(CategoryDistance cs) {
        for (int parent : graph.catParents[cs.getCatIndex()]) {
            if (!catDistances.containsKey(parent)) {
                double d = cs.getDistance() + graph.catCosts[parent];
                openCats.add(new CategoryDistance(parent, graph.cats[parent], d, (byte) +1));
            }
        }
    }

    /**
     * @return the live map of page id to distance; it is not a copy, so changes
     *         made by later steps are visible through it.
     */
    public TIntDoubleHashMap getPageDistances() {
        return pageDistances;
    }

    public boolean hasPageDistance(int pageId) {
        return pageDistances.containsKey(pageId);
    }

    /**
     * @return the recorded distance for the page. For a page without a recorded
     *         distance this is whatever the underlying map returns for a missing
     *         key (presumably its no-entry value; check with
     *         {@link #hasPageDistance(int)} first).
     */
    public double getPageDistance(int pageId) {
        return pageDistances.get(pageId);
    }

    /**
     * @param categoryId despite the name, this is a category index in the graph,
     *                   not a category page id
     */
    public boolean hasCategoryDistanceForIndex(int categoryId) {
        return catDistances.containsKey(categoryId);
    }

    /**
     * @param pageId page id of the category; it is translated to a graph index first
     */
    public boolean hasCategoryDistance(int pageId) {
        return catDistances.containsKey(graph.catIdToIndex(pageId));
    }

    /**
     * @param categoryId page id of the category; it is translated to a graph index first
     * @return the recorded distance; for an unvisited category this is whatever the
     *         underlying map returns for a missing key (check with
     *         {@link #hasCategoryDistance(int)} first).
     */
    public double getCategoryDistance(int categoryId) {
        return catDistances.get(graph.catIdToIndex(categoryId));
    }

    public double getCategoryDistanceForIndex(int catIndex) {
        return catDistances.get(catIndex);
    }

    /**
     * Pages and categories touched by a single call to {@link CategoryBfs#step()},
     * each mapped to the distance at which it was reached.
     */
    public class BfsVisited {
        public TIntDoubleHashMap pages = new TIntDoubleHashMap();
        public TIntDoubleHashMap cats = new TIntDoubleHashMap();

        public void clear() {
            pages.clear();
            cats.clear();
        }

        /** @return the largest page distance, or negative infinity when no page was recorded. */
        public double maxPageDistance() {
            return max(pages.values());
        }

        /** @return the largest category distance, or negative infinity when no category was recorded. */
        public double maxCatDistance() {
            return max(cats.values());
        }
    }

    /** Largest element of the array, or negative infinity for an empty array. */
    private double max(double[] A) {
        double max = Double.NEGATIVE_INFINITY;
        for (double x : A) {
            if (x > max) {
                max = x;
            }
        }
        return max;
    }
}