/**
* Copyright (C) 2001-2017 by RapidMiner and the contributors
*
* Complete list of developers available at our web site:
*
* http://rapidminer.com
*
* This program is free software: you can redistribute it and/or modify it under the terms of the
* GNU Affero General Public License as published by the Free Software Foundation, either version 3
* of the License, or (at your option) any later version.
*
* This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without
* even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License along with this program.
* If not, see http://www.gnu.org/licenses/.
*/
package com.rapidminer.operator.preprocessing.transformation.aggregation;
import java.util.Arrays;
/**
 * This is an {@link Aggregator} for the {@link MedianAggregationFunction}. It uses a variation of
 * the quickselect algorithm for computing the median in an average time of O(n). In case the number
 * of unweighted elements is even or the midpoint of the weights lies between two elements, the
 * midpoint of the two middle values will be returned as the median. The memory consumption will
 * grow linearly with the size of the dataset.
 *
 *
 * @author Marcel Seifert
 * @since 7.5
 */
public class MedianAggregator extends NumericalAggregator {

	/**
	 * This class implements a growable array of primitive doubles and provides getter, adder and
	 * size methods. It is used by the {@link MedianAggregator} as a lightweight data structure.
	 * Only the first {@link #size()} entries of {@link #getArray()} are valid; the remainder is
	 * unused capacity.
	 */
	private static class VariableDoubleArray {

		private static final int INITIAL_ARRAY_SIZE = 64;

		private int size = 0;

		private double[] data;

		public VariableDoubleArray() {
			data = new double[INITIAL_ARRAY_SIZE];
		}

		/** @return the number of values added so far */
		public int size() {
			return size;
		}

		/**
		 * @return the backing array; only indices {@code 0 .. size() - 1} hold added values
		 */
		public double[] getArray() {
			return data;
		}

		/**
		 * Appends a value, growing the backing array by roughly 25% when it is full.
		 *
		 * @param value
		 *            the value to append
		 */
		public void add(double value) {
			if (data.length == size) {
				// Math.max guarantees growth even for very small capacities (size >> 2 == 0)
				int newSize = Math.max(size + 1, size + (size >> 2));
				data = Arrays.copyOf(data, newSize);
			}
			data[size] = value;
			size++;
		}
	}

	private VariableDoubleArray values = null;
	private VariableDoubleArray weights = null;
	private int count = 0;
	private double weightCount = 0;

	public MedianAggregator(AggregationFunction function) {
		super(function);
	}

	@Override
	public void count(double value) {
		if (count == 0) {
			values = new VariableDoubleArray();
		}
		values.add(value);
		count++;
	}

	@Override
	public void count(double value, double weight) {
		if (count == 0) {
			values = new VariableDoubleArray();
			weights = new VariableDoubleArray();
		}
		values.add(value);
		weights.add(weight);
		count++;
		weightCount += weight;
	}

	@Override
	public double getValue() {
		// No values have been counted, so the median is undefined
		if (count == 0) {
			return Double.NaN;
		}
		if (weights == null) {
			return quickNth(values, count / 2.0);
		} else {
			return quickNthWeighted(values, weights, weightCount / 2.0);
		}
	}

	/**
	 * Implements a variation of quickSelect that selects the value at (1-based) rank n of the
	 * sorted values. If n is an integer, i.e. the rank lies between two values, the midpoint of the
	 * n-th and (n+1)-th smallest values is returned; otherwise the ceil(n)-th smallest value is
	 * returned. The selection is iterative so that bad pivot sequences cannot overflow the stack.
	 *
	 * @param values
	 *            The values as a {@link VariableDoubleArray}
	 * @param n
	 *            The rank to select
	 * @return The value at rank n (or the midpoint of the two values around it)
	 */
	private double quickNth(VariableDoubleArray values, double n) {
		VariableDoubleArray current = values;
		double target = n;
		while (true) {
			int size = current.size();
			double[] data = current.getArray();
			// Choose pivot from the middle of the list
			double pivot = data[size / 2];

			// Split into smaller, equal and greater list
			VariableDoubleArray smallerValues = new VariableDoubleArray();
			VariableDoubleArray greaterValues = new VariableDoubleArray();
			int equalCount = 0;
			for (int i = 0; i < size; i++) {
				double currentElement = data[i];
				if (currentElement < pivot) {
					smallerValues.add(currentElement);
				} else if (currentElement > pivot) {
					greaterValues.add(currentElement);
				} else {
					equalCount++;
				}
			}

			// Median between two different lists -> median is the midpoint of the greatest value
			// of the smaller list (or of the smallest value of the greater list) and the pivot
			if (smallerValues.size() == target) {
				return (pivot + maxOf(smallerValues, pivot)) / 2;
			} else if (smallerValues.size() + equalCount == target) {
				return (pivot + minOf(greaterValues, pivot)) / 2;
			}
			// Check which of the three lists contains the median and return it or adjust n
			else if (smallerValues.size() > target) {
				if (smallerValues.size() == 0) {
					return pivot;
				}
				current = smallerValues;
			} else if (smallerValues.size() + equalCount > target) {
				return pivot;
			} else {
				if (greaterValues.size() == 0) {
					return pivot;
				}
				target = target - smallerValues.size() - equalCount;
				current = greaterValues;
			}
		}
	}

	/**
	 * Implements a weighted variation of quickSelect. Selects the value at which the cumulative
	 * weight (in ascending value order) reaches n. If n lies exactly on the border between two
	 * values, the midpoint of these two values will be returned. The selection is iterative so that
	 * bad pivot sequences cannot overflow the stack.
	 *
	 * @param values
	 *            The values as a {@link VariableDoubleArray}
	 * @param weights
	 *            The weights as a {@link VariableDoubleArray}, index-aligned with {@code values}
	 * @param n
	 *            The cumulative weight to select
	 * @return The value holding the nth weight (or the midpoint of the two values around it)
	 */
	private double quickNthWeighted(VariableDoubleArray values, VariableDoubleArray weights, double n) {
		VariableDoubleArray currentValues = values;
		VariableDoubleArray currentWeights = weights;
		double target = n;
		while (true) {
			int size = currentValues.size();
			double[] valueData = currentValues.getArray();
			double[] weightData = currentWeights.getArray();
			// Choose pivot from the middle of the list
			double pivot = valueData[size / 2];

			// Split into smaller, equal and greater list
			VariableDoubleArray smallerValues = new VariableDoubleArray();
			VariableDoubleArray greaterValues = new VariableDoubleArray();
			VariableDoubleArray smallerWeights = new VariableDoubleArray();
			VariableDoubleArray greaterWeights = new VariableDoubleArray();
			double smallerWeightCount = 0;
			double equalWeightCount = 0;
			for (int i = 0; i < size; i++) {
				double currentElement = valueData[i];
				double currentWeight = weightData[i];
				if (currentElement < pivot) {
					smallerValues.add(currentElement);
					smallerWeights.add(currentWeight);
					smallerWeightCount += currentWeight;
				} else if (currentElement > pivot) {
					greaterValues.add(currentElement);
					greaterWeights.add(currentWeight);
				} else {
					equalWeightCount += currentWeight;
				}
			}

			// Median between two different lists -> median is the midpoint of the greatest value
			// of the smaller list (or of the smallest value of the greater list) and the pivot
			if (smallerWeightCount == target) {
				return (pivot + maxOf(smallerValues, pivot)) / 2;
			} else if (smallerWeightCount + equalWeightCount == target) {
				return (pivot + minOf(greaterValues, pivot)) / 2;
			}
			// Check which of the three lists contains the median and return it or adjust n
			else if (smallerWeightCount >= target) {
				// An empty partition can only be selected through floating-point rounding of the
				// weight sums; the pivot is then the closest remaining value. Without this guard the
				// selection would never terminate.
				if (smallerValues.size() == 0) {
					return pivot;
				}
				currentValues = smallerValues;
				currentWeights = smallerWeights;
			} else if (smallerWeightCount + equalWeightCount > target) {
				return pivot;
			} else {
				// Same rounding (or NaN weight) guard as above for the greater partition
				if (greaterValues.size() == 0) {
					return pivot;
				}
				target = target - smallerWeightCount - equalWeightCount;
				currentValues = greaterValues;
				currentWeights = greaterWeights;
			}
		}
	}

	/**
	 * Returns the largest valid value of the array.
	 *
	 * @param array
	 *            the values to scan
	 * @param fallback
	 *            returned if the array is empty
	 * @return the maximum, or {@code fallback} for an empty array
	 */
	private static double maxOf(VariableDoubleArray array, double fallback) {
		if (array.size() == 0) {
			return fallback;
		}
		// NEGATIVE_INFINITY (not Double.MIN_VALUE, which is the smallest positive double) so that
		// negative values and -Infinity are handled correctly
		double max = Double.NEGATIVE_INFINITY;
		double[] data = array.getArray();
		for (int i = 0; i < array.size(); i++) {
			if (data[i] > max) {
				max = data[i];
			}
		}
		return max;
	}

	/**
	 * Returns the smallest valid value of the array.
	 *
	 * @param array
	 *            the values to scan
	 * @param fallback
	 *            returned if the array is empty
	 * @return the minimum, or {@code fallback} for an empty array
	 */
	private static double minOf(VariableDoubleArray array, double fallback) {
		if (array.size() == 0) {
			return fallback;
		}
		// POSITIVE_INFINITY (not Double.MAX_VALUE) so that +Infinity values are handled correctly
		double min = Double.POSITIVE_INFINITY;
		double[] data = array.getArray();
		for (int i = 0; i < array.size(); i++) {
			if (data[i] < min) {
				min = data[i];
			}
		}
		return min;
	}
}