/* * This program is free software; you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation; either version 2 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with this program; if not, write to the Free Software * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. */ /* * MetaLabeler.java * Copyright (C) 2009 Aristotle University of Thessaloniki, Thessaloniki, Greece */ package mulan.classifier.meta.thresholding; import java.util.ArrayList; import java.util.Set; import java.util.TreeSet; import java.util.logging.Level; import java.util.logging.Logger; import mulan.classifier.MultiLabelLearner; import mulan.classifier.MultiLabelOutput; import mulan.data.DataUtils; import mulan.data.MultiLabelInstances; import mulan.transformations.RemoveAllLabels; import weka.classifiers.Classifier; import weka.core.Attribute; import weka.core.Instance; import weka.core.Instances; import weka.core.TechnicalInformation; import weka.core.TechnicalInformation.Field; import weka.core.TechnicalInformation.Type; /** * * @author Marios Ioannou * @author George Sakkas * @author Grigorios Tsoumakas * @version 2010.12.14 */ public class MetaLabeler extends Meta { /** the type of the class*/ private String classChoice; /** * Constructor that initializes the learner * * @param baseLearner the underlying multi-label learner * @param classifier the binary classification * @param metaDataChoice the type of meta-data * @param aClassChoice the type of the class */ public MetaLabeler(MultiLabelLearner baseLearner, Classifier classifier, String metaDataChoice, String aClassChoice) { super(baseLearner, classifier, metaDataChoice); if (!metaDataChoice.equals("Content-Based")) { try { foldLearner = baseLearner.makeCopy(); } catch (Exception ex) { Logger.getLogger(MetaLabeler.class.getName()).log(Level.SEVERE, null, ex); } kFoldsCV = 3; } classChoice = aClassChoice; } @Override public TechnicalInformation getTechnicalInformation() { TechnicalInformation result = new TechnicalInformation(Type.INPROCEEDINGS); result.setValue(Field.AUTHOR, "Lei Tang and Sugu Rajan and Yijay K. Narayanan"); result.setValue(Field.TITLE, "Large scale multi-label classification via metalabeler"); result.setValue(Field.BOOKTITLE, "Proceedings of the 18th international conference on World wide web "); result.setValue(Field.PAGES, "211-220"); result.setValue(Field.LOCATION, "Madrid, Spain"); result.setValue(Field.YEAR, "2009"); return result; } @Override protected MultiLabelOutput makePredictionInternal(Instance instance) throws Exception { //System.out.println(instance); MultiLabelOutput mlo = baseLearner.makePrediction(instance); int[] arrayOfRankink = new int[numLabels]; boolean[] predictedLabels = new boolean[numLabels]; Instance modifiedIns = modifiedInstanceX(instance, metaDatasetChoice); //System.out.println(modifiedIns); modifiedIns.insertAttributeAt(modifiedIns.numAttributes()); // set dataset to instance modifiedIns.setDataset(classifierInstances); //get the bipartition_key after classify the instance int bipartition_key; if (classChoice.compareTo("Nominal-Class") == 0) { double classify_key = classifier.classifyInstance(modifiedIns); String s = classifierInstances.attribute(classifierInstances.numAttributes() - 1).value((int) classify_key); bipartition_key = Integer.valueOf(s); } else { //Numeric-Class double classify_key = classifier.classifyInstance(modifiedIns); bipartition_key = (int) Math.round(classify_key); } if (mlo.hasRanking()) { arrayOfRankink = mlo.getRanking(); for (int i = 0; i < numLabels; i++) { if (arrayOfRankink[i] <= bipartition_key) { predictedLabels[i] = true; } else { predictedLabels[i] = false; } } } MultiLabelOutput final_mlo = new MultiLabelOutput(predictedLabels, mlo.getConfidences()); return final_mlo; } private int countTrueLabels(Instance instance) { int numTrueLabels = 0; for (int i = 0; i < numLabels; i++) { int labelIndice = labelIndices[i]; if (instance.dataset().attribute(labelIndice).value((int) instance.value(labelIndice)).equals("1")) { numTrueLabels++; } } return numTrueLabels; } protected Instances transformData(MultiLabelInstances trainingData) throws Exception { // initialize classifier instances classifierInstances = RemoveAllLabels.transformInstances(trainingData); classifierInstances = new Instances(classifierInstances, 0); Attribute target = null; if (classChoice.equals("Nominal-Class")) { int countTrueLabels = 0; Set<Integer> treeSet = new TreeSet(); for (int instanceIndex = 0; instanceIndex < trainingData.getDataSet().numInstances(); instanceIndex++) { countTrueLabels = 0; for (int i = 0; i < numLabels; i++) { int labelIndice = labelIndices[i]; if (trainingData.getDataSet().attribute(labelIndice).value((int) trainingData.getDataSet().instance(instanceIndex).value(labelIndice)).equals("1")) { countTrueLabels++; } } treeSet.add(countTrueLabels); } ArrayList<String> classlabel = new ArrayList<String>(); for (Integer x : treeSet) { classlabel.add(x.toString()); } target = new Attribute("Class", classlabel); } else if (classChoice.equals("Numeric-Class")) { target = new Attribute("Class"); } classifierInstances.insertAttributeAt(target, classifierInstances.numAttributes()); classifierInstances.setClassIndex(classifierInstances.numAttributes() - 1); // create instances if (metaDatasetChoice.equals("Content-Based")) { for (int instanceIndex = 0; instanceIndex < trainingData.getNumInstances(); instanceIndex++) { Instance instance = trainingData.getDataSet().instance(instanceIndex); double[] values = instance.toDoubleArray(); double[] newValues = new double[classifierInstances.numAttributes()]; for (int i = 0; i < featureIndices.length; i++) { newValues[i] = values[featureIndices[i]]; } //set the number of true labels of an instance int numTrueLabels = countTrueLabels(instance); if (classChoice.compareTo("Nominal-Class") == 0) { newValues[newValues.length - 1] = classifierInstances.attribute(classifierInstances.numAttributes() - 1).indexOfValue("" + numTrueLabels); } else if (classChoice.compareTo("Numeric-Class") == 0) { newValues[newValues.length - 1] = numTrueLabels; } Instance newInstance = DataUtils.createInstance(instance, instance.weight(), newValues); classifierInstances.add(newInstance); } } else { for (int k = 0; k < kFoldsCV; k++) { //Split data to train and test sets MultiLabelLearner tempLearner; MultiLabelInstances mlTest; if (kFoldsCV == 1) { tempLearner = baseLearner; mlTest = trainingData; } else { Instances train = trainingData.getDataSet().trainCV(kFoldsCV, k); Instances test = trainingData.getDataSet().testCV(kFoldsCV, k); MultiLabelInstances mlTrain = new MultiLabelInstances(train, trainingData.getLabelsMetaData()); mlTest = new MultiLabelInstances(test, trainingData.getLabelsMetaData()); tempLearner = foldLearner.makeCopy(); tempLearner.build(mlTrain); } // copy features and labels, set metalabels for (int instanceIndex = 0; instanceIndex < mlTest.getDataSet().numInstances(); instanceIndex++) { Instance instance = mlTest.getDataSet().instance(instanceIndex); // initialize new class values double[] newValues = new double[classifierInstances.numAttributes()]; // create features valuesX(tempLearner, instance, newValues, metaDatasetChoice); //set the number of true labels of an instance int numTrueLabels = countTrueLabels(instance); if (classChoice.compareTo("Nominal-Class") == 0) { newValues[newValues.length - 1] = classifierInstances.attribute(classifierInstances.numAttributes() - 1).indexOfValue("" + numTrueLabels); } else if (classChoice.compareTo("Numeric-Class") == 0) { newValues[newValues.length - 1] = numTrueLabels; } // add the new instance to classifierInstances Instance newInstance = DataUtils.createInstance(mlTest.getDataSet().instance(instanceIndex), mlTest.getDataSet().instance(instanceIndex).weight(), newValues); classifierInstances.add(newInstance); } } } return classifierInstances; } /** * Sets the number of folds for internal cv * * @param f the number of folds */ public void setFolds(int f) { kFoldsCV = f; } }