package org.wikibrain.sr; import com.typesafe.config.Config; import gnu.trove.set.TIntSet; import gnu.trove.set.hash.TIntHashSet; import org.apache.commons.cli.*; import org.apache.commons.io.FileUtils; import org.wikibrain.conf.Configuration; import org.wikibrain.conf.ConfigurationException; import org.wikibrain.conf.DefaultOptionBuilder; import org.wikibrain.core.WikiBrainException; import org.wikibrain.core.cmd.Env; import org.wikibrain.core.cmd.EnvBuilder; import org.wikibrain.core.dao.DaoException; import org.wikibrain.core.dao.LocalLinkDao; import org.wikibrain.core.dao.LocalPageDao; import org.wikibrain.core.lang.Language; import org.wikibrain.download.FileDownloader; import org.wikibrain.phrases.LinkProbabilityDao; import org.wikibrain.sr.dataset.Dataset; import org.wikibrain.sr.dataset.DatasetDao; import org.wikibrain.sr.dataset.FakeDatasetCreator; import org.wikibrain.sr.ensemble.EnsembleMetric; import org.wikibrain.sr.esa.SRConceptSpaceGenerator; import org.wikibrain.sr.milnewitten.MilneWittenMetric; import org.wikibrain.sr.wikify.Corpus; import org.wikibrain.sr.word2vec.Word2VecGenerator; import org.wikibrain.sr.word2vec.Word2VecTrainer; import org.wikibrain.utils.WpIOUtils; import java.io.BufferedReader; import java.io.File; import java.io.FileInputStream; import java.io.IOException; import java.net.URL; import java.util.ArrayList; import java.util.Arrays; import java.util.List; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.util.zip.GZIPInputStream; /** * A "script" to build the semantic relatedness models. * This script takes care to not load the metric in build() until after data directories are deleted. * * @author Shilad Sen */ public class SRBuilder { private static final Logger LOG = LoggerFactory.getLogger(SRBuilder.class); // The environment and configuration we will use. private final Env env; private final Configuration config; private Language language; private final File srDir; // The name of the metric we will use. // If null, corresponds to the configured default metric. private String metricName = null; private boolean deleteExistingData = true; // The maximum number of results private int maxResults = 500; // Information that corresponds to building the cosimilarity matrix private boolean buildCosimilarity = false; private TIntSet rowIds = null; private TIntSet colIds = null; // List of datasets that will be used private List<String> datasetNames; // If false, existing submetrics for ensemble and pairwsise sim that // are already built will not be rebuilt. private boolean skipBuiltMetrics = false; private TIntSet validMostSimilarIds = null; // We may need to create a fake gold standard for languages that don't have one. private boolean createFakeGoldStandard = false; private Dataset fakeGoldStandard = null; public static enum Mode { SIMILARITY, MOSTSIMILAR, BOTH } private Mode mode = Mode.BOTH; public SRBuilder(Env env, String metricName, Language language) throws ConfigurationException { this.env = env; this.language = language; this.config = env.getConfiguration(); this.srDir = new File(config.get().getString("sr.metric.path")); datasetNames = config.get().getStringList("sr.dataset.defaultsets"); // Properly resolve the default metric name. this.metricName = env.getConfigurator().resolveComponentName(SRMetric.class, metricName); if (!srDir.isDirectory()) { srDir.mkdirs(); } } public SRBuilder(Env env, String metricName) throws ConfigurationException { this(env, metricName, env.getDefaultLanguage()); } public synchronized SRMetric getMetric() throws ConfigurationException { return getMetric(metricName); } public synchronized SRMetric getMetric(String name) throws ConfigurationException { return env.getComponent(SRMetric.class, name, language); } /** * First deletes models if deleteExistingData is true, then builds the appropriate metrics. * @throws ConfigurationException * @throws DaoException * @throws IOException * @throws WikiBrainException */ public void build() throws ConfigurationException, DaoException, IOException, WikiBrainException, InterruptedException { if (deleteExistingData) { deleteDataDirectories(); } buildConceptsIfNecessary(); LOG.info("building metric " + metricName); for (String name : getSubmetrics(metricName)) { buildMetric(name); } } /** * This method takes care to not load the metric itself, and just deal in names. * Once the metric is loaded, it has already accessed its data files. * @throws ConfigurationException */ public void deleteDataDirectories() throws ConfigurationException { for (String name : getSubmetrics(metricName)) { File dir = FileUtils.getFile(srDir, name, language.getLangCode()); if (dir.exists()) { LOG.info("deleting metric directory " + dir); FileUtils.deleteQuietly(dir); } } } /** * Returns a list of metric names (including the passed in name) that are a submetric * of the specified metric. The metrics are topologically sorted by dependency, so the * parent metric will appear last. * * @param parentName * @return * @throws ConfigurationException */ public List<String> getSubmetrics(String parentName) throws ConfigurationException { String type = getMetricType(parentName); Config config = getMetricConfig(parentName); List<String> toAdd = new ArrayList<String>(); if (type.equals("ensemble") || type.equals("simple-ensemble")) { for (String child : config.getStringList("metrics")) { toAdd.addAll(getSubmetrics(child)); toAdd.add(child); } } else if (type.equals("sparsevector.mostsimilarconcepts")) { toAdd.addAll(getSubmetrics(config.getString("generator.basemetric"))); } else if (type.equals("milnewitten")) { toAdd.add(config.getString("inlink")); toAdd.add(config.getString("outlink")); } else if (config.hasPath("reliesOn")) { toAdd.addAll(config.getStringList("reliesOn")); } toAdd.add(parentName); List<String> results = new ArrayList<String>(); // Make sure things only appear once. We save the FIRST time they appear to preserve dependencies. for (String name : toAdd) { if (!results.contains(name)) { results.add(name); } } return results; } public void buildMetric(String name) throws ConfigurationException, DaoException, IOException, InterruptedException { LOG.info("building component metric " + name); String type = getMetricType(name); if (type.equals("densevector.word2vec")) { initWord2Vec(name); } SRMetric metric = getMetric(name); if (type.equals("ensemble")) { ((EnsembleMetric)metric).setTrainSubmetrics(false); // Do it by hand } else if (type.equals("sparsevector.mostsimilarconcepts")) { if (mode == Mode.SIMILARITY) { LOG.warn("metric " + name + " of type " + type + " requires mostSimilar... training BOTH"); mode = Mode.BOTH; } throw new UnsupportedOperationException("This block needs to occur earlier."); } else if (type.equals("milnewitten")){ ((MilneWittenMetric)metric).setTrainSubmetrics(false); } if (metric instanceof BaseSRMetric) { ((BaseSRMetric)metric).setBuildMostSimilarCache(buildCosimilarity); } Dataset ds = getDataset(); if (mode == Mode.SIMILARITY || mode == Mode.BOTH) { if (skipBuiltMetrics && metric.similarityIsTrained()) { LOG.info("metric " + name + " similarity() is already trained... skipping"); } else { metric.trainSimilarity(ds); } } if (mode == Mode.MOSTSIMILAR || mode == Mode.BOTH) { if (skipBuiltMetrics && metric.mostSimilarIsTrained()) { LOG.info("metric " + name + " mostSimilar() is already trained... skipping"); } else { Config config = getMetricConfig(name); int n = maxResults * EnsembleMetric.SEARCH_MULTIPLIER; TIntSet validIds = validMostSimilarIds; if (config.hasPath("maxResults")) { n = config.getInt("maxResults"); } if (config.hasPath("mostSimilarConcepts")) { String path = String.format("%s/%s.txt", config.getString("mostSimilarConcepts"), metric.getLanguage().getLangCode()); validIds = readIds(path); } metric.trainMostSimilar(ds, n, validIds); } } metric.write(); } private String localize(String str) { return str.replace("LANG", language.getLangCode()); } private void initWord2Vec(String name) throws ConfigurationException, IOException, DaoException, InterruptedException { Config config = getMetricConfig(name).getConfig("generator"); File model = Word2VecGenerator.getModelFile(config.getString("modelDir"), language); if (skipBuiltMetrics && model.isFile()) { return; } if (config.hasPath("prebuilt") && config.getBoolean("prebuilt")) { if (model.isFile()) { return; } File downloadPath = new File(localize(config.getString("binfile"))); if (!downloadPath.isFile()) { FileDownloader downloader = new FileDownloader(); downloader.download(new URL(localize(config.getString("url"))), downloadPath); } if (config.hasPath("languages") && !config.getStringList("languages").contains(language.getLangCode())) { throw new ConfigurationException( "word2vec model " + downloadPath + " does not support language" + language); } if (downloadPath.toString().toLowerCase().endsWith("gz")) { LOG.info("decompressing " + downloadPath + " to " + model); File tmp = File.createTempFile("word2vec", "bin"); try { FileUtils.deleteQuietly(tmp); GZIPInputStream gz = new GZIPInputStream(new FileInputStream(downloadPath)); FileUtils.copyInputStreamToFile(gz, tmp); gz.close(); model.getParentFile().mkdirs(); FileUtils.moveFile(tmp, model); } finally { FileUtils.deleteQuietly(tmp); } } else { FileUtils.copyFile(downloadPath, model); } return; } LinkProbabilityDao lpd = env.getComponent(LinkProbabilityDao.class, language); lpd.useCache(true); lpd.buildIfNecessary(); String corpusName = config.getString("corpus"); Corpus corpus = null; if (!corpusName.equals("NONE")) { corpus = env.getConfigurator().get(Corpus.class, config.getString("corpus"), "language", language.getLangCode()); if (!corpus.exists()) { corpus.create(); } } if (model.isFile() && (corpus == null || model.lastModified() > corpus.getCorpusFile().lastModified())) { return; } if (corpus == null) { throw new ConfigurationException( "word2vec metric " + name + " cannot build or find model!" + "configuration has no corpus, but model not found at " + model + "."); } Word2VecTrainer trainer = new Word2VecTrainer( env.getConfigurator().get(LocalPageDao.class), language); if (config.hasPath("dimensions")) { LOG.info("set number of dimensions to " + config.getInt("dimensions")); trainer.setLayer1Size(config.getInt("dimensions")); } if (config.hasPath("maxWords")) { LOG.info("set maxWords to " + config.getInt("maxWords")); trainer.setMaxWords(config.getInt("maxWords")); } if (config.hasPath("window")) { LOG.info("set window to " + config.getInt("maxWords")); trainer.setWindow(config.getInt("window")); } trainer.setKeepAllArticles(true); trainer.train(corpus.getDirectory()); trainer.save(model); } private void setValidMostSimilarIdsFromFile(String file) throws IOException { setValidMostSimilarIds(readIds(file)); } public void setValidMostSimilarIds(TIntSet validMostSimilarIds) { this.validMostSimilarIds = validMostSimilarIds; } private void buildConceptsIfNecessary() throws IOException, ConfigurationException, DaoException { boolean needsConcepts = false; for (String name : getSubmetrics(metricName)) { String type = getMetricType(name); if (type.equals("sparsevector.esa") || type.equals("sparsevector.mostsimilarconcepts")) { needsConcepts = true; } } if (!needsConcepts) { return; } File path = FileUtils.getFile( env.getConfiguration().get().getString("sr.concepts.path"), language.getLangCode() + ".txt" ); path.getParentFile().mkdirs(); // Check to see if concepts are already built if (path.isFile() && FileUtils.readLines(path).size() > 1) { return; } LOG.info("building concept file " + path.getAbsolutePath() + " for " + metricName); SRConceptSpaceGenerator gen = new SRConceptSpaceGenerator(language, env.getConfigurator().get(LocalLinkDao.class), env.getConfigurator().get(LocalPageDao.class)); gen.writeConcepts(path); LOG.info("finished creating concept file " + path.getAbsolutePath() + " with " + FileUtils.readLines(path).size() + " lines"); } public Dataset getDataset() throws ConfigurationException, DaoException { if (createFakeGoldStandard) { if (fakeGoldStandard == null) { Corpus c = env.getConfigurator().get( Corpus.class, "plain", "language", language.getLangCode()); try { if (!c.exists()) c.create(); FakeDatasetCreator creator = new FakeDatasetCreator(c); fakeGoldStandard = creator.generate(500); } catch (IOException e) { throw new DaoException(e); } } return fakeGoldStandard; } else { DatasetDao dao = env.getConfigurator().get(DatasetDao.class); List<Dataset> datasets = new ArrayList<Dataset>(); for (String name : datasetNames) { datasets.addAll(dao.getDatasetOrGroup(language, name)); // throws a DaoException if language is incorrect. } return new Dataset(datasets); // merge all datasets together into one. } } public String getMetricType() throws ConfigurationException { return getMetricType(metricName); } public String getMetricType(String name) throws ConfigurationException { Config config = getMetricConfig(name); String type = config.getString("type"); if (type.equals("densevector") || type.equals("sparsevector")) { type += "." + config.getString("generator.type"); } return type; } public Config getMetricConfig() throws ConfigurationException { return getMetricConfig(metricName); } public Config getMetricConfig(String name) throws ConfigurationException { return env.getConfigurator().getConfig(SRMetric.class, name); } public void setRowIdsFromFile(String path) throws IOException { rowIds = readIds(path); } public void setColIdsFromFile(String path) throws IOException { colIds = readIds(path); } public void setDatasetNames(List<String> datasetNames) { this.datasetNames = datasetNames; } public void setBuildCosimilarity(boolean buildCosimilarity) { this.buildCosimilarity = buildCosimilarity; } public void setMaxResults(int maxResults) { this.maxResults = maxResults; } public void setRowIds(TIntSet rowIds) { this.rowIds = rowIds; } public void setColIds(TIntSet colIds) { this.colIds = colIds; } public void setMode(Mode mode) { this.mode = mode; } public void setDeleteExistingData(boolean deleteExistingData) { this.deleteExistingData = deleteExistingData; } public void setSkipBuiltMetrics(boolean skipBuiltMetrics) { this.skipBuiltMetrics = skipBuiltMetrics; } public void setLanguage(Language language) {this.language = language; } private static TIntSet readIds(String path) throws IOException { TIntSet ids = new TIntHashSet(); BufferedReader reader = WpIOUtils.openBufferedReader(new File(path)); while (true) { String line = reader.readLine(); if (line == null) { break; } ids.add(Integer.valueOf(line.trim())); } reader.close(); return ids; } public void setCreateFakeGoldStandard(boolean createFakeGoldStandard) { this.createFakeGoldStandard = createFakeGoldStandard; } public static void main(String args[]) throws ConfigurationException, IOException, WikiBrainException, DaoException, InterruptedException { Options options = new Options(); //Number of Max Results(otherwise take from config) options.addOption( new DefaultOptionBuilder() .hasArg() .withLongOpt("max-results") .withDescription("maximum number of results") .create("r")); //Specify the Datasets options.addOption( new DefaultOptionBuilder() .hasArgs() .withLongOpt("gold") .withDescription("the set of gold standard datasets to train on") .create("g")); //Delete existing data models options.addOption( new DefaultOptionBuilder() .hasArg() .withLongOpt("delete") .withDescription("delete all existing SR data for the metric and its submetrics (true or false, default is true)") .create("d")); //Specify the Metrics options.addOption( new DefaultOptionBuilder() .hasArg() .withLongOpt("metric") .withDescription("set a local metric") .create("m")); // Row and column ids for most similar caches options.addOption( new DefaultOptionBuilder() .hasArg() .withLongOpt("rowids") .withDescription("page ids for rows of cosimilarity matrices (implies -s)") .create("p")); options.addOption( new DefaultOptionBuilder() .hasArg() .withLongOpt("colids") .withDescription("page ids for columns of cosimilarity matrices (implies -s)") .create("q")); // build the cosimilarity matrix options.addOption( new DefaultOptionBuilder() .withLongOpt("cosimilarity") .withDescription("build cosimilarity matrices") .create("s")); // sets the mode options.addOption( new DefaultOptionBuilder() .withLongOpt("mode") .hasArg() .withDescription("mode: similarity, mostsimilar, or both") .create("o")); // add option for valid most similar ids options.addOption( new DefaultOptionBuilder() .withLongOpt("validMostSimilarIds") .withDescription("Set valid most similar ids") .create("y")); // when building pairwise cosine and ensembles, don't rebuild already built sub-metrics. options.addOption( new DefaultOptionBuilder() .withLongOpt("skip-built") .withDescription("Don't rebuild already built bmetrics (implies -d false)") .create("k")); // when building pairwise cosine and ensembles, don't rebuild already built sub-metrics. options.addOption( new DefaultOptionBuilder() .withLongOpt("fake") .withDescription("Create a fake gold standard for the language.") .create("f")); EnvBuilder.addStandardOptions(options); CommandLineParser parser = new PosixParser(); CommandLine cmd; try { cmd = parser.parse(options, args); } catch (ParseException e) { System.err.println("Invalid option usage: " + e.getMessage()); new HelpFormatter().printHelp("SRBuilder", options); return; } Env env = new EnvBuilder(cmd).build(); String metric = cmd.hasOption("m") ? cmd.getOptionValue("m") : null; SRBuilder builder = new SRBuilder(env, metric); if (cmd.hasOption("g")) { builder.setDatasetNames(Arrays.asList(cmd.getOptionValues("g"))); } if (cmd.hasOption("p")) { builder.setRowIdsFromFile(cmd.getOptionValue("p")); builder.setBuildCosimilarity(true); } if (cmd.hasOption("q")) { builder.setColIdsFromFile(cmd.getOptionValue("q")); builder.setBuildCosimilarity(true); } if (cmd.hasOption("y")) { builder.setValidMostSimilarIdsFromFile(cmd.getOptionValue("y")); } if (cmd.hasOption("s")) { builder.setBuildCosimilarity(true); } if (cmd.hasOption("k")) { builder.setSkipBuiltMetrics(true); builder.setDeleteExistingData(false); } if (cmd.hasOption("d")) { builder.setDeleteExistingData(Boolean.valueOf(cmd.getOptionValue("d"))); } if (cmd.hasOption("o")) { builder.setMode(Mode.valueOf(cmd.getOptionValue("o").toUpperCase())); } if (cmd.hasOption("l")) { builder.setLanguage(Language.getByLangCode(cmd.getOptionValue("l"))); } if (cmd.hasOption("r")) { builder.setMaxResults(Integer.valueOf(cmd.getOptionValue("r"))); } if (cmd.hasOption("f")) { builder.setCreateFakeGoldStandard(true); } builder.build(); } }