package org.wikibrain.sr.utils; import gnu.trove.set.TIntSet; import org.apache.commons.io.FileUtils; import org.apache.commons.io.IOUtils; import org.wikibrain.core.dao.DaoException; import org.wikibrain.core.lang.LocalId; import org.wikibrain.core.lang.LocalString; import org.wikibrain.sr.SRMetric; import org.wikibrain.sr.SRResult; import org.wikibrain.sr.SRResultList; import org.wikibrain.sr.dataset.Dataset; import org.wikibrain.sr.disambig.Disambiguator; import org.wikibrain.sr.normalize.IdentityNormalizer; import org.wikibrain.sr.normalize.Normalizer; import org.wikibrain.utils.ParallelForEach; import org.wikibrain.utils.Procedure; import java.io.*; import java.util.ArrayList; import java.util.List; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.slf4j.Logger; import org.slf4j.LoggerFactory; /** * A pair of normalizers (similarity, mostSimilar) persisted to disk. * * @author Shilad Sen */ public class SrNormalizers { private static final Logger LOG = LoggerFactory.getLogger(SrNormalizers.class); public static final String SIMILARITY_NORMALIZER = "similarityNormalizer"; public static final String MOST_SIMILAR_NORMALIZER = "mostSimilarNormalizer"; private Normalizer mostSimilarNormalizer = new IdentityNormalizer(); private Normalizer similarityNormalizer = new IdentityNormalizer(); public SrNormalizers() {} public Normalizer getMostSimilarNormalizer() { return mostSimilarNormalizer; } public Normalizer getSimilarityNormalizer() { return similarityNormalizer; } public void setMostSimilarNormalizer(Normalizer normalizer) { this.mostSimilarNormalizer = normalizer; } public void setSimilarityNormalizer(Normalizer normalizer) { this.similarityNormalizer = normalizer; } public void clear(File dir) { FileUtils.deleteQuietly(new File(dir, MOST_SIMILAR_NORMALIZER)); FileUtils.deleteQuietly(new File(dir, SIMILARITY_NORMALIZER)); } public boolean hasReadableNormalizers(File dir) { return isValidNormalizer(dir, MOST_SIMILAR_NORMALIZER) || isValidNormalizer(dir, SIMILARITY_NORMALIZER); } /** * Reads the noramlizers from disk. * This method expects the files exist and are valid, so hasNormalizers should be called first. * @throws java.io.IOException */ public void read(File dir) throws IOException { if (isValidNormalizer(dir, MOST_SIMILAR_NORMALIZER)) { mostSimilarNormalizer = readNormalizer(dir, MOST_SIMILAR_NORMALIZER); } if (isValidNormalizer(dir, SIMILARITY_NORMALIZER)) { similarityNormalizer = readNormalizer(dir, SIMILARITY_NORMALIZER); } } public void write(File dir) throws IOException { writeNormalizer(dir, MOST_SIMILAR_NORMALIZER, mostSimilarNormalizer); writeNormalizer(dir, SIMILARITY_NORMALIZER, similarityNormalizer); } /** * Returns true if a normalizer exists and it is trained. * @param name * @return */ private boolean isValidNormalizer(File dir, String name) { File path = new File(dir, name); if (!path.isFile()) { return false; } try { return readNormalizer(dir, name).isTrained(); } catch (Exception e) { LOG.warn("Failed to load normalizer at " + path.getAbsolutePath() + ". Setting it to be invalid."); return false; } } /** * * @param metric * @param dataset */ public void trainSimilarity(final SRMetric metric, Dataset dataset) { if (similarityNormalizer instanceof IdentityNormalizer) { return; } if (!dataset.getLanguage().equals(metric.getLanguage())) { throw new IllegalArgumentException("SR metric has language " + metric.getLanguage() + " but dataset has language " + dataset.getLanguage()); } final Normalizer trainee = similarityNormalizer; similarityNormalizer = new IdentityNormalizer(); try { trainee.reset(); ParallelForEach.loop(dataset.getData(), new Procedure<KnownSim>() { public void call(KnownSim ks) throws IOException, DaoException { ks.maybeSwap(); SRResult sim = metric.similarity(ks.phrase1, ks.phrase2, false); trainee.observe(sim == null ? Double.NaN : sim.getScore(), ks.similarity); } }, 100); trainee.observationsFinished(); LOG.info("trained similarity normalizer: " + trainee.dump()); } finally { similarityNormalizer = trainee; } } /** * * @param metric * @param disambiguator * @param dataset * @param validIds * @param maxResults */ public void trainMostSimilar(final SRMetric metric, final Disambiguator disambiguator, Dataset dataset, final TIntSet validIds, final int maxResults) { if (mostSimilarNormalizer instanceof IdentityNormalizer) { return; } if (!dataset.getLanguage().equals(metric.getLanguage())) { throw new IllegalArgumentException("SR metric has language " + metric.getLanguage() + " but dataset has language " + dataset.getLanguage()); } final Normalizer trainee = mostSimilarNormalizer; mostSimilarNormalizer = new IdentityNormalizer(); try { trainee.reset(); ParallelForEach.loop(dataset.getData(), new Procedure<KnownSim>() { public void call(KnownSim ks) throws IOException, DaoException { ks.maybeSwap(); List<LocalString> localStrings = new ArrayList<LocalString>(); localStrings.add(new LocalString(ks.language, ks.phrase1)); localStrings.add(new LocalString(ks.language, ks.phrase2)); List<LocalId> ids = disambiguator.disambiguateTop(localStrings, null); if (ids != null && ids.size() == 2 && ids.get(0) != null && ids.get(1) != null) { LocalId lid1 = ids.get(0); LocalId lid2 = ids.get(1); SRResultList dsl = metric.mostSimilar(lid1.getId(), maxResults, validIds); if (dsl != null) { trainee.observe(dsl, dsl.getIndexForId(lid2.getId()), ks.similarity); } } } }, 100); trainee.observationsFinished(); LOG.info("trained most similar normalizer for " + metric.getName() + ": " + trainee.dump()); } finally { mostSimilarNormalizer = trainee; } } /** * Reads a single normalizer from disk. * @param name * @return * @throws java.io.IOException */ private Normalizer readNormalizer(File dir, String name) throws IOException { ObjectInputStream oip = null; try { oip = new ObjectInputStream(new FileInputStream(new File(dir, name))); return (Normalizer)oip.readObject(); } catch (ClassNotFoundException e) { throw new IllegalStateException(e); // should not happen } catch (FileNotFoundException e) { throw new IllegalStateException(e); // should not happen } finally { if (oip != null) IOUtils.closeQuietly(oip); } } /** * Writes a single normalizer to disk. * @param dir * @param name * @param normalizer * @throws IOException */ private void writeNormalizer(File dir, String name, Normalizer normalizer) throws IOException { ObjectOutputStream oop = new ObjectOutputStream(new FileOutputStream(new File(dir, name))); oop.writeObject(normalizer); oop.flush(); oop.close(); } }