package org.wikipedia.miner.annotation.weighting;

import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;

import org.apache.log4j.Logger;
import org.wikipedia.miner.annotation.Topic;
import org.wikipedia.miner.annotation.TopicDetector;
import org.wikipedia.miner.annotation.preprocessing.PreprocessedDocument;
import org.wikipedia.miner.comparison.ArticleComparer;
import org.wikipedia.miner.model.Article;
import org.wikipedia.miner.model.Wikipedia;
import org.wikipedia.miner.util.ProgressTracker;
import org.wikipedia.miner.util.RelatednessCache;
import org.wikipedia.miner.util.Result;
import org.wikipedia.miner.util.TopicIndexingSet;

import weka.classifiers.Classifier;
import weka.classifiers.meta.Bagging;
import weka.core.Instance;
import weka.core.Utils;
import weka.core.WekaException;

import org.dmilne.weka.wrapper.Dataset;
import org.dmilne.weka.wrapper.Decider;
import org.dmilne.weka.wrapper.DeciderBuilder;
import org.dmilne.weka.wrapper.InstanceBuilder;

/**
 * A {@link TopicWeighter} that weights detected topics by the probability that they are 
 * key topics of the document, as judged by a machine-learned classifier.
 */
public class TopicIndexer extends TopicWeighter {

    private Wikipedia wikipedia ;

    private enum Attributes {
        occurances, 
        maxDisambigConfidence, 
        avgDisambigConfidence, 
        relatednessToContext, 
        relatednessToOtherTopics, 
        maxLinkProbability, 
        avgLinkProbability, 
        generality, 
        firstOccurance, 
        lastOccurance, 
        spread
    } ;

    private Decider<Attributes, Boolean> decider ;
    private Dataset<Attributes, Boolean> dataset ;

    int candidatesConsidered = 0 ;

    public TopicIndexer(Wikipedia wikipedia) throws Exception {
        this.wikipedia = wikipedia ;

        decider = (Decider<Attributes, Boolean>) new DeciderBuilder<Attributes>("LinkDisambiguator", Attributes.class)
            .setDefaultAttributeTypeNumeric()
            .setClassAttributeTypeBoolean("isKeyTopic")
            .build();
    }

    public int getCandidatesConsidered() {
        return candidatesConsidered ;
    }

    /**
     * Weights each of the given topics by the classifier's confidence that it is a key topic.
     * 
     * @param topics the candidate topics to weight
     * @return a map of topic ids to weights (probabilities of being a key topic)
     * @throws Exception if the classifier has not been built or loaded yet
     */
    public HashMap<Integer,Double> getTopicWeights(Collection<Topic> topics) throws Exception {

        if (!decider.isReady()) 
            throw new WekaException("You must build (or load) a classifier first.") ;

        HashMap<Integer, Double> topicWeights = new HashMap<Integer, Double>() ;

        for (Topic topic: topics) {
            Instance i = getInstance(topic, null) ;

            double prob = decider.getDecisionDistribution(i).get(true) ;
            topicWeights.put(topic.getId(), prob) ;

            candidatesConsidered++ ;
        }

        return topicWeights ;
    }

    /**
     * Gathers training instances from the given set of manually indexed documents.
     * 
     * @param trainingSet the documents to train on
     * @param datasetName a name for the generated dataset
     * @param td the topic detector used to gather candidate topics from each document
     * @throws Exception if the training data cannot be gathered
     */
    public void train(TopicIndexingSet trainingSet, String datasetName, TopicDetector td) throws Exception{

        dataset = decider.createNewDataset();

        ProgressTracker tracker = new ProgressTracker(trainingSet.size(), "training", TopicIndexer.class) ;
        for (TopicIndexingSet.Item i: trainingSet) {
            train(i, td) ;
            tracker.update() ;
        }

        weightTrainingInstances() ;
    }

    /**
     * Tests the classifier against the given set of manually indexed documents, and returns 
     * the aggregated precision and recall of the automatically selected topics.
     * 
     * @param testSet the documents to test against
     * @param td the topic detector used to gather candidate topics from each document
     * @return the aggregated result over all tested documents
     * @throws Exception if the classifier has not been built or loaded yet
     */
    public Result<Integer> test(TopicIndexingSet testSet, TopicDetector td) throws Exception{

        if (!decider.isReady()) 
            throw new Exception("You must build (or load) a classifier first.") ;

        double worstRecall = 1 ;
        double worstPrecision = 1 ;

        int docsTested = 0 ;
        int perfectRecall = 0 ;
        int perfectPrecision = 0 ;

        Result<Integer> r = new Result<Integer>() ;

        ProgressTracker tracker = new ProgressTracker(testSet.size(), "Testing", TopicIndexer.class) ;
        for (TopicIndexingSet.Item item: testSet) {

            docsTested ++ ;

            Result<Integer> ir = test(item, td) ;

            if (ir.getRecall() == 1) perfectRecall++ ;
            if (ir.getPrecision() == 1) perfectPrecision++ ;

            worstRecall = Math.min(worstRecall, ir.getRecall()) ;
            worstPrecision = Math.min(worstPrecision, ir.getPrecision()) ;
            r.addIntermediateResult(ir) ;
            tracker.update() ;
        }

        System.out.println("worstR:" + worstRecall + ", worstP:" + worstPrecision) ;
        System.out.println("tested:" + docsTested + ", perfectR:" + perfectRecall + ", perfectP:" + perfectPrecision) ;

        return r ;
    }

    /**
     * Saves the training data generated by train() to the given file.
     * The data is saved in WEKA's arff format. 
     * 
     * @param file the file to save the training data to
     * @throws IOException if the file cannot be written to
     */
    @SuppressWarnings("unchecked")
    public void saveTrainingData(File file) throws Exception {
        Logger.getLogger(TopicIndexer.class).info("saving training data") ;
        dataset.save(file) ;
    }

    /**
     * Loads training data from the given file.
     * The file must be a valid WEKA arff file. 
     * 
     * @param file the file to load the training data from
     * @throws IOException if the file cannot be read.
     * @throws Exception if the file does not contain valid training data.
     */
    public void loadTrainingData(File file) throws Exception{
        Logger.getLogger(TopicIndexer.class).info("loading training data") ;
        dataset.load(file) ;
        weightTrainingInstances() ;
    }

    /**
     * Discards any training data that has been gathered or loaded so far.
     */
    public void clearTrainingData() {
        dataset = null ;
    }

    /**
     * Serializes the classifier and saves it to the given file.
     * 
     * @param file the file to save the classifier to
     * @throws IOException if the file cannot be written to
     */
    public void saveClassifier(File file) throws IOException {
        Logger.getLogger(TopicIndexer.class).info("saving classifier") ;
        decider.save(file) ;
    }

    /**
     * Loads the classifier from file.
     * 
     * @param file the file to load the classifier from
     * @throws Exception if the file does not contain a valid classifier
     */
    public void loadClassifier(File file) throws Exception {
        Logger.getLogger(TopicIndexer.class).info("loading classifier") ;
        decider.load(file) ;
    }

    /**
     * Trains the given classifier on the training data that has been gathered or loaded.
     * 
     * @param classifier the (untrained) classifier to train
     * @throws Exception if the classifier cannot be trained on the available data
     */
    public void buildClassifier(Classifier classifier) throws Exception {
        Logger.getLogger(TopicIndexer.class).info("building classifier") ;
        decider.train(classifier, dataset) ;
    }

    /**
     * Trains a default classifier (bagged C4.5 decision trees) on the training data 
     * that has been gathered or loaded.
     * 
     * @throws Exception if the classifier cannot be trained on the available data
     */
    public void buildDefaultClassifier() throws Exception {
        Logger.getLogger(TopicIndexer.class).info("building classifier") ;

        Classifier classifier = new Bagging() ;
        classifier.setOptions(Utils.splitOptions("-P 10 -S 1 -I 10 -W weka.classifiers.trees.J48 -- -U -M 2")) ;

        decider.train(classifier, dataset) ;
    }

    private void train(TopicIndexingSet.Item item, TopicDetector td) throws Exception{

        RelatednessCache rc = new RelatednessCache(new ArticleComparer(wikipedia)) ;

        Collection<Topic> topics = td.getTopics(item.getDocument(), rc) ;

        for (Topic topic: topics) 
            dataset.add(getInstance(topic, item.isTopic(topic))) ;
    }

    private Result<Integer> test(TopicIndexingSet.Item item, TopicDetector td) throws Exception{

        RelatednessCache rc = new RelatednessCache(new ArticleComparer(wikipedia)) ;

        Collection<Topic> topics = td.getTopics(item.getDocument(), rc) ;
        ArrayList<Topic> weightedTopics = this.getWeightedTopics(topics) ;

        HashSet<Integer> autoIds = new HashSet<Integer>() ;
        for (Topic topic: weightedTopics) {
            if (topic.getWeight() > 0.5) 
                autoIds.add(topic.getId()) ;
        }

        Result<Integer> result = new Result<Integer>(autoIds, item.getTopicIds()) ;
        System.out.println(" - " + result) ;

        return result ;
    }

    private Instance getInstance(Topic topic, Boolean isKeyTopic) throws Exception {

        InstanceBuilder<Attributes,Boolean> ib = decider.getInstanceBuilder()
            .setAttribute(Attributes.occurances, topic.getOccurances())
            .setAttribute(Attributes.maxDisambigConfidence, topic.getMaxDisambigConfidence())
            .setAttribute(Attributes.avgDisambigConfidence, topic.getAverageDisambigConfidence())
            .setAttribute(Attributes.relatednessToContext, 
                topic.getRelatednessToContext())
            .setAttribute(Attributes.relatednessToOtherTopics, topic.getRelatednessToOtherTopics())
            .setAttribute(Attributes.maxLinkProbability, topic.getMaxLinkProbability())
            .setAttribute(Attributes.avgLinkProbability, topic.getAverageLinkProbability())
            .setAttribute(Attributes.generality, topic.getGenerality())
            .setAttribute(Attributes.firstOccurance, topic.getFirstOccurance())
            .setAttribute(Attributes.lastOccurance, topic.getLastOccurance())
            .setAttribute(Attributes.spread, topic.getSpread()) ;

        if (isKeyTopic != null) 
            ib = ib.setClassAttribute(isKeyTopic) ;

        return ib.build() ;
    }

    //TODO: this should really be refactored as a separate filter
    /**
     * Reweights the training instances so that the two classes (key topics and non-key topics) 
     * contribute equal total weight, regardless of how unbalanced the gathered training data is.
     */
    @SuppressWarnings("unchecked")
    private void weightTrainingInstances() {

        double positiveInstances = 0 ;
        double negativeInstances = 0 ; 

        Enumeration<Instance> e = dataset.enumerateInstances() ;

        while (e.hasMoreElements()) {
            Instance i = (Instance) e.nextElement() ;

            // count instances by their class value (isKeyTopic), rather than by a fixed attribute index
            double isKeyTopic = i.classValue() ;

            if (isKeyTopic == 0) 
                positiveInstances ++ ;
            else
                negativeInstances ++ ;
        }

        double p = positiveInstances / (positiveInstances + negativeInstances) ;

        e = dataset.enumerateInstances() ;

        while (e.hasMoreElements()) {
            Instance i = (Instance) e.nextElement() ;

            double isKeyTopic = i.classValue() ;

            if (isKeyTopic == 0) 
                i.setWeight(0.5 * (1.0/p)) ;
            else
                i.setWeight(0.5 * (1.0/(1-p))) ;
        }
    }
}
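
/*
 * A minimal usage sketch (not part of the original class). It assumes an already configured 
 * Wikipedia instance, a TopicDetector, a gold-standard TopicIndexingSet, and a preprocessed 
 * document; the variable names and file paths below are illustrative assumptions only.
 *
 *   TopicIndexer indexer = new TopicIndexer(wikipedia) ;
 *
 *   // either gather training data and build a classifier from scratch ...
 *   indexer.train(trainingSet, "topicIndexingData", topicDetector) ;
 *   indexer.saveTrainingData(new File("data/topicIndexing.arff")) ;
 *   indexer.buildDefaultClassifier() ;
 *   indexer.saveClassifier(new File("models/topicIndexer.model")) ;
 *
 *   // ... or load a previously built classifier
 *   // indexer.loadClassifier(new File("models/topicIndexer.model")) ;
 *
 *   // weight candidate topics for a new document; test() treats weights above 0.5 as key topics
 *   Collection<Topic> topics = topicDetector.getTopics(doc, new RelatednessCache(new ArticleComparer(wikipedia))) ;
 *   ArrayList<Topic> weightedTopics = indexer.getWeightedTopics(topics) ;
 */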