package org.wikibrain.sr.vector; import com.typesafe.config.Config; import gnu.trove.map.TIntDoubleMap; import gnu.trove.map.TIntFloatMap; import gnu.trove.map.hash.TIntDoubleHashMap; import gnu.trove.map.hash.TIntFloatHashMap; import org.wikibrain.conf.Configuration; import org.wikibrain.conf.ConfigurationException; import org.wikibrain.conf.Configurator; import org.wikibrain.core.dao.DaoException; import org.wikibrain.core.lang.Language; import org.wikibrain.core.lang.LocalId; import org.wikibrain.core.lang.LocalString; import org.wikibrain.lucene.LuceneSearcher; import org.wikibrain.lucene.TextFieldElements; import org.wikibrain.lucene.WikiBrainScoreDoc; import org.wikibrain.sr.SRResultList; import org.wikibrain.sr.disambig.Disambiguator; import org.wikibrain.utils.WpCollectionUtils; import java.util.*; /** * Config looks like: * * { * weights: { * dab : 1.0 * sr : 1.0 * text : 1.0 * } * * numCandidates : { * sr : 10 * perSr : 2 * text : 50 * used : 20 * } * } * * A detailed description appears in the reference.conf * * @author Shilad Sen */ public class PhraseVectorCreator { private final LuceneSearcher searcher; private Language language; private SparseVectorSRMetric metric; private Disambiguator disambig; private SparseVectorGenerator generator; private double dabWeight = 1.0; private int numDabCands = 1; private double srWeight = 1.0; private int numSrCands = 0; private int numPerSrCand = 0; private double textWeight = 0.4; private int numTextCands = 50; private int numUsedCands = 20; public PhraseVectorCreator(LuceneSearcher searcher) { this.searcher = searcher; } public void setDabWeight(double dabWeight) { this.dabWeight = dabWeight; } public void setSrWeight(double srWeight) { this.srWeight = srWeight; } public void setNumSrCands(int numSrCands) { this.numSrCands = numSrCands; } public void setNumPerSrCand(int numPerSrCand) { this.numPerSrCand = numPerSrCand; } public void setTextWeight(double textWeight) { this.textWeight = textWeight; } public void setNumTextCands(int numTextCands) { this.numTextCands = numTextCands; } public void setNumUsedCands(int numUsedCands) { this.numUsedCands = numUsedCands; } public void setNumDabCands(int numDabCands) { this.numDabCands = numDabCands; } /** * Set metric must be called before this component can be used. * @param metric */ public void setMetric(SparseVectorSRMetric metric) { this.metric = metric; this.language = metric.getLanguage(); this.disambig = metric.getDisambiguator(); this.generator = metric.getGenerator(); } public TIntFloatMap[] getPhraseVectors(String ... phrases) throws DaoException { List<LocalString> local = new ArrayList<LocalString>(); for (String p : phrases) { local.add(new LocalString(language, p)); } List<LinkedHashMap<LocalId, Float>> candidates = disambig.disambiguate(local, null); if (candidates.size() != phrases.length) throw new IllegalStateException(); TIntFloatMap results[] = new TIntFloatMap[phrases.length]; for (int i = 0; i < phrases.length; i++) { results[i] = getPhraseVector(phrases[i], candidates.get(i)); } return results; } public TIntFloatMap getPhraseVector(String phrase) throws DaoException { LocalString ls = new LocalString(language, phrase); LinkedHashMap<LocalId, Float> candidates = disambig.disambiguate(ls, null); return getPhraseVector(phrase, candidates); } private TIntFloatMap getPhraseVector(String phrase, LinkedHashMap<LocalId, Float> dabCandidates) throws DaoException { if (dabCandidates == null || dabCandidates.isEmpty()) { return null; } LinkedHashMap<LocalId, Float> textCandidates = resolveTextual(phrase, numTextCands); LinkedHashMap<LocalId, Float> srCandidates = expandSR(phrase, dabCandidates, numSrCands, numPerSrCand); // StringBuffer buff = new StringBuffer("for phrase " + phrase + "\n"); TIntDoubleMap merged = new TIntDoubleHashMap(); double total = 0.0; int i = 0; for (Map.Entry<LocalId, Float> entry : dabCandidates.entrySet()) { if (i++ > numDabCands) { break; } // buff.append("\tdab: " + getTitle(entry.getKey()) + ": " + entry.getValue() + " * 1.0\n"); double v = entry.getValue() * dabWeight; merged.adjustOrPutValue(entry.getKey().getId(), v, v); total += v; } for (Map.Entry<LocalId, Float> entry : textCandidates.entrySet()) { // buff.append("\ttext: " + getTitle(entry.getKey()) + ": " + entry.getValue() + " * 1.0\n"); double v = entry.getValue() * textWeight; merged.adjustOrPutValue(entry.getKey().getId(), v, v); total += v; } for (Map.Entry<LocalId, Float> entry : srCandidates.entrySet()) { // buff.append("\tsr: " + getTitle(entry.getKey()) + ": " + entry.getValue() + " * 1.0\n"); double v = entry.getValue() * srWeight; merged.adjustOrPutValue(entry.getKey().getId(), v, v); total += v; } // System.out.println(buff.toString() + "\n\n\n"); int ids[] = WpCollectionUtils.sortMapKeys(merged, true); TIntFloatMap vector = new TIntFloatHashMap(); for (i = 0; i < numUsedCands && i < ids.length; i++) { TIntFloatMap candidateVector = generator.getVector(ids[i]); if (candidateVector != null) { for (int id : candidateVector.keys()) { double w = Math.sqrt(merged.get(ids[i]) / total); double v = candidateVector.get(id); vector.adjustOrPutValue(id, (float)(w * v), (float)(w * v)); } } } if (vector.isEmpty()) { return null; } else { return vector; } } private String getTitle(LocalId id) throws DaoException { return metric.getLocalPageDao().getById(language, id.getId()).getTitle().toString(); } private LinkedHashMap<LocalId, Float> resolveTextual(String phrase, int n) { if (n == 0) { return new LinkedHashMap<LocalId, Float>(); } WikiBrainScoreDoc results[] = searcher.getQueryBuilderByLanguage(language) .setPhraseQuery(new TextFieldElements().addPlainText(), phrase) .setNumHits(n*2) .search(); double total = 0.0; for (WikiBrainScoreDoc doc : results) { total += doc.score; } LinkedHashMap<LocalId, Float> expanded = new LinkedHashMap<LocalId, Float>(); for (int i = 0; i < n && i < results.length; i++) { expanded.put(new LocalId(language, results[i].wpId), (float)(results[i].score / total)); } return expanded; } /** * Expands a set of disambiguation candidates to include semantically related entities. * @param phrase * @param candidates * @param numCands * @param numPerCand * @return * @throws DaoException */ private LinkedHashMap<LocalId, Float> expandSR(String phrase, LinkedHashMap<LocalId, Float> candidates, int numCands, int numPerCand) throws DaoException { if (candidates == null || candidates.isEmpty()) { return null; } if (numCands == 0 || numPerCand == 0) { return new LinkedHashMap<LocalId, Float>(); } LinkedHashMap<LocalId, Float> expanded = new LinkedHashMap<LocalId, Float>(); int i = 0; for (LocalId id1 : candidates.keySet()) { SRResultList sr = metric.mostSimilar(id1.getId(), numCands * 2); if (sr != null && sr.numDocs() > 0) { for (int j = 0; j < numPerCand && j < sr.numDocs(); j++) { expanded.put(new LocalId(language, sr.getId(j)), (float)(sr.getScore(j) * candidates.get(id1))); } if (i++ >= numCands) { break; } } } return expanded; } public static class Provider extends org.wikibrain.conf.Provider<PhraseVectorCreator> { public Provider(Configurator configurator, Configuration config) throws ConfigurationException { super(configurator, config); } @Override public Class getType() { return PhraseVectorCreator.class; } @Override public String getPath() { return "sr.metric.phraseVectorCreator"; } @Override public PhraseVectorCreator get(String name, Config config, Map<String, String> runtimeParams) throws ConfigurationException { LuceneSearcher searcher = getConfigurator().get(LuceneSearcher.class, config.getString("lucene")); PhraseVectorCreator creator = new PhraseVectorCreator(searcher); if (config.hasPath("weights.dab")) { creator.setDabWeight(config.getDouble("weights.dab")); } if (config.hasPath("weights.sr")) { creator.setSrWeight(config.getDouble("weights.sr")); } if (config.hasPath("weights.text")) { creator.setTextWeight(config.getDouble("weights.text")); } if (config.hasPath("numCandidates.used")) { creator.setNumUsedCands(config.getInt("numCandidates.used")); } if (config.hasPath("numCandidates.dab")) { creator.setNumDabCands(config.getInt("numCandidates.dab")); } if (config.hasPath("numCandidates.text")) { creator.setNumTextCands(config.getInt("numCandidates.text")); } if (config.hasPath("numCandidates.sr")) { creator.setNumSrCands(config.getInt("numCandidates.sr")); } if (config.hasPath("numCandidates.perSr")) { creator.setNumPerSrCand(config.getInt("numCandidates.perSr")); } return creator; } } }