package org.wikibrain.sr.ensemble; import com.typesafe.config.Config; import gnu.trove.set.TIntSet; import org.wikibrain.conf.Configuration; import org.wikibrain.conf.ConfigurationException; import org.wikibrain.conf.Configurator; import org.wikibrain.core.dao.DaoException; import org.wikibrain.core.dao.DaoFilter; import org.wikibrain.core.dao.LocalPageDao; import org.wikibrain.core.lang.Language; import org.wikibrain.core.lang.LocalId; import org.wikibrain.core.lang.LocalString; import org.wikibrain.sr.*; import org.wikibrain.sr.dataset.Dataset; import org.wikibrain.sr.disambig.Disambiguator; import org.wikibrain.sr.utils.KnownSim; import org.wikibrain.utils.*; import java.io.*; import java.util.*; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.slf4j.Logger; import org.slf4j.LoggerFactory; /** * @author Matt Lesicko * @author Shilad Sen */ public class EnsembleMetric extends BaseSRMetric { private static final Logger LOG = LoggerFactory.getLogger(EnsembleMetric.class); public static final int MIN_SEARCH_DEPTH = 500; public static final int SEARCH_MULTIPLIER = 3; private List<SRMetric> metrics; private Ensemble ensemble; private boolean resolvePhrases = true; private boolean trainSubmetrics = true; public EnsembleMetric(String name, Language language, List<SRMetric> metrics, Ensemble ensemble, Disambiguator disambiguator, LocalPageDao pageHelper){ super(name, language, pageHelper, disambiguator); this.metrics=metrics; this.ensemble=ensemble; } public List<SRMetric> getMetrics() { return metrics; } public void setResolvePhrases(boolean resolvePhrases) { this.resolvePhrases = resolvePhrases; } @Override public SRConfig getConfig() { return new SRConfig(); } @Override public SRResult similarity(int pageId1, int pageId2, boolean explanations) throws DaoException { List<SRResult> scores = new ArrayList<SRResult>(); for (SRMetric metric : metrics){ scores.add(metric.similarity(pageId1,pageId2,explanations)); } return normalize(ensemble.predictSimilarity(scores)); } @Override public SRResult similarity(String phrase1, String phrase2, boolean explanations) throws DaoException { if (resolvePhrases) { return super.similarity(phrase1, phrase2, explanations); } List<SRResult> scores = new ArrayList<SRResult>(); for (SRMetric metric : metrics){ scores.add(metric.similarity(phrase1,phrase2,explanations)); } return normalize(ensemble.predictSimilarity(scores)); } @Override public SRResultList mostSimilar(int pageId, int maxResults, TIntSet validIds) throws DaoException { SRResultList mostSimilar= getCachedMostSimilar(pageId, maxResults, validIds); if (mostSimilar != null) { return mostSimilar; } List<SRResultList> scores = new ArrayList<SRResultList>(); for (SRMetric metric : metrics){ scores.add(metric.mostSimilar(pageId,getMaxResults(maxResults),validIds)); } SRResultList result = normalize(ensemble.predictMostSimilar(scores, maxResults, validIds)); return result; } @Override public SRResultList mostSimilar(String phrase, int maxResults, TIntSet validIds) throws DaoException { if (resolvePhrases) { return super.mostSimilar(phrase, maxResults, validIds); } List<SRResultList> scores = new ArrayList<SRResultList>(); for (SRMetric metric : metrics){ scores.add(metric.mostSimilar(phrase, getMaxResults(maxResults),validIds)); } return normalize(ensemble.predictMostSimilar(scores,maxResults, validIds)); } /** * Training cascades to base metrics. * @param dataset * @throws DaoException */ @Override public void trainSimilarity(final Dataset dataset) throws DaoException { if (trainSubmetrics) { for (SRMetric metric : metrics) { metric.trainSimilarity(dataset); } } final List<EnsembleSim> ensembleSims = new ArrayList<EnsembleSim>(); ParallelForEach.loop( dataset.getData(), new Procedure<KnownSim>() { @Override public void call(KnownSim ks) throws Exception { EnsembleSim es = new EnsembleSim(ks); for (SRMetric metric : metrics){ double score = Double.NaN; try { SRResult result = metric.similarity(ks.phrase1,ks.phrase2,false); if (result != null) { score = result.getScore(); } } catch (Exception e){ LOG.warn("Local sr metric " + metric.getName() + " failed for " + ks, e); } es.add(score, 0); } ensembleSims.add(es); } }, 100); ensemble.trainSimilarity(ensembleSims); super.trainSimilarity(dataset); } /** * Training cascades to base metrics. * TODO: adapt this to a MostSimilarDataset * @param dataset * @param numResults * @param validIds */ @Override public void trainMostSimilar(Dataset dataset, final int numResults, final TIntSet validIds){ if (getMostSimilarCache() != null) { clearMostSimilarCache(); } if (trainSubmetrics) { for (SRMetric metric : metrics){ metric.trainMostSimilar(dataset,numResults,validIds); } } List<EnsembleSim> ensembleSims = ParallelForEach.loop(dataset.getData(), new Function<KnownSim, EnsembleSim>() { public EnsembleSim call(KnownSim ks) throws DaoException { List<LocalString> localStrings = Arrays.asList( new LocalString(ks.language, ks.phrase1), new LocalString(ks.language, ks.phrase2) ); List<LocalId> ids = getDisambiguator().disambiguateTop(localStrings, null); if (ids.isEmpty() || ids.get(0).getId() <= 0) { return null; } int pageId = ids.get(0).getId(); EnsembleSim es = new EnsembleSim(ks); for (SRMetric metric : metrics) { double score = Double.NaN; int rank = -1; try { SRResultList dsl = metric.mostSimilar(pageId, getMaxResults(numResults), validIds); if (dsl != null && dsl.getIndexForId(ids.get(1).getId()) >= 0) { score = dsl.getScore(dsl.getIndexForId(ids.get(1).getId())); rank = dsl.getIndexForId(ids.get(1).getId()); } } catch (Exception e) { LOG.warn("Local sr metric " + metric.getName() + " failed for " + pageId, e); } finally { es.add(score, rank); } } return es; } }, 100); ensemble.trainMostSimilar(ensembleSims); super.trainMostSimilar(dataset, numResults, validIds); } private int getMaxResults(int numResults) { return Math.max(MIN_SEARCH_DEPTH, numResults * SEARCH_MULTIPLIER); } public void setTrainSubmetrics(boolean trainSubmetrics) { this.trainSubmetrics = trainSubmetrics; } @Override public void write() throws IOException { super.write(); ensemble.write(new File(getDataDir(), "ensemble").getAbsolutePath()); } @Override public void read() throws IOException{ super.read(); ensemble.read(new File(getDataDir(), "ensemble").getAbsolutePath()); } public static class Provider extends org.wikibrain.conf.Provider<SRMetric>{ public Provider(Configurator configurator, Configuration config) throws ConfigurationException { super(configurator, config); } @Override public Class getType() { return SRMetric.class; } @Override public String getPath() { return "sr.metric.local"; } @Override public SRMetric get(String name, Config config, Map<String, String> runtimeParams) throws ConfigurationException{ if (!config.getString("type").equals("ensemble")) { return null; } if (runtimeParams == null || !runtimeParams.containsKey("language")) { throw new IllegalArgumentException("Monolingual SR Metric requires 'language' runtime parameter"); } Language language = Language.getByLangCode(runtimeParams.get("language")); if (!config.hasPath("metrics")){ throw new ConfigurationException("Ensemble metric has no base metrics to use."); } List<SRMetric> metrics = new ArrayList<SRMetric>(); for (String metric : config.getStringList("metrics")){ metrics.add(getConfigurator().get(SRMetric.class, metric, "language", language.getLangCode())); } LocalPageDao pageDao = getConfigurator().get(LocalPageDao.class,config.getString("pageDao")); int numArticles = 0; try { numArticles = pageDao.getCount(DaoFilter.normalPageFilter(language)); } catch (DaoException e) { throw new ConfigurationException(e); } Ensemble ensemble; if (config.getString("ensemble").equals("linear")){ ensemble = new CorrelationEnsemble(metrics.size(), numArticles); } else if (config.getString("ensemble").equals("even")){ ensemble = new EvenEnsemble(); } else { throw new ConfigurationException("I don't know how to do that ensemble."); } Disambiguator disambiguator = getConfigurator().get(Disambiguator.class,config.getString("disambiguator"), "language", language.getLangCode()); EnsembleMetric sr = new EnsembleMetric(name, language, metrics,ensemble,disambiguator,pageDao); if (config.hasPath("resolvephrases")) { sr.setResolvePhrases(config.getBoolean("resolvephrases")); } BaseSRMetric.configureBase(getConfigurator(), sr, config); return sr; } } }