package org.wikipedia.miner.comparison;
import java.io.File;
import java.io.IOException;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.Collections;
import java.util.TreeSet;
import org.wikipedia.miner.annotation.Disambiguator;
import org.wikipedia.miner.model.Label;
import org.wikipedia.miner.model.Wikipedia;
import org.wikipedia.miner.util.CorrelationCalculator;
import org.wikipedia.miner.util.ProgressTracker;
import org.dmilne.weka.wrapper.*;
import weka.classifiers.Classifier;
import weka.classifiers.functions.GaussianProcesses;
import weka.classifiers.functions.SMO;
import weka.classifiers.meta.Bagging;
import weka.classifiers.meta.FilteredClassifier;
import weka.core.Instance;
import weka.core.Utils;
/**
 * Measures semantic relatedness between two {@link Label}s (terms).
 * <p>
 * Comparison happens in two machine-learned stages:
 * <ol>
 * <li>a <em>sense selector</em> scores every plausible pairing of the two labels' senses
 * (see {@link SensePair}) to decide which interpretation is intended, and</li>
 * <li>a <em>relatedness measurer</em> combines features of those sense pairs, plus link
 * statistics of the concatenated label text, into one relatedness value for the labels.</li>
 * </ol>
 * Sense-to-sense relatedness is delegated to an {@link ArticleComparer}. Before
 * {@link #compare(Label, Label)} can be used, both classifiers must be ready. They can be loaded
 * (the constructor loads the models named in the Wikipedia configuration, if any). Alternatively,
 * gather data with {@link #train(ComparisonDataSet, String)} and then call
 * {@link #buildDefaultClassifiers()}.
 * <p>
 * NOTE(review): instances hold mutable training state and a shared {@link DecimalFormat}, and
 * nothing here is synchronized. Presumably this class is not safe for concurrent use; confirm
 * before sharing it across threads.
 */
public class LabelComparer {

    /** Supplies the environment used to resolve labels and the configuration (models, thresholds). */
    private Wikipedia wikipedia ;

    /** Measures relatedness between individual senses (articles). */
    private ArticleComparer articleComparer ;

    /**
     * Features describing one candidate sense pair, used to decide whether it is the intended
     * interpretation. Constant names and order are kept as-is. NOTE(review): they presumably
     * define the attribute layout of saved models, so do not rename or reorder them.
     */
    private enum SenseAttr {
        predictedRelatedness, distanceFromBenchmarkRelatedness, distanceFromTopRelatedness, distanceFromTopPriorProbability, avgPriorProbability, maxPriorProbability, minPriorProbability, avgGenerality, maxGenerality, minGenerality
    }

    /**
     * Features describing a pair of labels, used to predict their overall relatedness.
     * {@code toString()} reads instance values by {@code ordinal()}. The misspelt
     * {@code concatenationOccurances} is deliberately kept; NOTE(review): it presumably must match
     * the attribute name in previously saved models.
     */
    private enum RelatednessAttr {
        bestSenseRelatedness, maxSenseRelatedness, avgSenseRelatedness, weightedAvgSenseRelatedness, concatenationPriorLinkProbability, concatenationOccurances
    }

    /** Decides, per sense pair, whether it is the valid interpretation ("isValid" class attribute). */
    private Decider<SenseAttr,Boolean> senseSelector ;

    /** Sense-selection training data; created by {@link #train(ComparisonDataSet, String)}. */
    private Dataset<SenseAttr,Boolean> senseDataset ;

    /** Predicts the numeric "relatedness" of a pair of labels. */
    private Decider<RelatednessAttr, Double> relatednessMeasurer ;

    /** Relatedness training data; created by {@link #train(ComparisonDataSet, String)}. */
    private Dataset<RelatednessAttr,Double> relatednessDataset ;

    /** Two-decimal formatter used by {@link ComparisonDetails#toString()}. */
    DecimalFormat df = new DecimalFormat("0.00") ;

    /**
     * Creates a label comparer. If the Wikipedia configuration names a label disambiguation model
     * and/or a label comparison model, those models are loaded immediately.
     *
     * @param wikipedia used to resolve labels and to read configuration
     * @param articleComparer used to measure relatedness between individual senses
     * @throws Exception if a decider cannot be built or a configured model cannot be loaded
     */
    @SuppressWarnings("unchecked") // build() result is cast to the parameterized Decider type
    public LabelComparer(Wikipedia wikipedia, ArticleComparer articleComparer) throws Exception {
        this.wikipedia = wikipedia ;
        this.articleComparer = articleComparer ;

        senseSelector = (Decider<SenseAttr, Boolean>) new DeciderBuilder<SenseAttr>("labelSenseSelector", SenseAttr.class)
            .setDefaultAttributeTypeNumeric()
            .setClassAttributeTypeBoolean("isValid")
            .build();

        relatednessMeasurer = (Decider<RelatednessAttr, Double>) new DeciderBuilder<RelatednessAttr>("labelRelatednessMeasurer", RelatednessAttr.class)
            .setDefaultAttributeTypeNumeric()
            .setClassAttributeTypeNumeric("relatedness")
            .build();

        if (wikipedia.getConfig().getLabelDisambiguationModel() != null)
            loadDisambiguationClassifier(wikipedia.getConfig().getLabelDisambiguationModel()) ;

        if (wikipedia.getConfig().getLabelComparisonModel() != null)
            loadComparisonClassifier(wikipedia.getConfig().getLabelComparisonModel()) ;
    }

    /**
     * Compares two labels. It chooses the most likely interpretation (sense pair) and predicts
     * the labels' relatedness.
     *
     * @param labelA the first label
     * @param labelB the second label
     * @return the comparison details, with candidate interpretations ordered best first and the
     *         predicted label relatedness
     * @throws IllegalStateException if either classifier has been neither trained+built nor loaded
     * @throws Exception if feature extraction or classification fails
     */
    public ComparisonDetails compare(Label labelA, Label labelB) throws Exception {
        if (!senseSelector.isReady())
            throw new IllegalStateException("You must train+build a new label sense selection classifier or load an existing one first") ;

        if (!relatednessMeasurer.isReady())
            throw new IllegalStateException("You must train+build a new label relatedness measuring classifier or load an existing one first") ;

        return new ComparisonDetails(labelA, labelB) ;
    }

    /**
     * A convenience function to compare labels without returning details.
     *
     * @param labelA the first label
     * @param labelB the second label
     * @return the semantic relatedness between the two labels, as predicted by the relatedness
     *         classifier. Callers in this class treat {@code null} as "no prediction".
     * @throws Exception under the same conditions as {@link #compare(Label, Label)}
     */
    public Double getRelatedness(Label labelA, Label labelB) throws Exception {
        return compare(labelA, labelB).getLabelRelatedness() ;
    }

    /**
     * Gathers sense-selection and relatedness training data from every item in the data set.
     * Previously gathered training data is discarded. Items with a label that does not exist are
     * skipped. Call {@link #buildDefaultClassifiers()} afterwards to train classifiers on the data.
     *
     * @param dataset items with known relatedness and (optionally) known correct sense ids
     * @param datasetName currently unused
     * @throws Exception if feature extraction fails
     */
    public void train(ComparisonDataSet dataset, String datasetName) throws Exception {
        senseDataset = senseSelector.createNewDataset() ;
        relatednessDataset = relatednessMeasurer.createNewDataset() ;

        ProgressTracker pn = new ProgressTracker(dataset.getItems().size(), "training", LabelComparer.class) ;
        for (ComparisonDataSet.Item item: dataset.getItems()) {
            train(item) ;
            pn.update() ;
        }

        //TODO: filter to resolve skewness?
    }

    /**
     * Saves the sense-selection training data gathered by {@link #train(ComparisonDataSet, String)}.
     *
     * @param file destination file
     * @throws IllegalStateException if no training data has been gathered yet
     * @throws IOException if the file cannot be written
     * @throws Exception if the dataset cannot be saved
     */
    public void saveDisambiguationTrainingData(File file) throws IOException, Exception {
        requireTrainingData() ;
        senseDataset.save(file) ;
    }

    /**
     * Saves the relatedness training data gathered by {@link #train(ComparisonDataSet, String)}.
     *
     * @param file destination file
     * @throws IllegalStateException if no training data has been gathered yet
     * @throws IOException if the file cannot be written
     * @throws Exception if the dataset cannot be saved
     */
    public void saveComparisonTrainingData(File file) throws IOException, Exception {
        requireTrainingData() ;
        relatednessDataset.save(file) ;
    }

    /**
     * Measures how well predicted relatedness correlates with the manual judgements in a data set.
     * Items with a label that does not exist, or for which no relatedness is predicted, are left out
     * of the correlation. Progress is still reported for them.
     *
     * @param dataset items with manually judged relatedness
     * @return the correlation between manual and predicted relatedness, as computed by
     *         {@link CorrelationCalculator#getCorrelation}
     * @throws Exception if comparison fails (e.g. classifiers not ready)
     */
    public double testRelatednessPrediction(ComparisonDataSet dataset) throws Exception {
        ArrayList<Double> manualMeasures = new ArrayList<Double>() ;
        ArrayList<Double> autoMeasures = new ArrayList<Double>() ;

        ProgressTracker pt = new ProgressTracker(dataset.getItems().size(), "testing relatedness prediction", LabelComparer.class) ;
        for (ComparisonDataSet.Item item: dataset.getItems()) {
            Label labelA = new Label(wikipedia.getEnvironment(), item.getTermA()) ;
            Label labelB = new Label(wikipedia.getEnvironment(), item.getTermB()) ;

            if (labelA.exists() && labelB.exists()) {
                Double manual = item.getRelatedness() ;
                Double auto = getRelatedness(labelA, labelB) ;

                if (auto != null) {
                    manualMeasures.add(manual) ;
                    autoMeasures.add(auto) ;
                }
            }

            // report progress for every item, including skipped ones, so the tracker reaches its total
            pt.update() ;
        }

        return CorrelationCalculator.getCorrelation(manualMeasures, autoMeasures) ;
    }

    /**
     * Measures how often the best interpretation chosen by {@link #compare(Label, Label)} matches
     * the known correct senses. Items without known sense ids (either id negative) are not counted.
     * Counted items with no interpretation at all count as incorrect.
     *
     * @param dataset items with known correct sense ids
     * @return the fraction of counted items that were disambiguated correctly, or {@code null} if
     *         no item had known sense ids
     * @throws Exception if comparison fails (e.g. classifiers not ready)
     */
    public Double testDisambiguationAccuracy(ComparisonDataSet dataset) throws Exception {
        int totalInterpretations = 0 ;
        int correctInterpretations = 0 ;

        ProgressTracker pt = new ProgressTracker(dataset.getItems().size(), "testing disambiguation accuracy", LabelComparer.class) ;
        for (ComparisonDataSet.Item item: dataset.getItems()) {
            if (item.getIdA() >= 0 && item.getIdB() >= 0) {
                totalInterpretations++ ;

                Label labelA = new Label(wikipedia.getEnvironment(), item.getTermA()) ;
                Label labelB = new Label(wikipedia.getEnvironment(), item.getTermB()) ;

                SensePair sp = this.compare(labelA, labelB).getBestInterpretation() ;
                if (sp != null && sp.getSenseA().getId() == item.getIdA() && sp.getSenseB().getId() == item.getIdB())
                    correctInterpretations ++ ;
            }

            // report progress for every item, including skipped ones, so the tracker reaches its total
            pt.update();
        }

        if (totalInterpretations > 0)
            return (double) correctInterpretations/totalInterpretations ;
        else
            return null ;
    }

    /**
     * Loads a previously saved sense-selection classifier.
     *
     * @param file the model file
     * @throws Exception if the model cannot be loaded
     */
    public void loadDisambiguationClassifier(File file) throws Exception {
        senseSelector.load(file) ;
    }

    /**
     * Loads a previously saved relatedness classifier.
     *
     * @param file the model file
     * @throws Exception if the model cannot be loaded
     */
    public void loadComparisonClassifier(File file) throws Exception {
        relatednessMeasurer.load(file) ;
    }

    /**
     * Saves the current sense-selection classifier.
     *
     * @param file destination file
     * @throws Exception if the model cannot be saved
     */
    public void saveDisambiguationClassifier(File file) throws Exception {
        senseSelector.save(file) ;
    }

    /**
     * Saves the current relatedness classifier.
     *
     * @param file destination file
     * @throws Exception if the model cannot be saved
     */
    public void saveComparisonClassifier(File file) throws Exception {
        relatednessMeasurer.save(file) ;
    }

    //TODO: saving and loading arff files?

    /**
     * Trains both classifiers on the data gathered by {@link #train(ComparisonDataSet, String)}.
     * <ul>
     * <li>The sense selector is bagged, unpruned J48 trees.</li>
     * <li>The relatedness measurer is a default {@link GaussianProcesses} regressor.</li>
     * </ul>
     *
     * @throws IllegalStateException if no training data has been gathered yet
     * @throws Exception if a classifier cannot be configured or trained
     */
    public void buildDefaultClassifiers() throws Exception {
        requireTrainingData() ;

        // Bagging: 10% bag size, seed 1, 10 iterations, base learner J48 (-U unpruned, -M 2 min instances per leaf)
        Classifier ssClassifier = new Bagging() ;
        ssClassifier.setOptions(Utils.splitOptions("-P 10 -S 1 -I 10 -W weka.classifiers.trees.J48 -- -U -M 2")) ;
        senseSelector.train(ssClassifier, senseDataset) ;

        Classifier rmClassifier = new GaussianProcesses() ;
        relatednessMeasurer.train(rmClassifier, relatednessDataset) ;
    }

    /** Fails fast with a clear message if {@code train(...)} has not yet created the training datasets. */
    private void requireTrainingData() {
        if (senseDataset == null || relatednessDataset == null)
            throw new IllegalStateException("You must call train() to gather training data first") ;
    }

    /**
     * Adds the training instances for one item, unless either of its labels does not exist.
     * The instances are added as a side effect of building the {@link ComparisonDetails}.
     */
    private void train(ComparisonDataSet.Item item) throws Exception {
        Label labelA = new Label(wikipedia.getEnvironment(), item.getTermA()) ;
        Label labelB = new Label(wikipedia.getEnvironment(), item.getTermB()) ;

        if (!labelA.exists() || !labelB.exists())
            return ;

        new ComparisonDetails(labelA, labelB, item.getIdA(), item.getIdB(), item.getRelatedness()) ;
    }

    /**
     * The result of comparing two labels: candidate interpretations (sense pairs), summary
     * relatedness features, and the label relatedness. The label relatedness is either known
     * (training) or predicted.
     */
    public class ComparisonDetails {

        private Label labelA ;
        private Label labelB ;

        /** The two label texts joined by a single space. */
        private Label concatenation ;

        private Double labelRelatedness ;

        /** Candidate sense pairs. After construction they are sorted best first (see {@link SensePair#compareTo}). */
        private ArrayList<SensePair> candidateInterpretations = new ArrayList<SensePair>() ;

        private double maxSpRelatedness ;
        private double avgSpRelatedness ;
        private double weightedAvgSpRelatedness ;

        /** @return the first label compared */
        public Label getLabelA() {
            return labelA;
        }

        /** @return the second label compared */
        public Label getLabelB() {
            return labelB;
        }

        /** @return the known (training) or predicted relatedness between the two labels */
        public Double getLabelRelatedness() {
            return labelRelatedness;
        }

        /** @return the candidate interpretations, best first; the internal list itself is returned */
        public ArrayList<SensePair> getCandidateInterpretations() {
            return candidateInterpretations;
        }

        /** @return the highest-ranked interpretation, or {@code null} if there are no candidates */
        public SensePair getBestInterpretation() {
            if (candidateInterpretations.size() > 0)
                return candidateInterpretations.get(0) ;
            else
                return null ;
        }

        /**
         * Constructs details for an item whose correct disambiguation and relatedness are already
         * known (training). Training instances are added to the enclosing comparer's datasets.
         *
         * @param labelA the first label
         * @param labelB the second label
         * @param artA id of the correct sense of labelA; a value &lt;= 0 means unknown
         * @param artB id of the correct sense of labelB; a value &lt;= 0 means unknown
         * @param relatedness the known relatedness of the two labels
         * @throws Exception if feature extraction fails
         */
        private ComparisonDetails(Label labelA, Label labelB, int artA, int artB, double relatedness) throws Exception {
            init(labelA, labelB, artA, artB, relatedness) ;
        }

        /**
         * Constructs details for an unseen pair, predicting the correct senses and the relatedness.
         *
         * @throws Exception if feature extraction or classification fails
         */
        private ComparisonDetails(Label labelA, Label labelB) throws Exception {
            init(labelA, labelB, null, null, null) ;
        }

        /**
         * Selects candidate sense pairs, computes their features, and then either records training
         * instances (when ids/relatedness are given) or predicts validity and relatedness (when
         * they are {@code null}).
         */
        private void init(Label labelA, Label labelB, Integer senseIdA, Integer senseIdB, Double relatedness) throws Exception {
            this.labelA = labelA ;
            this.labelB = labelB ;

            concatenation = new Label(wikipedia.getEnvironment(), labelA.getText() + " " + labelB.getText()) ;

            // Candidate selection. A pair whose relatedness exceeds the benchmark by more than
            // `spacer` (50%) becomes the new benchmark and resets the candidates. A pair that is
            // no more than 50% below the benchmark is kept as a candidate.
            double benchmarkRelatedness = 0 ;
            double spacer = 0.5 ;
            double topPriorProbability = 0 ;
            double topRelatedness = 0 ;

            for (Label.Sense senseA:labelA.getSenses()) {
                // NOTE(review): `break` (rather than `continue`) assumes senses are ordered by descending prior probability
                if (senseA.getPriorProbability() < wikipedia.getConfig().getMinSenseProbability())
                    break ;

                for (Label.Sense senseB:labelB.getSenses()) {
                    if (senseB.getPriorProbability() < wikipedia.getConfig().getMinSenseProbability())
                        break ;

                    SensePair sp = new SensePair(senseA, senseB) ;

                    if (sp.getSenseRelatedness() > benchmarkRelatedness + (benchmarkRelatedness*spacer)) {
                        // this sets a new benchmark; earlier candidates are no longer close enough
                        benchmarkRelatedness = sp.getSenseRelatedness() ;
                        candidateInterpretations.clear();
                        candidateInterpretations.add(sp);
                        topPriorProbability = sp.avgPriorProbability ;
                        topRelatedness = sp.senseRelatedness ;
                    } else if (sp.getSenseRelatedness() > benchmarkRelatedness - (benchmarkRelatedness*spacer)) {
                        // close enough to the benchmark to be considered
                        candidateInterpretations.add(sp);
                        if (sp.avgPriorProbability > topPriorProbability)
                            topPriorProbability = sp.avgPriorProbability ;
                        if (sp.senseRelatedness > topRelatedness)
                            topRelatedness = sp.senseRelatedness ;
                    }
                }
            }

            double totalSpRelatedness = 0 ;
            double totalWeightedSpRelatedness = 0 ;
            double totalWeight = 0 ;
            int spCount = 0 ;

            for (SensePair sp:candidateInterpretations) {
                sp.setDistanceFromBenchmarkRelatedness(benchmarkRelatedness-sp.getSenseRelatedness()) ;
                sp.setDistanceFromTopRelatedness(topRelatedness-sp.getSenseRelatedness()) ;
                sp.setDistanceFromTopPriorProbability(topPriorProbability-sp.avgPriorProbability) ;

                if (senseIdA != null && senseIdB != null) {
                    // training instance: correct senses are known; unknown ids (<= 0) add no sense instances
                    if (senseIdA > 0 && senseIdB > 0)
                        addSenseTrainingInstance(sp, senseIdA, senseIdB, relatedness) ;
                } else {
                    // correct senses must be predicted
                    sp.predictIsValid() ;
                }

                if (sp.getSenseRelatedness() > maxSpRelatedness)
                    maxSpRelatedness = sp.getSenseRelatedness() ;

                totalSpRelatedness += sp.getSenseRelatedness() ;
                totalWeightedSpRelatedness += (sp.avgPriorProbability * sp.getSenseRelatedness()) ;
                totalWeight += sp.avgPriorProbability ;
                spCount++ ;
            }

            // best first: by disambiguation confidence, then prior probability (see SensePair.compareTo)
            Collections.sort(candidateInterpretations) ;

            if (spCount > 0) {
                avgSpRelatedness = totalSpRelatedness/spCount ;
                // guard against NaN when every candidate has zero prior probability
                weightedAvgSpRelatedness = (totalWeight > 0) ? totalWeightedSpRelatedness/totalWeight : 0 ;
            } else {
                avgSpRelatedness = 0 ;
                weightedAvgSpRelatedness = 0 ;
            }

            if (relatedness != null) {
                // training instance, where relatedness is known
                labelRelatedness = relatedness ;
                relatednessDataset.add(getInstance()) ;
            } else {
                // relatedness must be predicted
                labelRelatedness = relatednessMeasurer.getDecision(getInstance()) ;
            }
        }

        /**
         * Labels one candidate as valid or invalid against the known correct senses, then adds it
         * to the sense-selection training data.
         * <p>
         * Instances are weighted by the pair's sense relatedness. Pairs of weakly related terms
         * matter little, because it hardly matters how they are interpreted. For invalid pairs, the
         * weight is averaged with the distance between the pair's relatedness and the known label
         * relatedness, so that negatives close to the true relatedness matter less.
         */
        private void addSenseTrainingInstance(SensePair sp, int senseIdA, int senseIdB, Double relatedness) throws Exception {
            sp.setIsValid(sp.getSenseA().getId() == senseIdA && sp.getSenseB().getId() == senseIdB) ;

            Instance i = sp.getInstance() ;

            double weight = sp.getSenseRelatedness() ;
            if (!sp.isValid) {
                double distToActualRelatedness = Math.abs(relatedness-sp.getSenseRelatedness()) ;
                weight = (weight + distToActualRelatedness) / 2 ;
            }

            i.setWeight(weight) ;
            senseDataset.add(i) ;
        }

        @Override
        public String toString() {
            try {
                Instance i = getInstance() ;

                StringBuffer sb = new StringBuffer() ;
                sb.append(labelA.getText()) ;
                sb.append(" vs. ") ;
                sb.append(labelB.getText()) ;
                sb.append(" br: " + df.format(i.value(RelatednessAttr.bestSenseRelatedness.ordinal()))) ;
                sb.append(" wr: " + df.format(i.value(RelatednessAttr.weightedAvgSenseRelatedness.ordinal()))) ;
                sb.append(" cpp: " + df.format(i.value(RelatednessAttr.concatenationPriorLinkProbability.ordinal()))) ;
                return sb.toString() ;
            } catch (Exception e) {
                return e.getMessage() ;
            }
        }

        /**
         * Builds the relatedness feature vector for this label pair. The class value is set only
         * when the label relatedness is already known.
         */
        protected Instance getInstance() throws ClassMissingException, AttributeMissingException {
            InstanceBuilder<RelatednessAttr,Double> ib = relatednessMeasurer.getInstanceBuilder() ;

            if (candidateInterpretations.size() > 0)
                ib.setAttribute(RelatednessAttr.bestSenseRelatedness, candidateInterpretations.get(0).getSenseRelatedness()) ;
            else
                ib.setAttribute(RelatednessAttr.bestSenseRelatedness, Instance.missingValue()) ;

            ib.setAttribute(RelatednessAttr.maxSenseRelatedness, maxSpRelatedness) ;
            ib.setAttribute(RelatednessAttr.avgSenseRelatedness, avgSpRelatedness) ;
            ib.setAttribute(RelatednessAttr.weightedAvgSenseRelatedness, weightedAvgSpRelatedness) ;
            ib.setAttribute(RelatednessAttr.concatenationPriorLinkProbability, concatenation.getLinkProbability()) ;
            // log(count + 1) keeps the feature defined for zero occurrences
            ib.setAttribute(RelatednessAttr.concatenationOccurances, Math.log(concatenation.getOccCount()+1)) ;

            if (labelRelatedness != null)
                ib.setClassAttribute(labelRelatedness) ;

            return ib.build() ;
        }
    }

    /**
     * One candidate interpretation of a label pair: a sense of label A paired with a sense of
     * label B, together with the features used to judge it.
     */
    public class SensePair implements Comparable<SensePair> {

        private Label.Sense senseA ;
        private Label.Sense senseB ;

        private Double avgPriorProbability ;
        private Double maxPriorProbability ;
        private Double minPriorProbability ;

        /** Generality features; left {@code null} when either sense lacks a generality value. */
        private Double avgGenerality ;
        private Double maxGenerality ;
        private Double minGenerality ;

        private Double distanceFromBenchmarkRelatedness ;
        private Double distanceFromTopRelatedness ;
        private Double distanceFromTopPriorProbability ;

        private Double senseRelatedness ;

        /** Known (training) or predicted validity; {@code null} until set. */
        private Boolean isValid = null ;

        /** 1.0/0.0 for training, the predicted probability of validity otherwise; {@code null} until set. */
        private Double disambiguationConfidence = null ;

        private SensePair(Label.Sense senseA, Label.Sense senseB) throws Exception {
            init(senseA, senseB) ;
        }

        /** Records known validity (training), with full confidence in the recorded value. */
        private void setIsValid(boolean valid) {
            isValid = valid ;
            disambiguationConfidence = valid ? 1.0 : 0.0 ;
        }

        /** Predicts validity with the sense selector; a pair is valid if P(valid) &gt; 0.5. */
        private void predictIsValid() throws ClassMissingException, AttributeMissingException, Exception {
            disambiguationConfidence = senseSelector.getDecisionDistribution(getInstance()).get(true) ;
            isValid = (disambiguationConfidence > 0.5) ;
        }

        private void setDistanceFromBenchmarkRelatedness(double distance) {
            distanceFromBenchmarkRelatedness = distance ;
        }

        private void setDistanceFromTopRelatedness(double distance) {
            distanceFromTopRelatedness = distance ;
        }

        private void setDistanceFromTopPriorProbability(double distance) {
            distanceFromTopPriorProbability = distance ;
        }

        /** Computes the prior-probability and generality features and the sense relatedness. */
        private void init(Label.Sense senseA, Label.Sense senseB) throws Exception {
            this.senseA = senseA ;
            this.senseB = senseB ;

            maxPriorProbability = Math.max(senseA.getPriorProbability(), senseB.getPriorProbability()) ;
            minPriorProbability = Math.min(senseA.getPriorProbability(), senseB.getPriorProbability()) ;
            avgPriorProbability = (maxPriorProbability+minPriorProbability)/2 ;

            if (senseA.getGenerality() != null && senseB.getGenerality() != null) {
                maxGenerality = (double)Math.max(senseA.getGenerality(), senseB.getGenerality()) ;
                minGenerality = (double)Math.min(senseA.getGenerality(), senseB.getGenerality()) ;
                avgGenerality = (maxGenerality+minGenerality)/2 ;
            }

            senseRelatedness = articleComparer.getRelatedness(senseA, senseB) ;
        }

        /** @return confidence that this is the correct interpretation, or {@code null} if not yet judged */
        public Double getDisambiguationConfidence() {
            return disambiguationConfidence ;
        }

        /** @return the sense chosen for the first label */
        public Label.Sense getSenseA() {
            return senseA;
        }

        /** @return the sense chosen for the second label */
        public Label.Sense getSenseB() {
            return senseB;
        }

        /** @return the relatedness between the two senses, as measured by the article comparer */
        public Double getSenseRelatedness() {
            return senseRelatedness;
        }

        /**
         * Builds the sense-selection feature vector for this pair. The class value is set only once
         * validity has been recorded or predicted.
         */
        protected Instance getInstance() throws ClassMissingException, AttributeMissingException {
            InstanceBuilder<SenseAttr,Boolean> ib = senseSelector.getInstanceBuilder() ;

            ib.setAttribute(SenseAttr.predictedRelatedness, senseRelatedness) ;
            ib.setAttribute(SenseAttr.avgPriorProbability, avgPriorProbability) ;
            ib.setAttribute(SenseAttr.maxPriorProbability, maxPriorProbability) ;
            ib.setAttribute(SenseAttr.minPriorProbability, minPriorProbability) ;
            ib.setAttribute(SenseAttr.avgGenerality, avgGenerality) ;
            ib.setAttribute(SenseAttr.maxGenerality, maxGenerality) ;
            ib.setAttribute(SenseAttr.minGenerality, minGenerality) ;
            ib.setAttribute(SenseAttr.distanceFromBenchmarkRelatedness, distanceFromBenchmarkRelatedness) ;
            ib.setAttribute(SenseAttr.distanceFromTopRelatedness, distanceFromTopRelatedness) ;
            ib.setAttribute(SenseAttr.distanceFromTopPriorProbability, distanceFromTopPriorProbability) ;

            if (disambiguationConfidence != null)
                ib.setClassAttribute(isValid) ;

            return ib.build() ;
        }

        /**
         * Orders pairs best first. The order is, in sequence:
         * <ol>
         * <li>descending disambiguation confidence (only when both pairs have one);</li>
         * <li>descending average prior probability;</li>
         * <li>the natural order of sense A, then of sense B.</li>
         * </ol>
         */
        public int compareTo(SensePair sp) {
            int cmp ;

            if (disambiguationConfidence != null && sp.disambiguationConfidence != null) {
                cmp = sp.disambiguationConfidence.compareTo(disambiguationConfidence) ;
                if (cmp != 0)
                    return cmp ;
            }

            cmp = sp.avgPriorProbability.compareTo(avgPriorProbability) ;
            if (cmp != 0)
                return cmp ;

            cmp = senseA.compareTo(sp.senseA) ;
            if (cmp != 0)
                return cmp ;

            return senseB.compareTo(sp.senseB) ;
        }

        @Override
        public String toString() {
            DecimalFormat fmt = new DecimalFormat("0.00") ;

            StringBuilder sb = new StringBuilder() ;
            sb.append(senseA) ;
            sb.append(" vs. ") ;
            sb.append(senseB) ;
            sb.append(" dc:" + formatOrNull(fmt, disambiguationConfidence)) ;
            sb.append(" r:" + formatOrNull(fmt, senseRelatedness)) ;
            sb.append(" pp:" + formatOrNull(fmt, avgPriorProbability)) ;
            sb.append(" distR:" + formatOrNull(fmt, distanceFromTopRelatedness)) ;
            sb.append(" distPP: " + formatOrNull(fmt, distanceFromTopPriorProbability)) ;
            return sb.toString() ;
        }

        /** Formats a possibly-unset value; DecimalFormat.format(null) would throw IllegalArgumentException. */
        private String formatOrNull(DecimalFormat fmt, Double value) {
            return (value == null) ? "null" : fmt.format(value) ;
        }
    }
}