package org.wikipedia.miner.comparison;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.StringTokenizer;
import java.util.Vector;
import org.wikipedia.miner.model.Wikipedia;
import org.wikipedia.miner.util.CorrelationCalculator;
import org.wikipedia.miner.util.ProgressTracker;
import weka.classifiers.Classifier;
import weka.classifiers.functions.GaussianProcesses;
import weka.core.Instance;
import org.dmilne.weka.wrapper.Dataset;
import org.dmilne.weka.wrapper.Decider;
import org.dmilne.weka.wrapper.DeciderBuilder;
import org.dmilne.weka.wrapper.InstanceBuilder;
/**
 * Assigns weights to {@link ConnectionSnippet}s, i.e. snippets of text taken from a
 * source article that connect two topics.
 * <p>
 * When a classifier has been loaded or built, the weight is the classifier's decision
 * over the features gathered in {@link #getInstance(ConnectionSnippet)}. Otherwise the
 * weight falls back to the mean relatedness of the snippet's source article to the two
 * topics.
 * <p>
 * Typical training workflow: {@link #train(Vector)} (or {@link #loadTrainingData(File)})
 * to gather training data, then {@link #buildClassifier(Classifier)} or
 * {@link #buildDefaultClassifier()} to fit a model, then optionally
 * {@link #saveClassifier(File)}.
 */
public class ConnectionSnippetWeighter {

    /**
     * Features describing a snippet. The constant names are used as attribute names by
     * the underlying decider, so they must not be renamed without regenerating any
     * saved training data and models.
     */
    enum Attributes {
        generality,
        inLinkCount,
        outLinkCount,
        isTopic1,
        relatednessToTopic1,
        isTopic2,
        relatednessToTopic2,
        sentenceIndex,
        wordCount,
        isListItem,
        isFromFirstParagraph,
        isAfterHeading
    }

    private final Wikipedia wikipedia;
    private final ArticleComparer cmp;

    private final Decider<Attributes, Double> snippetWeighter;

    /** Training data; {@code null} until {@link #train} or {@link #loadTrainingData} is called. */
    private Dataset<Attributes, Double> trainingDataset;

    /**
     * Creates a snippet weighter. If the Wikipedia configuration specifies a comparison
     * snippet model, the classifier is loaded from it immediately.
     *
     * @param wikipedia the Wikipedia instance whose configuration is consulted for a saved model
     * @param cmp the comparer used to measure relatedness between a snippet's source and its topics
     * @throws Exception if the decider cannot be built or the configured model cannot be loaded
     */
    @SuppressWarnings("unchecked")
    public ConnectionSnippetWeighter(Wikipedia wikipedia, ArticleComparer cmp) throws Exception {

        this.wikipedia = wikipedia;
        this.cmp = cmp;

        snippetWeighter = (Decider<Attributes, Double>) new DeciderBuilder<Attributes>("connectionSnippetWeighter", Attributes.class)
            .setDefaultAttributeTypeNumeric()
            .setAttributeTypeBoolean(Attributes.isTopic1)
            .setAttributeTypeBoolean(Attributes.isTopic2)
            .setAttributeTypeBoolean(Attributes.isAfterHeading)
            .setAttributeTypeBoolean(Attributes.isListItem)
            .setAttributeTypeBoolean(Attributes.isFromFirstParagraph)
            .setClassAttributeTypeNumeric("snippetWeight")
            .build();

        if (wikipedia.getConfig().getComparisonSnippetModel() != null) {
            this.loadClassifier(wikipedia.getConfig().getComparisonSnippetModel());
        }
    }

    /**
     * Returns the weight of the given snippet.
     *
     * @param snippet the snippet to weight
     * @return the classifier's decision if a classifier is ready; otherwise the mean of the
     *         relatedness of the snippet's source to topic 1 and to topic 2
     * @throws Exception if relatedness cannot be measured or the classifier fails
     */
    public double getWeight(ConnectionSnippet snippet) throws Exception {

        if (!snippetWeighter.isReady()) {
            // No classifier available, so just return the mean of the gathered measurements.
            double totalWeight = 0;
            totalWeight += cmp.getRelatedness(snippet.getSource(), snippet.getTopic1());
            totalWeight += cmp.getRelatedness(snippet.getSource(), snippet.getTopic2());

            return totalWeight / 2;
        }

        return snippetWeighter.getDecision(getInstance(snippet));
    }

    /**
     * Gathers training data from the given manually weighted snippets. Any previously
     * gathered or loaded training data is discarded.
     * <p>
     * Note that this only builds the training dataset; call
     * {@link #buildClassifier(Classifier)} or {@link #buildDefaultClassifier()} afterwards
     * to actually fit a classifier.
     *
     * @param weightedSnippets snippets whose weights have been set
     * @throws Exception if any snippet has no weight, or its features cannot be gathered
     */
    public void train(Vector<ConnectionSnippet> weightedSnippets) throws Exception {

        trainingDataset = snippetWeighter.createNewDataset();

        ProgressTracker pn = new ProgressTracker(weightedSnippets.size(), "training", ConnectionSnippetWeighter.class);
        for (ConnectionSnippet snippet : weightedSnippets) {

            if (snippet.getWeight() == null) {
                throw new Exception("Training snippet is not weighted");
            }

            trainingDataset.add(getInstance(snippet));
            pn.update();
        }
    }

    /**
     * Evaluates this weighter against manually weighted snippets.
     *
     * @param weightedSnippets snippets whose weights have been set
     * @return the correlation between the manual weights and the weights produced by
     *         {@link #getWeight(ConnectionSnippet)}
     * @throws Exception if any snippet has no weight, or a weight cannot be computed
     */
    public double test(Vector<ConnectionSnippet> weightedSnippets) throws Exception {

        List<Double> manualWeights = new ArrayList<Double>();
        List<Double> autoWeights = new ArrayList<Double>();

        // Report progress under this class, consistent with train().
        ProgressTracker pn = new ProgressTracker(weightedSnippets.size(), "testing", ConnectionSnippetWeighter.class);
        for (ConnectionSnippet snippet : weightedSnippets) {

            if (snippet.getWeight() == null) {
                throw new Exception("Testing snippet is not weighted");
            }

            manualWeights.add(snippet.getWeight());
            autoWeights.add(this.getWeight(snippet));
            pn.update();
        }

        return CorrelationCalculator.getCorrelation(manualWeights, autoWeights);
    }

    /**
     * Saves the training data generated by train() to the given file.
     * The data is saved in WEKA's arff format.
     *
     * @param file the file to save the training data to
     * @throws IOException if the file cannot be written to
     * @throws IllegalStateException if no training data has been gathered or loaded yet
     */
    public void saveTrainingData(File file) throws IOException, Exception {
        requireTrainingData();
        trainingDataset.save(file);
    }

    /**
     * Loads training data from the given file.
     * The file must be a valid WEKA arff file.
     *
     * @param file the file to load the training data from
     * @throws IOException if the file cannot be read.
     * @throws Exception if the file does not contain valid training data.
     */
    public void loadTrainingData(File file) throws IOException, Exception {
        // The dataset only exists once train() has run; create one on demand so that
        // training data can be loaded into a freshly constructed weighter.
        if (trainingDataset == null) {
            trainingDataset = snippetWeighter.createNewDataset();
        }
        trainingDataset.load(file);
    }

    /**
     * Serializes the classifier and saves it to the given file.
     *
     * @param file the file to save the classifier to
     * @throws IOException if the file cannot be written to
     */
    public void saveClassifier(File file) throws IOException {
        snippetWeighter.save(file);
    }

    /**
     * Loads the classifier from file.
     *
     * @param file the file to load the classifier from
     * @throws Exception if the classifier cannot be loaded
     */
    public void loadClassifier(File file) throws Exception {
        snippetWeighter.load(file);
    }

    /**
     * Trains the given classifier on the current training data and uses it for
     * subsequent weighting.
     *
     * @param classifier the WEKA classifier to train
     * @throws IllegalStateException if no training data has been gathered or loaded yet
     * @throws Exception if training fails
     */
    public void buildClassifier(Classifier classifier) throws Exception {
        requireTrainingData();
        snippetWeighter.train(classifier, trainingDataset);
    }

    /**
     * Trains a default classifier ({@link GaussianProcesses}) on the current training
     * data and uses it for subsequent weighting.
     *
     * @throws IllegalStateException if no training data has been gathered or loaded yet
     * @throws Exception if training fails
     */
    public void buildDefaultClassifier() throws Exception {
        requireTrainingData();
        Classifier classifier = new GaussianProcesses();
        snippetWeighter.train(classifier, trainingDataset);
    }

    /** Fails fast with a descriptive message when no training data is available. */
    private void requireTrainingData() {
        if (trainingDataset == null) {
            throw new IllegalStateException(
                "No training data available; call train() or loadTrainingData() first");
        }
    }

    /**
     * Builds the feature vector for a snippet. The class attribute is set only when the
     * snippet carries a weight, so the same method serves training and prediction.
     */
    private Instance getInstance(ConnectionSnippet snippet) throws Exception {

        InstanceBuilder<Attributes, Double> ib = snippetWeighter.getInstanceBuilder();

        ib.setAttribute(Attributes.generality, snippet.getSource().getGenerality());
        // Link counts are log-scaled; +1 keeps the logarithm defined for a count of zero.
        ib.setAttribute(Attributes.inLinkCount, Math.log(snippet.getSource().getDistinctLinksInCount() + 1));
        ib.setAttribute(Attributes.outLinkCount, Math.log(snippet.getSource().getDistinctLinksOutCount() + 1));

        ib.setAttribute(Attributes.isTopic1, snippet.getSource().getId() == snippet.getTopic1().getId());
        ib.setAttribute(Attributes.relatednessToTopic1, cmp.getRelatedness(snippet.getSource(), snippet.getTopic1()));

        ib.setAttribute(Attributes.isTopic2, snippet.getSource().getId() == snippet.getTopic2().getId());
        ib.setAttribute(Attributes.relatednessToTopic2, cmp.getRelatedness(snippet.getSource(), snippet.getTopic2()));

        ib.setAttribute(Attributes.sentenceIndex, snippet.getSentenceIndex());

        // Word count is the number of whitespace-delimited tokens in the plain text.
        StringTokenizer t = new StringTokenizer(snippet.getPlainText());
        ib.setAttribute(Attributes.wordCount, t.countTokens());

        ib.setAttribute(Attributes.isListItem, snippet.isListItem());
        ib.setAttribute(Attributes.isAfterHeading, snippet.followsHeading());

        // NOTE(review): Attributes.isFromFirstParagraph is declared and typed as boolean
        // in the constructor but is never populated here. How the builder treats an
        // unset attribute is not visible from this file -- confirm whether this is
        // intentional or whether a corresponding accessor on ConnectionSnippet is missing.

        if (snippet.getWeight() != null) {
            ib.setClassAttribute(snippet.getWeight());
        }

        return ib.build();
    }
}