package org.wikibrain.sr; import org.apache.commons.cli.*; import org.apache.commons.collections.CollectionUtils; import org.apache.commons.io.FileUtils; import org.wikibrain.conf.ConfigurationException; import org.wikibrain.conf.Configurator; import org.wikibrain.conf.DefaultOptionBuilder; import org.wikibrain.core.WikiBrainException; import org.wikibrain.core.cmd.Env; import org.wikibrain.core.cmd.EnvBuilder; import org.wikibrain.core.dao.DaoException; import org.wikibrain.core.lang.Language; import org.wikibrain.core.lang.LanguageSet; import org.wikibrain.sr.dataset.Dataset; import org.wikibrain.sr.dataset.DatasetDao; import java.io.File; import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; import java.util.List; /** * @author Matt Lesciko * @author Ben Hillmann */ public class MetricTrainer { public static void main(String[] args) throws ConfigurationException, DaoException, IOException, WikiBrainException { Options options = new Options(); //Number of Max Results(otherwise take from config) options.addOption( new DefaultOptionBuilder() .hasArg() .withLongOpt("max-results") .withDescription("maximum number of results") .create("r")); //Specify the Datasets options.addOption( new DefaultOptionBuilder() .hasArgs() .withLongOpt("gold") .withDescription("the set of gold standard datasets to train on") .create("g")); //Specify the Metrics options.addOption( new DefaultOptionBuilder() .hasArg() .withLongOpt("metric") .withDescription("set a local metric") .create("m")); EnvBuilder.addStandardOptions(options); CommandLineParser parser = new PosixParser(); CommandLine cmd; try { cmd = parser.parse(options, args); } catch (ParseException e) { System.err.println("Invalid option usage: " + e.getMessage()); new HelpFormatter().printHelp("MetricTrainer", options); return; } Env env = new EnvBuilder(cmd) .setProperty("sr.metric.training", true) .build(); Configurator c = env.getConfigurator(); if (!cmd.hasOption("m")&&!cmd.hasOption("u")){ System.err.println("Must specify a metric to train using -m or -u."); new HelpFormatter().printHelp("MetricTrainer", options); return; } int maxResults = cmd.hasOption("r")? Integer.parseInt(cmd.getOptionValue("r")) : c.getConf().get().getInt("sr.normalizer.defaultmaxresults"); String path = c.getConf().get().getString("sr.metric.path"); LanguageSet allLangs = env.getLanguages(); DatasetDao datasetDao = env.getConfigurator().get(DatasetDao.class); List<String> datasetNames; if (cmd.hasOption("g")){ datasetNames = Arrays.asList(cmd.getOptionValues("g")); } else { datasetNames = c.getConf().get().getStringList("sr.dataset.defaultsets"); } List<Dataset> datasets = new ArrayList<Dataset>(); for (String name : datasetNames) { DatasetDao.Info info = datasetDao.getInfo(name); Collection<Language> possibleLang = CollectionUtils.intersection( info.getLanguages().getLanguages(), allLangs.getLanguages()); if (possibleLang.isEmpty()) { System.err.println("dataset " + name + " is a language other than " + allLangs); System.exit(1); } if (possibleLang.size() > 1) { System.err.println("dataset " + name + " supports more than one language of " + allLangs + " please specify"); System.exit(1); } Language lang = possibleLang.iterator().next(); if (datasets.size() > 0 && !lang.equals(datasets.get(0).getLanguage())) { System.err.println("Language mismatch in datasets " + name + " and " + datasets.get(0).getName()); System.exit(1); } datasets.add(datasetDao.get(lang, name)); } SRMetric sr=null; if (cmd.hasOption("m")){ Language language = datasets.get(0).getLanguage(); FileUtils.deleteDirectory(new File(path+cmd.getOptionValue("m")+"/"+"normalizer/")); sr = c.get(SRMetric.class,cmd.getOptionValue("m"), "language", language.getLangCode()); } Dataset dataset = new Dataset(datasets); if (sr!=null){ sr.trainMostSimilar(dataset, maxResults, null); sr.trainSimilarity(dataset); sr.write(); sr.read(); } } }