package org.wikibrain.phrases;

import com.typesafe.config.Config;
import gnu.trove.map.TLongFloatMap;
import gnu.trove.map.TLongIntMap;
import gnu.trove.map.hash.TLongFloatHashMap;
import gnu.trove.map.hash.TLongIntHashMap;
import gnu.trove.set.TLongSet;
import gnu.trove.set.hash.TLongHashSet;
import org.apache.commons.io.FileUtils;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.wikibrain.conf.Configuration;
import org.wikibrain.conf.ConfigurationException;
import org.wikibrain.conf.Configurator;
import org.wikibrain.core.dao.DaoException;
import org.wikibrain.core.dao.DaoFilter;
import org.wikibrain.core.dao.RawPageDao;
import org.wikibrain.core.lang.Language;
import org.wikibrain.core.lang.LanguageSet;
import org.wikibrain.core.lang.StringNormalizer;
import org.wikibrain.core.model.NameSpace;
import org.wikibrain.core.model.RawPage;
import org.wikibrain.core.nlp.StringTokenizer;
import org.wikibrain.core.nlp.Token;
import org.wikibrain.utils.*;

import java.io.File;
import java.io.IOException;
import java.util.Iterator;
import java.util.List;
import java.util.Map;

/**
 * Calculates the probability that a section of text is hyperlinked. Useful for
 * detecting entities.
 *
 * @author Shilad Sen
 */
public class LinkProbabilityDao {
    private static final Logger LOG = LoggerFactory.getLogger(LinkProbabilityDao.class);

    private final File path;
    private final Language lang;
    private final RawPageDao pageDao;
    private final PhraseAnalyzerDao phraseDao;
    private final StringNormalizer normalizer;

    private ObjectDb<Double> db;
    private TLongFloatMap cache = null;
    private TLongSet subGrams = null;

    /**
     * Creates a dao backed by the given directory. If the directory does not exist yet,
     * the dao is unusable until {@link #build()} is called.
     *
     * @param path      Directory holding the on-disk ObjectDb.
     * @param lang      Language this dao covers.
     * @param pageDao   Source of raw article text.
     * @param phraseDao Source of anchor-text phrase counts.
     * @throws DaoException If an existing database cannot be opened.
     */
    public LinkProbabilityDao(File path, Language lang, RawPageDao pageDao, PhraseAnalyzerDao phraseDao) throws DaoException {
        this.path = path;
        this.lang = lang;
        this.pageDao = pageDao;
        this.phraseDao = phraseDao;
        this.normalizer = phraseDao.getStringNormalizer();
        if (path.exists()) {
            try {
                db = new ObjectDb<Double>(path, false);
            } catch (IOException e) {
                throw new DaoException(e);
            }
        } else {
            LOG.warn("path " + path + " does not exist... LinkProbabilityDao will not work until build() is called.");
        }
    }
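    /*
     * Usage sketch (illustrative only; the file path and the rawPageDao / anchorPhraseDao
     * variables below are assumptions, not part of this class):
     *
     *   LinkProbabilityDao dao = new LinkProbabilityDao(
     *           new File("dat/linkprob/simple"),
     *           Language.getByLangCode("simple"),
     *           rawPageDao, anchorPhraseDao);
     *   dao.buildIfNecessary();          // computes link probabilities on first use
     *   dao.useCache(true);              // optional in-memory cache of hashed phrases
     *   double p = dao.getLinkProbability("Barack Obama");
     */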
    /**
     * If true, creates an in-memory cache that maps a 64-bit hash code of each phrase to its
     * link probability. If an up-to-date cache does not exist on disk, it is built from the
     * underlying database and written out for future use.
     *
     * @param useCache If true, enables (and if necessary builds) the cache; if false, disables it.
     */
    public void useCache(boolean useCache) {
        if (!useCache) {
            this.cache = null;
            return;
        } else if (db == null) {
            this.cache = new TLongFloatHashMap();   // build cache later
            return;
        }
        File fp = new File(path + "-phrase-cache.bin");
        File fsg = new File(path + "-subgram-cache.bin");
        long tstamp = 0;
        try {
            Double doubleTstamp = db.get("tstamp");
            if (doubleTstamp == null) {
                tstamp = System.currentTimeMillis();
                db.put("tstamp", 1.0 * tstamp);
                db.flush();
            } else {
                tstamp = doubleTstamp.longValue();
            }
        } catch (IOException e) {
            throw new RuntimeException(e);
        } catch (ClassNotFoundException e) {
            throw new RuntimeException(e);
        }
        if (fp.isFile() && fp.lastModified() > tstamp && fsg.isFile() && fsg.lastModified() > tstamp) {
            try {
                cache = (TLongFloatMap) WpIOUtils.readObjectFromFile(fp);
                subGrams = (TLongSet) WpIOUtils.readObjectFromFile(fsg);
                LOG.info("Using up-to-date link probability cache files {} and {}", fp, fsg);
                return;
            } catch (IOException e) {
                LOG.warn("Using link probability dao cache failed: ", e);
            }
        }
        LOG.info("building cache...");
        TLongFloatMap cache = new TLongFloatHashMap();
        Iterator<Pair<String, Double>> iter = db.iterator();
        TLongSet subgrams = new TLongHashSet();
        while (iter.hasNext()) {
            Pair<String, Double> entry = iter.next();
            if (entry.getKey().equalsIgnoreCase("tstamp")) {
                // skip the timestamp marker
            } else if (entry.getKey().startsWith(":s:")) {
                long hash = Long.parseLong(entry.getKey().substring(3));
                subgrams.add(hash);
            } else {
                String[] tokens = entry.getKey().split(":", 2);   // key format is "langCode:phrase"
                long hash = hashCode(tokens[1]);
                cache.put(hash, entry.getRight().floatValue());
            }
        }
        this.cache = cache;
        this.subGrams = subgrams;
        LOG.info("created cache with " + cache.size() + " entries and " + subgrams.size() + " subgrams");
        try {
            WpIOUtils.writeObjectToFile(fp, cache);
            WpIOUtils.writeObjectToFile(fsg, subgrams);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Builds the dao if it has not already been built.
     *
     * @throws DaoException
     */
    public void buildIfNecessary() throws DaoException {
        if (!isBuilt()) build();
    }

    /**
     * @return The language associated with this dao.
     */
    public Language getLang() {
        return lang;
    }

    /**
     * Retrieves the probability that a mention is hyperlinked in Wikipedia.
     * The mention is normalized before lookup.
     *
     * @param mention The candidate phrase.
     * @return The link probability, or 0.0 if the phrase is unknown.
     * @throws DaoException If the underlying database cannot be read.
     */
    public double getLinkProbability(String mention) throws DaoException {
        return getLinkProbability(mention, true);
    }

    /**
     * Retrieves the probability that a mention is hyperlinked in Wikipedia.
     * If normalize is true, text normalization is performed first.
     *
     * @param mention   The candidate phrase.
     * @param normalize If true, the text is normalized.
     * @return The link probability, or 0.0 if the phrase is unknown.
     * @throws DaoException If the underlying database cannot be read.
     */
    public double getLinkProbability(String mention, boolean normalize) throws DaoException {
        if (db == null) {
            throw new IllegalStateException("Dao has not yet been built. Call build()");
        }
        String normalizedMention = cleanString(mention, normalize);
        if (cache != null && cache.size() > 0) {
            long hash = hashCode(normalizedMention);
            return cache.containsKey(hash) ? cache.get(hash) : 0.0;
        }
        String key = lang.getLangCode() + ":" + normalizedMention;
        Double d = null;
        try {
            d = db.get(key);
        } catch (IOException e) {
            throw new DaoException(e);
        } catch (ClassNotFoundException e) {
            throw new DaoException(e);
        }
        if (d == null) {
            return 0.0;
        } else {
            return d;
        }
    }
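    /*
     * Worked example of the estimate computed in build() below (numbers are illustrative):
     * an anchor text linked 40 times (numLinks) that appears 60 times in article plain text
     * (numText) is stored with p = 40 / (60 + 3.0) ≈ 0.63. The 3.0 is the smoothing constant
     * used in build(); it keeps phrases that are never observed in plain text from dividing
     * by zero.
     */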
    /**
     * Rebuilds the link probability dao. Deletes the existing dao if there is one.
     *
     * @throws DaoException
     */
    public synchronized void build() throws DaoException {
        if (db != null) {
            db.close();
        }
        if (path.exists()) {
            FileUtils.deleteQuietly(path);
        }
        path.mkdirs();
        try {
            this.db = new ObjectDb<Double>(path, true);
        } catch (IOException e) {
            throw new DaoException(e);
        }
        subGrams = new TLongHashSet();

        LOG.info("building link probabilities for language " + lang);

        // Register every known anchor text (keyed by hash) and all of its word prefixes,
        // so that processPage() can prune its n-gram scan.
        final TLongIntMap counts = new TLongIntHashMap();
        Iterator<String> iter = phraseDao.getAllPhrases(lang);
        StringTokenizer tokenizer = new StringTokenizer();
        while (iter.hasNext()) {
            String phrase = iter.next();
            List<String> words = tokenizer.getWords(lang, phrase);
            StringBuilder buffer = new StringBuilder();
            long hash = -1;
            for (int i = 0; i < words.size(); i++) {
                if (i > 0) buffer.append(' ');
                buffer.append(words.get(i));
                hash = hashCode(buffer.toString());
                subGrams.add(hash);
            }
            counts.put(hash, 0);
        }
        LOG.info("found " + counts.size() + " unique anchor texts and " + subGrams.size() + " subgrams");

        // Count how often each anchor text occurs in article plain text.
        DaoFilter filter = new DaoFilter()
                .setRedirect(false)
                .setLanguages(lang)
                .setDisambig(false)
                .setNameSpaces(NameSpace.ARTICLE);
        ParallelForEach.iterate(
                pageDao.get(filter).iterator(),
                WpThreadUtils.getMaxThreads(),
                100,
                new Procedure<RawPage>() {
                    @Override
                    public void call(RawPage page) throws Exception {
                        processPage(counts, page);
                    }
                },
                10000);

        // Estimate and store the probability for each anchor text.
        int count = 0;
        int misses = 0;
        double sum = 0.0;
        TLongSet completed = new TLongHashSet();
        TLongIntMap linkCounts = getPhraseLinkCounts();
        Iterator<Pair<String, PrunedCounts<Integer>>> phraseIter = phraseDao.getAllPhraseCounts(lang);
        while (phraseIter.hasNext()) {
            Pair<String, PrunedCounts<Integer>> pair = phraseIter.next();
            String phrase = cleanString(pair.getLeft());
            long hash = hashCode(phrase);
            if (completed.contains(hash)) {
                continue;
            }
            completed.add(hash);
            try {
                int numLinks = linkCounts.get(hash);
                int numText = counts.get(hash);
                if (numText == 0) {
                    misses++;
                }
                count++;
                double p = 1.0 * numLinks / (numText + 3.0);   // 3.0 for smoothing
                sum += p;
                db.put(lang.getLangCode() + ":" + phrase, p);
                if (cache != null) {
                    cache.put(hash, (float) p);
                }
            } catch (IOException e) {
                throw new DaoException(e);
            }
        }

        // Persist subgram markers and a timestamp so useCache() can detect stale cache files.
        for (long h : subGrams.toArray()) {
            try {
                db.put(":s:" + h, -1.0);
            } catch (IOException e) {
                throw new DaoException(e);
            }
        }
        try {
            db.put("tstamp", 1.0 * System.currentTimeMillis());
        } catch (IOException e) {
            throw new DaoException(e);
        }
        if (count != 0) {
            LOG.info(String.format(
                    "Inserted link probabilities for %d anchors with mean probability %.4f and %d misses",
                    count, sum / count, misses));
        }
        db.flush();
    }
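    /*
     * processPage() slides a growing word window over each sentence and counts plain-text
     * occurrences of known anchor texts. The subGrams set, filled in build() with every word
     * prefix of every anchor text, lets the inner loop stop early: once a window is not a
     * registered prefix, no longer window starting at the same word can match an anchor text.
     */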
    private void processPage(TLongIntMap counts, RawPage page) {
        Language lang = page.getLanguage();
        StringTokenizer tokenizer = new StringTokenizer();
        StringBuilder buffer = new StringBuilder();
        for (Token sentence : tokenizer.getSentenceTokens(lang, page.getPlainText())) {
            List<Token> words = tokenizer.getWordTokens(lang, sentence);
            for (int i = 0; i < words.size(); i++) {
                buffer.setLength(0);
                for (int j = i; j < words.size(); j++) {
                    if (j > i) {
                        buffer.append(' ');
                    }
                    buffer.append(words.get(j).getToken());
                    String phrase = cleanString(buffer.toString(), true);
                    long hash = hashCode(phrase);
                    if (subGrams.contains(hash)) {
                        synchronized (counts) {
                            if (counts.containsKey(hash)) {
                                counts.adjustValue(hash, 1);
                            }
                        }
                    } else {
                        break;  // no point in going any further...
                    }
                }
            }
        }
    }

    private TLongIntMap getPhraseLinkCounts() {
        Iterator<Pair<String, PrunedCounts<Integer>>> phraseIter = phraseDao.getAllPhraseCounts(lang);
        TLongIntMap counts = new TLongIntHashMap();
        while (phraseIter.hasNext()) {
            Pair<String, PrunedCounts<Integer>> pair = phraseIter.next();
            String phrase = cleanString(pair.getLeft());
            long hash = hashCode(phrase);
            int n = pair.getRight().getTotal();
            counts.adjustOrPutValue(hash, n, n);
        }
        return counts;
    }

    public boolean isBuilt() {
        return (db != null && !db.isEmpty());
    }

    public boolean isSubgram(String phrase, boolean normalize) {
        if (cache == null || subGrams == null) {
            throw new IllegalArgumentException("Subgrams require a cache!");
        }
        String cleaned = cleanString(phrase, normalize);
        long h = hashCode(cleaned);
        return cache.containsKey(h) || subGrams.contains(h);
    }

    private String cleanString(String s) {
        return cleanString(s, false);
    }

    private String cleanString(String s, boolean normalize) {
        if (normalize) s = normalizer.normalize(lang, s);
        StringTokenizer t = new StringTokenizer();
        return StringUtils.join(t.getWords(lang, s), " ");
    }

    static long hashCode(String string) {
        return WpStringUtils.longHashCode2(string);
    }

    public static class Provider extends org.wikibrain.conf.Provider<LinkProbabilityDao> {
        public Provider(Configurator configurator, Configuration config) throws ConfigurationException {
            super(configurator, config);
        }

        @Override
        public Class<LinkProbabilityDao> getType() {
            return LinkProbabilityDao.class;
        }

        @Override
        public String getPath() {
            return "phrases.linkProbability";
        }

        @Override
        public LinkProbabilityDao get(String name, Config config, Map<String, String> runtimeParams) throws ConfigurationException {
            LanguageSet ls = getConfigurator().get(LanguageSet.class);
            if (runtimeParams == null || !runtimeParams.containsKey("language")) {
                throw new IllegalArgumentException("LinkProbabilityDao requires 'language' runtime parameter.");
            }
            Language language = Language.getByLangCode(runtimeParams.get("language"));
            File path = new File(config.getString("path"), language.getLangCode());
            String pageName = config.hasPath("rawPageDao") ? config.getString("rawPageDao") : null;
            String phraseName = config.hasPath("phraseAnalyzer") ? config.getString("phraseAnalyzer") : null;
            RawPageDao rpd = getConfigurator().get(RawPageDao.class, pageName);
            PhraseAnalyzer pa = getConfigurator().get(PhraseAnalyzer.class, phraseName);
            if (!(pa instanceof AnchorTextPhraseAnalyzer)) {
                throw new ConfigurationException("LinkProbabilityDao's phraseAnalyzer must be an AnchorTextPhraseAnalyzer");
            }
            PhraseAnalyzerDao pad = ((AnchorTextPhraseAnalyzer) pa).getDao();
            try {
                return new LinkProbabilityDao(path, language, rpd, pad);
            } catch (DaoException e) {
                throw new ConfigurationException(e);
            }
        }
    }
}