/*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 2 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
*/
/*
* LabelsetPruning.java
* Copyright (C) 2009-2010 Aristotle University of Thessaloniki, Thessaloniki, Greece
*/
package mulan.classifier.transformation;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import mulan.data.LabelSet;
import mulan.data.MultiLabelInstances;
import weka.classifiers.Classifier;
import weka.core.Instance;
import weka.core.Instances;
/**
* Common functionality class for the PPT and PS algorithms.
*
* @author Grigorios Tsoumakas
* @version June 4, 2010
*/
public abstract class LabelsetPruning extends LabelPowerset {

    /** each distinct labelset of the training data, mapped to the instances that carry it */
    HashMap<LabelSet, ArrayList<Instance>> ListInstancePerLabel;
    /** occurrence threshold: a labelset is kept untouched only if it occurs more than p times */
    protected int p;
    /** format of the data: an empty Instances object created from the training set */
    Instances format;

    /**
     * Constructor that initializes the learner with a base algorithm and the
     * pruning parameter. It also selects confidence calculation method 2,
     * enables predictions based on confidences and sets the threshold to 0.21.
     *
     * @param classifier base single-label classification algorithm
     * @param aP occurrence threshold for labelsets; a labelset is passed on
     *        unchanged only when it occurs more than aP times
     * @throws IllegalArgumentException if aP is not larger than 0
     */
    public LabelsetPruning(Classifier classifier, int aP) {
        super(classifier);
        if (aP < 1) {
            throw new IllegalArgumentException("p should be larger than 0!");
        }
        this.p = aP;
        setConfidenceCalculationMethod(2);
        setMakePredictionsBasedOnConfidences(true);
        threshold = 0.21;
    }

    /**
     * Strategy hook for a labelset that does not occur more than p times.
     * It is invoked only after {@link #ListInstancePerLabel} has been filled
     * for the whole training set.
     *
     * @param ls the labelset that was rejected
     * @return the instances to add to the training set in its place; the
     *         returned list is handed directly to {@code addAll}
     */
    abstract ArrayList<Instance> processRejected(LabelSet ls);

    /**
     * Groups the training instances by labelset, prunes the labelsets that do
     * not occur more than p times (delegating their treatment to
     * {@link #processRejected(LabelSet)}) and trains the parent learner on
     * the resulting data.
     *
     * @param mlDataSet the multi-label training data
     * @throws Exception if building the underlying learner fails
     */
    @Override
    protected void buildInternal(MultiLabelInstances mlDataSet) throws Exception {
        Instances data = mlDataSet.getDataSet();
        format = new Instances(data, 0);

        // Step 1: bucket every training instance under its labelset.
        int total = data.numInstances();
        ListInstancePerLabel = new HashMap<LabelSet, ArrayList<Instance>>();
        for (int row = 0; row < total; row++) {
            Instance current = data.instance(row);

            double[] labelValues = new double[numLabels];
            for (int k = 0; k < numLabels; k++) {
                int attrIndex = labelIndices[k];
                // The stored value indexes the attribute's value strings; that
                // string is then parsed as a number (presumably "0" or "1").
                String nominal = data.attribute(attrIndex).value((int) current.value(attrIndex));
                labelValues[k] = Double.parseDouble(nominal);
            }

            LabelSet key = new LabelSet(labelValues);
            ArrayList<Instance> bucket = ListInstancePerLabel.get(key);
            if (bucket == null) {
                bucket = new ArrayList<Instance>();
                ListInstancePerLabel.put(key, bucket);
            }
            bucket.add(current);
        }

        // Step 2: walk over the labelsets. Those occurring more than p times
        // contribute all of their instances to the training set; for the
        // others, the strategy decides what is discarded or reintroduced.
        Instances pruned = new Instances(data, 0);
        for (LabelSet ls : ListInstancePerLabel.keySet()) {
            ArrayList<Instance> members = ListInstancePerLabel.get(ls);
            if (members.size() <= p) {
                pruned.addAll(processRejected(ls));
                continue;
            }
            for (Instance kept : members) {
                pruned.add(kept);
            }
        }

        super.buildInternal(new MultiLabelInstances(pruned, mlDataSet.getLabelsMetaData()));
    }
}