package org.wikipedia.miner.comparison;

import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.EnumSet;

import org.apache.log4j.Logger;
import org.wikipedia.miner.db.WDatabase.DatabaseType;
import org.wikipedia.miner.db.WEnvironment.StatisticName;
import org.wikipedia.miner.db.struct.DbIntList;
import org.wikipedia.miner.model.Article;
import org.wikipedia.miner.model.Wikipedia;
import org.wikipedia.miner.util.CorrelationCalculator;
import org.wikipedia.miner.util.ProgressTracker;
import org.wikipedia.miner.util.WikipediaConfiguration;

import org.dmilne.weka.wrapper.*;

import weka.classifiers.Classifier;
import weka.classifiers.functions.GaussianProcesses;
import weka.core.Instance;

/**
 * Measures the semantic relatedness of pairs of Wikipedia articles, using the links made
 * to and from them. Measures can be refined by a classifier trained over manually measured
 * article pairs; if no classifier is available, the mean of the gathered link measures is returned.
 */
public class ArticleComparer {

    /**
     * Data used to generate article relatedness measures.
     */
    public enum DataDependency {

        /**
         * Use links made to articles to measure relatedness. You should cache {@link DatabaseType#pageLinksIn} if using this mode extensively.
         */
        pageLinksIn,

        /**
         * Use links made from articles to measure relatedness. You should cache {@link DatabaseType#pageLinksOut} and {@link DatabaseType#pageLinkCounts} if using this mode extensively.
         */
        pageLinksOut,

        /**
         * Use link counts to measure relatedness. You should cache {@link DatabaseType#pageLinkCounts} if using this mode extensively.
         */
        linkCounts
    } ;

    private enum LinkDirection {In, Out} ;

    Wikipedia wikipedia ;
    EnumSet<DataDependency> dependancies ;

    int wikipediaArticleCount ;
    Double m ;

    private long articlesCompared = 0 ;

    enum Attributes {
        inLinkGoogleMeasure,
        //inLinkUnion,
        inLinkIntersection,
        inLinkVectorMeasure,
        outLinkGoogleMeasure,
        //outLinkUnion,
        outLinkIntersection,
        outLinkVectorMeasure,
    }

    Decider<Attributes,Double> relatednessMeasurer ;
    Dataset<Attributes,Double> trainingDataset ;

    //DoublePredictorOld relatednessMeasurer ;

    /**
     * Initializes a new ArticleComparer, using the data dependencies and (optional) classifier model
     * named in the given Wikipedia's configuration.
     *
     * @param wikipedia an active instance of Wikipedia
     * @throws Exception if the configuration does not specify the data dependencies to use
     */
    public ArticleComparer(Wikipedia wikipedia) throws Exception {

        WikipediaConfiguration conf = wikipedia.getConfig() ;

        if (conf.getArticleComparisonDependancies() == null)
            throw new Exception("The given wikipedia configuration does not specify default article comparison data dependencies");

        init(wikipedia, conf.getArticleComparisonDependancies()) ;
    }

    /**
     * Initializes a new ArticleComparer with the given data dependencies.
     *
     * @param wikipedia an active instance of Wikipedia
     * @param dependancies the data dependencies to use when measuring relatedness
     * @throws Exception if the dependencies do not include at least pageLinksIn or pageLinksOut
     */
    public ArticleComparer(Wikipedia wikipedia, EnumSet<DataDependency> dependancies) throws Exception {
        init(wikipedia, dependancies) ;
    }

    @SuppressWarnings("unchecked")
    private void init(Wikipedia wikipedia, EnumSet<DataDependency> dependancies) throws Exception {

        if (!dependancies.contains(DataDependency.pageLinksIn) && !dependancies.contains(DataDependency.pageLinksOut))
            throw new Exception("Dependencies must include at least pageLinksIn or pageLinksOut") ;

        this.wikipedia = wikipedia ;
        this.dependancies = dependancies ;

        wikipediaArticleCount = new Long(wikipedia.getEnvironment().retrieveStatistic(StatisticName.articleCount)).intValue() ;
        m = Math.log(wikipediaArticleCount) ;

        relatednessMeasurer = (Decider<Attributes, Double>) new DeciderBuilder<Attributes>("articleComparer", Attributes.class)
            .setDefaultAttributeTypeNumeric()
            .setClassAttributeTypeNumeric("relatedness")
            .build();

        if (wikipedia.getConfig().getArticleComparisonModel() != null)
            this.loadClassifier(wikipedia.getConfig().getArticleComparisonModel()) ;
    }

    /**
     * Measures the semantic relatedness of the two given articles.
     *
     * @param artA the first article
     * @param artB the second article
     * @return the measured relatedness
     * @throws Exception if the relatedness cannot be measured
     */
    public Double getRelatedness(Article artA, Article artB) throws Exception {

        if (artA.getId() == artB.getId())
            return 1.0 ;

        ArticleComparison cmp = getComparison(artA, artB) ;

        if (cmp == null)
            return 0.0 ;

        if ( (cmp.getInLinkIntersectionProportion() == null || cmp.getInLinkIntersectionProportion() == 0) &&
                (cmp.getOutLinkIntersectionProportion() == null || cmp.getOutLinkIntersectionProportion() == 0) )
            return 0.0 ;

        //System.out.println("gi " + cmp.getInLinkGoogleMeasure()) ;
        //System.out.println("go " + cmp.getOutLinkGoogleMeasure()) ;
        //System.out.println("ti " + cmp.getInLinkVectorMeasure()) ;
        //System.out.println("to " + cmp.getOutLinkVectorMeasure()) ;

        if (!relatednessMeasurer.isReady()) {
            //Logger.getLogger(ArticleComparer.class).debug("Article comparison without ml") ;

            //no classifier available, so just return mean of gathered measurements
            int count = 0 ;
            double total = 0 ;

            if (dependancies.contains(DataDependency.pageLinksIn)) {
                count++ ;
                total = total + cmp.getInLinkGoogleMeasure() ;

                if (dependancies.contains(DataDependency.linkCounts)) {
                    count++ ;
                    total = total + cmp.getInLinkVectorMeasure() ;
                }
            }

            if (dependancies.contains(DataDependency.pageLinksOut)) {
                count++ ;
                total = total + cmp.getOutLinkGoogleMeasure() ;

                if (dependancies.contains(DataDependency.linkCounts)) {
                    count++ ;
                    total = total + cmp.getOutLinkVectorMeasure() ;
                }
            }

            if (count == 0)
                return 0.0 ;
            else
                return total/count ;
        } else {
            return relatednessMeasurer.getDecision(getInstance(cmp, null)) ;
        }
    }

    /**
     * Gathers training data from the given dataset of article pairs with manually measured relatedness.
     *
     * @param dataset the dataset to train with
     * @throws Exception if the training data cannot be gathered
     */
    public void train(ComparisonDataSet dataset) throws Exception {

        trainingDataset = relatednessMeasurer.createNewDataset() ;

        ProgressTracker pn = new ProgressTracker(dataset.getItems().size(), "training", ArticleComparer.class) ;
        for (ComparisonDataSet.Item item: dataset.getItems()) {

            if (item.getIdA() < 0 || item.getIdB() < 0)
                continue ;

            Article artA = null;
            try {
                artA = new Article(wikipedia.getEnvironment(), item.getIdA()) ;
            } catch (Exception e) {
                Logger.getLogger(ArticleComparer.class).warn(item.getIdA() + " is not a valid article") ;
            }

            Article artB = null;
            try {
                artB = new Article(wikipedia.getEnvironment(), item.getIdB()) ;
            } catch (Exception e) {
                Logger.getLogger(ArticleComparer.class).warn(item.getIdB() + " is not a valid article") ;
            }

            if (artA != null && artB != null)
                train(artA, artB, item.getRelatedness()) ;

            pn.update() ;
        }
    }

    /**
     * Saves the training data generated by train() to the given file.
     * The data is saved in WEKA's arff format.
     *
     * @param file the file to save the training data to
     * @throws IOException if the file cannot be written to
     */
    public void saveTrainingData(File file) throws IOException, Exception {
        trainingDataset.save(file) ;
    }

    /**
     * Loads training data from the given file.
     * The file must be a valid WEKA arff file.
     *
     * @param file the file to load the training data from
     * @throws IOException if the file cannot be read.
     * @throws Exception if the file does not contain valid training data.
     */
    public void loadTrainingData(File file) throws IOException, Exception {
        trainingDataset = relatednessMeasurer.createNewDataset() ;
        trainingDataset.load(file) ;
    }

    /**
     * Serializes the classifier and saves it to the given file.
     *
     * @param file the file to save the classifier to
     * @throws IOException if the file cannot be written to
     */
    public void saveClassifier(File file) throws IOException {
        relatednessMeasurer.save(file) ;
    }

    /**
     * Loads the classifier from the given file.
     *
     * @param file the file to load the classifier from
     * @throws Exception if the file does not contain a valid classifier
     */
    public void loadClassifier(File file) throws Exception {
        relatednessMeasurer.load(file) ;
    }

    /**
     * Builds the given classifier from the gathered or loaded training data.
     *
     * @param classifier the classifier to build
     * @throws Exception if there is no training data available
     */
    public void buildClassifier(Classifier classifier) throws Exception {
        relatednessMeasurer.train(classifier, trainingDataset) ;
    }

    /**
     * Builds the default classifier (a Gaussian process regressor) from the gathered or loaded training data.
     *
     * @throws Exception if there is no training data available
     */
    public void buildDefaultClassifier() throws Exception {
        Classifier classifier = new GaussianProcesses() ;
        relatednessMeasurer.train(classifier, trainingDataset) ;
    }

    /**
     * Tests the comparer against the given dataset of article pairs with manually measured relatedness.
     *
     * @param dataset the dataset to test against
     * @return the correlation between the manual and automatic relatedness measures
     * @throws Exception if the articles cannot be compared
     */
    public double test(ComparisonDataSet dataset) throws Exception {

        ArrayList<Double> manualMeasures = new ArrayList<Double>() ;
        ArrayList<Double> autoMeasures = new ArrayList<Double>() ;

        ProgressTracker pn = new ProgressTracker(dataset.getItems().size(), "testing", ArticleComparer.class) ;
        for (ComparisonDataSet.Item item: dataset.getItems()) {

            if (item.getIdA() < 0 || item.getIdB() < 0) {
                Logger.getLogger(ArticleComparer.class).info("- ignoring " + item.getIdA() + ":" + item.getTermA() + " vs. " + item.getIdB() + ":" + item.getTermB()) ;
                continue ;
            }

            Article artA = null;
            try {
                artA = new Article(wikipedia.getEnvironment(), item.getIdA()) ;
            } catch (Exception e) {
                Logger.getLogger(ArticleComparer.class).warn(item.getIdA() + " is not a valid article") ;
            }

            Article artB = null;
            try {
                artB = new Article(wikipedia.getEnvironment(), item.getIdB()) ;
            } catch (Exception e) {
                Logger.getLogger(ArticleComparer.class).warn(item.getIdB() + " is not a valid article") ;
            }

            if (artA != null && artB != null) {
                manualMeasures.add(item.getRelatedness()) ;
                autoMeasures.add(this.getRelatedness(artA, artB)) ;
            }

            pn.update() ;
        }

        return CorrelationCalculator.getCorrelation(manualMeasures, autoMeasures) ;
    }

    private void train(Article artA, Article artB, double relatedness) throws Exception {

        ArticleComparison cmp = getComparison(artA, artB) ;
        if (cmp == null)
            return ;

        trainingDataset.add(getInstance(cmp, relatedness)) ;
    }

    /**
     * Gathers the link-based comparison features for the given pair of articles, or returns null
     * if no features can be gathered.
     */
    public ArticleComparison getComparison(Article artA, Article artB) {

        ArticleComparison cmp = new ArticleComparison(artA, artB) ;

        if (dependancies.contains(DataDependency.pageLinksIn))
            cmp = setPageLinkFeatures(cmp, LinkDirection.In, dependancies.contains(DataDependency.linkCounts)) ;

        if (dependancies.contains(DataDependency.pageLinksOut))
            cmp = setPageLinkFeatures(cmp, LinkDirection.Out, dependancies.contains(DataDependency.linkCounts)) ;

        if (!cmp.inLinkFeaturesSet() && !cmp.outLinkFeaturesSet())
            return null ;

        return cmp ;
    }

    // names of all parameters make sense if we assume dir is out
    private ArticleComparison setPageLinkFeatures(ArticleComparison cmp, LinkDirection dir, boolean useLinkCounts) {

        //don't gather training or testing data when articles are the same: this screws up normalization
        if (cmp.getArticleA().getId() == cmp.getArticleB().getId())
            return cmp ;

        ArrayList<Integer> linksA = getLinks(cmp.getArticleA().getId(), dir) ;
        ArrayList<Integer> linksB = getLinks(cmp.getArticleB().getId(), dir) ;

        //we can't do anything if there are no links
        if (linksA.isEmpty() || linksB.isEmpty())
            return cmp ;

        int intersection = 0 ;
        //int sentenceIntersection = 0 ;
        int union = 0 ;

        int indexA = 0 ;
        int indexB = 0 ;

        ArrayList<Double> vectA = new ArrayList<Double>() ;
        ArrayList<Double> vectB = new ArrayList<Double>() ;

        //get denominators for link frequency
        Integer linksFromSourceA = 0 ;
        Integer linksFromSourceB = 0 ;

        if (useLinkCounts) {
            if (dir == LinkDirection.Out) {
                linksFromSourceA = cmp.getArticleA().getTotalLinksOutCount() ;
                linksFromSourceB = cmp.getArticleB().getTotalLinksOutCount() ;
            } else {
                linksFromSourceA = cmp.getArticleA().getTotalLinksInCount() ;
                linksFromSourceB = cmp.getArticleB().getTotalLinksInCount() ;
            }
        }

        while (indexA < linksA.size() || indexB < linksB.size()) {

            //identify which links to use (A, B, or both)
            boolean useA = false;
            boolean useB = false;
            boolean mutual = false ;

            Integer linkA = null ;
            Integer linkB = null ;
            Article linkArt ;

            if (indexA < linksA.size())
                linkA = linksA.get(indexA) ;

            if (indexB < linksB.size())
                linkB = linksB.get(indexB) ;

            if (linkA != null && linkB != null && linkA.equals(linkB)) {
                useA = true ;
                useB = true ;
                linkArt = new Article(wikipedia.getEnvironment(), linkA) ;
                intersection++ ;

                //if (hasSentenceIntersection(linkA.getSentenceIndexes(), linkB.getSentenceIndexes()))
                //    sentenceIntersection++ ;
            } else {
                if (linkA != null && (linkB == null || linkA < linkB)) {
                    useA = true ;
                    linkArt = new Article(wikipedia.getEnvironment(), linkA) ;

                    if (linkA.equals(cmp.getArticleB().getId())) {
                        intersection++ ;
                        mutual = true ;
                    }
                } else {
                    useB = true ;
                    linkArt = new Article(wikipedia.getEnvironment(), linkB) ;

                    if (linkB.equals(cmp.getArticleA().getId())) {
                        intersection++ ;
                        mutual = true ;
                    }
                }
            }

            union++ ;

            if (useLinkCounts) {
                //calculate lfiaf values for each vector
                int linksToTarget ;
                if (dir == LinkDirection.Out)
                    linksToTarget = linkArt.getTotalLinksInCount() ;
                else
                    linksToTarget = linkArt.getTotalLinksOutCount() ;

                double valA = 0 ;
                double valB = 0 ;

                if (mutual) {
                    valA = 1 ;
                    valB = 1 ;
                } else {
                    if (useA) valA = getLfiaf(1, linksFromSourceA, linksToTarget) ;
                    if (useB) valB = getLfiaf(1, linksFromSourceB, linksToTarget) ;
                }

                vectA.add(valA) ;
                vectB.add(valB) ;
            }

            if (useA) indexA++ ;
            if (useB) indexB++ ;
        }

        //calculate google distance inspired measure
        Double googleMeasure = null ;

        if (intersection == 0) {
            googleMeasure = 1.0 ;
        } else {
            double a = Math.log(linksA.size()) ;
            double b = Math.log(linksB.size()) ;
            double ab = Math.log(intersection) ;

            googleMeasure = (Math.max(a, b) - ab) / (m - Math.min(a, b)) ;
        }

        googleMeasure = ArticleComparison.normalizeGoogleMeasure(googleMeasure) ;

        //calculate vector (tfidf) inspired measure
        Double vectorMeasure = null ;

        if (useLinkCounts) {
            if (vectA.isEmpty() || vectB.isEmpty())
                vectorMeasure = Math.PI/2 ;
            else {
                double dotProduct = 0 ;
                double magnitudeA = 0 ;
                double magnitudeB = 0 ;

                //StringBuffer strA = new StringBuffer() ;
                //StringBuffer strB = new StringBuffer() ;

                for (int x=0; x<vectA.size(); x++) {
                    double valA = vectA.get(x) ;
                    double valB = vectB.get(x) ;

                    /*
                    if (valA > 0) strA.append(df.format(valA) + "\t") ; else strA.append("-.---\t") ;
                    if (valB > 0) strB.append(df.format(valB) + "\t") ; else strB.append("-.---\t") ;
                    */

                    dotProduct = dotProduct + (valA * valB) ;
                    magnitudeA = magnitudeA + (valA * valA) ;
                    magnitudeB = magnitudeB + (valB * valB) ;
                }

                magnitudeA = Math.sqrt(magnitudeA) ;
                magnitudeB = Math.sqrt(magnitudeB) ;

                vectorMeasure = Math.acos(dotProduct / (magnitudeA * magnitudeB)) ;
                if (vectorMeasure.isNaN())
                    vectorMeasure = Math.PI/2 ;

                vectorMeasure = ArticleComparison.normalizeVectorMeasure(vectorMeasure) ;
            }

            //if (vectorMeasure.isNaN()) {
            //    System.out.println("A: (" + cmp.getArticleA() + ") " + strA) ;
            //    System.out.println("B: (" + cmp.getArticleB() + ") " + strB) ;
            //}
        }

        double intersectionProportion ;
        if (union == 0)
            intersectionProportion = 0 ;
        else
            intersectionProportion = (double)intersection/union ;

        //System.out.println("Intersection: " + intersection + "\tUnion: " + union) ;
        //System.out.println("Relatedness:" + df.format(googleMeasure)) ;
        //System.out.println();

        if (dir == LinkDirection.Out)
            cmp.setOutLinkFeatures(googleMeasure, vectorMeasure, union, intersectionProportion) ;
        else
            cmp.setInLinkFeatures(googleMeasure, vectorMeasure, union, intersectionProportion) ;

        return cmp ;
    }

    private ArrayList<Integer> getLinks(int artId, LinkDirection dir) {

        DbIntList ids ;

        if (dir == LinkDirection.In)
            ids = wikipedia.getEnvironment().getDbPageLinkInNoSentences().retrieve(artId) ;
        else
            ids = wikipedia.getEnvironment().getDbPageLinkOutNoSentences().retrieve(artId) ;

        if (ids == null || ids.getValues() == null)
            return new ArrayList<Integer>() ;

        return ids.getValues() ;
    }

    private double getLfiaf(int linksFromSourceToTarget, int linksFromSource, int linksToTarget) {

        if (linksFromSourceToTarget == 0 || linksFromSource == 0)
            return 0 ;

        double linkFreq = (double)linksFromSourceToTarget/linksToTarget ;
        double inverseArtFreq = Math.log((double)wikipediaArticleCount/linksFromSource) ; //cast avoids integer division

        return linkFreq * inverseArtFreq ;
    }

    private double wrapMissingValue(Number val) {

        if (val == null)
            return Instance.missingValue() ;
        else
            return val.doubleValue() ;
    }

    private Instance getInstance(ArticleComparison cmp, Double relatedness) throws ClassMissingException, AttributeMissingException {

        InstanceBuilder<Attributes, Double> ib = relatednessMeasurer.getInstanceBuilder() ;

        if (dependancies.contains(DataDependency.pageLinksIn)) {
            ib.setAttribute(Attributes.inLinkGoogleMeasure, wrapMissingValue(cmp.getInLinkGoogleMeasure())) ;
            //ib.setAttribute(Attributes.inLinkUnion, cmp.getInLinkUnion()) ;
            ib.setAttribute(Attributes.inLinkIntersection, wrapMissingValue(cmp.getInLinkIntersectionProportion())) ;

            if (dependancies.contains(DataDependency.linkCounts))
                ib.setAttribute(Attributes.inLinkVectorMeasure, wrapMissingValue(cmp.getInLinkVectorMeasure())) ;
            else
                ib.setAttribute(Attributes.inLinkVectorMeasure, 0) ;
        } else {
            ib.setAttribute(Attributes.inLinkGoogleMeasure, 0) ;
            //ib.setAttribute(Attributes.inLinkUnion, 0) ;
            ib.setAttribute(Attributes.inLinkIntersection, 0) ;
            ib.setAttribute(Attributes.inLinkVectorMeasure, 0) ;
        }

        if (dependancies.contains(DataDependency.pageLinksOut)) {
            ib.setAttribute(Attributes.outLinkGoogleMeasure, wrapMissingValue(cmp.getOutLinkGoogleMeasure())) ;
            //ib.setAttribute(Attributes.outLinkUnion, wrapMissingValue(cmp.getOutLinkUnion())) ;
            ib.setAttribute(Attributes.outLinkIntersection, wrapMissingValue(cmp.getOutLinkIntersectionProportion())) ;

            if (dependancies.contains(DataDependency.linkCounts))
                ib.setAttribute(Attributes.outLinkVectorMeasure, wrapMissingValue(cmp.getOutLinkVectorMeasure())) ;
            else
                ib.setAttribute(Attributes.outLinkVectorMeasure, 0) ;
        } else {
            ib.setAttribute(Attributes.outLinkGoogleMeasure, 0) ;
            //ib.setAttribute(Attributes.outLinkUnion, 0) ;
            ib.setAttribute(Attributes.outLinkIntersection, 0) ;
            ib.setAttribute(Attributes.outLinkVectorMeasure, 0) ;
        }

        if (relatedness != null)
            ib.setClassAttribute(relatedness) ;

        ib.replaceAllMissingValuesWith(0.0) ;

        return ib.build() ;
    }
}
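
/*
 * A minimal usage sketch, not part of the original class: given an already-constructed Wikipedia
 * instance whose configuration specifies the article comparison dependencies (and, optionally, a
 * trained model), it measures the relatedness of two articles looked up by title. The helper name
 * and the use of Wikipedia.getArticleByTitle(String) are illustrative assumptions, not part of the
 * ArticleComparer API defined above.
 */
class ArticleComparerExample {

    static Double compareByTitle(Wikipedia wikipedia, String titleA, String titleB) throws Exception {

        //build a comparer from the dependencies and (optional) model named in the configuration
        ArticleComparer comparer = new ArticleComparer(wikipedia) ;

        //resolve the titles to articles; either may be null if no such article exists
        Article artA = wikipedia.getArticleByTitle(titleA) ;
        Article artB = wikipedia.getArticleByTitle(titleB) ;

        if (artA == null || artB == null)
            return null ;

        //getRelatedness returns 1.0 for identical articles and 0.0 when no link overlap is found
        return comparer.getRelatedness(artA, artB) ;
    }
}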