/* * RapidMiner * * Copyright (C) 2001-2008 by Rapid-I and the contributors * * Complete list of developers available at our web site: * * http://rapid-i.com * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by * the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Affero General Public License for more details. * * You should have received a copy of the GNU Affero General Public License * along with this program. If not, see http://www.gnu.org/licenses/. */ package com.rapidminer.tools; import java.util.Enumeration; import java.util.Iterator; import weka.core.FastVector; import weka.core.Instance; import weka.core.Instances; import com.rapidminer.example.Attribute; import com.rapidminer.example.Example; import com.rapidminer.example.ExampleSet; import com.rapidminer.example.FastExample2SparseTransform; import com.rapidminer.example.Statistics; import com.rapidminer.operator.OperatorException; /** * This class extends the Weka class Instances and overrides all methods needed * to directly use a RapidMiner {@link ExampleSet} as source for Weka instead of * copying the complete data. * * @author Ingo Mierswa * @version $Id: WekaInstancesAdaptor.java,v 2.13 2006/04/06 15:23:38 * ingomierswa Exp $ */ public class WekaInstancesAdaptor extends Instances { private static final long serialVersionUID = 99943154106235423L; public static final int LEARNING = 0; public static final int PREDICTING = 1; public static final int CLUSTERING = 2; public static final int ASSOCIATION_RULE_MINING = 3; public static final int WEIGHTING = 4; /** * This enumeration implementation uses an ExampleReader (Iterator) to enumerate the * instances. */ private class InstanceEnumeration implements Enumeration { private Iterator<Example> reader; public InstanceEnumeration(Iterator<Example> reader) { this.reader = reader; } public Object nextElement() { return toWekaInstance(reader.next()); } public boolean hasMoreElements() { return reader.hasNext(); } } /** The example set which backs up the Instances object. */ private ExampleSet exampleSet; /** This transformation might help to speed up the creation of sparse examples. */ private transient FastExample2SparseTransform exampleTransform; /** The most frequent nominal values (only used for association rule mining, null otherwise). * -1 if attribute is numerical. */ private int[] mostFrequent = null; /** * The task type for which this instances object is used. Must be one out of * LEARNING, PREDICTING, CLUSTERING, ASSOCIATION_RULE_MINING, or WEIGHTING. For the * latter cases the original label attribute will be omitted. */ private int taskType = LEARNING; /** The label attribute or null if not desired (depending on task). */ private Attribute labelAttribute = null; /** The weight attribute or null if not available. */ private Attribute weightAttribute = null; /** Creates a new Instances object based on the given example set. */ public WekaInstancesAdaptor(String name, ExampleSet exampleSet, int taskType) throws OperatorException { super(name, getAttributeVector(exampleSet, taskType), 0); this.exampleSet = exampleSet; this.taskType = taskType; this.weightAttribute = exampleSet.getAttributes().getWeight(); this.exampleTransform = new FastExample2SparseTransform(exampleSet); switch (taskType) { case LEARNING: labelAttribute = exampleSet.getAttributes().getLabel(); setClassIndex(exampleSet.getAttributes().size()); break; case PREDICTING: labelAttribute = exampleSet.getAttributes().getPredictedLabel(); setClassIndex(exampleSet.getAttributes().size()); break; case CLUSTERING: labelAttribute = null; setClassIndex(-1); break; case ASSOCIATION_RULE_MINING: // in case of association learning the most frequent attribute // is needed to set to "unknown" exampleSet.recalculateAllAttributeStatistics(); this.mostFrequent = new int[exampleSet.getAttributes().size()]; int i = 0; for (Attribute attribute : exampleSet.getAttributes()) { if (attribute.isNominal()) { this.mostFrequent[i] = (int)exampleSet.getStatistics(attribute, Statistics.MODE); } else { this.mostFrequent[i] = -1; } i++; } labelAttribute = null; setClassIndex(-1); break; case WEIGHTING: labelAttribute = exampleSet.getAttributes().getLabel(); if (labelAttribute != null) setClassIndex(exampleSet.getAttributes().size()); break; } } protected Object readResolve() { try { this.exampleTransform = new FastExample2SparseTransform(this.exampleSet); } catch (OperatorException e) { // do nothing } return this; } // ================================================================================ // Overriding some Weka methods // ================================================================================ /** Returns an instance enumeration based on an ExampleReader. */ public Enumeration enumerateInstances() { return new InstanceEnumeration(exampleSet.iterator()); } /** Returns the i-th instance. */ public Instance instance(int i) { return toWekaInstance(exampleSet.getExample(i)); } /** Returns the number of instances. */ public int numInstances() { return exampleSet.size(); } // ================================================================================ // Transforming examples into Weka instances // ================================================================================ /** Gets an example and creates a Weka instance. */ private Instance toWekaInstance(Example example) { int numberOfRegularValues = example.getAttributes().size(); int numberOfValues = numberOfRegularValues + (labelAttribute != null ? 1 : 0); double[] values = new double[numberOfValues]; // set regular attribute values if (taskType == ASSOCIATION_RULE_MINING) { int a = 0; for (Attribute attribute : exampleSet.getAttributes()) { double value = example.getValue(attribute); if (attribute.isNominal()) { if (value == mostFrequent[a]) value = Double.NaN; // sets the most frequent value to missing // for association learning } values[a] = value; a++; } } else { int[] nonDefaultIndices = exampleTransform.getNonDefaultAttributeIndices(example); double[] nonDefaultValues = exampleTransform.getNonDefaultAttributeValues(example, nonDefaultIndices); for (int a = 0; a < nonDefaultIndices.length; a++) { values[nonDefaultIndices[a]] = nonDefaultValues[a]; } } // set label value if necessary switch (taskType) { case LEARNING: values[values.length - 1] = example.getValue(labelAttribute); break; case PREDICTING: values[values.length - 1] = Double.NaN; break; case WEIGHTING: if (labelAttribute != null) values[values.length - 1] = example.getValue(labelAttribute); break; default: break; } // get instance weight double weight = 1.0d; if (this.weightAttribute != null) weight = example.getValue(this.weightAttribute); // create new instance Instance instance = new Instance(weight, values); instance.setDataset(this); return instance; } // ================================================================================ private static FastVector getAttributeVector(ExampleSet exampleSet, int taskType) { // determine label Attribute label = null; switch (taskType) { case LEARNING: case WEIGHTING: label = exampleSet.getAttributes().getLabel(); break; case PREDICTING: label = exampleSet.getAttributes().getPredictedLabel(); break; default: break; } // add regular attributes FastVector attributeVector = new FastVector(exampleSet.getAttributes().size() + (label != null ? 1 : 0)); for (Attribute attribute : exampleSet.getAttributes()) { attributeVector.addElement(toWekaAttribute(attribute)); } // add label if (label != null) attributeVector.addElement(toWekaAttribute(label)); return attributeVector; } /** Converts an {@link Attribute} to a Weka attribute. */ private static weka.core.Attribute toWekaAttribute(Attribute attribute) { if (attribute == null) return null; weka.core.Attribute result = null; if (Ontology.ATTRIBUTE_VALUE_TYPE.isA(attribute.getValueType(), Ontology.NOMINAL)) { FastVector nominalValues = new FastVector(attribute.getMapping().getValues().size()); for (int i = 0; i < attribute.getMapping().getValues().size(); i++) { nominalValues.addElement(attribute.getMapping().mapIndex(i)); } result = new weka.core.Attribute(attribute.getName(), nominalValues); } else { result = new weka.core.Attribute(attribute.getName()); } return result; } }