package org.wikibrain.sr.category; import com.typesafe.config.Config; import gnu.trove.set.TIntSet; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.wikibrain.conf.Configuration; import org.wikibrain.conf.ConfigurationException; import org.wikibrain.conf.Configurator; import org.wikibrain.core.dao.DaoException; import org.wikibrain.core.dao.LocalCategoryMemberDao; import org.wikibrain.core.dao.LocalPageDao; import org.wikibrain.core.dao.sql.CategoryBfs; import org.wikibrain.core.lang.Language; import org.wikibrain.core.model.CategoryGraph; import org.wikibrain.sr.*; import org.wikibrain.sr.BaseSRMetric; import org.wikibrain.sr.SRMetric; import org.wikibrain.sr.dataset.Dataset; import org.wikibrain.sr.disambig.Disambiguator; import java.util.ArrayList; import java.util.Map; /** * <p>This metric is an enhanced variant of Stube and Ponzetto's WikiRelate.</p> * * <p>This class makes two enhancements to the original SR metric that improve * the accuracy and efficiency of the algorithm calculations. First, this class * uses page-rank weighted edge weights for the category graph to distinguish * between categories with different levels of specificity. Second, this class * uses bidirectional search (instead of vanilla breadth-first search) for the * similarity() method.</p> * * @author Matt Lesicko * @author Shilad Sen */ public class CategoryGraphSimilarity extends BaseSRMetric { private static final Logger LOG = LoggerFactory.getLogger(CategoryGraphSimilarity.class); private final CategoryGraph graph; LocalCategoryMemberDao catHelper; public CategoryGraphSimilarity(String name, Language language, LocalPageDao pageDao, Disambiguator disambiguator, LocalCategoryMemberDao categoryMemberDao) throws DaoException { super(name, language,pageDao,disambiguator); this.catHelper=categoryMemberDao; this.graph = categoryMemberDao.getGraph(language); } public double distanceToScore(double distance) { return distanceToScore(graph, distance); } public static double distanceToScore(CategoryGraph graph, double distance) { distance = Math.max(distance, graph.minCost); assert(graph.minCost < 1.0); // if this isn't true, direction is flipped. if (Double.isInfinite(distance)){ return 0.0; } return (Math.log(distance) / Math.log(graph.minCost)); } @Override public SRConfig getConfig() { SRConfig config = new SRConfig(); config.minScore = -1.0f; config.maxScore = +1.0f; return config; } /** * Some languages do not have categories. Don't choke on them! * @param dataset * @throws DaoException */ @Override public synchronized void trainSimilarity(Dataset dataset) throws DaoException { try { super.trainSimilarity(dataset); } catch (Exception e) { LOG.warn("Training of sr metric similarity " + getName() + " failed, disabling it.", e); } } /** * Some languages do not have categories. Don't choke on them! */ @Override public synchronized void trainMostSimilar(Dataset dataset, int numResults, TIntSet validIds) { try { super.trainMostSimilar(dataset, numResults, validIds); } catch (Exception e) { LOG.warn("Training of sr metric mostSimilar " + getName() + " failed, disabling it.", e); } } @Override public SRResult similarity(int pageId1, int pageId2, boolean explanations) throws DaoException { if (!similarityIsTrained()) { return new SRResult(0.0); } CategoryBfs bfs1 = new CategoryBfs(graph,pageId1,getLanguage(), Integer.MAX_VALUE, null, catHelper); CategoryBfs bfs2 = new CategoryBfs(graph,pageId2,getLanguage(), Integer.MAX_VALUE, null, catHelper); bfs1.setAddPages(false); bfs1.setExploreChildren(false); bfs2.setAddPages(false); bfs2.setExploreChildren(false); double shortestDistance = Double.POSITIVE_INFINITY; double maxDist1 = 0; double maxDist2 = 0; // Note that all the category ids below are dense indexes in [0, numCategories). // The mapping is determined by the graph. while ((bfs1.hasMoreResults() || bfs2.hasMoreResults()) && (maxDist1 + maxDist2 < shortestDistance)) { // Search from d1 while (bfs1.hasMoreResults() && (maxDist1 <= maxDist2 || !bfs2.hasMoreResults())) { CategoryBfs.BfsVisited visited = bfs1.step(); for (int catId : visited.cats.keys()) { if (bfs2.hasCategoryDistanceForIndex(catId)) { double d = bfs1.getCategoryDistanceForIndex(catId) + bfs2.getCategoryDistanceForIndex(catId) - graph.catCosts[catId]; // counted twice shortestDistance = Math.min(d, shortestDistance); } } maxDist1 = Math.max(maxDist1, visited.maxCatDistance()); } // Search from d2 while (bfs2.hasMoreResults() && (maxDist2 <= maxDist1 || !bfs1.hasMoreResults())) { CategoryBfs.BfsVisited visited = bfs2.step(); for (int catId : visited.cats.keys()) { if (bfs1.hasCategoryDistanceForIndex(catId)) { double d = bfs1.getCategoryDistanceForIndex(catId) + bfs2.getCategoryDistanceForIndex(catId) + 0 - graph.catCosts[catId]; // counted twice; shortestDistance = Math.min(d, shortestDistance); } } maxDist2 = Math.max(maxDist2, visited.maxCatDistance()); } } return new SRResult(distanceToScore(shortestDistance)); } @Override public SRResultList mostSimilar(int pageId, int maxResults, TIntSet validIds) throws DaoException { if (!mostSimilarIsTrained()) { return new SRResultList(0); } SRResultList results = getCachedMostSimilar(pageId, maxResults, validIds); if (results != null) { return results; } CategoryBfs bfs = new CategoryBfs(graph,pageId,getLanguage(), maxResults, validIds, catHelper); while (bfs.hasMoreResults()) { bfs.step(); } results = new SRResultList(bfs.getPageDistances().size()); int i = 0; for (int pageId2: bfs.getPageDistances().keys()) { results.set(i++, pageId2, distanceToScore(bfs.getPageDistances().get(pageId2))); } results.sortDescending(); return normalize(results); } public static class Provider extends org.wikibrain.conf.Provider<SRMetric> { public Provider(Configurator configurator, Configuration config) throws ConfigurationException { super(configurator, config); } @Override public Class getType() { return SRMetric.class; } @Override public String getPath() { return "sr.metric.local"; } @Override public SRMetric get(String name, Config config, Map<String, String> runtimeParams) throws ConfigurationException { if (!config.getString("type").equals("categorygraphsimilarity")) { return null; } if (runtimeParams == null || !runtimeParams.containsKey("language")){ throw new IllegalArgumentException("LocalCategoryGraphBuilder requires 'language' runtime parameter."); } Language language = Language.getByLangCode(runtimeParams.get("language")); CategoryGraphSimilarity sr = null; try { sr = new CategoryGraphSimilarity( name, language, getConfigurator().get(LocalPageDao.class,config.getString("pageDao")), getConfigurator().get(Disambiguator.class,config.getString("disambiguator"), "language", language.getLangCode()), getConfigurator().get(LocalCategoryMemberDao.class,config.getString("categoryMemberDao")) ); } catch (DaoException e) { throw new ConfigurationException(e); } configureBase(getConfigurator(), sr, config); return sr; } } }