package org.wikibrain.sr.wikify; import com.typesafe.config.Config; import gnu.trove.TCollections; import gnu.trove.map.TIntDoubleMap; import gnu.trove.map.hash.TIntDoubleHashMap; import gnu.trove.set.TIntSet; import gnu.trove.set.hash.TIntHashSet; import org.wikibrain.conf.Configuration; import org.wikibrain.conf.ConfigurationException; import org.wikibrain.conf.Configurator; import org.wikibrain.core.cmd.Env; import org.wikibrain.core.cmd.EnvBuilder; import org.wikibrain.core.dao.*; import org.wikibrain.core.lang.Language; import org.wikibrain.core.model.LocalLink; import org.wikibrain.core.model.NameSpace; import org.wikibrain.core.model.RawPage; import org.wikibrain.core.nlp.NGramCreator; import org.wikibrain.core.nlp.StringTokenizer; import org.wikibrain.core.nlp.Token; import org.wikibrain.phrases.*; import org.wikibrain.sr.SRMetric; import java.io.IOException; import java.util.*; import org.slf4j.Logger; import org.slf4j.LoggerFactory; /** * @author Shilad Sen */ public class MilneWittenWikifier implements Wikifier { private static final Logger LOG = LoggerFactory.getLogger(MilneWittenWikifier.class); private final LocalPageDao lpd; private final LocalLinkDao lld; private final RawPageDao rpd; private final SRMetric metric; private final PhraseAnalyzerDao phraseDao; private final LinkProbabilityDao linkProbDao; private final Language language; private int numTestingDocs = 100; private double minLinkProbability = 0.03; private int maxNGram = 3; private StringTokenizer tokenizer = new StringTokenizer(); private NGramCreator nGramCreator = new NGramCreator(); public MilneWittenWikifier(SRMetric metric, AnchorTextPhraseAnalyzer pa, LocalPageDao lpd, RawPageDao rpd, LocalLinkDao lld, LinkProbabilityDao linkProbDao) { this.lpd = lpd; this.linkProbDao = linkProbDao; this.phraseDao = pa.getDao(); this.metric = metric; this.rpd = rpd; this.lld = lld; this.language = metric.getLanguage(); } public void testWikify() throws DaoException { int barackId = lpd.getIdByTitle("Barack Obama", language, NameSpace.ARTICLE); RawPage rp = rpd.getById(language, barackId); for (int i = 0; i < 1; i++) { List<LocalLink> detected = wikify(rp.getLocalId()); System.out.println("Links detected for " + rp.getTitle() + " (" + i + ")"); for (LocalLink ll : detected) { System.out.println("\t" + ll + " page " + lpd.getById(language, ll.getDestId()).getTitle()); } } } private List<Token> getNGramTokens(String text) { List<Token> ngrams = new ArrayList<Token>(); for (Token sentence : tokenizer.getSentenceTokens(language, text)) { List<Token> words = tokenizer.getWordTokens(language, sentence); ngrams.addAll(nGramCreator.getNGramTokens(words, 1, maxNGram)); } return ngrams; } private double getLinkProbability(String phrase) throws DaoException { return linkProbDao.getLinkProbability(phrase); } @Override public List<LocalLink> wikify(int wpId, String text) throws DaoException { List<LinkInfo> candidates = getCandidates(text); identifyKnownCandidates(wpId, candidates); List<LinkInfo> detected = detectLinks(candidates); List<LocalLink> results = new ArrayList<LocalLink>(); for (LinkInfo li : detected) { results.add(new LocalLink(language, li.getAnchortext(), wpId, li.getDest(), true, li.getStartChar(), true, null)); } return results; } @Override public List<LocalLink> wikify(int wpId) throws DaoException { RawPage rp = rpd.getById(language, wpId); if (rp == null) { return new ArrayList<LocalLink>(); } return wikify(wpId, rp.getPlainText(false)); } @Override public List<LocalLink> wikify(String text) throws DaoException { List<LinkInfo> candidates = getCandidates(text); List<LinkInfo> detected = detectLinks(candidates); List<LocalLink> results = new ArrayList<LocalLink>(); for (LinkInfo li : detected) { results.add(new LocalLink(language, li.getAnchortext(), -1, li.getDest(), true, li.getStartChar(), true, null)); } // Sort by position Collections.sort(results, new Comparator<LocalLink>() { @Override public int compare(LocalLink l1, LocalLink l2) { return l1.getLocation() - l2.getLocation(); } }); return results; } private List<LinkInfo> detectLinks(List<LinkInfo> candidates) throws DaoException { Map<String, LinkInfo> scoreCache = new HashMap<String, LinkInfo>(); TIntDoubleMap relatedness = getRelatedness(candidates); for (LinkInfo li : candidates) { scoreLinkInfo(li, scoreCache, relatedness); } TIntSet used = new TIntHashSet(); // used characters Collections.sort(candidates); List<LinkInfo> detected = new ArrayList<LinkInfo>(); for (LinkInfo li : candidates) { if (li.getScore() < 0.01) { break; } if(!li.intersects(used)) { detected.add(li); li.markAsUsed(used); } // if (li.getDest() >= 0) { // System.out.println("link " + li.getAnchortext() + " to " + lpd.getById(language, li.getDest()) + " has score " + li.getScore()); // } } return detected; } private TIntDoubleMap getRelatedness(List<LinkInfo> candidates) throws DaoException { TIntSet knownSet = new TIntHashSet(); TIntSet candidateSet = new TIntHashSet(); for (LinkInfo li : candidates) { if (li.getKnownDest() != null) { knownSet.add(li.getKnownDest()); } else if (li.hasOnePossibility()) { knownSet.add(li.getTopPriorDestination()); } else { for (int wpId : li.getPrior().keySet()) { candidateSet.add(wpId); } } } int [] knownIds = knownSet.toArray(); int [] candidateIds = candidateSet.toArray(); double cosimilarity[][] = metric.cosimilarity(candidateIds, knownIds); TIntDoubleMap similarities = new TIntDoubleHashMap(); for (int i = 0; i < candidateIds.length; i++) { double sum = 0.0; for (double sim : cosimilarity[i]) { sum += sim; } similarities.put(candidateIds[i], sum / knownIds.length); } return similarities; } private void scoreLinkInfo(LinkInfo link, Map<String, LinkInfo> cache, TIntDoubleMap allRelatedness) throws DaoException { if (link.getKnownDest() != null) { link.setDest(link.getKnownDest()); link.setScore(1000000.0); return; } if (cache.containsKey(link.getAnchortext())) { LinkInfo existing = cache.get(link.getAnchortext()); link.setDest(existing.getDest()); link.setScore(existing.getScore()); return; } for (int wpId : link.getPrior().keySet()) { double prior = link.getPrior().get(wpId); double relatedness = allRelatedness.get(wpId); double score = prior * relatedness * link.getLinkProbability() * getGenerality(wpId); link.addScore(wpId, score); } if (link.getScores().size() == 0) { return; } link.setDest(link.getScores().getElement(0)); link.setScore(link.getScores().getScore(0)); if (link.getScores().size() == 1) { link.setScore(link.getScore() * 3); } else { double score2 = link.getScores().getScore(1); link.setScore(link.getScore() * Math.min(3.0, link.getScore() / score2)); } cache.put(link.getAnchortext(), link); } private final TIntDoubleMap generality = TCollections.synchronizedMap(new TIntDoubleHashMap()); private final int MAX_INLINKS = 1000; private double getGenerality(int wpId) throws DaoException { if (generality.containsKey(wpId)) { return generality.get(wpId); } int numInLinks = lld.getCount(new DaoFilter().setLanguages(language).setDestIds(wpId)); double g = 0.5 + Math.log(1 + Math.min(MAX_INLINKS, numInLinks)) / Math.log(1 + MAX_INLINKS); generality.put(numInLinks, numInLinks); return numInLinks; } private void identifyKnownCandidates(int wpId, List<LinkInfo> candidates) throws DaoException { Set<String> usedAnchors = new HashSet<String>(); /** * Hack: Mark the FIRST POSSIBLE of each candidate link as verified. */ for (LocalLink ll : lld.getLinks(language, wpId, true)) { if (ll.getDestId() < 0 || ll.getAnchorText() == null || usedAnchors.contains(ll.getAnchorText())) { continue; } for (LinkInfo li : candidates) { if (ll.getAnchorText().equals(li.getAnchortext())) { if (li.getKnownDest() != null) { LOG.info("conflict for link info " + li.getAnchortext() + " between " + li.getKnownDest() + " and " + ll.getDestId()); } else { li.setKnownDest(ll.getDestId()); break; } } } usedAnchors.add(ll.getAnchorText()); } } public List<LinkInfo> getTextContext(String text) throws DaoException { return getCandidates(text); } private List<LinkInfo> getCandidates(String text) throws DaoException { Map<String, LinkInfo> cache = new HashMap<String, LinkInfo>(); List<LinkInfo> candidates = new ArrayList<LinkInfo>(); for (Token ngram : getNGramTokens(text)) { LinkInfo li = makeLinkInfo(ngram, cache); if (li != null) { candidates.add(li); } } return candidates; } private LinkInfo makeLinkInfo(Token token, Map<String, LinkInfo> cache) throws DaoException { double linkProbability = getLinkProbability(token.getToken()); if (linkProbability < minLinkProbability) { return null; } if (cache.containsKey(token.getToken())) { LinkInfo old = cache.get(token.getToken()); LinkInfo li = new LinkInfo(); li.setLinkProbability(linkProbability); li.setAnchortext(token.getToken()); li.setStartChar(token.getBegin()); li.setEndChar(token.getEnd()); li.setPrior(old.getPrior()); return li; } PrunedCounts<Integer> counts = phraseDao.getPhraseCounts(language, token.getToken(), 30); if (counts != null && !counts.isEmpty()) { LinkInfo li = new LinkInfo(); li.setLinkProbability(linkProbability); li.setAnchortext(token.getToken()); li.setStartChar(token.getBegin()); li.setEndChar(token.getEnd()); li.setPrior(counts); cache.put(token.getToken(), li); return li; } else { return null; } } public static class Provider extends org.wikibrain.conf.Provider<Wikifier> { public Provider(Configurator configurator, Configuration config) throws ConfigurationException { super(configurator, config); } @Override public Class<Wikifier> getType() { return Wikifier.class; } @Override public String getPath() { return "sr.wikifier"; } @Override public Wikifier get(String name, Config config, Map<String, String> runtimeParams) throws ConfigurationException { if (runtimeParams == null || !runtimeParams.containsKey("language")){ throw new IllegalArgumentException("Wikifier requires 'language' runtime parameter."); } if (!config.getString("type").equals("milnewitten")) { return null; } Language language = Language.getByLangCode(runtimeParams.get("language")); Configurator c = getConfigurator(); String srName = config.getString("sr"); String phraseName = config.getString("phraseAnalyzer"); String linkName = config.getString("localLinkDao"); LinkProbabilityDao lpd = Env.getComponent(c, LinkProbabilityDao.class, language); if (config.getBoolean("useLinkProbabilityCache")) { lpd.useCache(true); } Wikifier dab = new MilneWittenWikifier( c.get(SRMetric.class, srName, "language", language.getLangCode()), (AnchorTextPhraseAnalyzer)c.get(PhraseAnalyzer.class, phraseName), c.get(LocalPageDao.class), c.get(RawPageDao.class), c.get(LocalLinkDao.class, linkName), lpd ); return dab; } } public static void main(String args[]) throws ConfigurationException, DaoException, IOException { Env env = EnvBuilder.envFromArgs(args); Configurator c = env.getConfigurator(); MilneWittenWikifier w = c.get(MilneWittenWikifier.class, "default", "language", "simple"); w.testWikify(); } }