package org.wikipedia.miner.examples; import java.io.BufferedReader; import java.io.File; import java.io.IOException; import java.io.InputStreamReader; import java.text.DecimalFormat; import org.wikipedia.miner.comparison.ArticleComparer; import org.wikipedia.miner.comparison.ComparisonDataSet; import org.wikipedia.miner.comparison.LabelComparer; import org.wikipedia.miner.db.WDatabase.DatabaseType; import org.wikipedia.miner.model.Wikipedia; import org.wikipedia.miner.util.WikipediaConfiguration; import weka.classifiers.Classifier; import weka.core.Utils; public class ComparisonWorkbench { private Wikipedia _wikipedia ; //directory in which files will be stored private File _dataDir ; private File _datasetFile ; private ComparisonDataSet _dataset ; //classes for performing comparison private ArticleComparer _artComparer ; private LabelComparer _labelComparer ; private File _arffArtCompare, _arffLabelDisambig, _arffLabelCompare ; private File _modelArtCompare, _modelLabelDisambig, _modelLabelCompare ; DecimalFormat df = new DecimalFormat("0.0000") ; public ComparisonWorkbench(File dataDir, ComparisonDataSet dataset, Wikipedia wikipedia) throws Exception { _dataDir = dataDir ; _wikipedia = wikipedia ; _dataset = dataset ; _artComparer = new ArticleComparer(_wikipedia) ; _labelComparer = new LabelComparer(_wikipedia, _artComparer) ; _arffArtCompare = new File(_dataDir.getPath() + "/art_compare.arff") ; _arffLabelDisambig = new File(_dataDir.getPath() + "/lbl_disambig.arff") ; _arffLabelCompare = new File(_dataDir.getPath() + "/lbl_compare.arff") ; _modelArtCompare = new File(_dataDir.getPath() + "/art_compare.model") ; _modelLabelDisambig = new File(_dataDir.getPath() + "/lbl_disambig.model") ; _modelLabelCompare = new File(_dataDir.getPath() + "/lbl_compare.model") ; } private void createArffFiles(String datasetName) throws IOException, Exception { _artComparer.train(_dataset); _artComparer.buildDefaultClassifier(); _artComparer.saveTrainingData(_arffArtCompare); _labelComparer.train(_dataset, datasetName); _labelComparer.buildDefaultClassifiers(); _labelComparer.saveDisambiguationTrainingData(_arffLabelDisambig); _labelComparer.saveComparisonTrainingData(_arffLabelCompare); } private void createClassifiers(String confArtCompare, String confLabelDisambig, String confLabelCompare) throws Exception { if (!_arffArtCompare.canRead() || !_arffLabelDisambig.canRead() || !_arffLabelCompare.canRead()) throw new Exception("Arff files have not yet been created") ; if (confArtCompare == null || confArtCompare.trim().length() == 0) { _artComparer.buildDefaultClassifier() ; } else { Classifier classifier = buildClassifierFromOptString(confArtCompare) ; _artComparer.buildClassifier(classifier) ; } _artComparer.saveClassifier(_modelArtCompare) ; if (confLabelDisambig == null || confLabelDisambig.trim().length() == 0) { _labelComparer.buildDefaultClassifiers() ; } else { Classifier classifierLabelDisambig = buildClassifierFromOptString(confLabelDisambig) ; Classifier classifierLabelCompare = buildClassifierFromOptString(confLabelCompare) ; //TODO: need to use provided classifiers _labelComparer.buildDefaultClassifiers(); } _labelComparer.saveDisambiguationClassifier(_modelLabelDisambig); _labelComparer.saveComparisonClassifier(_modelLabelCompare); } private Classifier buildClassifierFromOptString(String optString) throws Exception { String[] options = Utils.splitOptions(optString) ; String classname = options[0] ; options[0] = "" ; return (Classifier) Utils.forName(Classifier.class, classname, options) ; } private void evaluate() throws Exception { ComparisonDataSet[][] folds = _dataset.getFolds() ; double totalArtCompare = 0 ; double totalLabelDisambig = 0 ; double totalLabelCompare = 0 ; int foldIndex = 0; for (ComparisonDataSet[] fold:folds) { System.out.println("Fold " + foldIndex) ; foldIndex++ ; ComparisonDataSet trainingData = fold[0] ; ComparisonDataSet testData = fold[1] ; ArticleComparer artComparer = new ArticleComparer(_wikipedia) ; artComparer.train(trainingData); artComparer.buildDefaultClassifier(); double corrArtCompare = artComparer.test(testData) ; System.out.println(" - art comparison: " + df.format(corrArtCompare)); totalArtCompare += corrArtCompare ; LabelComparer lblComparer = new LabelComparer(_wikipedia, artComparer) ; lblComparer.train(trainingData, ""); lblComparer.buildDefaultClassifiers(); double accLabelDisambig = lblComparer.testDisambiguationAccuracy(testData) ; System.out.println(" - label disambiguation: " + df.format(accLabelDisambig)); totalLabelDisambig += accLabelDisambig ; double corrLabelCompare = lblComparer.testRelatednessPrediction(testData) ; System.out.println(" - label comparison: " + df.format(corrLabelCompare)); totalLabelCompare += corrLabelCompare ; } System.out.println(); System.out.println("art comparison (correllation); " + df.format(totalArtCompare/folds.length)) ; System.out.println("label disambiguation (accuracy); " + df.format(totalLabelDisambig/folds.length)) ; System.out.println("label comparison (correllation); " + df.format(totalLabelCompare/folds.length)) ; } public static void main(String args[]) throws Exception { File dataDir = new File(args[0]) ; File datasetFile = new File(args[1]) ; int maxRelatedness = Integer.parseInt(args[2]) ; ComparisonDataSet dataset = new ComparisonDataSet(datasetFile, maxRelatedness) ; WikipediaConfiguration conf = new WikipediaConfiguration(new File(args[3])) ; conf.addDatabaseToCache(DatabaseType.label) ; conf.addDatabaseToCache(DatabaseType.pageLinksInNoSentences) ; Wikipedia wikipedia = new Wikipedia(conf, false) ; ComparisonWorkbench trainer = new ComparisonWorkbench(dataDir, dataset, wikipedia) ; BufferedReader input = new BufferedReader(new InputStreamReader(System.in)) ; while (true) { System.out.println("What would you like to do?") ; System.out.println(" - [1] create arff files.") ; System.out.println(" - [2] create classifiers.") ; System.out.println(" - [3] evaluate classifiers.") ; System.out.println(" - or ENTER to quit.") ; String line = input.readLine() ; if (line.trim().length() == 0) break ; Integer choice = 0 ; try { choice = Integer.parseInt(line) ; } catch (Exception e) { System.out.println("Invalid Input") ; continue ; } switch(choice) { case 1: System.out.println("Dataset name:") ; String datasetName = input.readLine() ; trainer.createArffFiles(datasetName) ; break ; case 2: System.out.println("Article comparison classifer config (or ENTER to use default):") ; String confArtCompare = input.readLine() ; System.out.println("Label disambiguation classifer config (or ENTER to use default):") ; String confLabelDisambig = input.readLine() ; System.out.println("Label comparison classifer config (or ENTER to use default):") ; String confLabelDetect = input.readLine() ; trainer.createClassifiers(confArtCompare, confLabelDisambig, confLabelDetect) ; break ; case 3: trainer.evaluate() ; break ; default: System.out.println("Invalid Input") ; } } } }