/*
* RapidMiner
*
* Copyright (C) 2001-2011 by Rapid-I and the contributors
*
* Complete list of developers available at our web site:
*
* http://rapid-i.com
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see http://www.gnu.org/licenses/.
*/
package com.rapidminer.example.set;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Random;
import com.rapidminer.example.Attribute;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.tools.LogService;
import com.rapidminer.tools.RandomGenerator;
/**
* Creates a shuffled and stratified partition for an example set. The example
* set must have a nominal label. This partition builder can work in two modes:
* <ol>
* <li> The first working mode is automatically used for generic ratio arrays,
* especially for those containing different ratio values. As a consequence, it
* can no longer be guaranteed that each fold contains exactly the correct
* number of examples and each class at least once. </li>
* <li> In contrast to the first mode, the correct partition can be guaranteed
* at least for ratio arrays containing the same ratio value for all folds.
* The second mode is automatically used in this case (e.g. for cross
* validation). </li>
* </ol>
*
* @author Ingo Mierswa
* ingomierswa Exp $
*/
public class StratifiedPartitionBuilder implements PartitionBuilder {

    /**
     * Helper class pairing the position of an example (in the iteration order
     * of the example set) with the nominal value of its label, used for
     * sorting examples according to their class.
     */
    private static class ExampleIndex implements Comparable<ExampleIndex> {

        /** Position of the example in the iteration order of the example set. */
        final int exampleIndex;

        /** Nominal label value of the example. */
        final String className;

        public ExampleIndex(int exampleIndex, String className) {
            this.exampleIndex = exampleIndex;
            this.className = className;
        }

        /**
         * Compares by class name only. This ordering is deliberately NOT
         * consistent with {@link #equals(Object)}: examples of the same class
         * must compare as equal so that the stable
         * {@link Collections#sort(List)} keeps their previously shuffled
         * relative order. Do not add a tie-breaker here.
         */
        public int compareTo(ExampleIndex e) {
            return this.className.compareTo(e.className);
        }

        /** Two instances are equal if they refer to the same example position. */
        @Override
        public boolean equals(Object o) {
            if (!(o instanceof ExampleIndex)) {
                return false;
            }
            ExampleIndex other = (ExampleIndex) o;
            return this.exampleIndex == other.exampleIndex;
        }

        @Override
        public int hashCode() {
            // identical to Integer.valueOf(exampleIndex).hashCode()
            return this.exampleIndex;
        }

        @Override
        public String toString() {
            return exampleIndex + "(" + className + ")";
        }
    }

    /** The example set to be partitioned. */
    private final ExampleSet exampleSet;

    /** Source of randomness for all shuffling done by this builder. */
    private final Random random;

    /**
     * @param exampleSet the example set to partition; it must have a nominal label
     *        when {@link #createPartition(double[], int)} is called
     * @param useLocalRandomSeed passed on to
     *        {@link RandomGenerator#getRandomGenerator(boolean, int)}
     * @param seed passed on to
     *        {@link RandomGenerator#getRandomGenerator(boolean, int)}
     */
    public StratifiedPartitionBuilder(ExampleSet exampleSet, boolean useLocalRandomSeed, int seed) {
        this.exampleSet = exampleSet;
        this.random = RandomGenerator.getRandomGenerator(useLocalRandomSeed, seed);
    }

    /**
     * Returns a stratified partition for the given example set. The examples
     * must have a nominal label. If all ratio values are equal, the stricter
     * equal-size stratification is used; otherwise each class is split
     * separately according to the ratios.
     *
     * @param ratio fraction of examples per partition (presumably summing to 1;
     *        if the sum falls short, leftover examples go to the last partition)
     * @param size number of examples; must equal the size of the example set
     * @return array whose i-th entry is the partition index (0 to
     *         {@code ratio.length - 1}) of the i-th example in iteration order
     * @throws RuntimeException if the size does not match, the label is missing
     *         or not nominal, or no ratio values are given
     */
    public int[] createPartition(double[] ratio, int size) {
        Attribute label = exampleSet.getAttributes().getLabel();

        // typical errors
        if (size != exampleSet.size()) {
            throw new RuntimeException("Cannot create stratified Partition: given size and size of the example set must be equal!");
        }
        if (label == null) {
            throw new RuntimeException("Cannot create stratified Partition: example set must have a label!");
        }
        if (!label.isNominal()) {
            throw new RuntimeException("Cannot create stratified Partition: label of example set must be nominal!");
        }
        if (ratio == null || ratio.length == 0) {
            throw new RuntimeException("Cannot create stratified Partition: at least one ratio value must be given!");
        }

        double firstValue = ratio[0];
        for (int i = 1; i < ratio.length; i++) {
            if (ratio[i] != firstValue) {
                LogService.getGlobal().log("Not all ratio values are equal: using non-equal stratified sampling.", LogService.STATUS);
                return createNonEqualPartition(ratio, size, label);
            }
        }
        LogService.getGlobal().log("All ratio values are equal: using stratified sampling.", LogService.STATUS);
        return createEqualPartition(ratio, size, label);
    }

    /**
     * Returns a stratified partition for the given example set, assuming all
     * ratio values are equal. The examples must have a nominal label.
     */
    private int[] createEqualPartition(double[] ratio, int size, Attribute label) {
        // fill example list with indices and classes
        List<ExampleIndex> examples = new ArrayList<ExampleIndex>(size);
        Iterator<Example> reader = exampleSet.iterator();
        int index = 0;
        while (reader.hasNext()) {
            Example example = reader.next();
            examples.add(new ExampleIndex(index++, example.getNominalValue(label)));
        }

        // shuffle first, then sort by class: Collections.sort is stable, so the
        // examples within each class remain in random order
        Collections.shuffle(examples, random);
        Collections.sort(examples);

        // deal the class-sorted examples round-robin: offset 0 takes every k-th
        // example starting at position 0, offset 1 every k-th starting at 1, ...
        // so contiguous slices of the resulting order contain the classes in
        // roughly their overall proportions
        int numberOfPartitions = ratio.length;
        List<Integer> orderedIndices = new ArrayList<Integer>(examples.size());
        for (int start = 0; start < numberOfPartitions; start++) {
            for (int i = start; i < examples.size(); i += numberOfPartitions) {
                orderedIndices.add(examples.get(i).exampleIndex);
            }
        }

        // cut the stratified order into contiguous partitions
        int[] part = new int[orderedIndices.size()];
        assignPartitions(orderedIndices, ratio, part);
        return part;
    }

    /**
     * Returns a stratified partition for the given example set. The examples
     * must have a nominal label. In contrast to
     * {@link #createEqualPartition(double[], int, Attribute)} this method does
     * not require equal ratio values: the examples of each class are shuffled
     * and split according to the ratios separately.
     */
    private int[] createNonEqualPartition(double[] ratio, int size, Attribute label) {
        // fill list with example indices for each class
        Map<String, List<Integer>> classLists = new HashMap<String, List<Integer>>();
        Iterator<Example> reader = exampleSet.iterator();
        int index = 0;
        while (reader.hasNext()) {
            Example example = reader.next();
            String value = example.getNominalValue(label);
            List<Integer> classList = classLists.get(value);
            if (classList == null) {
                classList = new ArrayList<Integer>();
                classLists.put(value, classList);
            }
            classList.add(index++);
        }

        // shuffle each class list and create a partition for each class separately
        int[] part = new int[exampleSet.size()];
        for (List<Integer> classList : classLists.values()) {
            Collections.shuffle(classList, random);
            assignPartitions(classList, ratio, part);
        }
        return part;
    }

    /**
     * Computes partition boundaries for a sequence of the given length. Entry i
     * is the (rounded) position at which partition i starts; the last entry is
     * the rounded end of the last partition.
     */
    private static int[] computePartitionStarts(double[] ratio, int numberOfElements) {
        int[] starts = new int[ratio.length + 1];
        starts[0] = 0;
        double ratioSum = 0;
        for (int i = 1; i < starts.length; i++) {
            ratioSum += ratio[i - 1];
            starts[i] = (int) Math.round(numberOfElements * ratioSum);
        }
        return starts;
    }

    /**
     * Cuts an ordered sequence of example indices into contiguous partitions
     * according to the ratios and records the partition of every example in
     * {@code part}.
     */
    private static void assignPartitions(List<Integer> orderedIndices, double[] ratio, int[] part) {
        int[] starts = computePartitionStarts(ratio, orderedIndices.size());
        int lastPartition = ratio.length - 1;
        int p = 0;
        int position = 0;
        for (Integer exampleIndex : orderedIndices) {
            // a loop (not a single step) so that empty partitions, whose start
            // equals their end, are skipped; the bound keeps leftovers (ratio
            // sum below 1) in the last partition instead of overrunning it
            while (p < lastPartition && position >= starts[p + 1]) {
                p++;
            }
            part[exampleIndex.intValue()] = p;
            position++;
        }
    }
}