/*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 2 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
*/
/*
* SCut.java
* Copyright (C) 2009 Aristotle University of Thessaloniki, Thessaloniki, Greece
*/
package mulan.classifier.meta.thresholding;
import mulan.classifier.meta.*;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.logging.Level;
import java.util.logging.Logger;
import mulan.classifier.MultiLabelLearner;
import mulan.classifier.MultiLabelOutput;
import mulan.data.MultiLabelInstances;
import mulan.evaluation.measure.BipartitionMeasureBase;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformation.Field;
import weka.core.TechnicalInformation.Type;
import weka.core.Utils;
/**
* Class that implements the SCut method (Score-based local optimization).
* It computes a separate threshold for each label based on improving a user defined
* performance measure.
*
* @author Marios Ioannou
* @author George Sakkas
* @author Grigorios Tsoumakas
* @version 2010.12.14
*/
public class SCut extends MultiLabelMetaLearner {

    /** relative score change below which a label is considered converged */
    private static final double CONVERGENCE_TOLERANCE = 0.001;
    /** safety cap on the number of optimization sweeps over all labels */
    private static final int MAX_SWEEPS = 1000;

    /** measure for auto-tuning the threshold */
    BipartitionMeasureBase measure;
    /**
     * the folds of the cv to evaluate different thresholds; 0 means that no
     * cross-validation is performed and thresholds are tuned on the training set
     */
    int kFoldsCV;
    /** one threshold for each label to consider relevant */
    double[] thresholds;

    /**
     * Constructor that initializes the learner with a base algorithm, a measure
     * and the number of cross-validation folds.
     *
     * @param baseLearner the underlying multi-label learner
     * @param measure the bipartition measure whose value the thresholds are tuned for
     * @param folds the number of folds to split the dataset; 0 tunes the
     *              thresholds directly on the training set
     */
    public SCut(MultiLabelLearner baseLearner, BipartitionMeasureBase measure, int folds) {
        super(baseLearner);
        this.measure = measure;
        this.kFoldsCV = folds;
    }

    /**
     * Creates a new instance of SCut that tunes the thresholds directly on the
     * training set (no cross-validation).
     *
     * @param baseLearner the underlying multi-label learner
     * @param measure the bipartition measure whose value the thresholds are tuned for
     */
    public SCut(MultiLabelLearner baseLearner, BipartitionMeasureBase measure) {
        this(baseLearner, measure, 0);
    }

    /**
     * Returns the bibliographic reference of the paper describing the method.
     *
     * @return the technical information about this class
     */
    public TechnicalInformation getTechnicalInformation() {
        TechnicalInformation result = new TechnicalInformation(Type.INPROCEEDINGS);
        result.setValue(Field.AUTHOR, "Yiming Yang");
        result.setValue(Field.TITLE, "A study of thresholding strategies for text categorization");
        result.setValue(Field.BOOKTITLE, "Proceedings of the 24th annual international ACM SIGIR conference on Research and development in information retrieval");
        result.setValue(Field.PAGES, "137 - 145");
        result.setValue(Field.LOCATION, "New Orleans, Louisiana, United States");
        result.setValue(Field.YEAR, "2001");
        return result;
    }

    /**
     * Returns the candidate threshold associated with a position in a sorted
     * list of confidences: the smallest confidence itself for position 0,
     * otherwise the midpoint between the confidence at that position and its
     * predecessor.
     *
     * @param sortedConfidences confidences of one label in ascending order
     * @param index position of the candidate
     * @return the candidate threshold
     */
    private static double candidateThreshold(List<Double> sortedConfidences, int index) {
        if (index == 0) {
            return sortedConfidences.get(0);
        }
        return (sortedConfidences.get(index) + sortedConfidences.get(index - 1)) / 2;
    }

    /**
     * Decides whether the score of a label has stopped changing between two
     * consecutive sweeps, relative to the score of the very first sweep.
     *
     * @param current score of the latest sweep
     * @param previous score of the sweep before
     * @param first score of the first sweep
     * @return <code>true</code> if the label is considered converged
     */
    private static boolean hasConverged(double current, double previous, double first) {
        double change = Math.abs(current - previous);
        // An unchanged score is converged by definition. Checking it first also
        // covers first == 0, where change / first would be NaN and never pass.
        if (change == 0) {
            return true;
        }
        return change / first < CONVERGENCE_TOLERANCE;
    }

    /**
     * Evaluates the performance of different threshold values for each label.
     * Labels are optimized one at a time while the thresholds of the other
     * labels are kept fixed; sweeps over all labels are repeated until the
     * score of every label stops changing or the sweep cap is reached.
     *
     * @param learner the (already built) multi-label learner producing the confidences
     * @param data the data used to evaluate the different thresholds
     * @return one threshold for each label
     * @throws Exception if a prediction fails, the learner outputs no
     *                   confidences or the data set is empty
     */
    private double[] computeThresholds(MultiLabelLearner learner, MultiLabelInstances data) throws Exception {
        int numInstances = data.getNumInstances();
        if (numInstances == 0) {
            throw new IllegalArgumentException("Cannot compute thresholds from an empty data set");
        }
        Instances dataSet = data.getDataSet();

        double[][] confidences = new double[numInstances][];
        boolean[][] trueLabels = new boolean[numInstances][numLabels];
        List<List<Double>> sortedConfidences = new ArrayList<List<Double>>(numLabels);
        for (int l = 0; l < numLabels; l++) {
            sortedConfidences.add(new ArrayList<Double>(numInstances));
        }

        // get the confidences and true labels from all instances
        for (int i = 0; i < numInstances; i++) {
            Instance instance = dataSet.instance(i);
            // a failed prediction is propagated: continuing with an all-zero
            // row would silently distort the tuned thresholds
            MultiLabelOutput output = learner.makePrediction(instance);
            if (!output.hasConfidences()) {
                throw new IllegalArgumentException("SCut requires a base learner that outputs confidences");
            }
            confidences[i] = output.getConfidences();
            for (int l = 0; l < numLabels; l++) {
                int labelIndex = labelIndices[l];
                trueLabels[i][l] = dataSet.attribute(labelIndex).value((int) instance.value(labelIndex)).equals("1");
                sortedConfidences.get(l).add(confidences[i][l]);
            }
        }

        // sort the confidences and set the initial thresholds for all labels
        double[] currentThresholds = new double[numLabels];
        for (int l = 0; l < numLabels; l++) {
            Collections.sort(sortedConfidences.get(l));
            currentThresholds[l] = 0.5;
        }

        // one candidate threshold (and one measure copy) per instance
        BipartitionMeasureBase[] candidateMeasures = new BipartitionMeasureBase[numInstances];
        for (int c = 0; c < numInstances; c++) {
            candidateMeasures[c] = (BipartitionMeasureBase) measure.makeCopy();
            candidateMeasures[c].reset();
        }
        double[] distanceToIdeal = new double[numInstances];

        // per-label scores of the latest sweep, the sweep before and the first sweep
        double[] currentScore = new double[numLabels];
        double[] previousScore = new double[numLabels];
        double[] firstScore = new double[numLabels];

        int sweep = 0;
        int convergedLabels = 0;
        do {
            // remember the scores of the previous sweep
            System.arraycopy(currentScore, 0, previousScore, 0, numLabels);

            // compute the best threshold for every label in turn
            for (int j = 0; j < numLabels; j++) {
                List<Double> candidates = sortedConfidences.get(j);
                double score = 0;
                // evaluate the measure for every candidate threshold of label j
                for (int c = numInstances - 1; c >= 0; c--) {
                    BipartitionMeasureBase candidateMeasure = candidateMeasures[c];
                    candidateMeasure.reset();
                    currentThresholds[j] = candidateThreshold(candidates, c);
                    // predict all instances with the current set of thresholds
                    for (int i = 0; i < numInstances; i++) {
                        boolean[] predictedLabels = new boolean[numLabels];
                        for (int x = 0; x < numLabels; x++) {
                            predictedLabels[x] = confidences[i][x] >= currentThresholds[x];
                        }
                        candidateMeasure.update(new MultiLabelOutput(predictedLabels), trueLabels[i]);
                    }
                    score += candidateMeasure.getValue();
                }
                // keep the candidate whose measure value is closest to the ideal one
                for (int c = 0; c < numInstances; c++) {
                    distanceToIdeal[c] = Math.abs(measure.getIdealValue() - candidateMeasures[c].getValue());
                }
                currentThresholds[j] = candidateThreshold(candidates, Utils.minIndex(distanceToIdeal));

                currentScore[j] = score;
                if (sweep == 0) {
                    firstScore[j] = score;
                }
            }

            // count the labels whose score did not change since the last sweep;
            // the first sweep has nothing to compare against
            convergedLabels = 0;
            for (int l = 0; l < numLabels; l++) {
                if (sweep != 0 && hasConverged(currentScore[l], previousScore[l], firstScore[l])) {
                    convergedLabels++;
                }
            }
            sweep++;
        } while (convergedLabels != numLabels && sweep < MAX_SWEEPS);

        if (convergedLabels != numLabels) {
            Logger.getLogger(SCut.class.getName()).log(Level.WARNING,
                    "SCut threshold optimization stopped after " + MAX_SWEEPS
                    + " sweeps without converging");
        }
        return currentThresholds;
    }

    /**
     * Builds the base learner and tunes one threshold per label. Without
     * cross-validation (0 folds) the thresholds are tuned on the training set
     * itself; otherwise the thresholds obtained on the test part of every fold
     * are averaged and the base learner is finally built on the full training set.
     *
     * @param trainingSet the training data
     * @throws Exception if building a learner or computing the thresholds fails
     */
    @Override
    protected void buildInternal(MultiLabelInstances trainingSet) throws Exception {
        if (kFoldsCV == 0) {
            baseLearner.build(trainingSet);
            thresholds = computeThresholds(baseLearner, trainingSet);
        } else {
            thresholds = new double[numLabels];
            for (int fold = 0; fold < kFoldsCV; fold++) {
                // split the data into train and test sets
                Instances train = trainingSet.getDataSet().trainCV(kFoldsCV, fold);
                MultiLabelInstances mlTrain = new MultiLabelInstances(train, trainingSet.getLabelsMetaData());
                Instances test = trainingSet.getDataSet().testCV(kFoldsCV, fold);
                MultiLabelInstances mlTest = new MultiLabelInstances(test, trainingSet.getLabelsMetaData());

                MultiLabelLearner learner = baseLearner.makeCopy();
                learner.build(mlTrain);
                double[] foldThresholds = computeThresholds(learner, mlTest);
                for (int j = 0; j < numLabels; j++) {
                    thresholds[j] += foldThresholds[j];
                }
            }
            // average the thresholds over the folds
            for (int j = 0; j < numLabels; j++) {
                thresholds[j] /= kFoldsCV;
            }
            baseLearner.build(trainingSet);
        }
    }

    /**
     * Predicts a label as relevant when the confidence of the base learner
     * reaches the threshold tuned for that label. If the base learner outputs
     * no confidences, all labels are predicted as irrelevant and all returned
     * confidences are zero.
     *
     * @param instance the instance to predict
     * @return the thresholded bipartition together with the confidences
     * @throws Exception if the base learner fails to make a prediction
     */
    @Override
    public MultiLabelOutput makePredictionInternal(Instance instance) throws Exception {
        MultiLabelOutput baseOutput = baseLearner.makePrediction(instance);
        double[] arrayOfConfidences = new double[numLabels];
        boolean[] predictedLabels = new boolean[numLabels];
        if (baseOutput.hasConfidences()) {
            arrayOfConfidences = baseOutput.getConfidences();
            for (int i = 0; i < numLabels; i++) {
                predictedLabels[i] = arrayOfConfidences[i] >= thresholds[i];
            }
        }
        return new MultiLabelOutput(predictedLabels, arrayOfConfidences);
    }
}