/* * RapidMiner * * Copyright (C) 2001-2011 by Rapid-I and the contributors * * Complete list of developers available at our web site: * * http://rapid-i.com * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by * the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Affero General Public License for more details. * * You should have received a copy of the GNU Affero General Public License * along with this program. If not, see http://www.gnu.org/licenses/. */ package com.rapidminer.example; import java.io.File; import java.io.FileWriter; import java.io.IOException; import java.io.PrintWriter; import java.nio.charset.Charset; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.Comparator; import java.util.Iterator; import java.util.LinkedHashMap; import java.util.Map; import java.util.Set; import javax.xml.parsers.DocumentBuilderFactory; import javax.xml.parsers.ParserConfigurationException; import org.w3c.dom.Document; import org.w3c.dom.Element; import org.w3c.dom.Node; import org.w3c.dom.NodeList; import org.xml.sax.SAXException; import com.rapidminer.RapidMiner; import com.rapidminer.datatable.DataTable; import com.rapidminer.datatable.SimpleDataTable; import com.rapidminer.datatable.SimpleDataTableRow; import com.rapidminer.tools.Tools; import com.rapidminer.tools.math.AverageVector; /** * AttributeWeights holds the information about the weights of attributes of an example set. It is delivered by several * feature weighting algorithms or learning schemes. The use of a linked hash map ensures that the added features are * stored in the same sequence they were added. * * @author Ingo Mierswa */ public class AttributeWeights extends AverageVector { private static final long serialVersionUID = 7000978931118131854L; /** Indicates that the weights should not be sorted at all. */ public static final int NO_SORTING = 0; /** Indicates that the weights should be sorted in descending order. */ public static final int DECREASING = -1; /** Indicates that the weights should be sorted in ascending order. */ public static final int INCREASING = 1; /** Indicates that the the actual weights should be used for sorting. */ public static final int ORIGINAL_WEIGHTS = 0; /** Indicates that the the absolute weights should be used for sorting. */ public static final int ABSOLUTE_WEIGHTS = 1; /** This comparator sorts the names of attributes according to their weights. */ private class WeightComparator implements Comparator<String> { /** Indicates if absolute weights should be used for sorting. */ private final int comparatorWeightType; /** Indicates the sorting direction. */ private final int direction; /** Creates a new weight comparator. */ public WeightComparator(int direction, int comparatorWeightType) { this.comparatorWeightType = comparatorWeightType; this.direction = direction; } /** Creates two attribute weights. */ @Override public int compare(String o1, String o2) { double w1 = weightMap.get(o1).getWeight(); double w2 = weightMap.get(o2).getWeight(); if (comparatorWeightType == ABSOLUTE_WEIGHTS) { w1 = Math.abs(w1); w2 = Math.abs(w2); } return Double.compare(w1, w2) * direction; } } // ================================================================================ /** Indicates the type of sorting. */ private int sortType = NO_SORTING; /** Indicates if absolute or actual weights should be used for sorting. */ private int weightType = ORIGINAL_WEIGHTS; /** Maps the name of an attribute to the corresponding attribute weight. */ private Map<String, AttributeWeight> weightMap = new LinkedHashMap<String, AttributeWeight>(); /** Creates a new empty attribute weights object. */ public AttributeWeights() { } /** * Creates a new attribute weights object containing a weight of 1 for each of the given input attributes. */ public AttributeWeights(ExampleSet exampleSet) { for (Attribute attribute : exampleSet.getAttributes()) setWeight(attribute.getName(), 1.0d); } /** Clone constructor. */ private AttributeWeights(AttributeWeights weights) { super(); Iterator i = weights.getAttributeNames().iterator(); while (i.hasNext()) { String name = (String) i.next(); this.setWeight(name, weights.getWeight(name)); } } /** Returns the name of this AverageVector. */ @Override public String getName() { return "AttributeWeights"; } /** Sets the weight for the attribute with the given name. */ public void setWeight(String name, double weight) { AttributeWeight oldWeight = weightMap.get(name); if (Double.isNaN(weight)) { weightMap.remove(name); super.removeAveragable(oldWeight); } else if (oldWeight == null) { AttributeWeight attWeight = new AttributeWeight(this, name, weight); super.addAveragable(attWeight); weightMap.put(name, attWeight); } else { oldWeight.setWeight(weight); } } /** * Returns the weight for the attribute with the given name. Returns Double.NaN if the weight for the queried * attribute is not known. */ public double getWeight(String name) { AttributeWeight weight = weightMap.get(name); if (weight == null) return Double.NaN; else return weight.getWeight(); } /** Returns the currently used weight type. */ public int getWeightType() { return weightType; } /** Returns the currently used weight type. */ public void setWeightType(int weightType) { this.weightType = weightType; } /** Returns the currently used sorting type. */ public int getSortingType() { return sortType; } /** Sets the currently used sorting type. */ public void setSortingType(int sortingType) { this.sortType = sortingType; } /** Returns the number of features in this map. */ @Override public int size() { return weightMap.size(); } /** * This method removes the given attribute weight from this object. */ public void removeAttributeWeight(String attributeName) { this.weightMap.remove(attributeName); } /** * Returns an set of attribute names in this map ordered by their insertion time. */ public Set<String> getAttributeNames() { return weightMap.keySet(); } /** * Since this average vector cannot be compared this method always returns 0. */ @Override public int compareTo(Object o) { return 0; } /** Returns true if both objects have the same weight map. */ @Override public boolean equals(Object o) { if (!(o instanceof AttributeWeights)) { return false; } else { AttributeWeights other = (AttributeWeights) o; return this.weightMap.equals(other.weightMap); } } /** Returns the hash code of the weight map. */ @Override public int hashCode() { return this.weightMap.hashCode(); } /** * Sorts the given array of attribute names according to their weight, the sorting direction (ascending or * descending), and with respect to the fact if original or absolute weights should be used. Ascending means that * the attributes with the smallest weights come first. * * @param direction * <code>INCREASING</code> or <code>DECREASING</code> * @param comparatorType * <code>WEIGHT</code> or <code>WEIGHT_ABSOLUTE</code>. */ public void sortByWeight(String[] attributeNames, int direction, int comparatorType) { Arrays.sort(attributeNames, new WeightComparator(direction, comparatorType)); } /** * This will sort the weights either ascending if the boolean flag is true or descending. * @param direction * <code>ASCENDING</code> or <code>DESCENDING</code> * * @param comparatorType * <code>WEIGHT</code> or <code>WEIGHT_ABSOLUTE</code>. */ public void sort(int direction, int comparatorType) { Map<String, AttributeWeight> newWeightMap = new LinkedHashMap<String, AttributeWeight>(); ArrayList<String> attributes = new ArrayList<String>(weightMap.keySet()); Collections.sort(attributes, new WeightComparator(direction, comparatorType)); for (String attributeName: attributes) { newWeightMap.put(attributeName, weightMap.get(attributeName)); } weightMap = newWeightMap; } /** Saves the attribute weights into an XML file. */ public void save(File file) throws IOException { writeAttributeWeights(file, Tools.getDefaultEncoding()); } public void writeAttributeWeights(File file, Charset encoding) throws IOException { PrintWriter out = null; try { out = new PrintWriter(new FileWriter(file)); out.println("<?xml version=\"1.0\" encoding=\"" + encoding + "\"?>"); out.println("<attributeweights version=\"" + RapidMiner.getShortVersion() + "\">"); Iterator i = weightMap.keySet().iterator(); while (i.hasNext()) { String key = (String) i.next(); double weight = weightMap.get(key).getWeight(); out.println(" <weight name=\"" + key + "\" value=\"" + weight + "\"/>"); } out.println("</attributeweights>"); } catch (IOException e) { throw e; } finally { if (out != null) { out.close(); } } } /** Loads a new AttributeWeights object from the given XML file. */ public static AttributeWeights load(File file) throws IOException { AttributeWeights result = new AttributeWeights(); Document document = null; try { document = DocumentBuilderFactory.newInstance().newDocumentBuilder().parse(file); } catch (SAXException e1) { throw new IOException(e1.getMessage()); } catch (ParserConfigurationException e1) { throw new IOException(e1.getMessage()); } Element attributeWeightsElement = document.getDocumentElement(); if (!attributeWeightsElement.getTagName().equals("attributeweights")) { throw new IOException("Outer tag of attribute weights file must be <attributeweights>"); } NodeList weights = attributeWeightsElement.getChildNodes(); for (int i = 0; i < weights.getLength(); i++) { Node node = weights.item(i); if (node instanceof Element) { Element weightTag = (Element) node; String tagName = weightTag.getTagName(); if (!tagName.equals("weight")) throw new IOException("Only tags <weight> are allowed, was " + tagName); String name = weightTag.getAttribute("name"); String value = weightTag.getAttribute("value"); double weight = 1.0d; try { weight = Double.parseDouble(value); } catch (NumberFormatException e) { throw new IOException("Only numerical weights are allowed for the 'value' attribute."); } result.setWeight(name, weight); } } return result; } public String getExtension() { return "wgt"; } public String getFileDescription() { return "attribute weights file"; } /** Returns a string representation of this object. */ @Override public String toString() { return "AttributeWeights (containing weights for " + weightMap.size() + " attributes)"; } /** * Returns a deep clone of the attribute weights which provides the same sequence of attribute names. */ @Override public Object clone() { return new AttributeWeights(this); } /** This method normalizes all weights to the range 0 to 1. */ public void normalize() { double weightMin = Double.POSITIVE_INFINITY; double weightMax = Double.NEGATIVE_INFINITY; for (String name : getAttributeNames()) { double weight = Math.abs(getWeight(name)); weightMin = Math.min(weightMin, weight); weightMax = Math.max(weightMax, weight); } Iterator<AttributeWeight> w = weightMap.values().iterator(); double diff = weightMax - weightMin; while (w.hasNext()) { AttributeWeight attributeWeight = w.next(); double newWeight = 1.0d; if (diff != 0.0d) newWeight = (Math.abs(attributeWeight.getWeight()) - weightMin) / diff; attributeWeight.setWeight(newWeight); } } public DataTable createDataTable() { DataTable dataTable = new SimpleDataTable("Attribute Weights", new String[] { "attribute", "weight" }); for (Map.Entry<String, AttributeWeight> entry : weightMap.entrySet()) { String attName = entry.getKey(); AttributeWeight attWeight = entry.getValue(); double index = dataTable.mapString(0, attName); double weightValue = attWeight.getWeight(); double[] data = new double[] { index, weightValue }; dataTable.add(new SimpleDataTableRow(data, attName)); } return dataTable; } }