/*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 2 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
*/
/*
* HMC.java
* Copyright (C) 2009-2010 Aristotle University of Thessaloniki, Thessaloniki, Greece
*/
package mulan.classifier.meta;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import mulan.classifier.MultiLabelLearner;
import mulan.classifier.MultiLabelOutput;
import mulan.data.InvalidDataFormatException;
import mulan.data.LabelNode;
import mulan.data.LabelNodeImpl;
import mulan.data.LabelsMetaData;
import mulan.data.LabelsMetaDataImpl;
import mulan.data.MultiLabelInstances;
import mulan.data.DataUtils;
import mulan.transformations.RemoveAllLabels;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformation.Field;
import weka.core.TechnicalInformation.Type;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.Remove;
/**
* Class that implements a Hierarchical Multilabel classifier (HMC).
* The HMC classifier takes any kind of multilabel classifier as a parameter and
* builds a hierarchy of such classifiers. Every node of the hierarchy is a
* classifier that is trained separately. The root classifier is trained on all
* the data; further down the hierarchy tree the data is adjusted to each node:
* first, the instances that do not belong to the node are removed, and then the
* attributes that are unnecessary are removed as well.
*
* @author George Saridis
* @author Grigorios Tsoumakas
* @version 0.2
*/
public class HMC extends MultiLabelMetaLearner {
private LabelsMetaData originalMetaData;
private HMCNode root;
private Map<String, Integer> labelsAndIndices;
private long NoNodes = 0;
private long NoClassifierEvals = 0;
private long TotalUsedTrainInsts = 0;
public HMC(MultiLabelLearner baseLearner) throws Exception {
super(baseLearner);
}
@Override
public TechnicalInformation getTechnicalInformation() {
TechnicalInformation result = new TechnicalInformation(Type.INPROCEEDINGS);
result.setValue(Field.AUTHOR, "Grigorios Tsoumakas and Ioannis Katakis and Ioannis Vlahavas");
result.setValue(Field.TITLE, "Effective and Efficient Multilabel Classification in Domains with Large Number of Labels");
result.setValue(Field.BOOKTITLE, "Proc. ECML/PKDD 2008 Workshop on Mining Multidimensional Data (MMD'08)");
result.setValue(Field.LOCATION, "Antwerp, Belgium");
result.setValue(Field.YEAR, "2008");
return result;
}
private void buildRec(HMCNode node, Instances data) throws InvalidDataFormatException, Exception {
String metaLabel = node.getName();
//debug("Preparing node data");
Set<String> childrenLabels = new HashSet<String>();
Set<String> currentlyAvailableLabels = new HashSet<String>();
if (metaLabel.equals("root")) {
for (LabelNode child : originalMetaData.getRootLabels()) {
childrenLabels.add(child.getName());
}
currentlyAvailableLabels = originalMetaData.getLabelNames();
} else {
LabelNode labelNode = originalMetaData.getLabelNode(metaLabel);
for (LabelNode child : labelNode.getChildren()) {
childrenLabels.add(child.getName());
}
currentlyAvailableLabels = labelNode.getDescendantLabels();
}
// delete non-children labels
Set<String> labelsToDelete = new HashSet(currentlyAvailableLabels);
labelsToDelete.removeAll(childrenLabels);
//System.out.println("Children: " + Arrays.toString(childrenLabels.toArray()));
//System.out.println("Labels to delete:" + Arrays.toString(labelsToDelete.toArray()));
int[] indicesToDelete = new int[labelsToDelete.size()];
int counter1 = 0;
for (String label : labelsToDelete) {
indicesToDelete[counter1] = data.attribute(label).index();
counter1++;
}
Remove filter1 = new Remove();
filter1.setAttributeIndicesArray(indicesToDelete);
filter1.setInputFormat(data);
Instances nodeInstances = Filter.useFilter(data, filter1);
// System.out.println()
// create meta data
LabelsMetaDataImpl nodeMetaData = new LabelsMetaDataImpl();
for (String label : childrenLabels) {
nodeMetaData.addRootNode(new LabelNodeImpl(label));
}
// create multi-label instance
MultiLabelInstances nodeData = new MultiLabelInstances(nodeInstances, nodeMetaData);
//debug("Building model");
node.build(nodeData);
//debug("spark #instances:"+nodeInstances.numInstances());
TotalUsedTrainInsts += nodeInstances.numInstances();
NoNodes++;
//debug("spark:#nodes: "+ HMCNoNodes);
for (String childLabel : childrenLabels) {
LabelNode childNode = originalMetaData.getLabelNode(childLabel);
if (!childNode.hasChildren()) {
continue;
}
//debug("Preparing child data");
// remove instances where child is 0
int childMetaLabelIndex = data.attribute(childLabel).index();
Instances childData = new Instances(data);
for (int i = 0; i < childData.numInstances(); i++) {
if (childData.instance(i).stringValue(childMetaLabelIndex).equals("0")) {
childData.delete(i);
// While deleting an instance from the trainSet, i must reduced too
i--;
}
}
// delete non-descendant labels
Set<String> descendantLabels = childNode.getDescendantLabels();
Set<String> labelsToDelete2 = new HashSet(currentlyAvailableLabels);
labelsToDelete2.removeAll(descendantLabels);
//System.out.println("Labels to delete:" + Arrays.toString(labelsToDelete2.toArray()));
int[] indicesToDelete2 = new int[labelsToDelete2.size()];
int counter2 = 0;
for (String label : labelsToDelete2) {
indicesToDelete2[counter2] = childData.attribute(label).index();
counter2++;
}
Remove filter2 = new Remove();
filter2.setAttributeIndicesArray(indicesToDelete2);
filter2.setInputFormat(childData);
childData = Filter.useFilter(childData, filter2);
MultiLabelLearner mll = baseLearner.makeCopy();
HMCNode child = new HMCNode(childLabel, mll);
node.addChild(child);
buildRec(child, childData);
}
}
@Override
protected void buildInternal(MultiLabelInstances dataSet) throws Exception {
originalMetaData = dataSet.getLabelsMetaData();
Set<String> rootLabels = new HashSet<String>();
for (LabelNode node : originalMetaData.getRootLabels()) {
rootLabels.add(node.getName());
}
MultiLabelLearner mll = baseLearner.makeCopy();
root = new HMCNode("root", mll);
buildRec(root, dataSet.getDataSet());
labelsAndIndices = dataSet.getLabelsOrder();
}
protected MultiLabelOutput makePredictionInternal(Instance instance) throws Exception {
boolean[] predictedLabels = new boolean[numLabels];
double[] confidences = new double[numLabels];
makePrediction(root, instance, predictedLabels, confidences);
return new MultiLabelOutput(predictedLabels, confidences);
}
private void makePrediction(HMCNode currentNode, Instance instance, boolean[] predictedLabels, double[] confidences) throws Exception {
//System.out.println("Node: " + currentNode.getName());
double[] values = instance.toDoubleArray();
Instance transformed = DataUtils.createInstance(instance, 1, values);
// delete all labels apart from those of current node
int[] currentNodeLabelIndices = currentNode.getLabelIndices();
Set<Integer> indicesToKeep = new HashSet<Integer>();
for (int i = 0; i < currentNodeLabelIndices.length; i++) {
String labelToKeep = currentNode.getHeader().attribute(currentNodeLabelIndices[i]).name();
indicesToKeep.add(labelIndices[labelsAndIndices.get(labelToKeep)]);
}
if (labelIndices.length - indicesToKeep.size() != 0) {
int[] indicesToDelete = new int[labelIndices.length - indicesToKeep.size()];
int counter = 0;
for (int i = 0; i < labelIndices.length; i++) {
if (indicesToKeep.contains(labelIndices[i])) {
continue;
}
indicesToDelete[counter] = labelIndices[i];
counter++;
}
transformed = RemoveAllLabels.transformInstance(transformed, indicesToDelete);
}
transformed.setDataset(currentNode.getHeader());
// add as many attributes as the children
// System.out.println("header:" + currentNode.getHeader());
//System.out.println(transformed.toString());
//debug("working at node " + currentNode.getName());
//debug(Arrays.toString(predictedLabels));
NoClassifierEvals++;
MultiLabelOutput pred = currentNode.makePrediction(transformed);
int[] indices = currentNode.getLabelIndices();
boolean[] temp = pred.getBipartition();
for (int i = 0; i < temp.length; i++) {
String childName = currentNode.getHeader().attribute(indices[i]).name();
//System.out.println("childName:" + childName);
int idx = labelsAndIndices.get(childName);
if (pred.getBipartition()[i] == true) {
predictedLabels[idx] = true;
confidences[idx] = pred.getConfidences()[i];
if (currentNode.hasChildren()) {
for (HMCNode child : currentNode.getChildren()) {
if (child.getName().equals(childName)) {
makePrediction(child, instance, predictedLabels, confidences);
}
}
}
} else {
predictedLabels[idx] = false;
Set<String> descendantLabels = originalMetaData.getLabelNode(childName).getDescendantLabels();
if (descendantLabels != null) {
for (String label : descendantLabels) {
int idx2 = labelsAndIndices.get(label);
predictedLabels[idx2] = false;
confidences[idx2] = pred.getConfidences()[i];
}
}
}
}
}
/**
* Deletes the unnecessary attributes. Actually keeps only the children
* names of the node that is going to be trained as attributes and deletes
* the rest.
*
* @param mlData the instances from which the attributes will be removed
* @param currentLabel the name of the node whose children will be kept as attributes
* @return MultiLabelInstances
* @throws mulan.data.InvalidDataFormatException
*/
protected MultiLabelInstances deleteLabels(MultiLabelInstances mlData, String currentLabel, boolean keepSubTree) throws InvalidDataFormatException {
LabelsMetaData currentMetaData = mlData.getLabelsMetaData();
LabelNodeImpl currentLabelNode = (LabelNodeImpl) currentMetaData.getLabelNode(currentLabel);
Set<String> labelsToKeep;
Set<String> allLabels = mlData.getLabelsMetaData().getLabelNames();
LabelsMetaDataImpl labelsMetaData = new LabelsMetaDataImpl();
//Prepare the appropriate labelsMetaData
if (keepSubTree) {
labelsToKeep = currentLabelNode.getDescendantLabels();
for (String rootLabel : currentLabelNode.getChildrenLabels()) {
LabelNodeImpl rootNode = new LabelNodeImpl(rootLabel);
if (mlData.getLabelsMetaData().getLabelNode(rootLabel).hasChildren()) {
append(rootNode, mlData.getLabelsMetaData());
}
labelsMetaData.addRootNode(rootNode);
}
} else {
labelsToKeep = currentLabelNode.getChildrenLabels();
for (String rootLabel : labelsToKeep) {
LabelNodeImpl rootNode = new LabelNodeImpl(rootLabel);
labelsMetaData.addRootNode(rootNode);
}
}
//debug("Labels: " + labelsMetaData.getLabelNames().toString());
//Deleting labels from instances
for (String label : allLabels) {
if (!labelsToKeep.contains(label)) {
int idx = mlData.getDataSet().attribute(label).index();
mlData.getDataSet().deleteAttributeAt(idx);
}
}
return new MultiLabelInstances(mlData.getDataSet(), labelsMetaData);
}
private void append(LabelNodeImpl labelNode, LabelsMetaData labelsMetaData) {
LabelNode father = labelsMetaData.getLabelNode(labelNode.getName());
for (LabelNode child : father.getChildren()) {
LabelNodeImpl newLabelNode = new LabelNodeImpl(child.getName());
if (child.hasChildren()) {
append(newLabelNode, labelsMetaData);
}
labelNode.addChildNode(newLabelNode);
}
}
/**
* Deletes the unnecessary instances, the instances that have value 0 on
* given attribute.
*
* @param trainSet the trainSet on which the deletion will be applied
* @param attrIndex the index of the attribute that the deletion is based
*/
protected void deleteInstances(Instances trainSet, int attrIndex) {
for (int i = 0; i < trainSet.numInstances(); i++) {
if (trainSet.instance(i).stringValue(attrIndex).equals("0")) {
trainSet.delete(i);
// While deleting an instance from the trainSet, i must reduced too
i--;
}
}
}
//spark temporary edit
public long getNoNodes() {
return NoNodes;
}
public long getNoClassifierEvals() {
return NoClassifierEvals;
}
public long getTotalUsedTrainInsts() {
return TotalUsedTrainInsts;
}
}