/*
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 * RAkEL.java
 * Copyright (C) 2009-2010 Aristotle University of Thessaloniki, Thessaloniki, Greece
 */
package mulan.classifier.meta;

import java.util.Arrays;
import java.util.HashSet;
import java.util.Random;

import mulan.classifier.MultiLabelLearner;
import mulan.classifier.MultiLabelOutput;
import mulan.data.MultiLabelInstances;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformation.Field;
import weka.core.TechnicalInformation.Type;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.Remove;

/**
 * <!-- globalinfo-start -->
 *
 * <pre>
 * Class implementing a generalized version of the RAkEL (RAndom k-labELsets) algorithm.
 * </pre>
 *
 * For more information:
 *
 * <pre>
 * Tsoumakas, G., Vlahavas, I. (2007) "Random k-Labelsets: An Ensemble Method
 * for Multilabel Classification", Proc. 18th European Conference on Machine
 * Learning (ECML 2007), pp. 406-417, Warsaw, Poland, 17-21 September 2007.
 * </pre>
 *
 * <!-- globalinfo-end -->
 *
 * <!-- technical-bibtex-start --> BibTeX:
 *
 * <pre>
 * @inproceedings{tsoumakas+vlahavas:2007,
 *    author = {Tsoumakas, G. and Vlahavas, I.},
 *    title = {Random k-Labelsets: An Ensemble Method for Multilabel Classification},
 *    booktitle = {Proceedings of the 18th European Conference on Machine Learning (ECML 2007)},
 *    year = {2007},
 *    pages = {406--417},
 *    address = {Warsaw, Poland},
 *    month = {September 17-21},
 * }
 * </pre>
 *
 * <p/> <!-- technical-bibtex-end -->
 *
 * @author Grigorios Tsoumakas
 * @version $Revision: 0.04 $
 */
@SuppressWarnings("serial")
public class RAkEL extends MultiLabelMetaLearner {

    /**
     * Seed for replication of random experiments
     */
    private int seed = 0;
    /**
     * Random number generator
     */
    private Random rnd;
    /**
     * If true, the confidences of the base classifier are used when combining
     * the decisions of the ensemble members (currently unused).
     */
    //private boolean useConfidences = true;
    double[][] sumVotesIncremental;
    double[][] lengthVotesIncremental;
    double[] sumVotes;
    double[] lengthVotes;
    int numOfModels;
    double threshold = 0.5;
    int sizeOfSubset = 3;
    int[][] classIndicesPerSubset;
    int[][] absoluteIndicesToRemove;
    MultiLabelLearner[] subsetClassifiers;
    protected Remove[] remove;
    HashSet<String> combinations;

    /**
     * Returns an instance of a TechnicalInformation object, containing
     * detailed information about the technical background of this class,
     * e.g., paper reference or book this class is based on.
     *
     * @return the technical information about this class
     */
    @Override
    public TechnicalInformation getTechnicalInformation() {
        TechnicalInformation result = new TechnicalInformation(Type.INPROCEEDINGS);
        result.setValue(Field.AUTHOR, "Grigorios Tsoumakas, Ioannis Vlahavas");
        result.setValue(Field.TITLE, "Random k-Labelsets: An Ensemble Method for Multilabel Classification");
        result.setValue(Field.BOOKTITLE, "Proc. 18th European Conference on Machine Learning (ECML 2007)");
        result.setValue(Field.PAGES, "406 - 417");
        result.setValue(Field.LOCATION, "Warsaw, Poland");
        result.setValue(Field.MONTH, "17-21 September");
        result.setValue(Field.YEAR, "2007");
        return result;
    }

    /**
     * Creates a RAkEL instance with the given ensemble size and labelset size,
     * without specifying a base learner.
     *
     * @param models the number of member models
     * @param subset the size of each labelset
     * @throws Exception if the learner cannot be initialized
     */
    public RAkEL(int models, int subset) throws Exception {
        sizeOfSubset = subset;
        setNumModels(models);
    }

    /**
     * Creates a RAkEL instance based on the given multi-label learner.
     *
     * @param baseLearner the underlying multi-label learner
     */
    public RAkEL(MultiLabelLearner baseLearner) {
        super(baseLearner);
    }

    /**
     * Creates a RAkEL instance with the given base learner, ensemble size and
     * labelset size.
     *
     * @param baseLearner the underlying multi-label learner
     * @param models the number of member models
     * @param subset the size of each labelset
     */
    public RAkEL(MultiLabelLearner baseLearner, int models, int subset) {
        super(baseLearner);
        sizeOfSubset = subset;
        setNumModels(models);
    }

    /**
     * Creates a RAkEL instance with the given base learner, ensemble size,
     * labelset size and voting threshold.
     *
     * @param baseLearner the underlying multi-label learner
     * @param models the number of member models
     * @param subset the size of each labelset
     * @param threshold the threshold applied to the averaged votes
     */
    public RAkEL(MultiLabelLearner baseLearner, int models, int subset, double threshold) {
        super(baseLearner);
        sizeOfSubset = subset;
        setNumModels(models);
        this.threshold = threshold;
    }

    /**
     * Sets the seed for replication of random experiments.
     *
     * @param x the seed
     */
    public void setSeed(int x) {
        seed = x;
    }

    /**
     * Sets the size of each labelset.
     *
     * @param size the labelset size
     */
    public void setSizeOfSubset(int size) {
        sizeOfSubset = size;
        classIndicesPerSubset = new int[numOfModels][sizeOfSubset];
    }

    /**
     * Returns the size of each labelset.
     *
     * @return the labelset size
     */
    public int getSizeOfSubset() {
        return sizeOfSubset;
    }

    /**
     * Sets the number of member models.
     *
     * @param models the number of member models
     */
    public void setNumModels(int models) {
        numOfModels = models;
    }

    /**
     * Returns the number of member models.
     *
     * @return the number of member models
     */
    public int getNumModels() {
        return numOfModels;
    }

    /**
     * Computes the binomial coefficient "n choose m", i.e. the number of
     * distinct labelsets of size m that can be drawn from n labels.
     *
     * @param n the total number of labels
     * @param m the labelset size
     * @return the binomial coefficient
     */
    public static int binomial(int n, int m) {
        int[] b = new int[n + 1];
        b[0] = 1;
        for (int i = 1; i <= n; i++) {
            b[i] = 1;
            for (int j = i - 1; j > 0; --j) {
                b[j] += b[j - 1];
            }
        }
        return b[m];
    }

    @Override
    protected void buildInternal(MultiLabelInstances trainingData) throws Exception {
        rnd = new Random(seed);

        // need a structure to hold different combinations
        combinations = new HashSet<String>();

        // default number of models = twice the number of labels
        if (numOfModels == 0) {
            numOfModels = Math.min(2 * numLabels, binomial(numLabels, sizeOfSubset));
        }
        classIndicesPerSubset = new int[numOfModels][sizeOfSubset];
        absoluteIndicesToRemove = new int[numOfModels][sizeOfSubset];
        subsetClassifiers = new MultiLabelLearner[numOfModels];
        remove = new Remove[numOfModels];

        for (int i = 0; i < numOfModels; i++) {
            updateClassifier(trainingData, i);
        }
    }

    private void updateClassifier(MultiLabelInstances mlTrainData, int model) throws Exception {
        //todo: check if the following is unnecessary (was used for cvparam)
        if (combinations == null) {
            combinations = new HashSet<String>();
        }

        Instances trainData = mlTrainData.getDataSet();

        // select a random subset of classes not seen before
        // todo: select according to inverse distribution of current selection
        boolean[] selected;
        do {
            selected = new boolean[numLabels];
            for (int j = 0; j < sizeOfSubset; j++) {
                int randomLabel = rnd.nextInt(numLabels);
                while (selected[randomLabel]) {
                    randomLabel = rnd.nextInt(numLabels);
                }
                selected[randomLabel] = true;
                classIndicesPerSubset[model][j] = randomLabel;
            }
            Arrays.sort(classIndicesPerSubset[model]);
        } while (!combinations.add(Arrays.toString(classIndicesPerSubset[model])));
        debug("Building model " + (model + 1) + "/" + numOfModels + ", subset: "
                + Arrays.toString(classIndicesPerSubset[model]));

        // remove the unselected labels
        absoluteIndicesToRemove[model] = new int[numLabels - sizeOfSubset];
        int k = 0;
        for (int j = 0; j < numLabels; j++) {
            if (!selected[j]) {
                absoluteIndicesToRemove[model][k] = labelIndices[j];
                k++;
            }
        }
        remove[model] = new Remove();
        remove[model].setAttributeIndicesArray(absoluteIndicesToRemove[model]);
        remove[model].setInputFormat(trainData);
        remove[model].setInvertSelection(false);
        Instances trainSubset = Filter.useFilter(trainData, remove[model]);

        // build a MultiLabelLearner for the selected label subset
        subsetClassifiers[model] = getBaseLearner().makeCopy();
        subsetClassifiers[model].build(mlTrainData.reintegrateModifiedDataSet(trainSubset));
    }

    @Override
    protected MultiLabelOutput makePredictionInternal(Instance instance) throws Exception {
        double[] sumConf = new double[numLabels];
        sumVotes = new double[numLabels];
        lengthVotes = new double[numLabels];

        // gather votes
        for (int i = 0; i < numOfModels; i++) {
            remove[i].input(instance);
            remove[i].batchFinished();
            Instance newInstance = remove[i].output();
            MultiLabelOutput subsetMLO = subsetClassifiers[i].makePrediction(newInstance);
            for (int j = 0; j < sizeOfSubset; j++) {
                sumConf[classIndicesPerSubset[i][j]] += subsetMLO.getConfidences()[j];
                sumVotes[classIndicesPerSubset[i][j]] += subsetMLO.getBipartition()[j] ? 1 : 0;
                lengthVotes[classIndicesPerSubset[i][j]]++;
            }
        }

        double[] confidence1 = new double[numLabels];
        double[] confidence2 = new double[numLabels];
        boolean[] bipartition = new boolean[numLabels];
        for (int i = 0; i < numLabels; i++) {
            if (lengthVotes[i] != 0) {
                confidence1[i] = sumVotes[i] / lengthVotes[i];
                confidence2[i] = sumConf[i] / lengthVotes[i];
            } else {
                confidence1[i] = 0;
                confidence2[i] = 0;
            }
            bipartition[i] = confidence1[i] >= threshold;
        }

        // todo: optionally use confidence2 for ranking measures
        return new MultiLabelOutput(bipartition, confidence1);
    }
}
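/*
 * Minimal usage sketch (illustrative only, not part of the original class).
 * It assumes Mulan's LabelPowerset transformation-based learner and Weka's
 * J48 decision tree as the single-label base classifier, and hypothetical
 * dataset file names:
 *
 *   MultiLabelInstances train = new MultiLabelInstances("train.arff", "labels.xml");
 *   RAkEL rakel = new RAkEL(new LabelPowerset(new J48()), 2 * train.getNumLabels(), 3);
 *   rakel.setSeed(1);
 *   rakel.build(train);
 *   MultiLabelOutput prediction = rakel.makePrediction(train.getDataSet().instance(0));
 */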