/*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 2 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
*/
/*
* EnsembleOfClassifierChains.java
* Copyright (C) 2009-2010 Aristotle University of Thessaloniki, Thessaloniki, Greece
*/
package mulan.classifier.transformation;
import java.util.Arrays;
import java.util.Random;
import mulan.classifier.InvalidDataException;
import mulan.classifier.MultiLabelOutput;
import mulan.data.MultiLabelInstances;
import weka.classifiers.Classifier;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformation.Field;
import weka.core.TechnicalInformation.Type;
import weka.filters.Filter;
import weka.filters.unsupervised.instance.RemovePercentage;
/**
 * Ensemble of Classifier Chains (ECC): a multi-label learner that trains
 * several {@link ClassifierChain} models, each on its own sample of the
 * training data and with its own random label ordering, and combines their
 * outputs by averaging either the votes or the confidences.
 *
 * <!-- technical-bibtex-start --> <!-- technical-bibtex-end -->
 *
 * @author Eleftherios Spyromitros-Xioufis ( espyromi@csd.auth.gr )
 * @author Konstantinos Sechidis (sechidis@csd.auth.gr)
 * @author Grigorios Tsoumakas (greg@csd.auth.gr)
 *
 * @version 2010.10.11
 */
public class EnsembleOfClassifierChains extends TransformationBasedMultiLabelLearner {

    /**
     * The number of classifier chain models
     */
    protected int numOfModels;
    /**
     * An array of ClassifierChain models
     */
    protected ClassifierChain[] ensemble;
    /**
     * Random number generator
     */
    protected Random rand;
    /**
     * Whether the output is computed based on the average votes or on the
     * average confidences
     */
    protected boolean useConfidences;
    /**
     * Whether to use sampling with replacement to create the data of the models
     * of the ensemble
     */
    protected boolean useSamplingWithReplacement = true;
    /**
     * The size of each bag sample, as a percentage of the training size. Used
     * when useSamplingWithReplacement is true
     */
    protected int BagSizePercent = 100;
    /**
     * The size of each sample, as a percentage of the training size. Used when
     * useSamplingWithReplacement is false
     */
    protected double samplingPercentage = 67;

    /**
     * Creates a new object. The internal random number generator is seeded
     * with 1, so repeated builds on the same data draw the same sequence.
     *
     * @param classifier
     *            the base classifier for each ClassifierChain model
     * @param aNumOfModels
     *            the number of models; must be at least 1
     * @param doUseConfidences
     *            {@code true} to average the confidences of the models,
     *            {@code false} to average their votes
     * @param doUseSamplingWithReplacement
     *            {@code true} to build each model on a bag sampled with
     *            replacement, {@code false} to build it on a subset holding
     *            {@link #getSamplingPercentage()} percent of the data
     * @throws IllegalArgumentException
     *             if {@code aNumOfModels} is smaller than 1
     */
    public EnsembleOfClassifierChains(Classifier classifier, int aNumOfModels,
            boolean doUseConfidences, boolean doUseSamplingWithReplacement) {
        super(classifier);
        if (aNumOfModels < 1) {
            // Zero models would make every prediction 0/0 = NaN; a negative
            // count cannot even allocate the ensemble array.
            throw new IllegalArgumentException(
                    "The number of models must be at least 1, but was " + aNumOfModels);
        }
        numOfModels = aNumOfModels;
        useConfidences = doUseConfidences;
        useSamplingWithReplacement = doUseSamplingWithReplacement;
        ensemble = new ClassifierChain[aNumOfModels];
        rand = new Random(1);
    }

    /**
     * Returns the bag size used when sampling with replacement.
     *
     * @return the bag size as a percentage of the training set size
     */
    public int getBagSizePercent() {
        return BagSizePercent;
    }

    /**
     * Sets the bag size used when sampling with replacement. Values of 100 or
     * more leave the bag at the full training set size.
     *
     * @param bagSizePercent
     *            the bag size as a percentage of the training set size; must
     *            not be negative
     * @throws IllegalArgumentException
     *             if {@code bagSizePercent} is negative
     */
    public void setBagSizePercent(int bagSizePercent) {
        if (bagSizePercent < 0) {
            throw new IllegalArgumentException(
                    "The bag size percentage must not be negative, but was " + bagSizePercent);
        }
        BagSizePercent = bagSizePercent;
    }

    /**
     * Returns the sample size used when sampling without replacement.
     *
     * @return the sample size as a percentage of the training set size
     */
    public double getSamplingPercentage() {
        return samplingPercentage;
    }

    /**
     * Sets the sample size used when sampling without replacement.
     *
     * @param samplingPercentage
     *            the sample size as a percentage of the training set size;
     *            must lie in the range [0, 100]
     * @throws IllegalArgumentException
     *             if {@code samplingPercentage} is NaN or outside [0, 100]
     */
    public void setSamplingPercentage(double samplingPercentage) {
        // Written as a negated conjunction so that NaN is rejected as well.
        if (!(samplingPercentage >= 0 && samplingPercentage <= 100)) {
            throw new IllegalArgumentException(
                    "The sampling percentage must be between 0 and 100, but was "
                            + samplingPercentage);
        }
        this.samplingPercentage = samplingPercentage;
    }

    /**
     * Returns a string describing classifier.
     *
     * @return a description suitable for displaying in the
     *         explorer/experimenter gui
     */
    public String globalInfo() {
        return "Class implementing the Ensemble of Classifier Chains (ECC) algorithm."
                + "\n\n" + "For more information, see\n\n" + getTechnicalInformation().toString();
    }

    /**
     * Returns an instance of a TechnicalInformation object, containing detailed
     * information about the technical background of this class, e.g., paper
     * reference or book this class is based on.
     *
     * @return the technical information about this class
     */
    @Override
    public TechnicalInformation getTechnicalInformation() {
        TechnicalInformation result = new TechnicalInformation(Type.INPROCEEDINGS);
        result.setValue(Field.AUTHOR,
                "Read, Jesse and Pfahringer, Bernhard and Holmes, Geoff and Frank, Eibe");
        result.setValue(Field.TITLE, "Classifier Chains for Multi-label Classification");
        // NOTE(review): the proceedings title is stored under VOLUME; BOOKTITLE
        // may be the intended field - confirm before changing the emitted entry.
        result.setValue(Field.VOLUME, "Proceedings of ECML/PKDD 2009");
        result.setValue(Field.YEAR, "2009");
        result.setValue(Field.PAGES, "254--269");
        result.setValue(Field.ADDRESS, "Bled, Slovenia");
        return result;
    }

    /**
     * Builds the ensemble: for every model the copied training data is
     * shuffled, a sample is drawn from it, a random label ordering is
     * generated, and a ClassifierChain with that ordering is built on the
     * sample.
     *
     * @param trainingSet
     *            the multi-label training data
     * @throws Exception
     *             if sampling or building one of the chains fails
     */
    @Override
    protected void buildInternal(MultiLabelInstances trainingSet) throws Exception {
        // Work on a copy so that shuffling does not reorder the caller's data.
        Instances dataSet = new Instances(trainingSet.getDataSet());
        for (int i = 0; i < numOfModels; i++) {
            debug("ECC Building Model:" + (i + 1) + "/" + numOfModels);
            dataSet.randomize(rand);
            Instances sampledDataSet = sampleTrainingData(dataSet);
            MultiLabelInstances train = new MultiLabelInstances(sampledDataSet,
                    trainingSet.getLabelsMetaData());
            int[] chain = randomChainOrder();
            debug(Arrays.toString(chain));
            // Identical orderings across models are not ruled out; each model
            // simply receives an independently drawn random ordering.
            ensemble[i] = new ClassifierChain(baseClassifier, chain);
            ensemble[i].build(train);
        }
    }

    /**
     * Draws the training sample for one model from the (already shuffled)
     * data, using the configured sampling scheme.
     */
    private Instances sampleTrainingData(Instances dataSet) throws Exception {
        if (useSamplingWithReplacement) {
            int numInstances = dataSet.numInstances();
            // Computed in long: numInstances * BagSizePercent overflows int
            // for large data sets.
            long bagSize = (long) numInstances * BagSizePercent / 100;
            // NOTE(review): a freshly seeded Random(1) is used on every call,
            // so diversity between bags appears to rely on the shuffle done by
            // the caller - confirm whether 'rand' was intended here.
            Instances bag = dataSet.resampleWithWeights(new Random(1));
            if (bagSize < numInstances) {
                // bagSize is below numInstances here, so the narrowing cast is safe.
                bag = new Instances(bag, 0, (int) bagSize);
            }
            return bag;
        }
        // Without replacement: inverting the selection makes the filter output
        // the samplingPercentage portion instead of removing it.
        RemovePercentage rmvp = new RemovePercentage();
        rmvp.setInvertSelection(true);
        rmvp.setPercentage(samplingPercentage);
        rmvp.setInputFormat(dataSet);
        return Filter.useFilter(dataSet, rmvp);
    }

    /**
     * Generates a random ordering of the indices 0..numLabels-1 with a
     * Fisher-Yates shuffle, so that every ordering is equally likely.
     */
    private int[] randomChainOrder() {
        int[] chain = new int[numLabels];
        for (int j = 0; j < numLabels; j++) {
            chain[j] = j;
        }
        // Swap each position only with an index from the not-yet-fixed prefix;
        // swapping with an arbitrary index would bias the resulting orderings.
        for (int j = chain.length - 1; j > 0; j--) {
            int randomPosition = rand.nextInt(j + 1);
            int temp = chain[j];
            chain[j] = chain[randomPosition];
            chain[randomPosition] = temp;
        }
        return chain;
    }

    /**
     * Predicts by averaging over all models of the ensemble. Depending on
     * {@code useConfidences}, the per-label value is either the mean confidence
     * or the fraction of models that voted for the label. The averaged values
     * are passed to the MultiLabelOutput together with a threshold of 0.5.
     *
     * @param instance
     *            the instance to predict
     * @return the combined output of the ensemble
     * @throws Exception
     *             if one of the models fails to make a prediction
     * @throws InvalidDataException
     *             if the instance is rejected by one of the models
     */
    @Override
    protected MultiLabelOutput makePredictionInternal(Instance instance) throws Exception,
            InvalidDataException {
        // New arrays are zero-initialised, so no explicit fill is required.
        double[] confidence = new double[numLabels];
        for (int i = 0; i < numOfModels; i++) {
            MultiLabelOutput memberOutput = ensemble[i].makePrediction(instance);
            if (useConfidences) {
                double[] conf = memberOutput.getConfidences();
                for (int j = 0; j < numLabels; j++) {
                    confidence[j] += conf[j];
                }
            } else {
                boolean[] bip = memberOutput.getBipartition();
                for (int j = 0; j < numLabels; j++) {
                    if (bip[j]) {
                        confidence[j] += 1;
                    }
                }
            }
        }
        for (int j = 0; j < numLabels; j++) {
            confidence[j] /= numOfModels;
        }
        return new MultiLabelOutput(confidence, 0.5);
    }
}