/*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 2 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
*/
/*
* ClassifierChain.java
* Copyright (C) 2009-2010 Aristotle University of Thessaloniki, Thessaloniki, Greece
*/
package mulan.classifier.transformation;
import mulan.classifier.MultiLabelOutput;
import mulan.data.DataUtils;
import mulan.data.MultiLabelInstances;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.Classifier;
import weka.classifiers.meta.FilteredClassifier;
import weka.core.Attribute;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformation.Field;
import weka.core.TechnicalInformation.Type;
import weka.filters.unsupervised.attribute.Remove;
/**
*
* <!-- globalinfo-start -->
* <!-- globalinfo-end -->
*
* <!-- technical-bibtex-start -->
* <!-- technical-bibtex-end -->
*
* @author Eleftherios Spyromitros-Xioufis ( espyromi@csd.auth.gr )
* @author Konstantinos Sechidis (sechidis@csd.auth.gr)
* @author Grigorios Tsoumakas (greg@csd.auth.gr)
*/
public class ClassifierChain extends TransformationBasedMultiLabelLearner {

    /**
     * The chain ordering of the labels. Entry {@code i} is the position (within
     * the label set, i.e. an index into {@code labelIndices}) of the label whose
     * model is the {@code i}-th link of the chain. It is {@code null} until it is
     * either supplied to the constructor or defaulted to the identity ordering
     * by {@link #buildInternal(MultiLabelInstances)}.
     */
    private int[] chain;

    /**
     * Returns a string describing classifier.
     * @return a description suitable for
     * displaying in the explorer/experimenter gui
     */
    public String globalInfo() {
        return "Class implementing the Classifier Chains for Multi-label Classification algorithm." + "\n\n" + "For more information, see\n\n" + getTechnicalInformation().toString();
    }

    /**
     * Returns an instance of a TechnicalInformation object, containing detailed
     * information about the technical background of this class, e.g., paper
     * reference or book this class is based on.
     *
     * @return the technical information about this class
     */
    @Override
    public TechnicalInformation getTechnicalInformation() {
        TechnicalInformation result = new TechnicalInformation(Type.INPROCEEDINGS);
        result.setValue(Field.AUTHOR, "Read, Jesse and Pfahringer, Bernhard and Holmes, Geoff and Frank, Eibe");
        result.setValue(Field.TITLE, "Classifier Chains for Multi-label Classification");
        // For an @inproceedings entry the proceedings title belongs in BOOKTITLE, not VOLUME
        result.setValue(Field.BOOKTITLE, "Proceedings of ECML/PKDD 2009");
        result.setValue(Field.YEAR, "2009");
        result.setValue(Field.PAGES, "254--269");
        result.setValue(Field.ADDRESS, "Bled, Slovenia");
        return result;
    }

    /**
     * The binary models of the chain, one per label, stored in chain order.
     * These are Weka FilteredClassifier objects whose filter removes the labels
     * that come later in the chain than the model's own target label. Labels
     * earlier in the chain remain visible as input features.
     */
    protected FilteredClassifier[] ensemble;

    /**
     * Creates a new instance that trains its binary models in the given label
     * order.
     *
     * @param classifier the base-level classification algorithm that will be
     * used for training each of the binary models
     * @param aChain the chain ordering. Its first {@code numLabels} entries
     * must be a permutation of {@code 0..numLabels-1} (checked at build time).
     * {@code null} selects the identity ordering. The array is copied, so later
     * changes by the caller do not affect this learner.
     */
    public ClassifierChain(Classifier classifier, int[] aChain) {
        super(classifier);
        chain = (aChain == null) ? null : aChain.clone();
    }

    /**
     * Creates a new instance that chains the labels in their natural order.
     *
     * @param classifier the base-level classification algorithm that will be
     * used for training each of the binary models
     */
    public ClassifierChain(Classifier classifier) {
        super(classifier);
    }

    /**
     * Trains one binary model per label, following the chain order. The model
     * at chain position {@code i} predicts label {@code chain[i]}. It sees the
     * features plus the labels at chain positions {@code 0..i-1}.
     *
     * <p>The class index of the training data set is changed while the models
     * are built and restored afterwards.
     *
     * @param train the multi-label training data
     * @throws IllegalArgumentException if the chain is not valid for the
     * number of labels in {@code train}
     * @throws Exception if a filter or base classifier fails
     */
    @Override
    protected void buildInternal(MultiLabelInstances train) throws Exception {
        // Assign numLabels before it is used to size the default chain
        numLabels = train.getNumLabels();
        if (chain == null) {
            chain = new int[numLabels];
            for (int i = 0; i < numLabels; i++) {
                chain[i] = i;
            }
        }
        validateChain();

        ensemble = new FilteredClassifier[numLabels];
        Instances trainDataset = train.getDataSet();
        int originalClassIndex = trainDataset.classIndex();
        try {
            for (int i = 0; i < numLabels; i++) {
                // Set the target first so the filter below is configured with the right class
                trainDataset.setClassIndex(labelIndices[chain[i]]);

                Remove remove = new Remove();
                remove.setInvertSelection(false);
                remove.setAttributeIndicesArray(laterLabelAttributes(i));
                remove.setInputFormat(trainDataset);

                ensemble[i] = new FilteredClassifier();
                ensemble[i].setClassifier(AbstractClassifier.makeCopy(baseClassifier));
                ensemble[i].setFilter(remove);

                debug("Building model " + (i + 1) + "/" + numLabels);
                ensemble[i].buildClassifier(trainDataset);
            }
        } finally {
            // Do not leave the caller's data set with a label selected as its class
            trainDataset.setClassIndex(originalClassIndex);
        }
    }

    /**
     * Returns the attribute indices of the labels that follow chain position
     * {@code position}, listed from the end of the chain backwards. These are
     * the labels hidden from the model trained at that position.
     *
     * @param position a chain position in {@code [0, numLabels)}
     * @return the attribute indices to remove (empty for the last position)
     */
    private int[] laterLabelAttributes(int position) {
        int[] indices = new int[numLabels - 1 - position];
        for (int k = 0; k < indices.length; k++) {
            indices[k] = labelIndices[chain[numLabels - 1 - k]];
        }
        return indices;
    }

    /**
     * Checks that the first {@code numLabels} entries of {@link #chain} are
     * distinct label positions in {@code [0, numLabels)}.
     *
     * @throws IllegalArgumentException if the chain is too short, holds an
     * out-of-range entry, or repeats a label
     */
    private void validateChain() {
        if (chain.length < numLabels) {
            throw new IllegalArgumentException("Label chain has " + chain.length
                    + " entries but the training set has " + numLabels + " labels");
        }
        boolean[] seen = new boolean[numLabels];
        for (int i = 0; i < numLabels; i++) {
            int label = chain[i];
            if (label < 0 || label >= numLabels) {
                throw new IllegalArgumentException("Label chain entry " + label
                        + " at position " + i + " is outside [0, " + numLabels + ")");
            }
            if (seen[label]) {
                throw new IllegalArgumentException("Label " + label
                        + " appears more than once in the label chain");
            }
            seen[label] = true;
        }
    }

    /**
     * Predicts all labels by walking the chain. Each model's predicted value is
     * written into a working copy of the instance, so that later models in the
     * chain can use it as an input feature.
     *
     * @param instance the instance to classify
     * @return the bipartition and the confidence that each label equals "1"
     * @throws Exception if any model in the chain fails to classify
     */
    @Override
    protected MultiLabelOutput makePredictionInternal(Instance instance) throws Exception {
        boolean[] bipartition = new boolean[numLabels];
        double[] confidences = new double[numLabels];
        // Predicted label values are written into this copy rather than into the caller's instance
        Instance tempInstance = DataUtils.createInstance(instance, instance.weight(), instance.toDoubleArray());
        for (int position = 0; position < numLabels; position++) {
            int label = chain[position];
            double[] distribution = ensemble[position].distributionForInstance(tempInstance);
            int maxIndex = (distribution[0] > distribution[1]) ? 0 : 1;
            // Look the value "1" up by name so predictions are correct for both {0,1} and {1,0}
            Attribute classAttribute = ensemble[position].getFilter().getOutputFormat().classAttribute();
            bipartition[label] = classAttribute.value(maxIndex).equals("1");
            // The confidence of the label being equal to 1
            confidences[label] = distribution[classAttribute.indexOfValue("1")];
            // Feed the prediction forward as an input for the models further down the chain
            tempInstance.setValue(labelIndices[label], maxIndex);
        }
        return new MultiLabelOutput(bipartition, confidences);
    }
}