package org.wikibrain.sr.dataset;

import com.typesafe.config.Config;
import com.typesafe.config.ConfigValue;
import org.apache.commons.io.IOUtils;
import org.apache.commons.lang3.StringEscapeUtils;
import org.wikibrain.conf.Configuration;
import org.wikibrain.conf.ConfigurationException;
import org.wikibrain.conf.Configurator;
import org.wikibrain.core.dao.DaoException;
import org.wikibrain.core.lang.Language;
import org.wikibrain.core.lang.LanguageSet;
import org.wikibrain.core.lang.LocalId;
import org.wikibrain.core.lang.LocalString;
import org.wikibrain.sr.disambig.Disambiguator;
import org.wikibrain.sr.utils.KnownSim;
import org.wikibrain.utils.WpIOUtils;

import java.io.*;
import java.util.*;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Reads and writes datasets.
 *
 * Supports reading built-in datasets from resource files on the classpath.
 *
 * @author Shilad Sen
 * @author Matt Lesicko
 * @author Ben Hillmann
 */
public class DatasetDao {
    private static final Logger LOG = LoggerFactory.getLogger(DatasetDao.class);

    public static final String RESOURCE_DATSET = "/datasets";
    public static final String RESOURCE_DATASET_INFO = "/datasets/info.tsv";

    private final Collection<Info> info;
    private Map<String, List<String>> groups = new HashMap<String, List<String>>();
    private boolean normalize = true;       // If true, normalize all scores to [0,1]
    private boolean resolvePhrases = false;
    private Disambiguator disambiguator = null;

    /**
     * Information about a particular dataset.
     */
    public static class Info {
        private String name;
        private LanguageSet languages;

        public Info(String name, LanguageSet languages) {
            this.name = name;
            this.languages = languages;
        }

        public String getName() {
            return name;
        }

        public LanguageSet getLanguages() {
            return languages;
        }
    }

    /**
     * Creates a new dataset dao backed by the dataset info bundled on the classpath.
     */
    public DatasetDao() {
        try {
            this.info = readInfos();
        } catch (DaoException e) {
            throw new RuntimeException(e);  // errors shouldn't occur for compiled resources
        }
    }

    /**
     * Creates a new dataset dao with the specified dataset information.
     * @param info
     */
    public DatasetDao(Collection<Info> info) {
        this.info = info;
    }

    /**
     * If true, all datasets will be "normalized" to [0,1] scores.
     * @param normalize
     */
    public void setNormalize(boolean normalize) {
        this.normalize = normalize;
    }

    public List<Dataset> getAllInLanguage(Language lang) throws DaoException {
        List<Dataset> result = new ArrayList<Dataset>();
        for (Info i : info) {
            if (i.getLanguages().containsLanguage(lang)) {
                result.add(get(lang, i.getName()));
            }
        }
        return result;
    }

    /**
     * Reads a dataset from a file on disk.
     * The file name determines the column delimiter (see {@link #read(String, Language, BufferedReader)}).
     *
     * @param language The desired language
     * @param path The path to the dataset file.
     * @return The dataset
     * @throws DaoException
     */
    public Dataset read(Language language, File path) throws DaoException {
        try {
            return read(path.getName(), language, WpIOUtils.openBufferedReader(path));
        } catch (IOException e) {
            throw new DaoException(e);
        }
    }
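
    // A minimal usage sketch for read() above. The path and language code are
    // illustrative (not taken from the project); any tab-separated file with
    // phrase1, phrase2, and score columns should work, per the parser in
    // read(String, Language, BufferedReader) below:
    //
    //   DatasetDao dao = new DatasetDao();
    //   Dataset custom = dao.read(Language.getByLangCode("en"),
    //                             new File("/path/to/my-pairs.txt"));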

    /**
     * Reads a built-in dataset from the classpath with a particular name.
     * Some datasets support multiple languages (e.g. simple and en).
     * The dataset name can also be a group name (e.g. en-major).
     *
     * @param language The desired language
     * @param name The name of the dataset.
     * @return The dataset
     * @throws DaoException
     */
    public Dataset get(Language language, String name) throws DaoException {
        if (groups.containsKey(name)) {
            List<Dataset> members = new ArrayList<Dataset>();
            for (String n : groups.get(name)) {
                members.add(get(language, n));
            }
            return new Dataset(name, members);
        }
        if (name.contains("/") || name.contains("\\")) {
            throw new DaoException("get() reads a built-in dataset by name from the classpath. Try read() instead?");
        }
        Info info = getInfo(name);
        if (info == null) {
            throw new DaoException("no dataset with name '" + name + "'");
        }
        if (!info.languages.containsLanguage(language)) {
            throw new DaoException("dataset '" + name + "' does not support language " + language);
        }
        try {
            return read(name, language, WpIOUtils.openResource(RESOURCE_DATSET + "/" + name));
        } catch (IOException e) {
            throw new DaoException(e);
        }
    }

    /**
     * Returns true if the name is the name of a group of datasets.
     * @param name
     * @return
     */
    public boolean isGroup(String name) {
        return groups.containsKey(name);
    }

    /**
     * Returns all the member datasets in the specified group.
     * @param language
     * @param name
     * @return
     * @throws DaoException
     */
    public List<Dataset> getGroup(Language language, String name) throws DaoException {
        if (!groups.containsKey(name)) {
            throw new DaoException("no dataset group with name '" + name + "'");
        }
        List<Dataset> members = new ArrayList<Dataset>();
        for (String n : groups.get(name)) {
            members.add(get(language, n));
        }
        return members;
    }

    public List<Dataset> getDatasetOrGroup(Language language, String name) throws DaoException {
        if (isGroup(name)) {
            return getGroup(language, name);
        } else {
            return Arrays.asList(get(language, name));
        }
    }

    /**
     * @param name
     * @return Information about the dataset with the specified name, or null if none exists.
     */
    public Info getInfo(String name) {
        for (Info info : this.info) {
            if (info.name.equalsIgnoreCase(name)) {
                return info;
            }
        }
        return null;
    }

    /**
     * Sets the internal disambiguator AND sets resolvePhrases to true.
     * @param dab
     */
    public void setDisambiguator(Disambiguator dab) {
        this.disambiguator = dab;
        this.resolvePhrases = true;
    }

    /**
     * @param resolvePhrases If true, phrases are resolved to local page ids.
     *                       The disambiguator MUST be set as well.
     */
    public void setResolvePhrases(boolean resolvePhrases) {
        this.resolvePhrases = resolvePhrases;
        if (resolvePhrases && disambiguator == null) {
            throw new IllegalStateException("resolve phrases set to true, but no disambiguator specified.");
        }
    }

    public void setGroups(Map<String, List<String>> groups) {
        this.groups = groups;
    }
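
    // The parser below expects one phrase pair per line with at least three
    // columns: phrase1, phrase2, and a numeric similarity score. Columns are
    // tab-separated unless the dataset name ends in "csv", in which case they
    // are comma-separated. A hypothetical tab-separated line:
    //
    //   tiger<TAB>cat<TAB>7.35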

    /**
     * Reads a dataset from a buffered reader.
     *
     * @param name Name of the dataset; must end with "csv" for comma-separated files.
     * @param language Language of the dataset.
     * @param reader The input source for the dataset.
     * @return The dataset
     * @throws DaoException
     */
    protected Dataset read(String name, Language language, BufferedReader reader) throws DaoException {
        List<KnownSim> result = new ArrayList<KnownSim>();
        try {
            String delim = "\t";
            if (name.toLowerCase().endsWith("csv")) {
                delim = ",";
            }
            while (true) {
                String line = reader.readLine();
                if (line == null) break;
                String[] tokens = line.split(delim);
                if (tokens.length >= 3) {
                    KnownSim ks = new KnownSim(tokens[0], tokens[1], Double.valueOf(tokens[2]), language);
                    if (resolvePhrases) {
                        LocalId id1 = disambiguator.disambiguateTop(new LocalString(language, ks.phrase1), null);
                        LocalId id2 = disambiguator.disambiguateTop(new LocalString(language, ks.phrase2), null);
                        if (id1 != null) {
                            ks.wpId1 = id1.getId();
                        }
                        if (id2 != null) {
                            ks.wpId2 = id2.getId();
                        }
                    }
                    result.add(ks);
                } else {
                    throw new DaoException("Invalid line in dataset file " + name + ": " +
                            "'" + StringEscapeUtils.escapeJava(line) + "'");
                }
            }
            reader.close();
        } catch (IOException e) {
            throw new DaoException(e);
        }
        Dataset dataset = new Dataset(name, language, result);
        if (normalize) {
            dataset.normalize();
        }
        return dataset;
    }

    /**
     * Writes a dataset out to a particular path.
     * @param dataset
     * @param path
     * @throws DaoException
     */
    public void write(Dataset dataset, File path) throws DaoException {
        try {
            BufferedWriter writer = new BufferedWriter(new FileWriter(path));
            String delim = "\t";
            for (KnownSim ks : dataset.getData()) {
                writer.write(ks.phrase1 + delim + ks.phrase2 + delim + ks.similarity + "\n");
            }
            writer.flush();
            writer.close();
        } catch (IOException e) {
            throw new DaoException(e);
        }
    }

    /**
     * Reads the embedded info.tsv file from the classpath.
     * @return
     * @throws DaoException
     */
    public static Collection<Info> readInfos() throws DaoException {
        try {
            return readInfos(WpIOUtils.openResource(RESOURCE_DATASET_INFO));
        } catch (IOException e) {
            throw new DaoException(e);
        }
    }
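
    // Each line of info.tsv is expected to name one dataset and the languages
    // it supports, separated by a tab (the language column is parsed as a
    // LanguageSet below). The entries here are illustrative, not copied from
    // the bundled file:
    //
    //   wordsim353.txt<TAB>en,simple
    //   MC.txt<TAB>en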

    /**
     * Returns information about datasets in a reader.
     * @param reader
     * @return
     * @throws DaoException
     */
    public static Collection<Info> readInfos(BufferedReader reader) throws DaoException {
        try {
            List<Info> infos = new ArrayList<Info>();
            while (true) {
                try {
                    String line = reader.readLine();
                    if (line == null) {
                        break;
                    }
                    String[] tokens = line.trim().split("\t");
                    infos.add(new Info(tokens[0], new LanguageSet(tokens[1])));
                } catch (IOException e) {
                    throw new DaoException(e);
                }
            }
            return infos;
        } finally {
            IOUtils.closeQuietly(reader);
        }
    }

    public static class Provider extends org.wikibrain.conf.Provider<DatasetDao> {
        public Provider(Configurator configurator, Configuration config) throws ConfigurationException {
            super(configurator, config);
        }

        @Override
        public Class<DatasetDao> getType() {
            return DatasetDao.class;
        }

        @Override
        public String getPath() {
            return "sr.dataset.dao";
        }

        @Override
        public DatasetDao get(String name, Config config, Map<String, String> runtimeParams) throws ConfigurationException {
            if (!config.getString("type").equals("resource")) {
                return null;
            }
            DatasetDao dao = new DatasetDao();
            if (config.hasPath("normalize")) {
                dao.setNormalize(config.getBoolean("normalize"));
            }
            if (config.hasPath("disambig")) {
                dao.setDisambiguator(
                        getConfigurator().get(Disambiguator.class, config.getString("disambig")));
            }
            if (config.hasPath("resolvePhrases")) {
                dao.setResolvePhrases(config.getBoolean("resolvePhrases"));
            }
            Map<String, List<String>> groups = new HashMap<String, List<String>>();
            Config groupConfig = getConfig().get().getConfig("sr.dataset.groups");
            for (Map.Entry<String, ConfigValue> entry : groupConfig.entrySet()) {
                groups.put(entry.getKey(), (List<String>) entry.getValue().unwrapped());
            }
            dao.setGroups(groups);
            return dao;
        }
    }
}
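
/*
 * A minimal end-to-end sketch (the dataset name and language code are
 * illustrative and must match an entry in the bundled info.tsv; in a full
 * WikiBrain setup the dao would normally be obtained from a Configurator via
 * the Provider above):
 *
 *   DatasetDao dao = new DatasetDao();
 *   Dataset ds = dao.get(Language.getByLangCode("en"), "wordsim353.txt");
 *   for (KnownSim ks : ds.getData()) {
 *       System.out.println(ks.phrase1 + "\t" + ks.phrase2 + "\t" + ks.similarity);
 *   }
 */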