/* * This program is free software; you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation; either version 2 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with this program; if not, write to the Free Software * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. */ /* * SubsetLearner.java * Copyright (C) 2009-2010 Aristotle University of Thessaloniki, Thessaloniki, Greece */ package mulan.classifier.meta; import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; import mulan.classifier.MultiLabelLearner; import mulan.classifier.MultiLabelOutput; import mulan.classifier.transformation.LabelPowerset; import mulan.core.ArgumentNullException; import mulan.data.LabelClustering; import mulan.data.MultiLabelInstances; import weka.classifiers.AbstractClassifier; import weka.classifiers.Classifier; import weka.classifiers.meta.FilteredClassifier; import weka.core.Attribute; import weka.core.Instance; import weka.core.Instances; import weka.core.TechnicalInformation; import weka.filters.Filter; import weka.filters.unsupervised.attribute.Remove; /** * A class for learning a classifier according to disjoint label subsets: a multi-label learner * (the Label Powerset by default) is applied to subsets with multiple labels and a single-label learner * is applied to single label subsets. The final classification prediction is determined by combining * labels predicted by all the learned models. Note: the class is not multi-thread safe. <br> * <br> * There is a mechanism for caching and reusing learned classification models. The caching mechanism is * controlled by {@link #useCache} parameter. * * For more information: * Tenenboim, L., Rokach, L. and Shapira, B. (2009). "Multi-label Classification by Analyzing Labels Dependencies" * Proc. ECML/PKDD 2009 Workshop on Learning from Multi-Label Data (MLD'09) * "Tenenboim-Chekina, L., Rokach, L. and Shapira, B. (2010). "Identification of Label Dependencies for Multi-label * Classification". Proc. ICML 2010 Workshop on Learning from Multi-Label Data (MLD'10"); * * @author Lena Chekina (lenat@bgu.ac.il) * @author Vasiloudis Theodoros * @version 30.11.2010 */ public class SubsetLearner extends MultiLabelMetaLearner { /** * Arraylist containing the MultiLabelLearners that we will train and use to make the * predictions */ private ArrayList<MultiLabelLearner> multiLabelLearners; /** * Arraylist containing the FilteredClassifiers that we will train and use to make the * predictions */ private ArrayList<FilteredClassifier> singleLabelLearners; /** Array containing the way the labels will be split */ private int[][] splitOrder; /** Array containing the indices of the labels we are going to remove */ private int[][] absoluteIndicesToRemove; /** Array containing the Remove objects used to remove the labels for each split */ private Remove[] remove; /** Base single-label classifier that will be used for training and predictions */ protected Classifier baseSingleLabelClassifier; /** indication for disabled caching mechanism */ private boolean useCache = false; /** The method used to cluster the labels */ private LabelClustering clusterer = null; /** * HashMaps containing created models - caching mechanism is used, if enabled by setting the * useCache field to true, for GreedyLabelClustering and EnsembleOfSubsetLearners methods run time optimization */ private static HashMap<String, MultiLabelLearner> existingMultiLabelModels = new HashMap<String, MultiLabelLearner>(); private static HashMap<String, FilteredClassifier> existingSingleLabelModels = new HashMap<String, FilteredClassifier>(); private static HashMap<String, Remove> existingRemove = new HashMap<String, Remove>(); /** * Initialize the SubsetLearner with labels subsets partitioning and single label learner. * {@link mulan.classifier.transformation.LabelPowerset} method initialized with the specified * single label learner.will be used as multilabel learner. * * @param labelsSubsets subsets of dependent labels * @param singleLabelClassifier method used for single label classification */ public SubsetLearner(int[][] labelsSubsets, Classifier singleLabelClassifier) { super(new LabelPowerset(singleLabelClassifier)); if (singleLabelClassifier == null) { throw new ArgumentNullException("singleLabelClassifier"); } if (labelsSubsets == null) { throw new ArgumentNullException("labelsSubsets"); } baseSingleLabelClassifier = singleLabelClassifier; splitOrder = labelsSubsets; absoluteIndicesToRemove = new int[splitOrder.length][]; } /** * Initialize the SubsetLearner with labels set partitioning, multilabel and single label * learners. * * @param labelsSubsets subsets of dependent labels * @param multiLabelLearner method used for multilabel classification * @param singleLabelClassifier method used for single label classification */ public SubsetLearner(int[][] labelsSubsets, MultiLabelLearner multiLabelLearner, Classifier singleLabelClassifier) { super(multiLabelLearner); if (singleLabelClassifier == null) { throw new ArgumentNullException("singleLabelClassifier"); } if (labelsSubsets == null) { throw new ArgumentNullException("labelsSubsets"); } baseSingleLabelClassifier = singleLabelClassifier; splitOrder = labelsSubsets; absoluteIndicesToRemove = new int[splitOrder.length][]; } /** * Initialize the SubsetLearner with a label clustering method, multilabel and single label * learners. * * @param clusteringMethod * @param multiLabelLearner method used for multilabel classification * @param singleLabelClassifier method used for single label classification */ public SubsetLearner(LabelClustering clusteringMethod, MultiLabelLearner multiLabelLearner, Classifier singleLabelClassifier) { super(multiLabelLearner); if (clusteringMethod == null) { throw new ArgumentNullException("clusteringMethod"); } if (singleLabelClassifier == null) { throw new ArgumentNullException("singleLabelClassifier"); } baseSingleLabelClassifier = singleLabelClassifier; clusterer = clusteringMethod; } /** * Reset the label set partitioning. * * @param labelsSubsets - new label set partitioning */ public void resetSubsets(int[][] labelsSubsets) { splitOrder = labelsSubsets; absoluteIndicesToRemove = new int[splitOrder.length][]; } /** * We get the initial dataset through trainingSet. Then for each subset of labels as specified * by labelsSubsets we remove the unneeded labels and train the classifiers using * MultiLabelLearner for multi-label splits and BinaryRelevance approach for single label * splits. Each classification model constructed on a certain training data for a certain labels * subset along with related Remove object is stored in HashMap and can be reused when is needed * next time. * * @param trainingSet The initial {@link mulan.data.MultiLabelInstances} dataset * @throws Exception */ @Override protected void buildInternal(MultiLabelInstances trainingSet) throws Exception { if (clusterer != null) { splitOrder = clusterer.determineClusters(trainingSet); absoluteIndicesToRemove = new int[splitOrder.length][]; } remove = new Remove[splitOrder.length]; prepareIndicesToRemove(); multiLabelLearners = new ArrayList<MultiLabelLearner>(); // Create the lists which will contain the learners singleLabelLearners = new ArrayList<FilteredClassifier>(); int countSingle = 0, countMulti = 0; for (int totalSplitNo = 0; totalSplitNo < splitOrder.length; totalSplitNo++) { // Ensure ascending order of label indexes in the subset Arrays.sort(splitOrder[totalSplitNo]); int foldHash = trainingSet.getDataSet().toString().hashCode(); // create unique key of the trainingSet and the labels subset to be used for caching String modelKey = createKey(splitOrder[totalSplitNo], foldHash); if (splitOrder[totalSplitNo].length > 1) { buildMultiLabelModel(trainingSet, countMulti, totalSplitNo, modelKey); countMulti++; } else { buildSingleLabelModel(trainingSet, countSingle, totalSplitNo, modelKey); countSingle++; } } } /** * Get values into absoluteIndicesToRemove */ private void prepareIndicesToRemove() { int numofSplits = splitOrder.length; // Number of sets the main is going to be split into for (int r = 0; r < splitOrder.length; r++) { // Initialization required to avoid NullPointer exception absoluteIndicesToRemove[r] = new int[numLabels - splitOrder[r].length]; } boolean[][] Selected = new boolean[splitOrder.length][numLabels]; // Initialize an array containing which labels we want for (int i = 0; i < numofSplits; i++) { // Set true for the labels we need to keep for (int j = 0; j < splitOrder[i].length; j++) { Selected[i][splitOrder[i][j]] = true; } } for (int i = 0; i < numofSplits; i++) { // Get the labels we need to KEEP int k = 0; for (int j = 0; j < numLabels; j++) { if (!Selected[i][j]) { absoluteIndicesToRemove[i][k] = labelIndices[j]; k++; } } } } /** * Construct multilabel model. * * @param trainingSet The initial {@link mulan.data.MultiLabelInstances} dataset * @param countMulti the number of previous multilabel splits within the label-set partition * @param totalSplitNo the total number of previous splits within the label-set partition * @param modelKey the unique key of the trainingSet and the labels subset * @throws Exception */ private void buildMultiLabelModel(MultiLabelInstances trainingSet, int countMulti, int totalSplitNo, String modelKey) throws Exception { if (useCache && existingMultiLabelModels.containsKey(modelKey)) { // try to get existing model from cache MultiLabelLearner model = existingMultiLabelModels.get(modelKey); resetRandomSeed(model); // reset random seed of the classifier to it's initial value, such that it will be // equal to that if the classifier was just trained. multiLabelLearners.add(model.makeCopy()); remove[totalSplitNo] = existingRemove.get(modelKey); } else { // (there is no such model in cache) -> build it Instances trainSubset = trainingSet.getDataSet(); remove[totalSplitNo] = new Remove(); // Remove the unneeded labels remove[totalSplitNo].setAttributeIndicesArray(absoluteIndicesToRemove[totalSplitNo]); remove[totalSplitNo].setInputFormat(trainSubset); remove[totalSplitNo].setInvertSelection(false); trainSubset = Filter.useFilter(trainSubset, remove[totalSplitNo]); multiLabelLearners.add(baseLearner.makeCopy()); // Reintegrate dataset and train learner multiLabelLearners.get(countMulti).build( trainingSet.reintegrateModifiedDataSet(trainSubset)); if (useCache) { // add trained model and related Remove object to cache existingMultiLabelModels.put(modelKey, multiLabelLearners.get(countMulti)); existingRemove.put(modelKey, remove[totalSplitNo]); } } } /** * Construct single label model. * * @param trainingSet The initial {@link mulan.data.MultiLabelInstances} dataset * @param countSingle the number of previous single-label splits within the label-set partition * @param totalSplitNo the total number of previous splits within the label-set partition * @param modelKey the unique key of the trainingSet and the labels subset * @throws Exception */ private void buildSingleLabelModel(MultiLabelInstances trainingSet, int countSingle, int totalSplitNo, String modelKey) throws Exception { if (useCache && existingSingleLabelModels.containsKey(modelKey)) { // if single-label model is in cache -> get it FilteredClassifier model = existingSingleLabelModels.get(modelKey); Classifier classifier = model.getClassifier(); // reset random seed of the classifier to it's initial value, such that it will be equal to that if the classifier was just trained resetRandomSeed(classifier); singleLabelLearners.add(model); remove[totalSplitNo] = existingRemove.get(modelKey); } else { // the model is not in cache -> build the model and add it to cache singleLabelLearners.add(new FilteredClassifier()); // Initialize the FilteredClassifiers singleLabelLearners.get(countSingle).setClassifier( AbstractClassifier.makeCopy(baseSingleLabelClassifier)); Instances trainSubset = trainingSet.getDataSet(); remove[totalSplitNo] = new Remove(); // Set the remove filter for the FilteredClassifiers remove[totalSplitNo].setAttributeIndicesArray(absoluteIndicesToRemove[totalSplitNo]); remove[totalSplitNo].setInputFormat(trainSubset); remove[totalSplitNo].setInvertSelection(false); singleLabelLearners.get(countSingle).setFilter(remove[totalSplitNo]); // Set the remaining label as the class index trainSubset.setClassIndex(labelIndices[splitOrder[totalSplitNo][0]]); singleLabelLearners.get(countSingle).buildClassifier(trainSubset); // train if (useCache) { // add trained model and related Remove object to cache existingSingleLabelModels.put(modelKey, singleLabelLearners.get(countSingle)); existingRemove.put(modelKey, remove[totalSplitNo]); } } } /** * Concatenate all integers from an array with additional integer into a single string. * * @param set an array representing labels subset * @param fold a hash code of the current training set * @return a string in the form: "_l1_l2_ ... ln_fold" */ private String createKey(int[] set, int fold) { StringBuilder sb = new StringBuilder("_"); for (int i : set) { sb.append(i); sb.append("_"); } sb.append(fold); return sb.toString(); } /** * Invokes the setSeed(1) or setRandomSeed(1) method of the supplied object's Class, if such * method exist. * * @param model which random seed should be reset. */ public void resetRandomSeed(Object model) { Class aClass = model.getClass(); Method method = null; try { method = aClass.getMethod("setSeed", int.class); } catch (NoSuchMethodException e) { try { method = aClass.getMethod("setRandomSeed", int.class); } catch (NoSuchMethodException e2) { debug("NoSuchMethodExceptions: " + e.getMessage() + " and " + e2.getMessage()); } } try { if (method != null) method.invoke(model, 1); } catch (IllegalAccessException e) { debug("IllegalAccessException: " + e.getMessage()); } catch (InvocationTargetException e) { debug("InvocationTargetException: " + e.getMessage()); } } /** * Set random seed of all internal Learners to 1. */ public void setSeed() { for (MultiLabelLearner learner : multiLabelLearners) { resetRandomSeed(learner); } for (FilteredClassifier learner : singleLabelLearners) { resetRandomSeed(learner); } } /** * We make a prediction using a different method depending on whether the split has one or more * labels * * @param instance the instance for classification prediction * @return the {@link mulan.classifier.MultiLabelOutput} classification prediction for the * instance * @throws Exception */ public MultiLabelOutput makePredictionInternal(Instance instance) throws Exception { MultiLabelOutput[] MLO = new MultiLabelOutput[splitOrder.length]; int singleSplitNo = 0, multiSplitNo = 0; boolean[][] BooleanSubsets = new boolean[splitOrder.length][]; double[][] ConfidenceSubsets = new double[splitOrder.length][]; for (int r = 0; r < splitOrder.length; r++) { // Initilization required to avoid NullPointer exception BooleanSubsets[r] = new boolean[splitOrder[r].length]; ConfidenceSubsets[r] = new double[splitOrder[r].length]; } boolean[] BipartitionOut = new boolean[numLabels]; double[] ConfidenceOut = new double[numLabels]; // Make a prediction for the instance in each separate dataset // The learners have been trained for each seperate dataset in buildInternal for (int i = 0; i < splitOrder.length; i++) { if (splitOrder[i].length == 1) { // Prediction for single label splits double distribution[]; try { distribution = singleLabelLearners.get(singleSplitNo).distributionForInstance( instance); } catch (Exception e) { System.out.println(e); return null; } int maxIndex = (distribution[0] > distribution[1]) ? 0 : 1; // Ensure correct predictions both for class values {0,1} and {1,0} Attribute classAttribute = singleLabelLearners.get(singleSplitNo).getFilter() .getOutputFormat().classAttribute(); BooleanSubsets[i][0] = (classAttribute.value(maxIndex).equals("1")); // The confidence of the label being equal to 1 ConfidenceSubsets[i][0] = distribution[classAttribute.indexOfValue("1")]; singleSplitNo++; } else { // Prediction for multi label splits remove[i].input(instance); remove[i].batchFinished(); Instance newInstance = remove[i].output(); MLO[multiSplitNo] = multiLabelLearners.get(multiSplitNo) .makePrediction(newInstance); // Get each array of Bipartitions, confidences from each learner BooleanSubsets[i] = MLO[multiSplitNo].getBipartition(); ConfidenceSubsets[i] = MLO[multiSplitNo].getConfidences(); multiSplitNo++; } } // Concatenate the outputs while putting everything in its right place for (int i = 0; i < splitOrder.length; i++) { for (int j = 0; j < splitOrder[i].length; j++) { BipartitionOut[splitOrder[i][j]] = BooleanSubsets[i][j]; ConfidenceOut[splitOrder[i][j]] = ConfidenceSubsets[i][j]; } } return new MultiLabelOutput(BipartitionOut, ConfidenceOut); } public void setUseCache(boolean useCache) { this.useCache = useCache; } /** * Returns an instance of a TechnicalInformation object, containing detailed information about * the technical background of this class, e.g., paper reference or book this class is based on. * * @return the technical information about this class */ @Override public TechnicalInformation getTechnicalInformation() { TechnicalInformation result; result = new TechnicalInformation(TechnicalInformation.Type.INPROCEEDINGS); result.setValue(TechnicalInformation.Field.AUTHOR, "Lena Tenenboim, Lior Rokach, and Bracha Shapira"); result.setValue(TechnicalInformation.Field.TITLE, "Multi-label Classification by Analyzing Labels Dependencies"); result.setValue(TechnicalInformation.Field.VOLUME, "Proc. ECML/PKDD 2009 Workshop on Learning from Multi-Label Data (MLD'09)"); result.setValue(TechnicalInformation.Field.YEAR, "2009"); result.setValue(TechnicalInformation.Field.PAGES, "117--132"); result.setValue(TechnicalInformation.Field.ADDRESS, "Bled, Slovenia"); result.setValue(TechnicalInformation.Field.AUTHOR, "Lena Tenenboim-Chekina, Lior Rokach, and Bracha Shapira"); result.setValue(TechnicalInformation.Field.TITLE, "Identification of Label Dependencies for Multi-label Classification"); result.setValue(TechnicalInformation.Field.VOLUME, "Proc. ICML 2010 Workshop on Learning from Multi-Label Data (MLD'10"); result.setValue(TechnicalInformation.Field.YEAR, "2010"); result.setValue(TechnicalInformation.Field.PAGES, "53--60"); result.setValue(TechnicalInformation.Field.ADDRESS, "Haifa, Israel"); return result; } public String getModel() { String out = ""; for (int i = 0; i < multiLabelLearners.size(); i++) { out += ((LabelPowerset) multiLabelLearners.get(i)).getBaseClassifier().toString(); } return out; } }