package org.wikibrain.core.dao.sql;
import com.typesafe.config.Config;
import com.typesafe.config.ConfigValue;
import gnu.trove.map.TIntDoubleMap;
import gnu.trove.map.hash.TIntDoubleHashMap;
import gnu.trove.set.TIntSet;
import gnu.trove.set.hash.TIntHashSet;
import org.jooq.*;
import org.wikibrain.conf.Configuration;
import org.wikibrain.conf.ConfigurationException;
import org.wikibrain.conf.Configurator;
import org.wikibrain.core.WikiBrainException;
import org.wikibrain.core.dao.*;
import org.wikibrain.core.jooq.Tables;
import org.wikibrain.core.lang.Language;
import org.wikibrain.core.lang.LanguageInfo;
import org.wikibrain.core.model.*;
import java.io.File;
import java.util.*;
/**
*
* A SQL database implementation of the LocalCategoryMemberDao.
*
* @author Shilad Sen
* @author Ari Weiland
*
*/
public class LocalCategoryMemberSqlDao extends AbstractSqlDao<LocalCategoryMember> implements LocalCategoryMemberDao {
private static final TableField [] INSERT_FIELDS = new TableField[] {
Tables.CATEGORY_MEMBERS.LANG_ID,
Tables.CATEGORY_MEMBERS.CATEGORY_ID,
Tables.CATEGORY_MEMBERS.ARTICLE_ID,
};
private final LocalPageDao localPageDao;
private Map<Language, CategoryGraph> graphs = new HashMap<Language, CategoryGraph>();
/**
* Only used to identify top-level categories.
*/
private final UniversalPageDao univDao;
// See https://www.wikidata.org/wiki/Q4587687
public static final int TOP_LEVEL_CONCEPT = 4587687;
// Language-specific top-level concept overrides (language -> title)
private Map<Language, Title> topLevelLangOverrides = new HashMap<Language, Title>();
public LocalCategoryMemberSqlDao(WpDataSource dataSource, LocalPageDao localArticleDao) throws DaoException {
this(dataSource, localArticleDao, null);
}
public LocalCategoryMemberSqlDao(WpDataSource dataSource, LocalPageDao localArticleDao, UniversalPageDao univDao) throws DaoException {
super(dataSource, INSERT_FIELDS, "/db/category-members");
this.localPageDao = localArticleDao;
this.univDao = univDao;
}
@Override
public void save(LocalCategoryMember member) throws DaoException {
insert(
member.getLanguage().getId(),
member.getCategoryId(),
member.getArticleId()
);
}
public void addTopLevelOverride(Language language, String topLevelTitle) {
this.topLevelLangOverrides.put(language, new Title(topLevelTitle, language));
}
@Override
public Set<LocalPage> guessTopLevelCategories(Language language) throws DaoException {
int topLevelId = -1;
Title override = topLevelLangOverrides.get(language);
if (override != null) {
System.out.println("title is " + override);
topLevelId = localPageDao.getIdByTitle(override);
if (topLevelId < 0) {
LOG.warn("top level category {} for language {} not found.", override, language);
}
}
if (topLevelId < 0) {
if (univDao == null) {
throw new DaoException("Universal dao required for top level categories.");
}
topLevelId = univDao.getLocalId(language, TOP_LEVEL_CONCEPT);
}
Set<LocalPage> result = new HashSet<LocalPage>();
if (topLevelId < 0) {
return result;
}
for (int id : getCategoryMemberIds(language, topLevelId)) {
LocalPage page = localPageDao.getById(language, id);
if (page.getNameSpace() == NameSpace.CATEGORY) {
result.add(page);
}
}
return result;
}
@Override
public void save(LocalPage category, LocalPage article) throws DaoException, WikiBrainException {
if (!graphs.isEmpty()) {
graphs.clear();
}
save(new LocalCategoryMember(category, article));
}
@Override
public LocalPage getClosestCategory(LocalPage page, Set<LocalPage> candidates, boolean weightedDistance) throws DaoException {
CategoryGraph graph = getGraph(page.getLanguage());
CategoryBfs bfs = new CategoryBfs(graph, page.getLocalId(), page.getLanguage(), Integer.MAX_VALUE, null, this);
bfs.setAddPages(false);
bfs.setExploreChildren(false);
Map<Integer, LocalPage> indexToCandidates = new HashMap<Integer, LocalPage>();
for (LocalPage c : candidates) {
indexToCandidates.put(graph.catIdToIndex(c.getLocalId()), c);
}
List<LocalPage> matches = new ArrayList<LocalPage>();
while (bfs.hasMoreResults() && matches.isEmpty()) {
CategoryBfs.BfsVisited visited = bfs.step();
for (int catId : visited.cats.keys()) {
if (indexToCandidates.containsKey(catId)) {
matches.add(indexToCandidates.get(catId));
}
}
}
if (matches.isEmpty()) {
return null;
} else {
return matches.get(new Random().nextInt(matches.size()));
}
}
static class CatCost implements Comparable<CatCost> {
LocalPage topLevelCat;
int catId;
int catIndex; // dense internal id from category graph, not an article id.
double cost;
public CatCost(LocalPage topLevelCat, int catId, int catIndex, double cost) {
this.topLevelCat = topLevelCat;
this.catId = catId;
this.catIndex = catIndex;
this.cost = cost;
}
@Override
public int compareTo(CatCost o) {
return Double.compare(cost, o.cost);
}
}
@Override
public Map<LocalPage, TIntDoubleMap> getClosestCategories(Set<LocalPage> topLevelCats) throws DaoException {
return getClosestCategories(topLevelCats, null, true);
}
/**
* For each article, identifies the closest category among the specified candidate set.
* Distance is measured using shortest path in the category graph.
*
* @param candidateCategories The categories to consider as candidates (e.g. those considered "top-level").
* @param pageIds If not null, only considers articles in the provided pageIds.
* @param weighted If true, use page-rank weighted edges so paths that traverse more
* general categories are penalized more highly.
* @return Map with candidates as keys and the articles that have them as closest category
* as values. The values are a map of article ids to distances.
* @throws DaoException
*/
@Override
public Map<LocalPage, TIntDoubleMap> getClosestCategories(Set<LocalPage> candidateCategories, TIntSet pageIds, boolean weighted) throws DaoException {
Map<LocalPage, TIntDoubleMap> results = new HashMap<LocalPage, TIntDoubleMap>();
if (candidateCategories.isEmpty()) {
return results;
}
Language language = candidateCategories.iterator().next().getLanguage();
CategoryGraph graph = getGraph(language);
int numPages = (pageIds == null) ? LanguageInfo.getByLanguage(language).getNumArticles() : pageIds.size();
PriorityQueue<CatCost> frontier = new PriorityQueue<CatCost>();
for (LocalPage p : candidateCategories) {
if (p.getLanguage() != language) throw new IllegalStateException("Category languages must be identitical");
CatCost cc = new CatCost(p, p.getLocalId(), graph.catIdToIndex(p.getLocalId()), 0.0);
if (cc.catIndex >= 0) {
frontier.add(cc);
}
results.put(p, new TIntDoubleHashMap(numPages));
}
TIntSet visited = new TIntHashSet(numPages*3); // both articles and categories
while (!frontier.isEmpty()) {
CatCost cc = frontier.poll();
if (visited.contains(cc.catId)) continue;
visited.add(cc.catId);
// Handle pages of categories
for (int pageId : graph.catPages[cc.catIndex]) {
if (!visited.contains(pageId)
&& (pageIds == null || pageIds.contains(pageId))) {
visited.add(pageId);
results.get(cc.topLevelCat).put(pageId, cc.cost);
}
}
// Descend to unexplored child categories.
for (int childIndex : graph.catChildren[cc.catIndex]) {
int childId = graph.catIndexToId(childIndex);
if (!visited.contains(childId)) {
double childCost = cc.cost + (weighted ? graph.catCosts[childIndex] : 1.0);
frontier.add(new CatCost(cc.topLevelCat, childId, childIndex, childCost));
}
}
}
return results;
}
/**
* Returns distance to specified categories for requested pages.
* Distance is measured using shortest path in the category graph.
*
* @param candidateCategories The categories to consider as candidates (e.g. those considered "top-level").
* @param pageId The article id we want to find.
* @param weighted If true, use page-rank weighted edges so paths that traverse more
* general categories are penalized more highly.
* @return Map with article ids as keys and distances to each category id as values.
* @throws DaoException
*
*/
@Override
public TIntDoubleMap getCategoryDistances(Set<LocalPage> candidateCategories, int pageId, boolean weighted) throws DaoException {
Language language = candidateCategories.iterator().next().getLanguage();
CategoryGraph graph = getGraph(language);
Map<Integer, TIntDoubleMap> results = new HashMap<Integer, TIntDoubleMap>();
// Indexes for goal categories
TIntSet goalIndexes = new TIntHashSet();
for (LocalPage p : candidateCategories) {
int i = graph.catIdToIndex(p.getLocalId());
if (i >= 0) goalIndexes.add(i);
}
// Search upwards from each page
TIntSet visited = new TIntHashSet();
// all we care about in CatCost for this search is the category index and cost
PriorityQueue<CatCost> frontier = new PriorityQueue<CatCost>();
TIntDoubleMap distances = new TIntDoubleHashMap();
for (int catId : getCategoryIds(language, pageId)) {
int i = graph.catIdToIndex(catId);
if (i >= 0) frontier.add(new CatCost(null, -1, i, graph.catCosts[i]));
}
while (!frontier.isEmpty() && distances.size() != candidateCategories.size()) {
CatCost cc = frontier.poll();
if (visited.contains(cc.catIndex)) continue;
visited.add(cc.catIndex);
if (goalIndexes.contains(cc.catIndex)) {
distances.put(graph.catIndexToId(cc.catIndex), cc.cost);
} else {
// Ascend to unexplored parent categories.
for (int parentIndex : graph.catParents[cc.catIndex]) {
if (!visited.contains(parentIndex)) {
double parentCost = cc.cost + (weighted ? graph.catCosts[parentIndex] : 1.0);
frontier.add(new CatCost(null, -1, parentIndex, parentCost));
}
}
}
}
return distances;
}
/**
* This method should generally not be used.
* @param daoFilter a set of filters to limit the search
* @return
* @throws DaoException
*/
@Override
public Iterable<LocalCategoryMember> get(DaoFilter daoFilter) throws DaoException {
DSLContext context = getJooq();
try {
Collection<Condition> conditions = new ArrayList<Condition>();
if (daoFilter.getLangIds() != null) {
conditions.add(Tables.CATEGORY_MEMBERS.LANG_ID.in(daoFilter.getLangIds()));
}
Cursor<Record> result = context.select().
from(Tables.CATEGORY_MEMBERS).
where(conditions).
limit(daoFilter.getLimitOrInfinity()).
fetchLazy(getFetchSize());
return new SimpleSqlDaoIterable<LocalCategoryMember>(result, context) {
@Override
public LocalCategoryMember transform(Record r) {
return buildLocalCategoryMember(r);
}
};
} catch (RuntimeException e) {
freeJooq(context);
throw e;
}
}
@Override
public int getCount(DaoFilter daoFilter) throws DaoException{
DSLContext context = getJooq();
try {
Collection<Condition> conditions = new ArrayList<Condition>();
if (daoFilter.getLangIds() != null) {
conditions.add(Tables.CATEGORY_MEMBERS.LANG_ID.in(daoFilter.getLangIds()));
}
return context.selectCount().
from(Tables.CATEGORY_MEMBERS).
where(conditions).
fetchOne().value1();
} finally {
freeJooq(context);
}
}
@Override
public Collection<Integer> getCategoryMemberIds(Language language, int categoryId) throws DaoException {
DSLContext context = getJooq();
try {
Result<Record> result = context.select().
from(Tables.CATEGORY_MEMBERS).
where(Tables.CATEGORY_MEMBERS.CATEGORY_ID.eq(categoryId)).
and(Tables.CATEGORY_MEMBERS.LANG_ID.eq(language.getId())).
fetch();
return extractIds(result, false);
} finally {
freeJooq(context);
}
}
@Override
public Collection<Integer> getCategoryMemberIds(LocalPage localCategory) throws DaoException {
return getCategoryMemberIds(localCategory.getLanguage(), localCategory.getLocalId());
}
@Override
public Map<Integer, LocalPage> getCategoryMembers(Language language, int categoryId) throws DaoException {
Collection<Integer> articleIds = getCategoryMemberIds(language, categoryId);
return localPageDao.getByIds(language, articleIds);
}
@Override
public Map<Integer, LocalPage> getCategoryMembers(LocalPage localCategory) throws DaoException {
Collection<Integer> articleIds = getCategoryMemberIds(localCategory);
return localPageDao.getByIds(localCategory.getLanguage(), articleIds);
}
@Override
public Collection<Integer> getCategoryIds(Language language, int articleId) throws DaoException {
DSLContext context = getJooq();
try {
Result<Record> result = context.select().
from(Tables.CATEGORY_MEMBERS).
where(Tables.CATEGORY_MEMBERS.ARTICLE_ID.eq(articleId)).
and(Tables.CATEGORY_MEMBERS.LANG_ID.eq(language.getId())).
fetch();
return extractIds(result, true);
} finally {
freeJooq(context);
}
}
@Override
public Collection<Integer> getCategoryIds(LocalPage localArticle) throws DaoException {
return getCategoryIds(localArticle.getLanguage(), localArticle.getLocalId());
}
@Override
public Map<Integer, LocalPage> getCategories(Language language, int articleId) throws DaoException {
Collection<Integer> categoryIds = getCategoryIds(language, articleId);
return localPageDao.getByIds(language, categoryIds);
}
@Override
public Map<Integer, LocalPage> getCategories(LocalPage localArticle) throws DaoException {
Collection<Integer> categoryIds = getCategoryIds(localArticle);
return localPageDao.getByIds(localArticle.getLanguage(), categoryIds);
}
@Override
public synchronized CategoryGraph getGraph(Language language) throws DaoException {
if (graphs.containsKey(language)) {
return graphs.get(language);
}
String key = "cat-graph-" + language.getLangCode();
if (cache != null) {
CategoryGraph graph = (CategoryGraph) cache.get(key, LocalPage.class, LocalCategoryMember.class);
if (graph != null) {
graphs.put(language, graph);
return graph;
}
}
LocalCategoryGraphBuilder builder = new LocalCategoryGraphBuilder(localPageDao, this);
CategoryGraph graph = builder.build(language);
cache.put(key, graph);
graphs.put(language, graph);
return graph;
}
private Collection<Integer> extractIds(Result<Record> result, boolean categoryIds) {
if (result.isEmpty()) {
return null;
}
Collection<Integer> pageIds = new ArrayList<Integer>();
for(Record record : result) {
pageIds.add(categoryIds ?
record.getValue(Tables.CATEGORY_MEMBERS.CATEGORY_ID) :
record.getValue(Tables.CATEGORY_MEMBERS.ARTICLE_ID)
);
}
return pageIds;
}
private LocalCategoryMember buildLocalCategoryMember(Record r) {
return new LocalCategoryMember(
r.getValue(Tables.CATEGORY_MEMBERS.CATEGORY_ID),
r.getValue(Tables.CATEGORY_MEMBERS.ARTICLE_ID),
Language.getById(r.getValue(Tables.CATEGORY_MEMBERS.LANG_ID))
);
}
public static class Provider extends org.wikibrain.conf.Provider<LocalCategoryMemberDao> {
public Provider(Configurator configurator, Configuration config) throws ConfigurationException {
super(configurator, config);
}
@Override
public Class getType() {
return LocalCategoryMemberDao.class;
}
@Override
public String getPath() {
return "dao.localCategoryMember";
}
@Override
public LocalCategoryMemberDao get(String name, Config config, Map<String, String> runtimeParams) throws ConfigurationException {
if (!config.getString("type").equals("sql")) {
return null;
}
try {
UniversalPageDao univDao = null;
MetaInfoDao metaDao = getConfigurator().get(MetaInfoDao.class);
if (metaDao.isLoaded(UniversalPage.class)) {
univDao = getConfigurator().get(UniversalPageDao.class);
}
LocalCategoryMemberSqlDao dao = new LocalCategoryMemberSqlDao(
getConfigurator().get(
WpDataSource.class,
config.getString("dataSource")),
getConfigurator().get(LocalPageDao.class),
univDao);
Config c = config.getConfig("topLevelCats");
for (Map.Entry<String, ConfigValue> e : c.entrySet()) {
dao.addTopLevelOverride(Language.getByLangCode(e.getKey()), (String) e.getValue().unwrapped());
}
String cachePath = getConfig().get().getString("dao.sqlCachePath");
File cacheDir = new File(cachePath);
if (!cacheDir.isDirectory()) {
cacheDir.mkdirs();
}
dao.useCache(cacheDir);
return dao;
} catch (DaoException e) {
throw new ConfigurationException(e);
}
}
}
}