/*
* RapidMiner
*
* Copyright (C) 2001-2007 by Rapid-I and the contributors
*
* Complete list of developers available at our web site:
*
* http://rapid-i.com
*
* This program is free software; you can redistribute it and/or
* modify it under the terms of the GNU General Public License as
* published by the Free Software Foundation; either version 2 of the
* License, or (at your option) any later version.
*
* This program is distributed in the hope that it will be useful, but
* WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307
* USA.
*/
package com.rapidminer.operator.preprocessing.discretization;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedList;
import com.rapidminer.example.Attribute;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.example.Statistics;
import com.rapidminer.example.table.AttributeFactory;
import com.rapidminer.operator.IOObject;
import com.rapidminer.operator.OperatorDescription;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.UserError;
import com.rapidminer.tools.Ontology;
/**
* A filter that discretizes all numeric attributes in the dataset into nominal
* attributes. The discretization is performed by selecting a bin boundary
* minimizing the entropy in the induced partitions. The method is then applied
* recursively to both new partitions until the (MDL-based) stopping criterion
* is met. For details see a) "Multi-Interval Discretization of
* Continuous-Valued Attributes for Classification Learning" (Fayyad, Irani) and
* b) "Supervised and Unsupervised Discretization of Continuous Features"
* (Dougherty, Kohavi, Sahami). Skips all special attributes, including the
* label.
*
* @author Dirk Dach
* @version $Id: MinimalEntropyPartitioning.java,v 1.5 2006/04/14 11:42:27
* ingomierswa Exp $
*/
public class MinimalEntropyPartitioning extends Discretization {

    /** Natural logarithm of 2, used to convert natural logarithms to base 2. */
    private static final double LN_2 = Math.log(2);

    /** Orders attribute/label pairs ascending by their attribute value (index 0). */
    private static final Comparator<double[]> BY_ATTRIBUTE_VALUE = new Comparator<double[]>() {
        public int compare(double[] pair1, double[] pair2) {
            return Double.compare(pair1[0], pair2[0]);
        }
    };

    public MinimalEntropyPartitioning(OperatorDescription description) {
        super(description);
    }

    /**
     * Finds the boundary minimizing the class information entropy of the given
     * partition and decides, using the MDL criterion of Fayyad and Irani,
     * whether splitting there is worthwhile.
     *
     * @param truncatedExamples attribute/label pairs ({@code pair[0]} = attribute
     *        value, {@code pair[1]} = label index); must not contain NaN values
     * @param label the nominal label attribute, used for the number of label values
     * @return the best splitpoint (examples with values less than or equal to it
     *         form the first partition), or {@code null} if the partition should
     *         not be split any further
     */
    private Double getMinEntropySplitpoint(LinkedList<double[]> truncatedExamples, Attribute label) {
        int numberOfLabels = label.getMapping().size();
        int n = truncatedExamples.size();
        if (n < 2) {
            return null; // zero or one example cannot be split
        }

        // Sorting allows all candidate boundaries to be evaluated in one sweep
        // instead of re-partitioning the whole set once per candidate.
        double[][] sorted = truncatedExamples.toArray(new double[n][]);
        Arrays.sort(sorted, BY_ATTRIBUTE_VALUE);

        int[] totalLabelDistribution = new int[numberOfLabels];
        for (double[] pair : sorted) {
            totalLabelDistribution[(int) pair[1]]++;
        }
        double totalEntropy = entropy(totalLabelDistribution, n);

        // Label distributions of the lower (values <= cut) and upper partition.
        int[] lowerDistribution = new int[numberOfLabels];
        int[] upperDistribution = totalLabelDistribution.clone();

        double minClassInformationEntropy = totalEntropy;
        double bestSplitpoint = Double.NaN;
        double bestEntropy1 = 0.0d;
        double bestEntropy2 = 0.0d;
        int bestK1 = 0; // number of different labels in partition 1 of the best split
        int bestK2 = 0; // number of different labels in partition 2 of the best split
        for (int i = 0; i < n - 1; i++) {
            int labelIndex = (int) sorted[i][1];
            lowerDistribution[labelIndex]++;
            upperDistribution[labelIndex]--;
            // A cut is only possible between two different values; this also
            // guarantees that both resulting partitions are non-empty.
            if (sorted[i][0] == sorted[i + 1][0]) {
                continue;
            }
            int s1 = i + 1; // instances in partition 1
            int s2 = n - s1; // instances in partition 2
            double entropy1 = entropy(lowerDistribution, s1);
            double entropy2 = entropy(upperDistribution, s2);
            double classInformationEntropy = ((double) s1 / n) * entropy1 + ((double) s2 / n) * entropy2;
            if (classInformationEntropy < minClassInformationEntropy) {
                minClassInformationEntropy = classInformationEntropy;
                bestSplitpoint = sorted[i][0];
                bestEntropy1 = entropy1;
                bestEntropy2 = entropy2;
                // Remember k1/k2 of the best split (not of the last candidate).
                bestK1 = countPresentLabels(lowerDistribution);
                bestK2 = countPresentLabels(upperDistribution);
            }
        }
        if (Double.isNaN(bestSplitpoint)) {
            return null; // no cut reduces the entropy
        }

        // MDL termination criterion (Fayyad/Irani); k is the number of labels
        // actually present in this partition.
        int k = countPresentLabels(totalLabelDistribution);
        double gain = totalEntropy - minClassInformationEntropy;
        double delta = log2(Math.pow(3.0, k) - 2) - (k * totalEntropy - bestK1 * bestEntropy1 - bestK2 * bestEntropy2);
        if (gain >= (log2(n - 1) + delta) / n) {
            return Double.valueOf(bestSplitpoint);
        }
        return null;
    }

    /**
     * Computes the base-2 Shannon entropy of a label distribution. Labels with a
     * count of zero contribute nothing (0 * log2(0) is taken as 0), which avoids
     * the NaN produced by evaluating {@code 0 * log2(0)} literally.
     *
     * @param counts number of instances per label index
     * @param total sum of {@code counts}; an empty distribution has entropy 0
     */
    private double entropy(int[] counts, int total) {
        if (total <= 0) {
            return 0.0d;
        }
        double result = 0.0d;
        for (int count : counts) {
            if (count > 0) {
                double frequency = (double) count / total;
                result -= frequency * log2(frequency);
            }
        }
        return result;
    }

    /** Returns how many label indices have a positive count in {@code counts}. */
    private static int countPresentLabels(int[] counts) {
        int present = 0;
        for (int count : counts) {
            if (count > 0) {
                present++;
            }
        }
        return present;
    }

    /*
     * LinkedList partition consist of double arrays of size 2. array[0]=value
     * of the current attribute, array[1]=corresponding label value.
     */
    /**
     * Recursively (breadth-first) splits the start partition at minimal-entropy
     * boundaries until the stopping criterion holds for every partition.
     *
     * @return all splitpoints found, in discovery order; empty if none was found
     */
    private ArrayList<Double> getSplitpoints(LinkedList<double[]> startPartition, Attribute label) {
        LinkedList<LinkedList<double[]>> border = new LinkedList<LinkedList<double[]>>();
        ArrayList<Double> result = new ArrayList<Double>();
        border.addLast(startPartition);
        while (!border.isEmpty()) {
            LinkedList<double[]> currentPartition = border.removeFirst();
            Double splitpoint = getMinEntropySplitpoint(currentPartition, label);
            if (splitpoint == null) {
                continue;
            }
            result.add(splitpoint);
            double splitValue = splitpoint.doubleValue();
            LinkedList<double[]> lowerPartition = new LinkedList<double[]>();
            LinkedList<double[]> upperPartition = new LinkedList<double[]>();
            for (double[] attributeLabelPair : currentPartition) {
                if (attributeLabelPair[0] <= splitValue) {
                    lowerPartition.addLast(attributeLabelPair);
                } else {
                    upperPartition.addLast(attributeLabelPair);
                }
            }
            // Both partitions are non-empty because splitpoints always lie
            // strictly below the partition maximum, so the loop terminates.
            border.addLast(lowerPartition);
            border.addLast(upperPartition);
        }
        return result;
    }

    /**
     * Delivers the maximum range thresholds for all attributes, i.e. the value
     * getRanges()[a][b] is the b-th threshold for the a-th attribute. The last
     * threshold of each numeric attribute is its maximum; entries for nominal
     * attributes are left {@code null}. Examples with a missing attribute or
     * label value are ignored when searching for splitpoints.
     */
    public double[][] getRanges(ExampleSet exampleSet) {
        double[][] ranges = new double[exampleSet.getAttributes().size()][];
        Attribute label = exampleSet.getAttributes().getLabel();
        int a = 0;
        for (Attribute attribute : exampleSet.getAttributes()) {
            if (!attribute.isNominal()) { // skip nominal attributes
                LinkedList<double[]> startPartition = new LinkedList<double[]>();
                Iterator<Example> reader = exampleSet.iterator();
                while (reader.hasNext()) {
                    Example example = reader.next();
                    double value = example.getValue(attribute);
                    double labelValue = example.getValue(label);
                    // Missing values carry no information for the split and a
                    // NaN label would otherwise be cast to label index 0.
                    if (Double.isNaN(value) || Double.isNaN(labelValue)) {
                        continue;
                    }
                    startPartition.addLast(new double[] { value, labelValue });
                }
                ArrayList<Double> splitpointsOfAttribute = getSplitpoints(startPartition, label);
                ranges[a] = new double[splitpointsOfAttribute.size() + 1];
                for (int i = 0; i < splitpointsOfAttribute.size(); i++) {
                    ranges[a][i] = splitpointsOfAttribute.get(i).doubleValue();
                }
                ranges[a][ranges[a].length - 1] = exampleSet.getStatistics(attribute, Statistics.MAXIMUM);
                Arrays.sort(ranges[a]);
            }
            a++;
        }
        return ranges;
    }

    /**
     * Discretizes all numeric attributes of the input example set into nominal
     * attributes with values "range1", "range2", ... and removes numeric
     * attributes for which no splitpoint was found.
     *
     * @throws UserError if the example set has no nominal label
     */
    public IOObject[] apply() throws OperatorException {
        ExampleSet exampleSet = getInput(ExampleSet.class);
        Attribute label = exampleSet.getAttributes().getLabel();
        if ((label == null) || (!label.isNominal()))
            throw new UserError(this, 101, getName(), (label == null ? "no label" : label.getName()));
        exampleSet.recalculateAllAttributeStatistics();
        checkForStop();
        double[][] ranges = getRanges(exampleSet);
        // Needed since the value type is changed below, so isNominal() can no
        // longer tell which attributes were originally numeric.
        boolean[] numerical = new boolean[ranges.length];

        // change attribute type
        int a = 0;
        for (Attribute attribute : exampleSet.getAttributes()) {
            if (!attribute.isNominal()) { // skip nominal attributes
                numerical[a] = true;
                attribute = exampleSet.getAttributes().replace(attribute, AttributeFactory.changeValueType(attribute, Ontology.NOMINAL));
                for (int b = 0; b < ranges[a].length; b++) {
                    attribute.getMapping().mapString("range" + (b + 1));
                }
            } else {
                numerical[a] = false;
            }
            a++;
        }

        // change data: map each value to the first range whose upper bound it does not exceed
        Iterator<Example> reader = exampleSet.iterator();
        while (reader.hasNext()) {
            Example example = reader.next();
            a = 0;
            for (Attribute attribute : exampleSet.getAttributes()) {
                if (numerical[a] && ranges[a] != null) {
                    double value = example.getValue(attribute);
                    for (int b = 0; b < ranges[a].length; b++) {
                        if (value <= ranges[a][b]) {
                            example.setValue(attribute, attribute.getMapping().mapString("range" + (b + 1)));
                            break;
                        }
                    }
                }
                a++;
            }
            checkForStop();
        }

        // remove useless attributes with no splitpoint (only the maximum as bound)
        a = 0;
        Iterator<Attribute> i = exampleSet.getAttributes().iterator();
        while (i.hasNext()) {
            i.next();
            if (numerical[a] && ranges[a].length == 1) {
                i.remove();
            }
            a++;
        }
        return new IOObject[] { exampleSet };
    }

    /**
     * Returns the base-2 logarithm of {@code arg}. Note that, like
     * {@link Math#log(double)}, this yields -Infinity for 0 and NaN for negative
     * arguments.
     */
    public double log2(double arg) {
        return Math.log(arg) / LN_2;
    }

    public Class[] getOutputClasses() {
        return new Class[] { ExampleSet.class };
    }

    public Class[] getInputClasses() {
        return new Class[] { ExampleSet.class };
    }
}