/*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 2 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
*/
/*
* LabelPowerset.java
* Copyright (C) 2009-2010 Aristotle University of Thessaloniki, Thessaloniki, Greece
*/
package mulan.classifier.transformation;
import java.util.Arrays;
import java.util.Random;
import java.util.logging.Level;
import java.util.logging.Logger;
import mulan.classifier.MultiLabelOutput;
import mulan.data.LabelSet;
import mulan.core.Util;
import mulan.data.MultiLabelInstances;
import mulan.transformations.LabelPowersetTransformation;
import weka.classifiers.Classifier;
import weka.core.Instance;
import weka.core.Instances;
/**
* Class that implements a label powerset classifier <p>
*
* @author Grigorios Tsoumakas
* @author Robert Friberg
* @version $Revision: 0.05 $
*/
public class LabelPowerset extends TransformationBasedMultiLabelLearner {

    /**
     * Selects how per-label confidences are derived from the base classifier's
     * output:
     * <ul>
     * <li>0: crisp confidences - 1 for every label in the predicted labelset,
     * 0 for every other label;</li>
     * <li>1 (default): every label in the predicted labelset gets x and every
     * other label gets 1 - x, where x is the probability the base classifier
     * assigned to the winning labelset;</li>
     * <li>2: the confidence of a label is the sum of the probabilities of all
     * labelsets (class values) that contain it, as introduced by the PPT
     * algorithm.</li>
     * </ul>
     */
    private int confidenceCalculationMethod = 1;
    /**
     * Whether the output bipartition is recomputed by thresholding the
     * confidences (see {@link #threshold}) instead of being taken directly from
     * the winning labelset.
     */
    protected boolean makePredictionsBasedOnConfidences = false;
    /**
     * Threshold used to decide the 1/0 output value of each label from its
     * confidence when {@link #makePredictionsBasedOnConfidences} is enabled:
     * a label is relevant iff its confidence is strictly greater than this value.
     */
    protected double threshold = 0.5;
    /** The object that performs the data transformation */
    protected LabelPowersetTransformation transformation;
    /**
     * Random number generator for randomly resolving ties between equally
     * probable labelsets during prediction.
     */
    protected Random Rand;

    /**
     * Constructor that initializes the learner with a base classifier. The tie
     * breaking random generator is seeded with 1.
     *
     * @param classifier the base single-label classification algorithm
     */
    public LabelPowerset(Classifier classifier) {
        super(classifier);
        Rand = new Random(1);
    }

    /**
     * Sets whether the output bipartition should be obtained by thresholding
     * the calculated confidences (see {@link #setThreshold(double)}) rather than
     * taken directly from the most probable labelset.
     *
     * @param value true to derive the bipartition from the confidences
     */
    public void setMakePredictionsBasedOnConfidences(boolean value) {
        makePredictionsBasedOnConfidences = value;
    }

    /**
     * Sets the seed of the random generator used to break ties during prediction.
     *
     * @param s the seed
     */
    public void setSeed(int s) {
        Rand = new Random(s);
    }

    /**
     * Sets the threshold used for obtaining the bipartition from the
     * confidences when predictions are based on confidences.
     *
     * @param t the threshold; a label is predicted relevant iff its confidence
     *          is strictly greater than t
     */
    public void setThreshold(double t) {
        threshold = t;
    }

    /**
     * Sets the method of calculating the confidence of each label. Values other
     * than 0, 1 or 2 are ignored and leave the current method unchanged.
     *
     * @param method 0 (crisp), 1 (probability of the winning labelset) or
     *               2 (PPT-style marginalisation over all labelsets)
     */
    public void setConfidenceCalculationMethod(int method) {
        if (method == 0 || method == 1 || method == 2) {
            confidenceCalculationMethod = method;
        }
    }

    /**
     * Transforms the multi-label training set into a single-label one whose
     * class values are the labelsets, and trains the base classifier on it.
     * When the transformed class has a single value (only one labelset was seen),
     * the base classifier is not built; prediction handles that case directly.
     *
     * @param mlData the multi-label training data
     * @throws Exception if the transformation or the base classifier fails
     */
    protected void buildInternal(MultiLabelInstances mlData) throws Exception {
        transformation = new LabelPowersetTransformation();
        debug("Transforming the training set.");
        Instances transformedData = transformation.transformInstances(mlData);
        debug("Building single-label classifier.");
        // skip training on a unary class: there is nothing to discriminate
        if (transformedData.attribute(transformedData.numAttributes() - 1).numValues() > 1) {
            baseClassifier.buildClassifier(transformedData);
        }
    }

    /**
     * Predicts the labelset of an instance by asking the base classifier for
     * the most probable class value (labelset), breaking ties randomly, and
     * computes per-label confidences according to the configured method.
     *
     * @param instance the multi-label instance to classify
     * @return the bipartition and confidences for the instance
     * @throws Exception if transforming the instance, the base classifier or
     *                   decoding a labelset fails (previously such failures were
     *                   logged and then surfaced as a NullPointerException)
     */
    protected MultiLabelOutput makePredictionInternal(Instance instance) throws Exception {
        // unary class: only one labelset was seen during training and the base
        // classifier was never built (see buildInternal), so predict it directly
        if (transformation.getTransformedFormat().classAttribute().numValues() == 1) {
            LabelSet only = LabelSet.fromBitString(classValue(0));
            return new MultiLabelOutput(only.toBooleanArray(), only.toDoubleArray());
        }

        Instance transformedInstance = transformation.transformInstance(instance, labelIndices);
        double[] distribution = baseClassifier.distributionForInstance(transformedInstance);

        // the most probable labelset wins; ties are resolved with Rand
        int classIndex = Util.RandomIndexOfMax(distribution, Rand);
        LabelSet predicted = LabelSet.fromBitString(classValue(classIndex));
        boolean[] bipartition = predicted.toBooleanArray();
        // must run before thresholding: method 1 reads the unthresholded bipartition
        double[] confidences = computeConfidences(predicted, bipartition, distribution, classIndex);

        if (makePredictionsBasedOnConfidences) {
            for (int i = 0; i < confidences.length; i++) {
                bipartition[i] = confidences[i] > threshold;
            }
        }
        return new MultiLabelOutput(bipartition, confidences);
    }

    /**
     * Returns the string value (labelset bit string) of the given class index
     * in the transformed format.
     */
    private String classValue(int index) {
        return transformation.getTransformedFormat().classAttribute().value(index);
    }

    /**
     * Computes one confidence per label according to
     * {@link #confidenceCalculationMethod}.
     *
     * @param predicted    the winning labelset
     * @param bipartition  the winning labelset as a boolean array
     * @param distribution the base classifier's class probability distribution
     * @param classIndex   index of the winning class in the distribution
     * @return the per-label confidences
     * @throws Exception if a class value cannot be decoded into a labelset
     */
    private double[] computeConfidences(LabelSet predicted, boolean[] bipartition,
            double[] distribution, int classIndex) throws Exception {
        switch (confidenceCalculationMethod) {
            case 0: {
                // copy the whole array so there is one confidence per label
                double[] crisp = predicted.toDoubleArray();
                return Arrays.copyOf(crisp, crisp.length);
            }
            case 1: {
                double[] confidences = new double[numLabels];
                double prob = distribution[classIndex];
                for (int i = 0; i < numLabels; i++) {
                    confidences[i] = bipartition[i] ? prob : 1 - prob;
                }
                return confidences;
            }
            case 2: {
                // marginalise: each label accumulates the probability of every
                // labelset that contains it
                double[] confidences = new double[numLabels];
                for (int c = 0; c < distribution.length; c++) {
                    double[] members = LabelSet.fromBitString(classValue(c)).toDoubleArray();
                    double mass = distribution[c];
                    for (int j = 0; j < numLabels; j++) {
                        if (members[j] == 1) {
                            confidences[j] += mass;
                        }
                    }
                }
                return confidences;
            }
            default:
                // unreachable through setConfidenceCalculationMethod, which only accepts 0-2
                throw new IllegalStateException("Unknown confidence calculation method: "
                        + confidenceCalculationMethod);
        }
    }
}