package org.wikibrain.sr.wikify; import com.typesafe.config.Config; import gnu.trove.list.TDoubleList; import gnu.trove.list.array.TDoubleArrayList; import gnu.trove.map.TIntDoubleMap; import gnu.trove.map.hash.TIntDoubleHashMap; import gnu.trove.set.TIntSet; import gnu.trove.set.hash.TIntHashSet; import org.wikibrain.conf.Configuration; import org.wikibrain.conf.ConfigurationException; import org.wikibrain.conf.Configurator; import org.wikibrain.core.cmd.Env; import org.wikibrain.core.dao.*; import org.wikibrain.core.lang.Language; import org.wikibrain.core.model.LocalLink; import org.wikibrain.core.model.RawPage; import org.wikibrain.core.nlp.StringTokenizer; import org.wikibrain.core.nlp.Token; import org.wikibrain.phrases.*; import org.wikibrain.sr.SRMetric; import org.wikibrain.sr.utils.Leaderboard; import org.wikibrain.sr.vector.FeatureFilter; import org.wikibrain.utils.Scoreboard; import java.util.*; import org.slf4j.Logger; import org.slf4j.LoggerFactory; /** * Wikifier based on Doug Downey's approach described in * * http://web-ngram.research.microsoft.com/erd2014/Docs/submissions/erd14_submission_24.pdf * @author Shilad Sen */ public class WebSailWikifier implements Wikifier { private static final Logger LOG = LoggerFactory.getLogger(WebSailWikifier.class); /** * TODO: Make this configurable */ private int numTrainingLinks = 50000; private final Wikifier identityWikifier; private final SRMetric metric; private final LinkProbabilityDao linkProbDao; private final Language language; private final PhraseTokenizer phraseTokenizer; private final LocalLinkDao linkDao; private final PhraseAnalyzerDao phraseDao; private final RawPageDao rawPageDao; private double desiredLinkRecall = 0.98; private double minLinkProbability = 0.01; private double minFinalScore = 0.001; public WebSailWikifier(Wikifier identityWikifier, RawPageDao rawPageDao, LocalLinkDao linkDao, LinkProbabilityDao linkProbDao, PhraseAnalyzerDao phraseDao, SRMetric metric) throws DaoException { this.identityWikifier = identityWikifier; this.metric = metric; this.language = metric.getLanguage(); this.linkDao = linkDao; this.linkProbDao = linkProbDao; this.rawPageDao = rawPageDao; this.phraseDao = phraseDao; this.phraseTokenizer = new PhraseTokenizer(linkProbDao); learnMinLinkProbability(); } public void setDesiredLinkRecall(double recall) throws DaoException { this.desiredLinkRecall = recall; this.learnMinLinkProbability(); } public void setMinLinkProbability(double minProb) { this.minLinkProbability = minProb; } private void learnMinLinkProbability() throws DaoException { if (!linkProbDao.isBuilt()) { linkProbDao.build(); } LOG.info("Learning minimum link probability"); TDoubleList probs = new TDoubleArrayList(); DaoFilter filter = new DaoFilter() .setLanguages(language) .setHasDest(true) .setLimit(numTrainingLinks); for (LocalLink ll : linkDao.get(filter)) { if (ll.getDestId() < 0) throw new IllegalStateException(); double p = linkProbDao.getLinkProbability(ll.getAnchorText()); probs.add(p); } probs.sort(); probs.reverse(); int index = (int)(desiredLinkRecall * probs.size()); minLinkProbability = (index >= probs.size()) ? 0.0 : probs.get(index); LOG.info("Set minimum link probability to " + minLinkProbability + " to achieve " + desiredLinkRecall + " recall"); } private List<LinkInfo> getCandidates(int wpId, String text) throws DaoException { return getCandidates(text); // We should do something smarter with the text. } private List<LinkInfo> getCandidates(String text) throws DaoException { List<LinkInfo> candidates = new ArrayList<LinkInfo>(); StringTokenizer tokenizer = new StringTokenizer(); for (Token sentence : tokenizer.getSentenceTokens(language, text)) { for (Token phrase : phraseTokenizer.makePhraseTokens(language, sentence)) { double p = linkProbDao.getLinkProbability(phrase.getToken()); if (p > minLinkProbability) { LinkInfo li = new LinkInfo(phrase); li.setLinkProbability(p); candidates.add(li); } } } return candidates; } @Override public List<LocalLink> wikify(int wpId, String text) throws DaoException { // Find all mentions that are linked with some likelihood List<LinkInfo> mentions = getCandidates(wpId, text); // Find disambiguation candidates for each possible mention for (LinkInfo li : mentions) { li.setPrior(phraseDao.getPhraseCounts(language, li.getAnchortext(), 5)); } // Calculate the relatedness of each mention to known links in the article TIntSet existingIds = getActualLinks(wpId); TIntDoubleMap sr = calculateConceptRelatedness(existingIds, mentions); // Score every possible mention for (LinkInfo li : mentions) { scoreInfo(existingIds, li, sr); } return link(wpId, text, mentions); } @Override public List<LocalLink> wikify(int wpId) throws DaoException { RawPage page = rawPageDao.getById(language, wpId); if (page == null) { return new ArrayList<LocalLink>(); } else { return wikify(wpId, page.getPlainText(false)); } } @Override public List<LocalLink> wikify(String text) throws DaoException { List<LinkInfo> mentions = getCandidates(text); // Temporarily score eveything based on link probability and prior for (LinkInfo li : mentions) { PrunedCounts<Integer> prior = phraseDao.getPhraseCounts(language, li.getAnchortext(), 5); li.setPrior(prior); if (prior == null || prior.isEmpty()) continue; double p = 1.0 * prior.values().iterator().next() / (prior.getTotal() + 1); li.setScore(Math.sqrt(li.getLinkProbability()) * p); } // Take the top scoring items as existing ids Collections.sort(mentions); TIntSet existingIds = new TIntHashSet(); for (int i = 0; i < mentions.size(); i++) { LinkInfo li = mentions.get(i); if (li.getPrior() == null || li.getPrior().isEmpty()) continue; double p = 1.0 * li.getPrior().values().iterator().next() / (li.getPrior().getTotal() + 1); // String name = phraseDao.getPageCounts(language, li.getTopPriorDestination(), 1).keySet().iterator().next(); if ((li.getScore() > 0.01 && i < 3 && p >= 0.5) || (li.getScore() > 0.25 && p >= 0.5)) { existingIds.add(li.getTopPriorDestination()); } } TIntDoubleMap sr = calculateConceptRelatedness(existingIds, mentions); // Score every possible mention for (LinkInfo li : mentions) { scoreInfo(existingIds, li, sr); } return link(-1, text, mentions); } private void scoreInfo(TIntSet existingIds, LinkInfo li, TIntDoubleMap sr) { if (li.getPrior() == null || li.getPrior().isEmpty()) { return; } Scoreboard<Integer> scores = li.getScores(); for (int id : li.getPrior().keySet()) { double score = 0.4 * sr.get(id) + 0.6 * li.getPrior().get(id) / li.getPrior().getTotal(); score *= li.getLinkProbability(); if (existingIds.contains(id)) { score += 0.2; } scores.add(id, score); } li.setDest(scores.getElement(0)); double multiplier = (scores.size() == 1) ? 0.2 : (scores.getScore(0) - scores.getScore(1)); li.setScore(scores.getScore(0) * multiplier); } private TIntSet getActualLinks(int wpId) throws DaoException { TIntSet existingIds = new TIntHashSet(); for (LocalLink ll : linkDao.getLinks(language, wpId, true)) { if (ll.getDestId() >= 0) { existingIds.add(ll.getDestId()); } } // hack: add the link itself existingIds.add(wpId); return existingIds; } private TIntDoubleMap calculateConceptRelatedness(TIntSet existingIds, List<LinkInfo> infos) throws DaoException { TIntSet candidateIds = new TIntHashSet(); for (LinkInfo li : infos) { if (li.getPrior() != null) { for (int id : li.getPrior().keySet()) { candidateIds.add(id); } } } int existing[] = existingIds.toArray(); int candidates[] = candidateIds.toArray(); TIntDoubleMap results = new TIntDoubleHashMap(); if (existing.length == 0 || candidates.length == 0) { return results; } double [][] cosim = metric.cosimilarity(candidates, existing); for (int i = 0; i < candidates.length; i++) { double sum = 0.0; for (double s : cosim[i]) { if (!Double.isInfinite(s) && !Double.isNaN(s)) { sum += s; } } results.put(candidates[i], sum / existing.length); } return results; } public void setMinFinalScore(double minFinalScore) { this.minFinalScore = minFinalScore; } private List<LocalLink> link(int wpId, String text, List<LinkInfo> infos) throws DaoException { BitSet used = new BitSet(text.length()); List<LocalLink> results = identityWikifier.wikify(wpId, text); for (LocalLink li : results) { used.set(li.getLocation(), li.getLocation() + li.getAnchorText().length()); } Collections.sort(infos); for (LinkInfo li : infos) { if (li.getDest() != null && li.getScore() > minFinalScore && used.get(li.getStartChar(), li.getEndChar()).isEmpty()) { results.add(li.toLocalLink(language, wpId)); used.set(li.getStartChar(), li.getEndChar()); } } Collections.sort(results, new Comparator<LocalLink>() { @Override public int compare(LocalLink o1, LocalLink o2) { return o1.getLocation() - o2.getLocation(); } }); return results; } public static class Provider extends org.wikibrain.conf.Provider<Wikifier> { public Provider(Configurator configurator, Configuration config) throws ConfigurationException { super(configurator, config); } @Override public Class<Wikifier> getType() { return Wikifier.class; } @Override public String getPath() { return "sr.wikifier"; } @Override public Wikifier get(String name, Config config, Map<String, String> runtimeParams) throws ConfigurationException { if (runtimeParams == null || !runtimeParams.containsKey("language")){ throw new IllegalArgumentException("Wikifier requires 'language' runtime parameter."); } if (!config.getString("type").equals("websail")) { return null; } Language language = Language.getByLangCode(runtimeParams.get("language")); Configurator c = getConfigurator(); String srName = config.getString("sr"); String phraseName = config.getString("phraseAnalyzer"); String identityName = config.getString("identityWikifier"); String linkName = config.getString("localLinkDao"); LinkProbabilityDao lpd = Env.getComponent(c, LinkProbabilityDao.class, language); if (config.getBoolean("useLinkProbabilityCache")) { lpd.useCache(true); } try { return new WebSailWikifier( c.get(Wikifier.class, identityName, "language", language.getLangCode()), c.get(RawPageDao.class), c.get(LocalLinkDao.class, linkName), lpd, ((AnchorTextPhraseAnalyzer)c.get(PhraseAnalyzer.class, phraseName)).getDao(), c.get(SRMetric.class, srName, "language", language.getLangCode()) ); } catch (DaoException e) { throw new ConfigurationException(e); } } } }