/* * This program is free software; you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation; either version 2 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with this program; if not, write to the Free Software * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. */ /* * IncludeLabelsClassifier.java * Copyright (C) 2009-2010 Aristotle University of Thessaloniki, Thessaloniki, Greece */ package mulan.classifier.transformation; import mulan.classifier.*; import mulan.data.MultiLabelInstances; import mulan.transformations.PT6Transformation; import weka.classifiers.Classifier; import weka.core.Instance; import weka.core.Instances; /** * A multilabel classifier based on Problem Transformation 6. * The multiple label attributes are mapped to two attributes: * a) one nominal attribute containing the class * b) one binary attribute containing whether it is true. * * @author Robert Friberg * @author Grigorios Tsoumakas * @version $Revision: 0.04 $ */ public class IncludeLabelsClassifier extends TransformationBasedMultiLabelLearner { /** * The transformation used by the classifier */ private PT6Transformation pt6Trans; /** * A dataset with the format needed by the base classifier. * It is potentially expensive copying datasets with many attributes, * so it is used for building the classifier and then it's mlData * are discarded and it is reused during prediction. */ protected Instances transformed; /** * Constructor that initializes a new learner with the given base classifier * * @param classifier */ public IncludeLabelsClassifier(Classifier classifier) { super(classifier); } @Override public void buildInternal(MultiLabelInstances mlData) throws Exception { //Do the transformation //and generate the classifier pt6Trans = new PT6Transformation(); debug("Transforming the dataset"); transformed = pt6Trans.transformInstances(mlData); debug("Building the base-level classifier"); baseClassifier.buildClassifier(transformed); transformed.delete(); } protected MultiLabelOutput makePredictionInternal(Instance instance) throws Exception { double[] confidences = new double[numLabels]; boolean[] bipartition = new boolean[numLabels]; Instance newInstance = pt6Trans.transformInstance(instance); //calculate confidences //debug(instance.toString()); for (int i = 0; i < numLabels; i++) { newInstance.setDataset(transformed); newInstance.setValue(newInstance.numAttributes() - 2, instance.dataset().attribute(labelIndices[i]).name()); //debug(newInstance.toString()); double[] temp = baseClassifier.distributionForInstance(newInstance); //debug(temp.toString()); confidences[i] = temp[transformed.classAttribute().indexOfValue("1")]; //debug("" + confidences[i]); bipartition[i] = temp[transformed.classAttribute().indexOfValue("1")] >= temp[transformed.classAttribute().indexOfValue("0")] ? true : false; //debug("" + bipartition[i]); } MultiLabelOutput mlo = new MultiLabelOutput(bipartition, confidences); return mlo; } }