/*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 2 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
*/
/*
* OneThreshold.java
* Copyright (C) 2009-2010 Aristotle University of Thessaloniki, Thessaloniki, Greece
*/
package mulan.classifier.meta.thresholding;
import java.util.Arrays;
import java.util.logging.Level;
import java.util.logging.Logger;
import mulan.classifier.*;
import mulan.classifier.meta.MultiLabelMetaLearner;
import mulan.core.MulanRuntimeException;
import mulan.data.LabelsMetaData;
import mulan.data.MultiLabelInstances;
import mulan.evaluation.measure.*;
import weka.core.Utils;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformation.Field;
import weka.core.TechnicalInformation.Type;
/**
 * Meta-learner that converts the confidences produced by a base multi-label
 * learner into a bipartition by applying one single threshold to all labels.
 * The threshold is tuned during building so that a given bipartition measure
 * gets as close as possible to its ideal value, either directly on the
 * training data or, when a number of folds is given, as the mean of the
 * thresholds found in each fold of a cross-validation.
 *
 * @author Marios Ioannou
 * @author George Sakkas
 * @author Grigorios Tsoumakas
 * @version 2010.12.14
 */
public class OneThreshold extends MultiLabelMetaLearner {

    /** final threshold value */
    private double threshold;
    /** measure for auto-tuning the threshold */
    private BipartitionMeasureBase measure;
    /** the folds of the cv to evaluate different thresholds (0 = no cross-validation) */
    private int folds = 0;
    /** copy of a clean multi-label learner to use at each fold */
    private MultiLabelLearner foldLearner;

    /**
     * Creates a learner whose threshold is tuned through cross-validation.
     *
     * @param baseLearner the underlying multi-label learner
     * @param aMeasure the measure to optimize
     * @param someFolds number of cross-validation folds
     * @throws IllegalArgumentException if {@code someFolds} is less than 2
     */
    public OneThreshold(MultiLabelLearner baseLearner, BipartitionMeasureBase aMeasure, int someFolds) {
        super(baseLearner);
        if (someFolds < 2) {
            throw new IllegalArgumentException("folds should be more than 1");
        }
        measure = aMeasure;
        folds = someFolds;
        try {
            foldLearner = baseLearner.makeCopy();
        } catch (Exception ex) {
            // Construction stays best-effort; buildInternal retries the copy so
            // that the failure is reported to the caller at build time.
            Logger.getLogger(OneThreshold.class.getName()).log(Level.SEVERE,
                    "Could not copy the base learner for cross-validation", ex);
        }
    }

    /**
     * Creates a learner whose threshold is tuned directly on the training
     * data (no cross-validation).
     *
     * @param baseLearner the underlying multi-label learner
     * @param aMeasure measure to optimize
     */
    public OneThreshold(MultiLabelLearner baseLearner, BipartitionMeasureBase aMeasure) {
        super(baseLearner);
        measure = aMeasure;
    }

    /**
     * Reads the true label values of an instance.
     *
     * @param instance the instance to read the labels from
     * @return one flag per label, true when the nominal value of the label
     *         attribute is "1"
     */
    private boolean[] readTrueLabels(Instance instance) {
        boolean[] trueLabels = new boolean[numLabels];
        for (int counter = 0; counter < numLabels; counter++) {
            int classIdx = labelIndices[counter];
            String classValue = instance.attribute(classIdx).value((int) instance.value(classIdx));
            trueLabels[counter] = classValue.equals("1");
        }
        return trueLabels;
    }

    /**
     * Builds the bipartition obtained by thresholding the given confidences.
     *
     * @param confidences the per-label confidences
     * @param aThreshold the threshold to apply
     * @return one flag per label, true when its confidence is greater than or
     *         equal to the threshold
     */
    private boolean[] bipartitionAt(double[] confidences, double aThreshold) {
        boolean[] bipartition = new boolean[numLabels];
        for (int k = 0; k < numLabels; k++) {
            bipartition[k] = confidences[k] >= aThreshold;
        }
        return bipartition;
    }

    /**
     * Evaluates the performance of the learner on a data set according to a
     * bipartition measure for a range of thresholds
     *
     * @param learner the (already built) learner whose confidences are thresholded
     * @param data the test data to evaluate different thresholds
     * @param measure the evaluation is based on this parameter
     * @param min the minimum threshold
     * @param step the step to increase threshold from min to max
     * @param max the maximum threshold
     * @return the threshold of the grid whose measure value is closest to the
     *         ideal value of the measure
     * @throws Exception if making a prediction or copying the measure fails
     */
    private double computeThreshold(MultiLabelLearner learner, MultiLabelInstances data, BipartitionMeasureBase measure, double min, double step, double max) throws Exception {
        int numOfThresholds = (int) Math.rint((max - min) / step + 1);
        BipartitionMeasureBase[] measureForThreshold = new BipartitionMeasureBase[numOfThresholds];
        for (int i = 0; i < numOfThresholds; i++) {
            measureForThreshold[i] = (BipartitionMeasureBase) measure.makeCopy();
            measureForThreshold[i].reset();
        }
        // marks the thresholds for which the measure could not be updated
        boolean[] thresholdHasProblem = new boolean[numOfThresholds];

        for (int j = 0; j < data.getNumInstances(); j++) {
            Instance instance = data.getDataSet().instance(j);
            if (data.hasMissingLabels(instance)) {
                continue;
            }
            MultiLabelOutput mlo = learner.makePrediction(instance);
            boolean[] trueLabels = readTrueLabels(instance);
            double[] confidences = mlo.getConfidences();

            // Iterate over the grid by index and derive each threshold as
            // min + t * step. Repeatedly adding step to a running value
            // accumulates rounding error, so the number of iterations could
            // differ from numOfThresholds (skipping the last grid point or
            // indexing past the arrays).
            for (int t = 0; t < numOfThresholds; t++) {
                boolean[] bipartition = bipartitionAt(confidences, min + t * step);
                try {
                    MultiLabelOutput temp = new MultiLabelOutput(bipartition);
                    measureForThreshold[t].update(temp, trueLabels);
                } catch (MulanRuntimeException e) {
                    thresholdHasProblem[t] = true;
                }
            }
        }

        // distance of each threshold's measure value from the ideal value;
        // problematic thresholds get the worst possible distance
        double[] performance = new double[numOfThresholds];
        for (int i = 0; i < numOfThresholds; i++) {
            if (!thresholdHasProblem[i]) {
                performance[i] = Math.abs(measure.getIdealValue() - measureForThreshold[i].getValue());
            } else {
                performance[i] = Double.MAX_VALUE;
            }
        }
        return min + Utils.minIndex(performance) * step;
    }

    /**
     * Searches for the best threshold in two stages: a coarse search from 0
     * to 1 with step 0.1, followed by a fine search with step 0.01 in the
     * range of +/- 0.05 around the threshold found in the first stage.
     *
     * @param learner the (already built) learner whose confidences are thresholded
     * @param data the test data to evaluate different thresholds
     * @param measure the evaluation is based on this parameter
     * @return the threshold found in the second stage
     * @throws Exception if the evaluation of the thresholds fails
     */
    private double computeThreshold(MultiLabelLearner learner, MultiLabelInstances data, BipartitionMeasureBase measure) throws Exception {
        double stage1 = computeThreshold(learner, data, measure, 0, 0.1, 1);
        debug("1st stage threshold = " + stage1);
        double stage2 = computeThreshold(learner, data, measure, stage1 - 0.05, 0.01, stage1 + 0.05);
        debug("2nd stage threshold = " + stage2);
        return stage2;
    }

    /**
     * Builds the base learner on the full training data and tunes the
     * threshold. Without folds the threshold is computed on the training data
     * using the base learner itself; otherwise a fresh copy of the learner is
     * built on the training part of each fold, a threshold is computed on the
     * test part, and the final threshold is the mean over all folds.
     *
     * @param trainingData the data to build the learner and tune the threshold on
     * @throws Exception if building a learner or computing a threshold fails
     */
    protected void buildInternal(MultiLabelInstances trainingData) throws Exception {
        if (folds != 0 && foldLearner == null) {
            // The copy attempted in the constructor failed and was only
            // logged. Retry before the base learner is built so that the
            // actual cause is thrown here, instead of a NullPointerException
            // inside the fold loop.
            foldLearner = baseLearner.makeCopy();
        }
        baseLearner.build(trainingData);
        if (folds == 0) {
            threshold = computeThreshold(baseLearner, trainingData, measure);
        } else {
            LabelsMetaData labelsMetaData = trainingData.getLabelsMetaData();
            double[] thresholds = new double[folds];
            for (int f = 0; f < folds; f++) {
                Instances train = trainingData.getDataSet().trainCV(folds, f);
                MultiLabelInstances trainMulti = new MultiLabelInstances(train, labelsMetaData);
                Instances test = trainingData.getDataSet().testCV(folds, f);
                MultiLabelInstances testMulti = new MultiLabelInstances(test, labelsMetaData);
                MultiLabelLearner tempLearner = foldLearner.makeCopy();
                tempLearner.build(trainMulti);
                thresholds[f] = computeThreshold(tempLearner, testMulti, measure);
            }
            threshold = Utils.mean(thresholds);
        }
    }

    /**
     * Predicts with the base learner and replaces its bipartition with the
     * one obtained by applying the tuned threshold to its confidences.
     *
     * @param instance the instance to make a prediction for
     * @return an output holding the thresholded bipartition together with the
     *         confidences of the base learner
     * @throws Exception if the base learner fails to make a prediction
     * @throws InvalidDataException declared for the base learner's prediction call
     */
    protected MultiLabelOutput makePredictionInternal(Instance instance) throws Exception, InvalidDataException {
        MultiLabelOutput mlo = baseLearner.makePrediction(instance);
        boolean[] predictedLabels = bipartitionAt(mlo.getConfidences(), threshold);
        return new MultiLabelOutput(predictedLabels, mlo.getConfidences());
    }

    /**
     * Returns the bibliographic reference attached to this learner.
     *
     * @return the technical information of the referenced publication
     */
    @Override
    public TechnicalInformation getTechnicalInformation() {
        TechnicalInformation info = new TechnicalInformation(Type.INPROCEEDINGS);
        info.setValue(Field.AUTHOR, "Read, Jesse and Pfahringer, Bernhard and Holmes, Geoff");
        info.setValue(Field.YEAR, "2008");
        info.setValue(Field.TITLE, "Multi-label Classification Using Ensembles of Pruned Sets");
        info.setValue(Field.BOOKTITLE, "Data Mining, 2008. ICDM '08. Eighth IEEE International Conference on");
        info.setValue(Field.PAGES, "995-1000");
        info.setValue(Field.LOCATION, "Pisa, Italy");
        return info;
    }

    /**
     * Returns the calculated threshold
     *
     * @return the calculated threshold
     */
    public double getThreshold() {
        return threshold;
    }
}