/*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 2 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
*/
/*
* Evaluator.java
* Copyright (C) 2009-2010 Aristotle University of Thessaloniki, Thessaloniki, Greece
*/
package mulan.evaluation;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import java.util.Set;
import java.util.logging.Level;
import java.util.logging.Logger;
import mulan.classifier.MultiLabelLearner;
import mulan.classifier.MultiLabelOutput;
import mulan.data.MultiLabelInstances;
import mulan.evaluation.measure.*;
import weka.core.Instance;
import weka.core.Instances;
/**
* Evaluator - responsible for generating evaluation data
* @author rofr
* @author Grigorios Tsoumakas
* @version 2010.11.06
*/
public class Evaluator {

    /** Shared logger for failures that are reported instead of propagated. */
    private static final Logger LOGGER = Logger.getLogger(Evaluator.class.getName());

    // seed for reproduction of cross-validation results
    private int seed = 1;
    // when false divisions-by-zero are ignored in certain measures
    private boolean strict = true;

    /**
     * Sets the seed for reproduction of cross-validation results
     *
     * @param aSeed seed for reproduction of cross-validation results
     */
    public void setSeed(int aSeed) {
        seed = aSeed;
    }

    /**
     * Controls how divisions-by-zero are handled
     *
     * @param isStrict when false divisions-by-zero are ignored in certain measures
     */
    public void setStrict(boolean isStrict) {
        strict = isStrict;
    }

    /**
     * Evaluates a {@link MultiLabelLearner} on given test data set using
     * specified evaluation measures. All measures are reset before the
     * evaluation starts. Instances with missing labels are skipped. A measure
     * whose update throws an exception receives no further updates, but it is
     * still part of the returned evaluation.
     *
     * @param learner the learner to be evaluated
     * @param data the test data set the learner is evaluated on
     * @param measures the evaluation measures to compute
     * @return an Evaluation object built from the given measures
     * @throws IllegalArgumentException if an input parameter is null
     * @throws Exception if an error occurs during the evaluation, for example
     *         while obtaining a prediction from the learner
     */
    public Evaluation evaluate(MultiLabelLearner learner, MultiLabelInstances data, List<Measure> measures) throws IllegalArgumentException, Exception {
        checkLearner(learner);
        checkData(data);
        checkMeasures(measures);

        // reset measures
        for (Measure m : measures) {
            m.reset();
        }

        int numLabels = data.getNumLabels();
        int[] labelIndices = data.getLabelIndices();
        // measures that already threw during update; they are not updated again
        Set<Measure> failed = new HashSet<Measure>();
        Instances testData = data.getDataSet();
        int numInstances = testData.numInstances();
        for (int instanceIndex = 0; instanceIndex < numInstances; instanceIndex++) {
            Instance instance = testData.instance(instanceIndex);
            if (data.hasMissingLabels(instance)) {
                continue;
            }
            MultiLabelOutput output = learner.makePrediction(instance);
            boolean[] trueLabels = getTrueLabels(instance, numLabels, labelIndices);
            for (Measure m : measures) {
                if (failed.contains(m)) {
                    continue;
                }
                try {
                    m.update(output, trueLabels);
                } catch (Exception ex) {
                    failed.add(m);
                    // report once per measure instead of dropping the failure silently
                    LOGGER.log(Level.WARNING, "Measure " + m.getClass().getName()
                            + " failed to update and is excluded from further updates", ex);
                }
            }
        }

        return new Evaluation(measures);
    }

    /**
     * Rejects a null learner.
     *
     * @param learner the learner to check
     * @throws IllegalArgumentException if the learner is null
     */
    private void checkLearner(MultiLabelLearner learner) {
        if (learner == null) {
            throw new IllegalArgumentException("Learner to be evaluated is null.");
        }
    }

    /**
     * Rejects a null data set.
     *
     * @param data the data set to check
     * @throws IllegalArgumentException if the data set is null
     */
    private void checkData(MultiLabelInstances data) {
        if (data == null) {
            throw new IllegalArgumentException("Evaluation data object is null.");
        }
    }

    /**
     * Rejects a null list of measures.
     *
     * @param measures the list of measures to check
     * @throws IllegalArgumentException if the list is null
     */
    private void checkMeasures(List<Measure> measures) {
        if (measures == null) {
            throw new IllegalArgumentException("List of evaluation measures to compute is null.");
        }
    }

    /**
     * Rejects a number of folds smaller than two.
     *
     * @param someFolds the number of folds to check
     * @throws IllegalArgumentException if the number of folds is less than two
     */
    private void checkFolds(int someFolds) {
        if (someFolds < 2) {
            throw new IllegalArgumentException("Number of folds must be at least two or higher.");
        }
    }

    /**
     * Evaluates a {@link MultiLabelLearner} on given test data set. The
     * measures to compute are selected automatically based on the kind of
     * output the learner produces and on the label meta-data of the data set.
     *
     * @param learner the learner to be evaluated
     * @param data the data set for evaluation
     * @return the evaluation result
     * @throws IllegalArgumentException if either of input parameters is null.
     * @throws Exception if an error occurs during the evaluation
     */
    public Evaluation evaluate(MultiLabelLearner learner, MultiLabelInstances data) throws IllegalArgumentException, Exception {
        checkLearner(learner);
        checkData(data);
        List<Measure> measures = prepareMeasures(learner, data);
        return evaluate(learner, data, measures);
    }

    /**
     * Builds the list of measures applicable to the given learner and data.
     * A copy of the learner predicts the first instance of the data set, and
     * the measure families are chosen from what that prediction contains
     * (bipartition, ranking, confidences). A hierarchical measure is added
     * when the label meta-data form a hierarchy.
     *
     * Any exception (including the one raised for an empty data set, which has
     * no first instance) is logged and the measures collected so far are
     * returned.
     *
     * @param learner the learner whose output type decides the measures
     * @param data the data set providing the probe instance and label info
     * @return the list of applicable measures, possibly empty
     */
    private List<Measure> prepareMeasures(MultiLabelLearner learner, MultiLabelInstances data) {
        List<Measure> measures = new ArrayList<Measure>();

        try {
            // probe with a copy so that the given learner object is not used here
            MultiLabelLearner copyOfLearner = learner.makeCopy();
            MultiLabelOutput prediction = copyOfLearner.makePrediction(data.getDataSet().instance(0));
            int numOfLabels = data.getNumLabels();

            // add bipartition-based measures if applicable
            if (prediction.hasBipartition()) {
                // add example-based measures
                measures.add(new HammingLoss());
                measures.add(new SubsetAccuracy());
                measures.add(new ExampleBasedPrecision(strict));
                measures.add(new ExampleBasedRecall(strict));
                measures.add(new ExampleBasedFMeasure(strict));
                measures.add(new ExampleBasedAccuracy(strict));
                // add label-based measures
                measures.add(new MicroPrecision(numOfLabels));
                measures.add(new MicroRecall(numOfLabels));
                measures.add(new MicroFMeasure(numOfLabels));
                measures.add(new MacroPrecision(numOfLabels, strict));
                measures.add(new MacroRecall(numOfLabels, strict));
                measures.add(new MacroFMeasure(numOfLabels, strict));
            }
            // add ranking-based measures if applicable
            if (prediction.hasRanking()) {
                measures.add(new AveragePrecision());
                measures.add(new Coverage());
                measures.add(new OneError());
                measures.add(new IsError());
                measures.add(new ErrorSetSize());
                measures.add(new RankingLoss());
            }
            // add confidence measures if applicable
            if (prediction.hasConfidences()) {
                measures.add(new MeanAveragePrecision(numOfLabels));
                measures.add(new MicroAUC(numOfLabels));
                measures.add(new MacroAUC(numOfLabels));
            }
            // add hierarchical measures if applicable
            if (data.getLabelsMetaData().isHierarchy()) {
                measures.add(new HierarchicalLoss(data));
            }
        } catch (Exception ex) {
            LOGGER.log(Level.SEVERE, "Failed to determine the applicable evaluation measures", ex);
        }

        return measures;
    }

    /**
     * Reads the true label values of an instance. A label counts as relevant
     * when the nominal value of its attribute is the string "1".
     *
     * @param instance the instance to read the labels from
     * @param numLabels the number of labels
     * @param labelIndices the attribute indices of the labels
     * @return an array with one entry per label, true for relevant labels
     */
    private boolean[] getTrueLabels(Instance instance, int numLabels, int[] labelIndices) {
        boolean[] trueLabels = new boolean[numLabels];
        for (int counter = 0; counter < numLabels; counter++) {
            int classIdx = labelIndices[counter];
            String classValue = instance.attribute(classIdx).value((int) instance.value(classIdx));
            // constant-first comparison, safe should the value be null
            trueLabels[counter] = "1".equals(classValue);
        }
        return trueLabels;
    }

    /**
     * Evaluates a {@link MultiLabelLearner} via cross-validation on given data
     * set with defined number of folds and seed. The measures are selected
     * automatically for each fold.
     *
     * @param learner the learner to be evaluated via cross-validation
     * @param data the multi-label data set for cross-validation
     * @param someFolds the number of folds, at least two
     * @return a {@link MultipleEvaluation} object holding the results
     * @throws IllegalArgumentException if the learner or the data is null, or
     *         if the number of folds is less than two
     */
    public MultipleEvaluation crossValidate(MultiLabelLearner learner, MultiLabelInstances data, int someFolds) {
        checkLearner(learner);
        checkData(data);
        checkFolds(someFolds);
        return innerCrossValidate(learner, data, false, null, someFolds);
    }

    /**
     * Evaluates a {@link MultiLabelLearner} via cross-validation on given data
     * set using given evaluation measures with defined number of folds and seed.
     *
     * @param learner the learner to be evaluated via cross-validation
     * @param data the multi-label data set for cross-validation
     * @param measures the evaluation measures to compute
     * @param someFolds the number of folds, at least two
     * @return a {@link MultipleEvaluation} object holding the results
     * @throws IllegalArgumentException if the learner, the data or the list of
     *         measures is null, or if the number of folds is less than two
     */
    public MultipleEvaluation crossValidate(MultiLabelLearner learner, MultiLabelInstances data, List<Measure> measures, int someFolds) {
        checkLearner(learner);
        checkData(data);
        checkMeasures(measures);
        // same validation as the overload without measures
        checkFolds(someFolds);
        return innerCrossValidate(learner, data, true, measures, someFolds);
    }

    /**
     * Runs the cross-validation. A copy of the data is shuffled with the
     * configured seed; for every fold a copy of the learner is built on the
     * training part and evaluated on the test part. A fold that throws is
     * logged and its slot in the result array stays null.
     *
     * @param learner the learner to copy and build for each fold
     * @param data the complete multi-label data set
     * @param hasMeasures whether the given measures are used; otherwise the
     *        measures are selected automatically for each fold
     * @param measures the measures to compute, only read when hasMeasures is true
     * @param someFolds the number of folds
     * @return a {@link MultipleEvaluation} wrapping the per-fold evaluations
     */
    private MultipleEvaluation innerCrossValidate(MultiLabelLearner learner, MultiLabelInstances data, boolean hasMeasures, List<Measure> measures, int someFolds) {
        Evaluation[] evaluation = new Evaluation[someFolds];

        // work on a copy so that the caller's data set keeps its order
        Instances workingSet = new Instances(data.getDataSet());
        workingSet.randomize(new Random(seed));
        for (int i = 0; i < someFolds; i++) {
            System.out.println("Fold " + (i + 1) + "/" + someFolds);
            try {
                Instances train = workingSet.trainCV(someFolds, i);
                Instances test = workingSet.testCV(someFolds, i);
                MultiLabelInstances mlTrain = new MultiLabelInstances(train, data.getLabelsMetaData());
                MultiLabelInstances mlTest = new MultiLabelInstances(test, data.getLabelsMetaData());
                MultiLabelLearner clone = learner.makeCopy();
                clone.build(mlTrain);
                if (hasMeasures) {
                    // NOTE(review): the same measure objects are handed to every
                    // fold and evaluate() resets them each time; whether earlier
                    // folds keep their results depends on Evaluation copying the
                    // measures - confirm against the Evaluation class.
                    evaluation[i] = evaluate(clone, mlTest, measures);
                } else {
                    evaluation[i] = evaluate(clone, mlTest);
                }
            } catch (Exception ex) {
                LOGGER.log(Level.SEVERE, "Cross-validation fold " + (i + 1) + "/" + someFolds
                        + " failed; its evaluation is left null", ex);
            }
        }
        return new MultipleEvaluation(evaluation);
    }
}