package org.wikipedia.miner.annotation.weighting;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.HashSet;
import org.apache.log4j.Logger;
import org.wikipedia.miner.annotation.Topic;
import org.wikipedia.miner.annotation.TopicDetector;
import org.wikipedia.miner.comparison.ArticleComparer;
import org.wikipedia.miner.model.Wikipedia;
import org.wikipedia.miner.util.ProgressTracker;
import org.wikipedia.miner.util.RelatednessCache;
import org.wikipedia.miner.util.Result;
import org.wikipedia.miner.util.TopicIndexingSet;
import weka.classifiers.Classifier;
import weka.classifiers.meta.Bagging;
import weka.core.Instance;
import weka.core.Utils;
import weka.core.WekaException;
import org.dmilne.weka.wrapper.Dataset;
import org.dmilne.weka.wrapper.Decider;
import org.dmilne.weka.wrapper.DeciderBuilder;
import org.dmilne.weka.wrapper.InstanceBuilder;
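/**
 * A TopicWeighter that weights each detected topic by the probability that it is a key topic
 * of the document, as judged by a classifier trained over features such as how often the topic
 * occurs, how confidently it was disambiguated, how strongly it relates to the context and to
 * other topics, its link probability, its generality, and where it occurs in the document.
 *
 * A typical workflow is sketched below; the arff file name and the {@code wikipedia} and
 * {@code topics} variables are illustrative, assuming a configured Wikipedia instance and a
 * collection of detected topics obtained elsewhere:
 * <pre>{@code
 * TopicIndexer indexer = new TopicIndexer(wikipedia) ;
 * indexer.loadTrainingData(new File("topicIndexing.arff")) ;
 * indexer.buildDefaultClassifier() ;
 * HashMap<Integer,Double> weights = indexer.getTopicWeights(topics) ;
 * }</pre>
 */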
public class TopicIndexer extends TopicWeighter {
private Wikipedia wikipedia ;
private enum Attributes {
occurances, maxDisambigConfidence, avgDisambigConfidence,
relatednessToContext, relatednessToOtherTopics,
maxLinkProbability, avgLinkProbability,
generality, firstOccurance, lastOccurance, spread
} ;
private Decider<Attributes, Boolean> decider ;
private Dataset<Attributes, Boolean> dataset ;
int candidatesConsidered = 0 ;
@SuppressWarnings("unchecked")
public TopicIndexer(Wikipedia wikipedia) throws Exception {
this.wikipedia = wikipedia ;
decider = (Decider<Attributes, Boolean>) new DeciderBuilder<Attributes>("TopicIndexer", Attributes.class)
.setDefaultAttributeTypeNumeric()
.setClassAttributeTypeBoolean("isKeyTopic")
.build();
}
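/**
 * @return the number of candidate topics that have been weighted by this indexer so far
 */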
public int getCandidatesConsidered() {
return candidatesConsidered ;
}
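/**
 * Weights each of the given topics by the classifier's confidence that it is a key topic
 * of the document, as a probability between 0 and 1.
 *
 * @param topics the detected topics to weight
 * @return a map from topic id to weight
 * @throws WekaException if the classifier has not been built or loaded yet
 */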
public HashMap<Integer,Double> getTopicWeights(Collection<Topic> topics) throws Exception {
if (!decider.isReady())
throw new WekaException("You must build (or load) the classifier first.") ;
HashMap<Integer, Double> topicWeights = new HashMap<Integer, Double>() ;
for (Topic topic: topics) {
Instance i = getInstance(topic, null) ;
double prob = decider.getDecisionDistribution(i).get(true) ;
topicWeights.put(topic.getId(), prob) ;
candidatesConsidered++ ;
}
return topicWeights ;
}
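/**
 * Generates training data from a set of manually indexed documents. Each candidate topic
 * detected in a document becomes one training instance, labelled by whether it was
 * manually chosen as a key topic; the instances are then re-weighted to balance the classes.
 *
 * @param trainingSet the manually indexed documents to learn from
 * @param datasetName a name for the generated dataset (currently unused)
 * @param td the detector used to gather candidate topics from each document
 * @throws Exception if the training data cannot be generated
 */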
public void train(TopicIndexingSet trainingSet, String datasetName, TopicDetector td) throws Exception{
dataset = decider.createNewDataset();
ProgressTracker tracker = new ProgressTracker(trainingSet.size(), "training", TopicIndexer.class) ;
for (TopicIndexingSet.Item i: trainingSet) {
train(i, td) ;
tracker.update() ;
}
weightTrainingInstances() ;
}
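/**
 * Tests the trained classifier against a set of manually indexed documents, treating
 * topics weighted above 0.5 as predicted key topics.
 *
 * @param testSet the manually indexed documents to test against
 * @param td the detector used to gather candidate topics from each document
 * @return precision and recall of the predicted key topics, aggregated over all documents
 * @throws Exception if the classifier has not been built or loaded yet
 */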
public Result<Integer> test(TopicIndexingSet testSet, TopicDetector td) throws Exception{
if (!decider.isReady())
throw new WekaException("You must build (or load) the classifier first.") ;
double worstRecall = 1 ;
double worstPrecision = 1 ;
int docsTested = 0 ;
int perfectRecall = 0 ;
int perfectPrecision = 0 ;
Result<Integer> r = new Result<Integer>() ;
ProgressTracker tracker = new ProgressTracker(testSet.size(), "testing", TopicIndexer.class) ;
for (TopicIndexingSet.Item item:testSet) {
docsTested ++ ;
Result<Integer> ir = test(item, td) ;
if (ir.getRecall() ==1) perfectRecall++ ;
if (ir.getPrecision() == 1) perfectPrecision++ ;
worstRecall = Math.min(worstRecall, ir.getRecall()) ;
worstPrecision = Math.min(worstPrecision, ir.getPrecision()) ;
r.addIntermediateResult(ir) ;
tracker.update() ;
}
System.out.println("worstR:" + worstRecall + ", worstP:" + worstPrecision) ;
System.out.println("tested:" + docsTested + ", perfectR:" + perfectRecall + ", perfectP:" + perfectPrecision) ;
return r ;
}
/**
* Saves the training data generated by train() to the given file.
* The data is saved in WEKA's arff format.
*
* @param file the file to save the training data to
* @throws IOException if the file cannot be written to
*/
@SuppressWarnings("unchecked")
public void saveTrainingData(File file) throws Exception {
Logger.getLogger(TopicIndexer.class).info("saving training data") ;
dataset.save(file) ;
}
/**
* Loads training data from the given file.
* The file must be a valid WEKA arff file.
*
* @param file the file to load the training data from
* @throws IOException if the file cannot be read.
* @throws Exception if the file does not contain valid training data.
*/
public void loadTrainingData(File file) throws Exception{
Logger.getLogger(TopicIndexer.class).info("loading training data") ;
dataset = decider.createNewDataset() ;
dataset.load(file) ;
weightTrainingInstances() ;
}
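/**
 * Discards any training data generated by train() or loaded via loadTrainingData().
 */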
public void clearTrainingData() {
dataset = null ;
}
/**
* Serializes the classifer and saves it to the given file.
*
* @param file the file to save the classifier to
* @throws IOException if the file cannot be written to
*/
public void saveClassifier(File file) throws IOException {
Logger.getLogger(TopicIndexer.class).info("saving classifier") ;
decider.save(file) ;
}
/**
 * Loads a previously saved classifier from the given file.
 *
 * @param file the file to load the classifier from
 * @throws Exception if the file does not contain a valid classifier
 */
public void loadClassifier(File file) throws Exception {
Logger.getLogger(TopicIndexer.class).info("loading classifier") ;
decider.load(file) ;
}
/**
 * Trains the given classifier on the training data generated by train() or loaded
 * via loadTrainingData().
 *
 * @param classifier the classifier to train
 * @throws Exception if there is no training data available
 */
public void buildClassifier(Classifier classifier) throws Exception {
Logger.getLogger(TopicIndexer.class).info("building classifier") ;
decider.train(classifier, dataset) ;
}
public void buildDefaultClassifier() throws Exception {
Logger.getLogger(TopicIndexer.class).info("building classifier") ;
Bagging classifier = new Bagging() ;
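// bags of 10% of the training set (-P 10), seed 1 (-S 1), 10 iterations (-I 10),
// over unpruned C4.5 (J48) trees with at least 2 instances per leaf (-U -M 2)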
classifier.setOptions(Utils.splitOptions("-P 10 -S 1 -I 10 -W weka.classifiers.trees.J48 -- -U -M 2")) ;
decider.train(classifier, dataset) ;
}
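/**
 * Gathers candidate topics from a single manually indexed document, and adds one training
 * instance per candidate, labelled by whether it was manually chosen as a key topic.
 */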
private void train(TopicIndexingSet.Item item, TopicDetector td) throws Exception{
RelatednessCache rc = new RelatednessCache(new ArticleComparer(wikipedia)) ;
Collection<Topic> topics = td.getTopics(item.getDocument(), rc) ;
for (Topic topic: topics)
dataset.add(getInstance(topic, item.isTopic(topic))) ;
}
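/**
 * Detects and weights the topics of a single document, and measures the topics weighted
 * above 0.5 against the manually assigned key topics.
 */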
private Result<Integer> test(TopicIndexingSet.Item item, TopicDetector td) throws Exception{
RelatednessCache rc = new RelatednessCache(new ArticleComparer(wikipedia)) ;
Collection<Topic> topics = td.getTopics(item.getDocument(), rc) ;
ArrayList<Topic> weightedTopics = this.getWeightedTopics(topics) ;
HashSet<Integer> autoIds = new HashSet<Integer>() ;
for (Topic topic: weightedTopics) {
if (topic.getWeight() > 0.5)
autoIds.add(topic.getId()) ;
}
Result<Integer> result = new Result<Integer>(autoIds,item.getTopicIds()) ;
System.out.println(" - " + result) ;
return result ;
}
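/**
 * Builds a feature vector for the given topic. The class attribute is only set when
 * isKeyTopic is known (i.e. during training); during prediction it is left unset.
 */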
private Instance getInstance(Topic topic, Boolean isKeyTopic) throws Exception {
InstanceBuilder<Attributes,Boolean> ib = decider.getInstanceBuilder()
.setAttribute(Attributes.occurances, topic.getOccurances())
.setAttribute(Attributes.maxDisambigConfidence, topic.getMaxDisambigConfidence())
.setAttribute(Attributes.avgDisambigConfidence, topic.getAverageDisambigConfidence())
.setAttribute(Attributes.relatednessToContext, topic.getRelatednessToContext())
.setAttribute(Attributes.relatednessToOtherTopics, topic.getRelatednessToOtherTopics())
.setAttribute(Attributes.maxLinkProbability, topic.getMaxLinkProbability())
.setAttribute(Attributes.avgLinkProbability, topic.getAverageLinkProbability())
.setAttribute(Attributes.generality, topic.getGenerality())
.setAttribute(Attributes.firstOccurance, topic.getFirstOccurance())
.setAttribute(Attributes.lastOccurance, topic.getLastOccurance())
.setAttribute(Attributes.spread, topic.getSpread()) ;
if (isKeyTopic != null)
ib = ib.setClassAttribute(isKeyTopic) ;
return ib.build() ;
}
//TODO: this should really be refactored as a separate filter
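// Re-weights the training instances so that the two classes carry equal total weight:
// if a fraction p of the instances fall in one class, each of them is weighted 0.5/p and
// each instance of the other class 0.5/(1-p). The balancing is symmetric, so it does not
// matter which nominal index of the class attribute represents "true".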
@SuppressWarnings("unchecked")
private void weightTrainingInstances() {
double positiveInstances = 0 ;
double negativeInstances = 0 ;
Enumeration<Instance> e = dataset.enumerateInstances() ;
while (e.hasMoreElements()) {
Instance i = e.nextElement() ;
double classValue = i.classValue() ;
if (classValue == 0)
positiveInstances ++ ;
else
negativeInstances ++ ;
}
// avoid divide-by-zero weights if the dataset is empty or contains only one class
if (positiveInstances == 0 || negativeInstances == 0) return ;
double p = positiveInstances / (positiveInstances + negativeInstances) ;
e = dataset.enumerateInstances() ;
while (e.hasMoreElements()) {
Instance i = e.nextElement() ;
double classValue = i.classValue() ;
if (classValue == 0)
i.setWeight(0.5 * (1.0/p)) ;
else
i.setWeight(0.5 * (1.0/(1-p))) ;
}
}
}