package org.wikibrain.sr.vector; import com.typesafe.config.Config; import gnu.trove.map.TIntFloatMap; import gnu.trove.map.TIntIntMap; import gnu.trove.map.hash.TIntIntHashMap; import gnu.trove.set.TIntSet; import gnu.trove.set.hash.TIntHashSet; import org.wikibrain.conf.Configuration; import org.wikibrain.conf.ConfigurationException; import org.wikibrain.conf.Configurator; import org.wikibrain.core.dao.DaoException; import org.wikibrain.core.dao.DaoFilter; import org.wikibrain.core.dao.LocalPageDao; import org.wikibrain.core.lang.Language; import org.wikibrain.core.model.NameSpace; import org.wikibrain.matrix.MatrixRow; import org.wikibrain.matrix.SparseMatrix; import org.wikibrain.matrix.SparseMatrixRow; import org.wikibrain.sr.SRResultList; import org.wikibrain.sr.utils.Leaderboard; import org.wikibrain.sr.utils.SimUtils; import org.wikibrain.utils.WpIOUtils; import java.io.File; import java.io.IOException; import java.util.Map; import org.slf4j.Logger; import org.slf4j.LoggerFactory; /** * Google Similarity measure, as described in Milne and Witten: * * http://www.cs.waikato.ac.nz/~ihw/papers/08-DM-IHW-Semantic_relatedness.pdf * * @author Shilad Sen */ public class GoogleSimilarity implements VectorSimilarity { private static final Logger LOG = LoggerFactory.getLogger(CosineSimilarity.class); private TIntIntMap lengths = new TIntIntHashMap(); // lengths of each row private TIntSet idsInResults = new TIntHashSet(); private final int numPages; private SparseMatrix features; private SparseMatrix transpose; public GoogleSimilarity(int numPages) { this.numPages = numPages; } @Override public synchronized void setMatrices(SparseMatrix features, SparseMatrix transpose, File dataDir) throws IOException { this.features = features; this.transpose = transpose; File idCacheFile = new File(dataDir, "googleSimilarity-ids.bin"); File lengthCacheFile = new File(dataDir, "googleSimilarity-lengths.bin"); if (lengthCacheFile.exists() && lengthCacheFile.lastModified() >= features.lastModified() && idCacheFile.exists() && idCacheFile.lastModified() >= transpose.lastModified()) { LOG.info("reading matrix information from cache"); lengths = (TIntIntMap) WpIOUtils.readObjectFromFile(lengthCacheFile); idsInResults = (TIntSet) WpIOUtils.readObjectFromFile(idCacheFile); } else { LOG.info("building cached matrix information"); lengths.clear(); idsInResults.clear(); for (SparseMatrixRow row : features) { lengths.put(row.getRowIndex(), row.getNumCols()); } idsInResults.addAll(transpose.getRowIds()); WpIOUtils.writeObjectToFile(lengthCacheFile, lengths); WpIOUtils.writeObjectToFile(idCacheFile, idsInResults); } } @Override public double similarity(TIntFloatMap vector1, TIntFloatMap vector2) { if (vector2.size() < vector1.size()) { TIntFloatMap tmp = vector1; vector1 = vector2; vector2 = tmp; } int size1 = vector1.size(); int size2 = vector2.size(); if (size1 == 0 || size2 == 0) { return 0.0; } int intersect = 0; for (int id1 : vector1.keys()) { if (vector2.containsKey(id1)) { intersect++; } } if (intersect == 0) { return 0.0; } return SimUtils.googleSimilarity(size1, size2, intersect, numPages); } @Override public double similarity(MatrixRow a, MatrixRow b) { int na = a.getNumCols(); int nb = b.getNumCols(); int intersect = 0; int i = 0, j = 0; if((na == 0 || nb == 0)) { // do not perform calculations if one or both are 0 return 0; } // Start by getting the first column in each matrix int ca = a.getColIndex(i); int cb = b.getColIndex(j); while (i < na && j < nb) { if (ca < cb) { // if matrix a has a lower value, then get the next column i++; ca = a.getColIndex(i); } else if (ca > cb) { // if matrix b has a lower value, then get the next column j++; cb = b.getColIndex(j); } else { // if both have the same value, increment the intersection and get the next columns in both matrices i++; j++; intersect++; ca = a.getColIndex(i); cb = b.getColIndex(j); } } return SimUtils.googleSimilarity(na, nb, intersect, numPages); } @Override public SRResultList mostSimilar(TIntFloatMap query, int maxResults, TIntSet validIds) throws IOException { TIntIntMap intersections = new TIntIntHashMap(); for (int id1 : query.keys()){ SparseMatrixRow row = transpose.getRow(id1); for (int i = 0; i < row.getNumCols(); i++) { int id2 = row.getColIndex(i); if (validIds == null || validIds.contains(id2)) { intersections.adjustOrPutValue(id2, 1, 1); } } } Leaderboard leaderboard = new Leaderboard(maxResults); for (int id: intersections.keys()) { double sim = SimUtils.googleSimilarity(query.size(), lengths.get(id), intersections.get(id), numPages); leaderboard.tallyScore(id, sim); } return leaderboard.getTop(); } @Override public double getMinValue() { return -1.0; } @Override public double getMaxValue() { return 1.0; } public static class Provider extends org.wikibrain.conf.Provider<VectorSimilarity> { public Provider(Configurator configurator, Configuration config) throws ConfigurationException { super(configurator, config); } @Override public Class getType() { return VectorSimilarity.class; } @Override public String getPath() { return "sr.metric.similarity"; } @Override public VectorSimilarity get(String name, Config config, Map<String, String> runtimeParams) throws ConfigurationException { if (!config.getString("type").equals("google")) { return null; } if (!runtimeParams.containsKey("language")) { throw new IllegalArgumentException("GoogleSimilarity requires 'language' runtime parameter"); } Language language = Language.getByLangCode(runtimeParams.get("language")); LocalPageDao lpDao = getConfigurator().get(LocalPageDao.class); try { int numPages = lpDao.getCount( new DaoFilter() .setLanguages(language) .setRedirect(false) .setDisambig(false) .setNameSpaces(NameSpace.ARTICLE)); return new GoogleSimilarity(numPages); } catch (DaoException e) { throw new ConfigurationException(e); } } } }