package org.wikipedia.miner.comparison; import java.io.File; import java.io.IOException; import java.util.ArrayList; import java.util.List; import java.util.StringTokenizer; import java.util.Vector; import org.wikipedia.miner.model.Wikipedia; import org.wikipedia.miner.util.CorrelationCalculator; import org.wikipedia.miner.util.ProgressTracker; import weka.classifiers.Classifier; import weka.classifiers.functions.GaussianProcesses; import weka.core.Instance; import org.dmilne.weka.wrapper.Dataset; import org.dmilne.weka.wrapper.Decider; import org.dmilne.weka.wrapper.DeciderBuilder; import org.dmilne.weka.wrapper.InstanceBuilder; public class ConnectionSnippetWeighter { enum Attributes { generality, inLinkCount, outLinkCount, isTopic1, relatednessToTopic1, isTopic2, relatednessToTopic2, sentenceIndex, wordCount, isListItem, isFromFirstParagraph, isAfterHeading } private Wikipedia wikipedia ; private ArticleComparer cmp ; private Decider<Attributes,Double> snippetWeighter ; private Dataset<Attributes,Double> trainingDataset ; @SuppressWarnings("unchecked") public ConnectionSnippetWeighter(Wikipedia wikipedia, ArticleComparer cmp) throws Exception { this.wikipedia = wikipedia ; this.cmp = cmp ; snippetWeighter = (Decider<Attributes, Double>) new DeciderBuilder<Attributes>("connectionSnippetWeighter", Attributes.class) .setDefaultAttributeTypeNumeric() .setAttributeTypeBoolean(Attributes.isTopic1) .setAttributeTypeBoolean(Attributes.isTopic2) .setAttributeTypeBoolean(Attributes.isAfterHeading) .setAttributeTypeBoolean(Attributes.isListItem) .setAttributeTypeBoolean(Attributes.isFromFirstParagraph) .setClassAttributeTypeNumeric("snippetWeight") .build(); if (wikipedia.getConfig().getComparisonSnippetModel() != null) this.loadClassifier(wikipedia.getConfig().getComparisonSnippetModel()) ; } public double getWeight(ConnectionSnippet snippet) throws Exception { if (!snippetWeighter.isReady()) { //Logger.getLogger(ArticleComparer.class).debug("Article comparison without ml") ; //no classifier available, so just return mean of gathered measurements ; double totalWeight = 0 ; totalWeight += cmp.getRelatedness(snippet.getSource(), snippet.getTopic1()) ; totalWeight += cmp.getRelatedness(snippet.getSource(), snippet.getTopic2()) ; return totalWeight / 2 ; } else { return snippetWeighter.getDecision(getInstance(snippet)) ; } } public void train(Vector<ConnectionSnippet> weightedSnippets) throws Exception { trainingDataset = snippetWeighter.createNewDataset() ; ProgressTracker pn = new ProgressTracker(weightedSnippets.size(), "training", ConnectionSnippetWeighter.class) ; for (ConnectionSnippet snippet: weightedSnippets) { if (snippet.getWeight() == null) throw new Exception("Training snippet is not weighted") ; trainingDataset.add(getInstance(snippet)) ; pn.update() ; } } public double test(Vector<ConnectionSnippet> weightedSnippets) throws Exception { List<Double> manualWeights = new ArrayList<Double>() ; List<Double> autoWeights = new ArrayList<Double>() ; ProgressTracker pn = new ProgressTracker(weightedSnippets.size(), "testing", ArticleComparer.class) ; for (ConnectionSnippet snippet: weightedSnippets) { if (snippet.getWeight() == null) throw new Exception("Testing snippet is not weighted") ; manualWeights.add(snippet.getWeight()) ; autoWeights.add(this.getWeight(snippet)) ; pn.update() ; } return CorrelationCalculator.getCorrelation(manualWeights, autoWeights) ; } /** * Saves the training data generated by train() to the given file. * The data is saved in WEKA's arff format. * * @param file the file to save the training data to * @throws IOException if the file cannot be written to */ public void saveTrainingData(File file) throws IOException, Exception { trainingDataset.save(file) ; } /** * Loads training data from the given file. * The file must be a valid WEKA arff file. * * @param file the file to save the training data to * @throws IOException if the file cannot be read. * @throws Exception if the file does not contain valid training data. */ public void loadTrainingData(File file) throws IOException, Exception{ trainingDataset.load(file) ; } /** * Serializes the classifer and saves it to the given file. * * @param file the file to save the classifier to * @throws IOException if the file cannot be read */ public void saveClassifier(File file) throws IOException { snippetWeighter.save(file) ; } /** * Loads the classifier from file * * @param file * @throws Exception */ public void loadClassifier(File file) throws Exception { snippetWeighter.load(file) ; } /** * * * @param classifier * @throws Exception */ public void buildClassifier(Classifier classifier) throws Exception { snippetWeighter.train(classifier, trainingDataset) ; } /** * * * @param classifier * @throws Exception */ public void buildDefaultClassifier() throws Exception { Classifier classifier = new GaussianProcesses() ; snippetWeighter.train(classifier, trainingDataset) ; } private Instance getInstance(ConnectionSnippet snippet) throws Exception { InstanceBuilder<Attributes, Double> ib = snippetWeighter.getInstanceBuilder() ; ib.setAttribute(Attributes.generality, snippet.getSource().getGenerality()) ; ib.setAttribute(Attributes.inLinkCount, Math.log(snippet.getSource().getDistinctLinksInCount() +1)) ; ib.setAttribute(Attributes.outLinkCount, Math.log(snippet.getSource().getDistinctLinksOutCount() +1)) ; ib.setAttribute(Attributes.isTopic1, snippet.getSource().getId() == snippet.getTopic1().getId()) ; ib.setAttribute(Attributes.relatednessToTopic1, cmp.getRelatedness(snippet.getSource(), snippet.getTopic1())) ; ib.setAttribute(Attributes.isTopic2, snippet.getSource().getId() == snippet.getTopic2().getId()) ; ib.setAttribute(Attributes.relatednessToTopic2, cmp.getRelatedness(snippet.getSource(), snippet.getTopic2())) ; ib.setAttribute(Attributes.sentenceIndex, snippet.getSentenceIndex()) ; StringTokenizer t = new StringTokenizer(snippet.getPlainText()) ; ib.setAttribute(Attributes.wordCount, t.countTokens()) ; ib.setAttribute(Attributes.isListItem, snippet.isListItem()) ; ib.setAttribute(Attributes.isAfterHeading, snippet.followsHeading()) ; if (snippet.getWeight() != null) ib.setClassAttribute(snippet.getWeight()) ; return ib.build() ; } }