package org.wikibrain.sr.milnewitten; import com.typesafe.config.Config; import gnu.trove.map.TIntFloatMap; import gnu.trove.map.hash.TIntFloatHashMap; import gnu.trove.set.TIntSet; import gnu.trove.set.hash.TIntHashSet; import org.wikibrain.conf.Configuration; import org.wikibrain.conf.ConfigurationException; import org.wikibrain.conf.Configurator; import org.wikibrain.core.dao.DaoException; import org.wikibrain.core.dao.DaoFilter; import org.wikibrain.core.dao.LocalLinkDao; import org.wikibrain.core.dao.LocalPageDao; import org.wikibrain.core.lang.Language; import org.wikibrain.core.lang.LocalId; import org.wikibrain.core.model.LocalLink; import org.wikibrain.core.model.NameSpace; import org.wikibrain.phrases.AnchorTextPhraseAnalyzer; import org.wikibrain.phrases.PhraseAnalyzer; import org.wikibrain.phrases.PrunedCounts; import org.wikibrain.sr.SRMetric; import org.wikibrain.sr.SRResult; import org.wikibrain.sr.SRResultList; import org.wikibrain.sr.dataset.Dataset; import org.wikibrain.sr.normalize.Normalizer; import org.wikibrain.sr.utils.SimUtils; import java.io.File; import java.io.IOException; import java.util.LinkedHashMap; import java.util.Map; /** * @author Shilad Sen */ public class SimpleMilneWitten implements SRMetric { private final String name; private final Language language; private final LocalPageDao pageDao; private final LocalLinkDao linkDao; private final AnchorTextPhraseAnalyzer phraseAnalyzer; private final int numArticles; private File dataDir; public SimpleMilneWitten(String name, Language language, LocalPageDao pageDao, LocalLinkDao linkDao, AnchorTextPhraseAnalyzer phraseAnalyzer) throws DaoException { this.name = name; this.language = language; this.pageDao = pageDao; this.linkDao = linkDao; this.phraseAnalyzer = phraseAnalyzer; this.numArticles = pageDao.getCount( new DaoFilter() .setLanguages(language) .setDisambig(false) .setRedirect(false) .setNameSpaces(NameSpace.ARTICLE)); } @Override public String getName() { return name; } @Override public Language getLanguage() { return language; } @Override public File getDataDir() { return dataDir; } @Override public void setDataDir(File dir) { this.dataDir = dir; } @Override public SRResult similarity(int pageId1, int pageId2, boolean explanations) throws DaoException { double s1 = googleInlink(pageId1, pageId2); double s2 = cosineOutlink(pageId1, pageId2); return new SRResult(0.5 * s1 + 0.5 * s2); } private TIntSet getInlinks(int pageId1) throws DaoException { TIntSet inlinks = new TIntHashSet(); for (LocalLink ll : linkDao.getLinks(language, pageId1, false)) { inlinks.add(ll.getSourceId()); } return inlinks; } private TIntSet getOutlinks(int pageId1) throws DaoException { TIntSet outlinks = new TIntHashSet(); for (LocalLink ll : linkDao.getLinks(language, pageId1, true)) { outlinks.add(ll.getDestId()); } return outlinks; } private double googleInlink(int pageId1, int pageId2) throws DaoException { TIntSet inlinks1 = getInlinks(pageId1); TIntSet inlinks2 = getInlinks(pageId2); if (inlinks1.isEmpty() && inlinks2.isEmpty()) { return 0.0; } int a = inlinks1.size(); int b = inlinks2.size(); TIntSet intersection = new TIntHashSet(inlinks1.toArray()); intersection.retainAll(inlinks2); int ab = intersection.size(); return 1.0 - ( (Math.log(Math.max(a, b)) - Math.log(ab)) / (Math.log(numArticles) - Math.log(Math.min(a, b))) ); } private double cosineOutlink(int pageId1, int pageId2) throws DaoException { TIntSet outlinks1 = getOutlinks(pageId1); TIntSet outlinks2 = getOutlinks(pageId2); TIntFloatMap v1 = makeOutlinkVector(outlinks1); TIntFloatMap v2 = makeOutlinkVector(outlinks2); if (v1.isEmpty() || v2.isEmpty()) { return 0.0; } return SimUtils.cosineSimilarity(v1, v2); } private int getNumLinks(int wpId) throws DaoException { return linkDao.getCount(new DaoFilter().setLanguages(language).setSourceIds(wpId)); } private TIntFloatMap makeOutlinkVector(TIntSet links) throws DaoException { TIntFloatMap vector = new TIntFloatHashMap(); for (int wpId : links.toArray()) { vector.put(wpId, (float) Math.log(1.0 * numArticles / getNumLinks(wpId))); } return vector; } @Override public SRResult similarity(String phrase1, String phrase2, boolean explanations) throws DaoException { LinkedHashMap<LocalId, Float> candidates1 = phraseAnalyzer.resolve(language, phrase1, 100); LinkedHashMap<LocalId, Float> candidates2 = phraseAnalyzer.resolve(language, phrase2, 100); if (candidates1 == null || candidates2 == null) { return null; } double highestScore = Double.NEGATIVE_INFINITY; for (LocalId lid1 : candidates1.keySet()) { for (LocalId lid2 : candidates2.keySet()) { double score = similarity(lid1.getId(), lid2.getId(), false).getScore(); if (score > highestScore) { highestScore = score; } } } double result = 0.0; double highestPop = Double.NEGATIVE_INFINITY; for (LocalId lid1 : candidates1.keySet()) { for (LocalId lid2 : candidates2.keySet()) { double pop = candidates1.get(lid1) * candidates2.get(lid2); double score = similarity(lid1.getId(), lid2.getId(), false).getScore(); if (score >= 0.4 * highestScore && pop >= highestPop) { highestPop = pop; result = score; } } } int n1 = getPhraseCount(phrase1 + " " + phrase2); int n2 = getPhraseCount(phrase2 + " " + phrase1); if (n1 + n2 > 0) { result += Math.log(n1 + n2 + 1) / 10; } return new SRResult(result); } private int getPhraseCount(String phrase) throws DaoException { PrunedCounts<Integer> pages = phraseAnalyzer.getDao().getPhraseCounts(language, phrase, 1); if (pages == null) { return 0; } else { return pages.getTotal(); } } @Override public SRResultList mostSimilar(int pageId, int maxResults) throws DaoException { throw new UnsupportedOperationException(); } @Override public SRResultList mostSimilar(int pageId, int maxResults, TIntSet validIds) throws DaoException { throw new UnsupportedOperationException(); } @Override public SRResultList mostSimilar(String phrase, int maxResults) throws DaoException { throw new UnsupportedOperationException(); } @Override public SRResultList mostSimilar(String phrase, int maxResults, TIntSet validIds) throws DaoException { throw new UnsupportedOperationException(); } @Override public void write() throws IOException {} @Override public void read() {} @Override public void trainSimilarity(Dataset dataset) throws DaoException { } @Override public void trainMostSimilar(Dataset dataset, int numResults, TIntSet validIds) { } @Override public boolean similarityIsTrained() { return true; } @Override public boolean mostSimilarIsTrained() { return false; } @Override public double[][] cosimilarity(int[] wpRowIds, int[] wpColIds) throws DaoException { throw new UnsupportedOperationException(); } @Override public double[][] cosimilarity(String[] rowPhrases, String[] colPhrases) throws DaoException { throw new UnsupportedOperationException(); } @Override public double[][] cosimilarity(int[] ids) throws DaoException { throw new UnsupportedOperationException(); } @Override public double[][] cosimilarity(String[] phrases) throws DaoException { throw new UnsupportedOperationException(); } @Override public Normalizer getMostSimilarNormalizer() { throw new UnsupportedOperationException(); } @Override public void setMostSimilarNormalizer(Normalizer n) { throw new UnsupportedOperationException(); } @Override public Normalizer getSimilarityNormalizer() { throw new UnsupportedOperationException(); } @Override public void setSimilarityNormalizer(Normalizer n) { throw new UnsupportedOperationException(); } public static class Provider extends org.wikibrain.conf.Provider<SRMetric> { public Provider(Configurator configurator, Configuration config) throws ConfigurationException { super(configurator, config); } @Override public Class<SRMetric> getType() { return SRMetric.class; } @Override public String getPath() { return "sr.metric.local"; } @Override public SRMetric get(String name, Config config, Map<String, String> runtimeParams) throws ConfigurationException { if (!config.getString("type").equals("simplemilnewitten")) { return null; } if (runtimeParams == null || !runtimeParams.containsKey("language")){ throw new IllegalArgumentException("SimpleMilneWitten requires 'language' runtime parameter."); } Language language = Language.getByLangCode(runtimeParams.get("language")); try { return new SimpleMilneWitten( name, language, getConfigurator().get(LocalPageDao.class), getConfigurator().get(LocalLinkDao.class), (AnchorTextPhraseAnalyzer) getConfigurator().get(PhraseAnalyzer.class, "anchortext") ); } catch (DaoException e) { throw new ConfigurationException(e); } } } }