/*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 2 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
*/
/*
* RCut.java
* Copyright (C) 2009-2010 Aristotle University of Thessaloniki, Thessaloniki, Greece
*/
package mulan.classifier.meta.thresholding;
import mulan.classifier.meta.*;
import java.util.logging.Level;
import java.util.logging.Logger;
import mulan.classifier.*;
import mulan.core.MulanRuntimeException;
import mulan.data.InvalidDataFormatException;
import mulan.data.LabelsMetaData;
import mulan.data.MultiLabelInstances;
import mulan.evaluation.measure.BipartitionMeasureBase;
import weka.core.Utils;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformation.Field;
import weka.core.TechnicalInformation.Type;
/**
 * RCut (rank-based cut): selects the k top-ranked labels for each instance,
 * where k is either the rounded label cardinality of the training set or is
 * automatically tuned to optimize a given bipartition measure.
 *
 * <!-- globalinfo-start -->
 *
 * <pre>
 * Class that implements the rank-based thresholding strategy.
 * </pre>
 *
 * For more information:
 *
 * <pre>
 * Yang, Y. (2001) A study of thresholding strategies for text categorization.
 * Proceedings of the 24th annual international ACM SIGIR conference on Research
 * and development in information retrieval, pp. 137-145.
 * </pre>
 *
 * <!-- globalinfo-end -->
 *
 * <!-- technical-bibtex-start --> BibTeX:
 *
 * <pre>
 * @inproceedings{yang:2001,
 * author = {Yang, Y.},
 * title = {A study of thresholding strategies for text categorization},
 * booktitle = {Proceedings of the 24th annual international ACM SIGIR conference on Research and development in information retrieval},
 * year = {2001},
 * pages = {137--145},
 * }
 * </pre>
 *
 * <p/> <!-- technical-bibtex-end -->
 *
 * @author Marios Ioannou
 * @author George Sakkas
 * @author Grigorios Tsoumakas
 * @version 2010.12.14
 */
public class RCut extends MultiLabelMetaLearner {

    /** the number of top-ranked labels predicted as relevant for each instance */
    private int t = 0;
    /**
     * measure used to tune {@link #t}; when null, t is set to the rounded label
     * cardinality of the training set
     */
    private BipartitionMeasureBase measure;
    /**
     * number of cross-validation folds used to tune {@link #t}; 0 means that t
     * is tuned on the training set itself (only relevant when a measure is set)
     */
    private int folds;
    /**
     * copy of the base learner taken at construction time; a fresh copy of it
     * is built at each cross-validation fold (null if copying failed)
     */
    private MultiLabelLearner foldLearner;

    /**
     * Creates a new instance of RCut that predicts, for each instance, the
     * top-ranked labels up to the rounded label cardinality of the training set.
     *
     * @param baseLearner the underlying multi-label learner
     */
    public RCut(MultiLabelLearner baseLearner) {
        super(baseLearner);
    }

    /**
     * Creates a new instance of RCut whose threshold is tuned with
     * cross-validation on the training set.
     *
     * @param baseLearner the underlying multi-label learner
     * @param aMeasure measure to optimize
     * @param someFolds cross-validation folds (0 tunes on the training set
     *        itself; otherwise it must be at least 2)
     */
    public RCut(MultiLabelLearner baseLearner, BipartitionMeasureBase aMeasure, int someFolds) {
        super(baseLearner);
        measure = aMeasure;
        folds = someFolds;
        try {
            foldLearner = baseLearner.makeCopy();
        } catch (Exception ex) {
            // best effort: the failure is reported again by autoTuneThreshold if CV is actually used
            Logger.getLogger(RCut.class.getName()).log(Level.SEVERE,
                    "Could not copy the base learner for cross-validation", ex);
        }
    }

    /**
     * Creates a new instance of RCut whose threshold is tuned on the training
     * set itself.
     *
     * @param baseLearner the underlying multi-label learner
     * @param aMeasure measure to optimize
     */
    public RCut(MultiLabelLearner baseLearner, BipartitionMeasureBase aMeasure) {
        super(baseLearner);
        measure = aMeasure;
    }

    /**
     * Automatically selects a threshold based on training set performance
     * evaluated using cross-validation. The selected threshold is stored in
     * {@link #t}.
     *
     * @param trainingData the data to split into folds
     * @param measure performance is evaluated based on this measure
     * @param folds number of cross-validation folds (at least 2)
     * @throws IllegalArgumentException if folds is smaller than 2
     * @throws IllegalStateException if no copy of the base learner is available
     * @throws InvalidDataFormatException if a fold cannot be wrapped as multi-label data
     * @throws Exception if building or querying a fold learner fails
     */
    private void autoTuneThreshold(MultiLabelInstances trainingData, BipartitionMeasureBase measure, int folds) throws InvalidDataFormatException, Exception {
        if (folds < 2) {
            throw new IllegalArgumentException("folds should be more than 1");
        }
        if (foldLearner == null) {
            throw new IllegalStateException("No copy of the base learner is available for cross-validation");
        }
        double[] totalDiff = new double[numLabels + 1];
        LabelsMetaData labelsMetaData = trainingData.getLabelsMetaData();
        Instances dataSet = trainingData.getDataSet();
        for (int f = 0; f < folds; f++) {
            MultiLabelInstances trainMulti = new MultiLabelInstances(dataSet.trainCV(folds, f), labelsMetaData);
            MultiLabelInstances testMulti = new MultiLabelInstances(dataSet.testCV(folds, f), labelsMetaData);
            // fresh copy per fold so that no state leaks from one fold's model to the next
            MultiLabelLearner tempLearner = foldLearner.makeCopy();
            tempLearner.build(trainMulti);
            double[] diff = computeThreshold(tempLearner, testMulti, measure);
            for (int k = 0; k < diff.length; k++) {
                totalDiff[k] += diff[k];
            }
        }
        t = Utils.minIndex(totalDiff);
    }

    /**
     * Evaluates every candidate threshold (keeping the top 0 up to numLabels
     * ranked labels) on the given data.
     *
     * The learner is queried once per instance. The cached rankings are then
     * turned into bipartitions for each threshold, and the measure is computed
     * over the whole data set before being compared with its ideal value. This
     * makes the comparison valid for label-based as well as example-based
     * measures. Instances with missing labels are skipped. The given measure is
     * reset and updated as a side effect.
     *
     * @param learner the built learner whose rankings are thresholded
     * @param data the test data to evaluate different thresholds
     * @param measure the evaluation is based on this measure
     * @return for each threshold, the absolute difference between the ideal
     *         value of the measure and its value on data (positive infinity
     *         when the measure is undefined); all zeros if no instance could
     *         be evaluated
     * @throws Exception if the learner fails to make a prediction
     */
    private double[] computeThreshold(MultiLabelLearner learner, MultiLabelInstances data, BipartitionMeasureBase measure) throws Exception {
        Instances dataSet = data.getDataSet();
        int numInstances = data.getNumInstances();
        int[][] rankings = new int[numInstances][];
        double[][] confidences = new double[numInstances][];
        boolean[][] truths = new boolean[numInstances][];
        int evaluated = 0;
        for (int j = 0; j < numInstances; j++) {
            Instance instance = dataSet.instance(j);
            if (data.hasMissingLabels(instance)) {
                continue;
            }
            MultiLabelOutput mlo = learner.makePrediction(instance);
            rankings[evaluated] = mlo.getRanking();
            confidences[evaluated] = mlo.getConfidences();
            truths[evaluated] = extractTrueLabels(instance);
            evaluated++;
        }

        double[] diff = new double[numLabels + 1];
        if (evaluated == 0) {
            // nothing to learn from: every threshold is equally (un)informative
            return diff;
        }
        for (int threshold = 0; threshold <= numLabels; threshold++) {
            measure.reset();
            for (int j = 0; j < evaluated; j++) {
                boolean[] bipartition = topRanked(rankings[j], threshold);
                measure.update(new MultiLabelOutput(bipartition, confidences[j]), truths[j]);
            }
            double difference = Math.abs(measure.getIdealValue() - measure.getValue());
            // an undefined value must never win the minimum search
            diff[threshold] = Double.isNaN(difference) ? Double.POSITIVE_INFINITY : difference;
        }
        return diff;
    }

    /**
     * Reads the relevance of every label of an instance; a label is relevant
     * when its nominal value is "1".
     */
    private boolean[] extractTrueLabels(Instance instance) {
        boolean[] truth = new boolean[numLabels];
        for (int counter = 0; counter < numLabels; counter++) {
            int classIdx = labelIndices[counter];
            String classValue = instance.attribute(classIdx).value((int) instance.value(classIdx));
            truth[counter] = classValue.equals("1");
        }
        return truth;
    }

    /**
     * Marks as relevant every label whose rank is at most k. Assumes ranks are
     * 1-based (1 = most relevant), as implied by using ranking[i] <= t to keep
     * t labels.
     */
    private boolean[] topRanked(int[] ranking, int k) {
        boolean[] bipartition = new boolean[numLabels];
        for (int i = 0; i < numLabels; i++) {
            bipartition[i] = ranking[i] <= k;
        }
        return bipartition;
    }

    /**
     * Builds the base learner and then selects the threshold t: the rounded
     * label cardinality when no measure is given, otherwise the value that
     * best optimizes the measure (on the training set when folds is 0, via
     * cross-validation otherwise).
     *
     * @param trainingData the training data
     * @throws MulanRuntimeException if the base learner does not produce a ranking
     * @throws Exception if building the learner or tuning the threshold fails
     */
    protected void buildInternal(MultiLabelInstances trainingData) throws Exception {
        baseLearner.build(trainingData);
        MultiLabelOutput mlo = baseLearner.makePrediction(trainingData.getDataSet().firstInstance());
        if (!mlo.hasRanking()) {
            throw new MulanRuntimeException("Learner is not a ranker");
        }
        if (measure == null) {
            // by default set threshold equal to the rounded average cardinality
            t = (int) Math.round(trainingData.getCardinality());
        } else if (folds == 0) {
            // tune on the training set, reusing the already built base learner
            t = Utils.minIndex(computeThreshold(baseLearner, trainingData, measure));
        } else {
            autoTuneThreshold(trainingData, measure, folds);
        }
    }

    /**
     * Returns bibliographic information about the rank-based thresholding strategy.
     *
     * @return the technical information of Yang (2001)
     */
    public TechnicalInformation getTechnicalInformation() {
        TechnicalInformation result = new TechnicalInformation(Type.INPROCEEDINGS);
        result.setValue(Field.AUTHOR, "Yiming Yang");
        result.setValue(Field.TITLE, "A study of thresholding strategies for text categorization");
        result.setValue(Field.BOOKTITLE, "Proceedings of the 24th annual international ACM SIGIR conference on Research and development in information retrieval");
        result.setValue(Field.PAGES, "137 - 145");
        result.setValue(Field.LOCATION, "New Orleans, Louisiana, United States");
        result.setValue(Field.YEAR, "2001");
        return result;
    }

    /**
     * Predicts the top t ranked labels of the base learner as relevant, keeping
     * the base learner's confidences.
     *
     * @param instance the instance to classify
     * @return a bipartition with the top t labels set, plus the confidences
     * @throws Exception if the base learner fails
     * @throws InvalidDataException if the instance is not valid for the base learner
     */
    protected MultiLabelOutput makePredictionInternal(Instance instance) throws Exception, InvalidDataException {
        MultiLabelOutput mlo = baseLearner.makePrediction(instance);
        boolean[] predictedLabels = topRanked(mlo.getRanking(), t);
        return new MultiLabelOutput(predictedLabels, mlo.getConfidences());
    }
}