package edu.isistan.daclassifier;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import java.util.Set;
import java.util.logging.Level;
import java.util.logging.Logger;
import mulan.classifier.MultiLabelLearner;
import mulan.classifier.MultiLabelOutput;
import mulan.data.MultiLabelInstances;
import mulan.evaluation.Evaluation;
import mulan.evaluation.MultipleEvaluation;
import mulan.evaluation.measure.*;
import weka.core.Instance;
import weka.core.Instances;
/**
 * Evaluator for multi-label learners supporting single-split evaluation,
 * k-fold cross-validation and incremental random-percentage validation.
 * Modeled on Mulan's {@code Evaluator}.
 */
public class CustomEvaluator {
// seed for reproduction of cross-validation results (kept public for backward compatibility)
public int seed = 1;
// when false divisions-by-zero are ignored in certain measures
private boolean strict = true;
/**
 * Sets the seed for reproduction of cross-validation results
 *
 * @param aSeed seed for reproduction of cross-validation results
 */
public void setSeed(int aSeed) {
    seed = aSeed;
}
/**
 * Controls how divisions-by-zero are handled
 *
 * @param isStrict when false divisions-by-zero are ignored
 */
public void setStrict(boolean isStrict) {
    strict = isStrict;
}
/**
 * Evaluates a {@link MultiLabelLearner} on a given test data set using the
 * specified evaluation measures. Instances with missing labels are skipped;
 * a measure that throws during an update is excluded from further updates
 * (its partial result is still returned, matching Mulan's behavior).
 *
 * @param learner the (already trained) learner to be evaluated
 * @param data the test data set
 * @param measures the evaluation measures to compute
 * @return an {@link Evaluation} object holding the computed measures
 * @throws IllegalArgumentException if an input parameter is null
 * @throws Exception if prediction fails
 */
public Evaluation evaluate(MultiLabelLearner learner, MultiLabelInstances data, List<Measure> measures) throws IllegalArgumentException, Exception {
    checkLearner(learner);
    checkData(data);
    checkMeasures(measures);
    // reset measures so results are not contaminated by earlier runs
    for (Measure m : measures) {
        m.reset();
    }
    int numLabels = data.getNumLabels();
    int[] labelIndices = data.getLabelIndices();
    // measures that threw during update; excluded from subsequent updates
    Set<Measure> failed = new HashSet<Measure>();
    Instances testData = data.getDataSet();
    int numInstances = testData.numInstances();
    for (int instanceIndex = 0; instanceIndex < numInstances; instanceIndex++) {
        Instance instance = testData.instance(instanceIndex);
        if (data.hasMissingLabels(instance)) {
            continue; // ground truth incomplete: cannot score this instance
        }
        MultiLabelOutput output = learner.makePrediction(instance);
        boolean[] trueLabels = getTrueLabels(instance, numLabels, labelIndices);
        for (Measure m : measures) {
            if (!failed.contains(m)) {
                try {
                    m.update(output, trueLabels);
                } catch (Exception ex) {
                    // measure not applicable to this output type; drop it from updates
                    failed.add(m);
                }
            }
        }
    }
    return new Evaluation(measures);
}
// Precondition check: learner must not be null.
private void checkLearner(MultiLabelLearner learner) {
    if (learner == null) {
        throw new IllegalArgumentException("Learner to be evaluated is null.");
    }
}
// Precondition check: data set must not be null.
private void checkData(MultiLabelInstances data) {
    if (data == null) {
        throw new IllegalArgumentException("Evaluation data object is null.");
    }
}
// Precondition check: measure list must not be null.
private void checkMeasures(List<Measure> measures) {
    if (measures == null) {
        throw new IllegalArgumentException("List of evaluation measures to compute is null.");
    }
}
// Precondition check: cross-validation needs at least two folds.
private void checkFolds(int someFolds) {
    if (someFolds < 2) {
        throw new IllegalArgumentException("Number of folds must be at least two or higher.");
    }
}
/**
 * Evaluates a {@link MultiLabelLearner} on a given test data set using a
 * default set of measures derived from the learner's output type.
 *
 * @param learner the (already trained) learner to be evaluated
 * @param data the test data set
 * @return the evaluation result
 * @throws IllegalArgumentException if either input parameter is null
 * @throws Exception if prediction fails
 */
public Evaluation evaluate(MultiLabelLearner learner, MultiLabelInstances data) throws IllegalArgumentException, Exception {
    checkLearner(learner);
    checkData(data);
    List<Measure> measures = prepareMeasures(learner, data);
    return evaluate(learner, data, measures);
}
/**
 * Builds the default measure list by probing one prediction of (a copy of)
 * the learner and adding the measure families that apply to its output type
 * (bipartition, ranking, confidences, hierarchy). On failure an empty (or
 * partial) list is returned and the error is logged.
 */
private List<Measure> prepareMeasures(MultiLabelLearner learner, MultiLabelInstances data) {
    List<Measure> measures = new ArrayList<Measure>();
    MultiLabelOutput prediction;
    try {
        // probe with a copy so the caller's learner state is untouched
        MultiLabelLearner copyOfLearner = learner.makeCopy();
        prediction = copyOfLearner.makePrediction(data.getDataSet().instance(0));
        // add bipartition-based measures if applicable
        if (prediction.hasBipartition()) {
            // example-based measures
            measures.add(new HammingLoss());
            measures.add(new SubsetAccuracy());
            measures.add(new ExampleBasedPrecision(strict));
            measures.add(new ExampleBasedRecall(strict));
            measures.add(new ExampleBasedFMeasure(strict));
            measures.add(new ExampleBasedAccuracy(strict));
            // label-based measures
            int numOfLabels = data.getNumLabels();
            measures.add(new MicroPrecision(numOfLabels));
            measures.add(new MicroRecall(numOfLabels));
            measures.add(new MicroFMeasure(numOfLabels));
            measures.add(new MacroPrecision(numOfLabels, strict));
            measures.add(new MacroRecall(numOfLabels, strict));
            measures.add(new MacroFMeasure(numOfLabels, strict));
        }
        // add ranking-based measures if applicable
        if (prediction.hasRanking()) {
            measures.add(new AveragePrecision());
            measures.add(new Coverage());
            measures.add(new OneError());
            measures.add(new IsError());
            measures.add(new ErrorSetSize());
            measures.add(new RankingLoss());
        }
        // add confidence measures if applicable
        if (prediction.hasConfidences()) {
            int numOfLabels = data.getNumLabels();
            measures.add(new MeanAveragePrecision(numOfLabels));
            measures.add(new MicroAUC(numOfLabels));
            measures.add(new MacroAUC(numOfLabels));
        }
        // add hierarchical measures if applicable
        if (data.getLabelsMetaData().isHierarchy()) {
            measures.add(new HierarchicalLoss(data));
        }
    } catch (Exception ex) {
        Logger.getLogger(CustomEvaluator.class.getName()).log(Level.SEVERE, null, ex);
    }
    return measures;
}
/**
 * Extracts the ground-truth label bipartition of an instance: a label is
 * considered relevant when its nominal attribute value equals "1".
 */
private boolean[] getTrueLabels(Instance instance, int numLabels, int[] labelIndices) {
    boolean[] trueLabels = new boolean[numLabels];
    for (int counter = 0; counter < numLabels; counter++) {
        int classIdx = labelIndices[counter];
        String classValue = instance.attribute(classIdx).value((int) instance.value(classIdx));
        trueLabels[counter] = classValue.equals("1");
    }
    return trueLabels;
}
// Results of the most recent crossValidate/randomPercentageValidate run.
private MultipleEvaluation trainingMultipleEvaluation;
private MultipleEvaluation testingMultipleEvaluation;
/** @return training-set results of the most recent validation run (null before any run). */
public MultipleEvaluation getTrainingMultipleEvaluation() {
    return trainingMultipleEvaluation;
}
/** @return test-set results of the most recent validation run (null before any run). */
public MultipleEvaluation getTestingMultipleEvaluation() {
    return testingMultipleEvaluation;
}
/**
 * Evaluates a {@link MultiLabelLearner} via cross-validation on a given data
 * set with the defined number of folds and seed. Results are retrieved via
 * {@link #getTrainingMultipleEvaluation()} and {@link #getTestingMultipleEvaluation()}.
 *
 * @param learner the learner to be evaluated via cross-validation
 * @param data the multi-label data set for cross-validation
 * @param someFolds the number of folds (at least 2)
 */
public void crossValidate(MultiLabelLearner learner, MultiLabelInstances data, int someFolds)
{
    checkLearner(learner);
    checkData(data);
    checkFolds(someFolds);
    innerCrossValidate(learner, data, false, null, someFolds);
}
/**
 * Evaluates a {@link MultiLabelLearner} via cross-validation on a given data
 * set using the given evaluation measures with the defined number of folds and
 * seed. Results are retrieved via {@link #getTrainingMultipleEvaluation()} and
 * {@link #getTestingMultipleEvaluation()}.
 *
 * @param learner the learner to be evaluated via cross-validation
 * @param data the multi-label data set for cross-validation
 * @param measures the evaluation measures to compute
 * @param someFolds the number of folds (at least 2)
 */
public void crossValidate(MultiLabelLearner learner, MultiLabelInstances data, List<Measure> measures, int someFolds)
{
    checkLearner(learner);
    checkData(data);
    checkMeasures(measures);
    innerCrossValidate(learner, data, true, measures, someFolds);
}
// Shared k-fold implementation; stores per-fold train/test evaluations in the
// result fields. A failing fold is logged and leaves a null entry in its slot
// (unchanged behavior; consumers of MultipleEvaluation should be aware).
private void innerCrossValidate(MultiLabelLearner learner, MultiLabelInstances data, boolean hasMeasures, List<Measure> measures, int someFolds) {
    Evaluation[] trainingEvaluation = new Evaluation[someFolds];
    Evaluation[] testingEvaluation = new Evaluation[someFolds];
    // copy before shuffling so the caller's data set order is untouched
    Instances workingSet = new Instances(data.getDataSet());
    workingSet.randomize(new Random(seed));
    for (int i = 0; i < someFolds; i++) {
        System.out.println("Fold " + (i + 1) + "/" + someFolds);
        try {
            Instances train = workingSet.trainCV(someFolds, i);
            Instances test = workingSet.testCV(someFolds, i);
            MultiLabelInstances mlTrain = new MultiLabelInstances(train, data.getLabelsMetaData());
            MultiLabelInstances mlTest = new MultiLabelInstances(test, data.getLabelsMetaData());
            // train a fresh copy per fold so folds do not contaminate each other
            MultiLabelLearner clone = learner.makeCopy();
            clone.build(mlTrain);
            if (hasMeasures) {
                trainingEvaluation[i] = evaluate(clone, mlTrain, measures);
                testingEvaluation[i] = evaluate(clone, mlTest, measures);
            }
            else {
                trainingEvaluation[i] = evaluate(clone, mlTrain);
                testingEvaluation[i] = evaluate(clone, mlTest);
            }
        } catch (Exception ex) {
            Logger.getLogger(CustomEvaluator.class.getName()).log(Level.SEVERE, null, ex);
        }
    }
    trainingMultipleEvaluation = new MultipleEvaluation(trainingEvaluation);
    testingMultipleEvaluation = new MultipleEvaluation(testingEvaluation);
}
// Precondition check for randomPercentageValidate. Both values must be strictly
// positive: a zero percentage or increment would produce a zero-length train
// window / zero jump length and hang innerRandomPercentageValidate.
private void checkPercentage(int randomPercentage, int increment) {
    if (randomPercentage <= 0 || randomPercentage > 100) {
        throw new IllegalArgumentException("Percentage must be greater than zero and equal-or-less than one hundred.");
    }
    if(increment <= 0 || increment > 100 || increment > randomPercentage) {
        throw new IllegalArgumentException("Increment must be greater than zero, equal-or-less than one hundred and lower than the random percentage.");
    }
}
/**
 * Evaluates a {@link MultiLabelLearner} by sliding a training window of
 * {@code randomPercentage}% over the (shuffled) data set in steps of
 * {@code increment}%, using a default set of measures. Results are retrieved
 * via {@link #getTrainingMultipleEvaluation()} and {@link #getTestingMultipleEvaluation()}.
 *
 * @param learner the learner to be evaluated
 * @param data the multi-label data set for validation
 * @param randomPercentage percentage of instances used for training (1-100)
 * @param increment window step as a percentage of the data set (1-randomPercentage)
 */
public void randomPercentageValidate(MultiLabelLearner learner, MultiLabelInstances data, int randomPercentage, int increment)
{
    checkLearner(learner);
    checkData(data);
    checkPercentage(randomPercentage, increment);
    innerRandomPercentageValidate(learner, data, false, null, randomPercentage, increment);
}
/**
 * Evaluates a {@link MultiLabelLearner} by sliding a training window of
 * {@code randomPercentage}% over the (shuffled) data set in steps of
 * {@code increment}%, using the given evaluation measures. Results are
 * retrieved via {@link #getTrainingMultipleEvaluation()} and
 * {@link #getTestingMultipleEvaluation()}.
 *
 * @param learner the learner to be evaluated
 * @param data the multi-label data set for validation
 * @param measures the evaluation measures to compute
 * @param randomPercentage percentage of instances used for training (1-100)
 * @param increment window step as a percentage of the data set (1-randomPercentage)
 */
public void randomPercentageValidate(MultiLabelLearner learner, MultiLabelInstances data, List<Measure> measures, int randomPercentage, int increment)
{
    checkLearner(learner);
    checkData(data);
    checkMeasures(measures);
    checkPercentage(randomPercentage, increment);
    innerRandomPercentageValidate(learner, data, true, measures, randomPercentage, increment);
}
// Shared sliding-window implementation. For each window position the instances
// inside [pointer, pointer + trainCapacity) of the shuffled set form the
// training set and all others the test set.
private void innerRandomPercentageValidate(MultiLabelLearner learner, MultiLabelInstances data, boolean hasMeasures, List<Measure> measures, int randomPercentage, int increment) {
    List<Evaluation> trainingEvaluation = new ArrayList<Evaluation>();
    List<Evaluation> testingEvaluation = new ArrayList<Evaluation>();
    // copy before shuffling so the caller's data set order is untouched
    Instances workingSet = new Instances(data.getDataSet());
    workingSet.randomize(new Random(seed));
    // Calculating values for sampling according to the percentage
    int fullCapacity = workingSet.size();
    double newPercentage = ((double) randomPercentage) / 100d;
    int trainCapacity = (int) (newPercentage * fullCapacity);
    int testCapacity = (int) ((1 - newPercentage) * fullCapacity);
    // Calculating values for different slices of the dataset for train & test.
    double newIncrement = ((double) increment) / 100d;
    // Clamp to >= 1: on small data sets the product can round down to 0,
    // which would stall the window and loop forever.
    int jumpLength = Math.max(1, (int) (newIncrement * fullCapacity));
    int pointer = 0;
    int pointerCount = 0;
    while(pointer + trainCapacity < fullCapacity) {
        System.out.println("Pointer " + (pointerCount + 1) + " / " + "Percentage " + randomPercentage);
        try {
            Instances train = new Instances(data.getDataSet(), trainCapacity);
            Instances test = new Instances(data.getDataSet(), testCapacity);
            // Filling the test and train datasets with instances
            for(int index = 0; index < fullCapacity; index++) {
                Instance instance = workingSet.get(index);
                if(index >= pointer && index < pointer + trainCapacity)
                    train.add(instance);
                else
                    test.add(instance);
            }
            MultiLabelInstances mlTrain = new MultiLabelInstances(train, data.getLabelsMetaData());
            MultiLabelInstances mlTest = new MultiLabelInstances(test, data.getLabelsMetaData());
            MultiLabelLearner clone = learner.makeCopy();
            clone.build(mlTrain);
            if (hasMeasures) {
                trainingEvaluation.add(evaluate(clone, mlTrain, measures));
                testingEvaluation.add(evaluate(clone, mlTest, measures));
            }
            else {
                trainingEvaluation.add(evaluate(clone, mlTrain));
                testingEvaluation.add(evaluate(clone, mlTest));
            }
        }
        catch (Exception ex) {
            Logger.getLogger(CustomEvaluator.class.getName()).log(Level.SEVERE, null, ex);
        }
        finally {
            // Always advance the window; previously a throwing iteration left
            // the pointer unchanged and retried the same slice forever.
            pointer += jumpLength;
            pointerCount++;
        }
    }
    trainingMultipleEvaluation = new MultipleEvaluation(trainingEvaluation.toArray(new Evaluation[] { }));
    testingMultipleEvaluation = new MultipleEvaluation(testingEvaluation.toArray(new Evaluation[] { }));
}
}