package org.wikibrain.sr.disambig;

import com.typesafe.config.Config;
import org.wikibrain.conf.Configuration;
import org.wikibrain.conf.ConfigurationException;
import org.wikibrain.conf.Configurator;
import org.wikibrain.core.dao.DaoException;
import org.wikibrain.core.lang.Language;
import org.wikibrain.core.lang.LocalId;
import org.wikibrain.core.lang.LocalString;
import org.wikibrain.core.model.LocalPage;
import org.wikibrain.phrases.PhraseAnalyzer;
import org.wikibrain.utils.WpCollectionUtils;

import java.util.*;

/**
 * Disambiguator that combines the candidates returned by several
 * {@link PhraseAnalyzer}s.
 *
 * <p>{@link #disambiguateTop(LocalString, Set)} lets every analyzer vote with
 * its single best candidate and returns the candidate with the most votes.
 * {@link #disambiguate(List, Set)} instead sums, per page id, the scores of up
 * to 20 candidates from every analyzer.</p>
 *
 * <p>The {@code context} argument of every method is accepted for interface
 * compatibility but is not consulted by this implementation.</p>
 *
 * @author Matt Lesicko
 */
public class TopResultConsensusDisambiguator extends Disambiguator {
    /** Analyzers that are queried, in order, for every phrase. */
    private final List<PhraseAnalyzer> phraseAnalyzers;

    /**
     * @param phraseAnalyzers analyzers to consult; the list is stored by
     *                        reference, not copied
     */
    public TopResultConsensusDisambiguator(List<PhraseAnalyzer> phraseAnalyzers) {
        this.phraseAnalyzers = phraseAnalyzers;
    }

    /**
     * Returns the page that the largest number of analyzers rank first for
     * the given phrase.
     *
     * @param phrase  phrase to resolve; its language is passed to the analyzers
     * @param context unused
     * @return the winning id, or {@code null} when no analyzer produced a
     *         candidate; on a tie the candidate that was proposed first wins
     * @throws DaoException if an analyzer fails
     */
    public LocalId disambiguateTop(LocalString phrase, Set<LocalString> context) throws DaoException {
        // Insertion-ordered so that ties are broken by analyzer order.
        Map<LocalId, Integer> votes = new LinkedHashMap<LocalId, Integer>();
        for (PhraseAnalyzer phraseAnalyzer : phraseAnalyzers) {
            LinkedHashMap<LocalId, Float> top =
                    phraseAnalyzer.resolve(phrase.getLanguage(), phrase.getString(), 1);
            if (top == null || top.isEmpty()) {
                continue;   // this analyzer has no opinion on the phrase
            }
            LocalId candidate = top.keySet().iterator().next();
            Integer previous = votes.get(candidate);
            votes.put(candidate, previous == null ? 1 : previous + 1);
        }

        LocalId best = null;
        int bestVotes = 0;
        for (Map.Entry<LocalId, Integer> vote : votes.entrySet()) {
            // Strict comparison keeps the earliest candidate on a tie.
            if (vote.getValue() > bestVotes) {
                bestVotes = vote.getValue();
                best = vote.getKey();
            }
        }
        return best;
    }

    /**
     * Applies {@link #disambiguateTop(LocalString, Set)} to every phrase.
     *
     * @param phrases phrases to resolve
     * @param context unused
     * @return one entry per phrase, in the same order; an entry is
     *         {@code null} when no analyzer produced a candidate for it
     * @throws DaoException if an analyzer fails
     */
    public List<LocalId> disambiguateTop(List<LocalString> phrases, Set<LocalString> context) throws DaoException {
        List<LocalId> ids = new ArrayList<LocalId>();
        for (LocalString phrase : phrases) {
            ids.add(disambiguateTop(phrase, context));
        }
        return ids;
    }

    /**
     * For every phrase, sums the scores that all analyzers assign to each
     * page id and returns the ids ordered by
     * {@code WpCollectionUtils.sortMapKeys(sums, true)} (presumably highest
     * sum first -- confirm against that helper).
     *
     * @param phrases phrases to resolve
     * @param context unused
     * @return one map per phrase, in the same order as {@code phrases}; a
     *         map is empty when no analyzer returned a candidate
     * @throws DaoException if an analyzer fails
     */
    @Override
    public List<LinkedHashMap<LocalId, Float>> disambiguate(List<LocalString> phrases, Set<LocalString> context) throws DaoException {
        List<LinkedHashMap<LocalId, Float>> results = new ArrayList<LinkedHashMap<LocalId, Float>>();
        if (phrases.isEmpty()) {
            return results;
        }
        for (LocalString phrase : phrases) {
            // Use the phrase's own language for both lookup and result ids so
            // that a mixed-language batch is not tagged with the language of
            // the first phrase.
            Language lang = phrase.getLanguage();
            Map<Integer, Double> pageSums = new HashMap<Integer, Double>();
            for (PhraseAnalyzer pa : phraseAnalyzers) {
                LinkedHashMap<LocalId, Float> probs = pa.resolve(lang, phrase.getString(), 20);
                if (probs == null) {
                    continue;   // same tolerance as disambiguateTop
                }
                for (Map.Entry<LocalId, Float> entry : probs.entrySet()) {
                    int id = entry.getKey().getId();
                    double score = entry.getValue();
                    Double running = pageSums.get(id);
                    pageSums.put(id, running == null ? score : running + score);
                }
            }
            LinkedHashMap<LocalId, Float> pageResult = new LinkedHashMap<LocalId, Float>();
            for (Integer key : WpCollectionUtils.sortMapKeys(pageSums, true)) {
                pageResult.put(new LocalId(lang, key), pageSums.get(key).floatValue());
            }
            results.add(pageResult);
        }
        return results;
    }

    /**
     * Configuration provider that builds a
     * {@link TopResultConsensusDisambiguator} from a config block whose
     * {@code type} is {@code topResultConsensus}.
     */
    public static class Provider extends org.wikibrain.conf.Provider<Disambiguator> {
        public Provider(Configurator configurator, Configuration config) throws ConfigurationException {
            super(configurator, config);
        }

        @Override
        public Class getType() {
            return Disambiguator.class;
        }

        @Override
        public String getPath() {
            return "sr.disambig";
        }

        /**
         * @return a new disambiguator wired with the analyzers named in the
         *         {@code phraseAnalyzers} string list, or {@code null} when
         *         the config's {@code type} is not {@code topResultConsensus}
         * @throws ConfigurationException if an analyzer cannot be obtained
         */
        @Override
        public Disambiguator get(String name, Config config, Map<String, String> runtimeParams) throws ConfigurationException {
            if (!config.getString("type").equals("topResultConsensus")) {
                return null;
            }
            List<PhraseAnalyzer> phraseAnalyzers = new ArrayList<PhraseAnalyzer>();
            for (String analyzer : config.getStringList("phraseAnalyzers")) {
                phraseAnalyzers.add(getConfigurator().get(PhraseAnalyzer.class, analyzer));
            }
            return new TopResultConsensusDisambiguator(phraseAnalyzers);
        }
    }
}