/*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 2 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
*/
/*
* ICDM08EnsembleOfPrunedSets.java
* Copyright (C) 2009-2010 Aristotle University of Thessaloniki, Thessaloniki, Greece
*/
package mulan.experiments;
/**
* @author Emmanouela Stachtiari
* @author Grigorios Tsoumakas
* @version 2010.12.10
*/
import java.util.HashMap;
import java.util.LinkedList;
import java.util.Random;
import mulan.classifier.meta.thresholding.OneThreshold;
import mulan.classifier.transformation.EnsembleOfPrunedSets;
import mulan.classifier.transformation.PrunedSets;
import mulan.data.MultiLabelInstances;
import mulan.evaluation.Evaluation;
import mulan.evaluation.Evaluator;
import mulan.evaluation.MultipleEvaluation;
import mulan.evaluation.measure.BipartitionMeasureBase;
import mulan.evaluation.measure.ExampleBasedAccuracy;
import mulan.evaluation.measure.ExampleBasedFMeasure;
import mulan.evaluation.measure.HammingLoss;
import mulan.evaluation.measure.Measure;
import weka.classifiers.functions.SMO;
import weka.core.Instances;
import weka.core.TechnicalInformation;
import weka.core.Utils;
import weka.core.TechnicalInformation.Field;
import weka.core.TechnicalInformation.Type;
/**
* Class replicating an experiment from a published paper
*
* @author Grigorios Tsoumakas
* @version 2010.12.27
*/
public class ICDM08EnsembleOfPrunedSets extends Experiment {

    /**
     * Entry point replicating the EnsembleOfPrunedSets experiment.
     * <p>
     * Performs 5 repetitions of 2-fold cross-validation. Within each fold, a
     * grid search over PrunedSets parameters (p in 5..2, b in 1..3, strategies
     * A and B) is run via inner 5-fold cross-validation on the training part;
     * the best configuration per measure is then wrapped in an
     * EnsembleOfPrunedSets with a tuned OneThreshold and evaluated on the
     * held-out fold. Aggregated statistics are printed per measure at the end.
     *
     * @param args command line arguments: -path (directory of the data set)
     *             and -filestem (base name of the .arff/.xml pair)
     */
    public static void main(String[] args) {
        try {
            String path = Utils.getOption("path", args);
            String filestem = Utils.getOption("filestem", args);
            System.out.println("Loading the data set");
            MultiLabelInstances dataSet = new MultiLabelInstances(path + filestem + ".arff", path + filestem + ".xml");

            Evaluator evaluator;
            // Fix: the array must hold three measures; it was previously sized 2,
            // which made the third assignment below throw
            // ArrayIndexOutOfBoundsException before the experiment could start.
            Measure[] evaluationMeasures = new Measure[3];
            evaluationMeasures[0] = new ExampleBasedAccuracy(false);
            evaluationMeasures[1] = new HammingLoss();
            evaluationMeasures[2] = new ExampleBasedFMeasure(false);

            // One aggregated result container per measure, keyed by measure name.
            HashMap<String, MultipleEvaluation> result = new HashMap<String, MultipleEvaluation>();
            for (Measure m : evaluationMeasures) {
                MultipleEvaluation me = new MultipleEvaluation();
                result.put(m.getName(), me);
            }

            // Fixed seed so the 5x2 cross-validation splits are reproducible.
            Random random = new Random(1);
            for (int repetition = 0; repetition < 5; repetition++) {
                // perform 2-fold CV and add each to the current results
                dataSet.getDataSet().randomize(random);
                for (int fold = 0; fold < 2; fold++) {
                    System.out.println("Experiment " + (repetition * 2 + fold + 1));
                    Instances train = dataSet.getDataSet().trainCV(2, fold);
                    MultiLabelInstances multiTrain = new MultiLabelInstances(train, dataSet.getLabelsMetaData());
                    Instances test = dataSet.getDataSet().testCV(2, fold);
                    MultiLabelInstances multiTest = new MultiLabelInstances(test, dataSet.getLabelsMetaData());

                    // Per-measure bookkeeping for the parameter grid search.
                    HashMap<String, Integer> bestP = new HashMap<String, Integer>();
                    HashMap<String, Integer> bestB = new HashMap<String, Integer>();
                    HashMap<String, PrunedSets.Strategy> bestStrategy = new HashMap<String, PrunedSets.Strategy>();
                    // bestDiff tracks the smallest distance from each measure's
                    // ideal value seen so far; initialized to the worst case.
                    HashMap<String, Double> bestDiff = new HashMap<String, Double>();
                    for (Measure m : evaluationMeasures) {
                        bestDiff.put(m.getName(), Double.MAX_VALUE);
                    }

                    System.out.println("Searching parameters");
                    for (int p = 5; p > 1; p--) {
                        for (int b = 1; b < 4; b++) {
                            MultipleEvaluation innerResult = null;
                            LinkedList<Measure> measures;
                            PrunedSets ps;
                            double diff;

                            // Evaluate strategy A for this (p, b) via inner 5-fold CV.
                            evaluator = new Evaluator();
                            ps = new PrunedSets(new SMO(), p, PrunedSets.Strategy.A, b);
                            // Fresh copies of the measures so inner runs don't
                            // accumulate state across configurations.
                            measures = new LinkedList<Measure>();
                            for (Measure m : evaluationMeasures) {
                                measures.add(m.makeCopy());
                            }
                            System.out.print("p=" + p + " b=" + b + " strategy=A ");
                            innerResult = evaluator.crossValidate(ps, multiTrain, measures, 5);
                            for (Measure m : evaluationMeasures) {
                                System.out.print(m.getName() + ": " + innerResult.getMean(m.getName()) + " ");
                                diff = Math.abs(m.getIdealValue() - innerResult.getMean(m.getName()));
                                // <= so that ties prefer the configuration
                                // examined later in the grid.
                                if (diff <= bestDiff.get(m.getName())) {
                                    bestDiff.put(m.getName(), diff);
                                    bestP.put(m.getName(), p);
                                    bestB.put(m.getName(), b);
                                    bestStrategy.put(m.getName(), PrunedSets.Strategy.A);
                                }
                            }
                            System.out.println();

                            // Evaluate strategy B for the same (p, b).
                            evaluator = new Evaluator();
                            ps = new PrunedSets(new SMO(), p, PrunedSets.Strategy.B, b);
                            measures = new LinkedList<Measure>();
                            for (Measure m : evaluationMeasures) {
                                measures.add(m.makeCopy());
                            }
                            System.out.print("p=" + p + " b=" + b + " strategy=B ");
                            innerResult = evaluator.crossValidate(ps, multiTrain, measures, 5);
                            for (Measure m : evaluationMeasures) {
                                System.out.print(m.getName() + ": " + innerResult.getMean(m.getName()) + " ");
                                diff = Math.abs(m.getIdealValue() - innerResult.getMean(m.getName()));
                                if (diff <= bestDiff.get(m.getName())) {
                                    bestDiff.put(m.getName(), diff);
                                    bestP.put(m.getName(), p);
                                    bestB.put(m.getName(), b);
                                    bestStrategy.put(m.getName(), PrunedSets.Strategy.B);
                                }
                            }
                            System.out.println();
                        }
                    }

                    // Train the ensemble with each measure's best configuration
                    // and evaluate on the held-out fold.
                    for (Measure m : evaluationMeasures) {
                        System.out.println(m.getName());
                        System.out.println("Best p: " + bestP.get(m.getName()));
                        System.out.println("Best strategy: " + bestStrategy.get(m.getName()));
                        System.out.println("Best b: " + bestB.get(m.getName()));
                        EnsembleOfPrunedSets eps = new EnsembleOfPrunedSets(63, 10, 0.5, bestP.get(m.getName()), bestStrategy.get(m.getName()), bestB.get(m.getName()), new SMO());
                        // Tune a single decision threshold for this measure
                        // using internal 5-fold cross-validation.
                        OneThreshold ot = new OneThreshold(eps, (BipartitionMeasureBase) m.makeCopy(), 5);
                        ot.build(multiTrain);
                        System.out.println("Best threshold: " + ot.getThreshold());
                        evaluator = new Evaluator();
                        Evaluation e = evaluator.evaluate(ot, multiTest);
                        System.out.println(e.toCSV());
                        result.get(m.getName()).addEvaluation(e);
                    }
                }
            }

            // Report aggregated statistics over all 10 (5x2) evaluations.
            for (Measure m : evaluationMeasures) {
                System.out.println(m.getName());
                result.get(m.getName()).calculateStatistics();
                System.out.println(result.get(m.getName()));
            }
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Returns the bibliographic reference of the replicated paper.
     *
     * @return the technical information for Read (2008), ICDM'08
     */
    @Override
    public TechnicalInformation getTechnicalInformation() {
        TechnicalInformation result = new TechnicalInformation(Type.CONFERENCE);
        result.setValue(Field.AUTHOR, "Read, Jesse");
        result.setValue(Field.TITLE, "Multi-label Classification using Ensembles of Pruned Sets");
        result.setValue(Field.PAGES, "995-1000");
        result.setValue(Field.BOOKTITLE, "ICDM'08: Eighth IEEE International Conference on Data Mining");
        result.setValue(Field.YEAR, "2008");
        return result;
    }
}