/*
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 * HierarchyBuilder.java
 * Copyright (C) 2009-2010 Aristotle University of Thessaloniki, Thessaloniki, Greece
 */
package mulan.classifier.meta;

import java.io.File;
import java.io.IOException;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import java.util.Set;
import java.util.logging.Level;
import java.util.logging.Logger;
import javax.xml.parsers.DocumentBuilder;
import javax.xml.parsers.DocumentBuilderFactory;
import javax.xml.transform.OutputKeys;
import javax.xml.transform.Source;
import javax.xml.transform.Transformer;
import javax.xml.transform.TransformerFactory;
import javax.xml.transform.dom.DOMSource;
import javax.xml.transform.stream.StreamResult;
import mulan.data.InvalidDataFormatException;
import mulan.data.LabelNode;
import mulan.data.LabelNodeImpl;
import mulan.data.LabelsMetaData;
import mulan.data.LabelsMetaDataImpl;
import mulan.data.MultiLabelInstances;
import mulan.data.DataUtils;
import org.w3c.dom.Document;
import org.w3c.dom.Element;
import weka.clusterers.EM;
import weka.core.Attribute;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.converters.ArffSaver;

/**
 * Class that builds a hierarchy on the flat labels of a given multi-label
 * dataset. The hierarchy can be built with one of three methods: random
 * partitioning, clustering (EM) or balanced clustering (constrained k-means).
 *
 * @author George Saridis
 * @author Grigorios Tsoumakas
 * @version 0.1
 */
public class HierarchyBuilder implements Serializable {

    /** the number of partitions created at each node of the hierarchy */
    private int numPartitions;
    /** the XML document holding the constructed label hierarchy */
    private Document labelsXMLDoc;
    /** the method used to partition the labels */
    private Method method;

    /**
     * Constructs a new HierarchyBuilder.
     *
     * @param partitions the number of partitions created at each node
     * @param method the method used to partition the labels
     */
    public HierarchyBuilder(int partitions, Method method) {
        numPartitions = partitions;
        this.method = method;
    }

    /**
     * Builds a hierarchical multi-label dataset. First a hierarchy is built
     * on top of the labels of a flat multi-label dataset, by recursively
     * partitioning the labels into a specified number of clusters, according
     * to the selected partitioning method. Then the values of the new
     * "meta-labels" are properly set, so that the hierarchy is respected.
     *
     * @param mlData the multi-label data on which the new hierarchy will be built
     * @return the new multi-label data
     * @throws java.lang.Exception
     */
    public MultiLabelInstances buildHierarchy(MultiLabelInstances mlData) throws Exception {
        LabelsMetaData labelsMetaData = buildLabelHierarchy(mlData);
        return HierarchyBuilder.createHierarchicalDataset(mlData, labelsMetaData);
    }
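    /*
     * Illustrative sketch of the node naming scheme (not part of the original
     * documentation). Assuming labels {A, B, C, D, E, F} and numPartitions = 2,
     * one possible result of the recursive partitioning is:
     *
     *   MetaLabel 1
     *       MetaLabel 1.1
     *           A
     *           B
     *       C
     *   MetaLabel 2
     *       ...
     *
     * Root partitions are named "MetaLabel i"; deeper partitions append ".i"
     * to the name of their parent. A partition that ends up with a single
     * label attaches that label directly, without an intermediate meta-label.
     * The particular grouping depends on the selected Method.
     */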
    /**
     * Builds a hierarchy of labels on top of the labels of a flat multi-label
     * dataset, by recursively partitioning the labels into a specified number
     * of partitions.
     *
     * @param mlData the multi-label data on which the new hierarchy will be built
     * @return a hierarchy of labels
     * @throws java.lang.Exception
     */
    public LabelsMetaData buildLabelHierarchy(MultiLabelInstances mlData) throws Exception {
        if (numPartitions > mlData.getNumLabels()) {
            throw new IllegalArgumentException("Number of labels is smaller than the number of partitions");
        }
        Set<String> setOfLabels = mlData.getLabelsMetaData().getLabelNames();
        List<String> listOfLabels = new ArrayList<String>();
        for (String label : setOfLabels) {
            listOfLabels.add(label);
        }

        ArrayList<String>[] childrenLabels = null;
        switch (method) {
            case Random:
                childrenLabels = randomPartitioning(numPartitions, listOfLabels);
                break;
            case Clustering:
                childrenLabels = clustering(numPartitions, listOfLabels, mlData, false);
                break;
            case BalancedClustering:
                childrenLabels = clustering(numPartitions, listOfLabels, mlData, true);
                break;
        }

        for (int i = 0; i < numPartitions; i++) {
            if (childrenLabels[i].size() == listOfLabels.size()) {
                // another idea is to add leaves here
                childrenLabels = randomPartitioning(numPartitions, listOfLabels);
                break;
            }
        }

        LabelsMetaDataImpl metaData = new LabelsMetaDataImpl();
        for (int i = 0; i < numPartitions; i++) {
            if (childrenLabels[i].size() == 0) {
                continue;
            }
            if (childrenLabels[i].size() == 1) {
                metaData.addRootNode(new LabelNodeImpl(childrenLabels[i].get(0)));
                continue;
            }
            if (childrenLabels[i].size() > 1) {
                LabelNodeImpl metaLabel = new LabelNodeImpl("MetaLabel " + (i + 1));
                createLabelsMetaDataRecursive(metaLabel, childrenLabels[i], mlData);
                metaData.addRootNode(metaLabel);
            }
        }
        return metaData;
    }

    /**
     * Builds a hierarchical multi-label dataset and saves it to disk, as an
     * ARFF file holding the transformed instances and an XML file holding the
     * constructed label hierarchy.
     *
     * @param mlData the multi-label data on which the new hierarchy will be built
     * @param arffName the name of the ARFF file to save the new dataset to
     * @param xmlName the name of the XML file to save the label hierarchy to
     * @return the new multi-label data
     * @throws java.lang.Exception
     */
    public MultiLabelInstances buildHierarchyAndSaveFiles(MultiLabelInstances mlData, String arffName, String xmlName) throws Exception {
        MultiLabelInstances newData = buildHierarchy(mlData);
        saveToArffFile(newData.getDataSet(), new File(arffName));
        createXMLFile(newData.getLabelsMetaData());
        saveToXMLFile(xmlName);
        return newData;
    }
    private void createLabelsMetaDataRecursive(LabelNodeImpl node, List<String> labels, MultiLabelInstances mlData) {
        if (labels.size() <= numPartitions) {
            for (int i = 0; i < labels.size(); i++) {
                LabelNodeImpl child = new LabelNodeImpl(labels.get(i));
                node.addChildNode(child);
            }
            return;
        }

        ArrayList<String>[] childrenLabels = null;
        switch (method) {
            case Random:
                childrenLabels = randomPartitioning(numPartitions, labels);
                break;
            case Clustering:
                childrenLabels = clustering(numPartitions, labels, mlData, false);
                break;
            case BalancedClustering:
                childrenLabels = clustering(numPartitions, labels, mlData, true);
                break;
        }

        for (int i = 0; i < numPartitions; i++) {
            if (childrenLabels[i].size() == labels.size()) {
                // another idea is to add leaves here
                childrenLabels = randomPartitioning(numPartitions, labels);
                break;
            }
        }

        for (int i = 0; i < numPartitions; i++) {
            if (childrenLabels[i].size() == 0) {
                continue;
            }
            if (childrenLabels[i].size() == 1) {
                LabelNodeImpl child = new LabelNodeImpl(childrenLabels[i].get(0));
                node.addChildNode(child);
                continue;
            }
            if (childrenLabels[i].size() > 1) {
                LabelNodeImpl child = new LabelNodeImpl(node.getName() + "." + (i + 1));
                node.addChildNode(child);
                createLabelsMetaDataRecursive(child, childrenLabels[i], mlData);
            }
        }
    }

    private ArrayList<String>[] clustering(int clusters, List<String> labels, MultiLabelInstances mlData, boolean balanced) {
        ArrayList<String>[] childrenLabels = new ArrayList[clusters];
        for (int i = 0; i < clusters; i++) {
            childrenLabels[i] = new ArrayList<String>();
        }

        // transpose data and keep only labels in the parameter list
        int numInstances = mlData.getDataSet().numInstances();
        ArrayList<Attribute> attInfo = new ArrayList<Attribute>(numInstances);
        for (int i = 0; i < numInstances; i++) {
            Attribute att = new Attribute("instance" + (i + 1));
            attInfo.add(att);
        }
        System.out.println("constructing instances");
        Instances transposed = new Instances("transposed", attInfo, 0);
        for (int i = 0; i < labels.size(); i++) {
            double[] values = new double[numInstances];
            for (int j = 0; j < numInstances; j++) {
                values[j] = mlData.getDataSet().instance(j).value(mlData.getDataSet().attribute(labels.get(i)));
            }
            Instance newInstance = DataUtils.createInstance(mlData.getDataSet().instance(0), 1, values);
            transposed.add(newInstance);
        }

        if (!balanced) {
            EM clusterer = new EM();
            try {
                // cluster the labels
                clusterer.setNumClusters(clusters);
                System.out.println("clustering");
                clusterer.buildClusterer(transposed);
                // return the clustering
                for (int i = 0; i < labels.size(); i++) {
                    childrenLabels[clusterer.clusterInstance(transposed.instance(i))].add(labels.get(i));
                }
            } catch (Exception ex) {
                Logger.getLogger(HierarchyBuilder.class.getName()).log(Level.SEVERE, null, ex);
            }
        } else {
            ConstrainedKMeans clusterer = new ConstrainedKMeans();
            try {
                // cluster the labels
                clusterer.setMaxIterations(20);
                clusterer.setNumClusters(clusters);
                System.out.println("balanced clustering");
                clusterer.buildClusterer(transposed);
                // return the clustering
                for (int i = 0; i < labels.size(); i++) {
                    childrenLabels[clusterer.clusterInstance(transposed.instance(i))].add(labels.get(i));
                }
            } catch (Exception ex) {
                Logger.getLogger(HierarchyBuilder.class.getName()).log(Level.SEVERE, null, ex);
            }
        }
        return childrenLabels;
    }

    private ArrayList<String>[] randomPartitioning(int partitions, List<String> labels) {
        ArrayList<String>[] childrenLabels = new ArrayList[partitions];
        for (int i = 0; i < partitions; i++) {
            childrenLabels[i] = new ArrayList<String>();
        }
        Random rnd = new Random();
        while (!labels.isEmpty()) {
            for (int i = 0; i < partitions; i++) {
                if (labels.isEmpty()) {
                    break;
                }
                String rndLabel = labels.remove(rnd.nextInt(labels.size()));
                childrenLabels[i].add(rndLabel);
            }
        }
        return childrenLabels;
    }
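    /*
     * Illustrative sketch of how meta-label values are set (not part of the
     * original documentation). Assuming the hierarchy MetaLabel 1 -> {A, B}
     * and an instance annotated with A = 1 and B = 0, the hierarchical
     * dataset produced by the method below contains that instance with
     * MetaLabel 1 = 1, because a "1" in a leaf label is propagated up to the
     * root. An instance with A = 0 and B = 0 keeps MetaLabel 1 = 0.
     */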
    /**
     * Creates the hierarchical dataset according to the original multi-label
     * instances object and the constructed label hierarchy
     *
     * @param mlData the original multi-label instances
     * @param metaData the metadata of the constructed label hierarchy
     * @return the produced dataset
     * @throws InvalidDataFormatException
     */
    public static MultiLabelInstances createHierarchicalDataset(MultiLabelInstances mlData, LabelsMetaData metaData) throws InvalidDataFormatException {
        Set<String> leafLabels = mlData.getLabelsMetaData().getLabelNames();
        Set<String> metaLabels = metaData.getLabelNames();
        for (String string : leafLabels) {
            metaLabels.remove(string);
        }

        Instances dataSet = mlData.getDataSet();
        int numMetaLabels = metaLabels.size();

        // copy existing attributes
        ArrayList<Attribute> atts = new ArrayList<Attribute>(dataSet.numAttributes() + numMetaLabels);
        for (int i = 0; i < dataSet.numAttributes(); i++) {
            atts.add(dataSet.attribute(i));
        }
        ArrayList<String> labelValues = new ArrayList<String>();
        labelValues.add("0");
        labelValues.add("1");

        // add metalabel attributes
        for (String metaLabel : metaLabels) {
            atts.add(new Attribute(metaLabel, labelValues));
        }

        // initialize dataset
        Instances newDataSet = new Instances("hierarchical", atts, dataSet.numInstances());

        // copy features and labels, set metalabels
        for (int i = 0; i < dataSet.numInstances(); i++) {
            //System.out.println("Constructing instance " + (i+1) + "/" + dataSet.numInstances());
            // initialize new values
            double[] newValues = new double[newDataSet.numAttributes()];
            Arrays.fill(newValues, 0);

            // copy features and labels
            double[] values = dataSet.instance(i).toDoubleArray();
            System.arraycopy(values, 0, newValues, 0, values.length);

            // set metalabels
            for (String label : leafLabels) {
                Attribute att = dataSet.attribute(label);
                if (att.value((int) dataSet.instance(i).value(att)).equals("1")) {
                    //System.out.println(label);
                    //System.out.println(Arrays.toString(metaData.getLabelNames().toArray()));
                    LabelNode currentNode = metaData.getLabelNode(label);
                    // put 1 all the way up to the root, unless you see a 1, in which case stop
                    while (currentNode.hasParent()) {
                        currentNode = currentNode.getParent();
                        Attribute currentAtt = newDataSet.attribute(currentNode.getName());
                        // change the following to refer to the array
                        if (newValues[atts.indexOf(currentAtt)] == 1) {
                            // no need to go further up
                            break;
                        } else {
                            // put 1
                            newValues[atts.indexOf(currentAtt)] = 1;
                        }
                    }
                }
            }
            Instance instance = dataSet.instance(i);
            newDataSet.add(DataUtils.createInstance(instance, instance.weight(), newValues));
        }
        return new MultiLabelInstances(newDataSet, metaData);
    }

    private void saveToArffFile(Instances dataSet, File file) throws IOException {
        ArffSaver saver = new ArffSaver();
        saver.setInstances(dataSet);
        saver.setFile(file);
        saver.writeBatch();
    }

    private void createXMLFile(LabelsMetaData metaData) throws Exception {
        DocumentBuilderFactory docBF = DocumentBuilderFactory.newInstance();
        DocumentBuilder docBuilder = docBF.newDocumentBuilder();
        labelsXMLDoc = docBuilder.newDocument();
        Element rootElement = labelsXMLDoc.createElement("labels");
        rootElement.setAttribute("xmlns", "http://mulan.sourceforge.net/labels");
        labelsXMLDoc.appendChild(rootElement);
        for (LabelNode rootLabel : metaData.getRootLabels()) {
            Element newLabelElem = labelsXMLDoc.createElement("label");
            newLabelElem.setAttribute("name", rootLabel.getName());
            appendElement(newLabelElem, rootLabel);
            rootElement.appendChild(newLabelElem);
        }
    }

    private void saveToXMLFile(String fileName) {
        Source source = new DOMSource(labelsXMLDoc);
        File xmlFile = new File(fileName);
        StreamResult result = new StreamResult(xmlFile);
        try {
            Transformer transformer = TransformerFactory.newInstance().newTransformer();
            transformer.setOutputProperty(OutputKeys.INDENT, "yes");
            transformer.setOutputProperty("{http://xml.apache.org/xslt}indent-amount", "4");
            transformer.setOutputProperty(OutputKeys.METHOD, "xml");
            transformer.transform(source, result);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    private void appendElement(Element labelElem, LabelNode labelNode) {
        for (LabelNode childNode : labelNode.getChildren()) {
            Element newLabelElem = labelsXMLDoc.createElement("label");
            newLabelElem.setAttribute("name", childNode.getName());
            appendElement(newLabelElem, childNode);
            labelElem.appendChild(newLabelElem);
        }
    }

    /**
     * The method used to partition the labels when building the hierarchy.
     */
    public enum Method {

        Random, Clustering, BalancedClustering
    }
}
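/*
 * Minimal usage sketch, not part of the original Mulan source: it shows how
 * HierarchyBuilder might be driven end-to-end. The class name, the dataset
 * file names and the number of partitions are hypothetical placeholders.
 */
class HierarchyBuilderExample {

    public static void main(String[] args) throws Exception {
        // load a flat multi-label dataset (ARFF data plus XML label declaration)
        MultiLabelInstances flat = new MultiLabelInstances("flat.arff", "flat.xml");

        // build a 3-way balanced hierarchy and save the transformed dataset
        HierarchyBuilder builder = new HierarchyBuilder(3, HierarchyBuilder.Method.BalancedClustering);
        MultiLabelInstances hierarchical = builder.buildHierarchyAndSaveFiles(flat, "hierarchical.arff", "hierarchical.xml");

        // the new dataset contains the original labels plus the generated meta-labels
        System.out.println("Labels after adding the hierarchy: " + hierarchical.getNumLabels());
    }
}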