package com.facebook.hive.udf;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDAF;
import org.apache.hadoop.hive.ql.exec.UDAFEvaluator;
import org.apache.hadoop.hive.serde2.io.DoubleWritable;
import org.apache.hadoop.io.LongWritable;
import java.util.*;
/**
* Calculate weighted percentile values. We use the definition given under
* the section 'Weighted percentile' at:
* http://en.wikipedia.org/wiki/Percentile
*
* Rows where either the weight or the value is NULL are discarded. The
* percentile argument should be an array of values between 0 and 1; an
* exception is thrown if this is not the case.
*
* Note that this is slightly different from a percentile computed by
* replicating each 'value' 'weight' times, even if 'weight' is integral. For
* example, suppose we are given the following two rows: (0, 1), (1, 99).
* Replicating and computing the median will yield 1 (under most median
* algorithms) whereas the algorithm below will yield a number slightly less
* than 1.
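*
* Example usage (a sketch, not taken from this file: it assumes the UDAF is
* registered under the name 'percentile', and that 'src' is a hypothetical
* table with a BIGINT column 'val' and a DOUBLE column 'wt'):
*
*   SELECT percentile(val, wt, array(0.25, 0.5, 0.75)) FROM src;
*
* This returns an array of the weighted 25th, 50th and 75th percentiles.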
*/
@Description(name = "percentile",
value = "_FUNC_(value, weight, pc) - Returns the weighted percentiles at" +
" 'pc' of 'value' given 'weight'.")
public class UDAFWeightedPercentile extends UDAF {
/**
* A state class to store intermediate aggregation results.
*/
public static class State {
// Maps each distinct input value to the total weight accumulated for it.
private Map<LongWritable, DoubleWritable> counts;
// The requested percentiles, each in [0, 1].
private List<DoubleWritable> percentiles;
}
/**
* A comparator to sort map entries in ascending order of the data value
* (the map key).
*/
public static class MyComparator
implements Comparator<Map.Entry<LongWritable, DoubleWritable>> {
@Override
public int compare(Map.Entry<LongWritable, DoubleWritable> o1,
Map.Entry<LongWritable, DoubleWritable> o2) {
return o1.getKey().compareTo(o2.getKey());
}
}
/**
* Add i to the accumulated weight for the value o in the State object s.
*/
private static void increment(State s, LongWritable o, double i) {
if (s.counts == null) {
s.counts = new HashMap<LongWritable, DoubleWritable>();
}
DoubleWritable count = s.counts.get(o);
if (count == null) {
// We have to create a new key object: o is owned by the caller, which
// may reuse it and change its value after this method returns.
LongWritable key = new LongWritable();
key.set(o.get());
s.counts.put(key, new DoubleWritable(i));
} else {
count.set(count.get() + i);
}
}
/**
* Get the percentile value. This follows the formula on Wikipedia under
* "Weighted percentile".
* Using that notation,
* p_n = (100 / S_N) * (S_n - w_n / 2)
* v = v_k + (p - p_k) / (p_{k+1} - p_k) * (v_{k+1} - v_k)
*
* 'position' here is equal to S_N * p / 100; denote it by P.
* Each entry in 'entriesList', e_n, is equal to S_n - w_n / 2.
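*
* Worked example (the rows (0, 1) and (1, 99) from the class comment, with
* p = 0.5): S_N = 100, so P = 50. The sorted entries give
* e_1 = 1 - 1/2 = 0.5 and e_2 = 100 - 99/2 = 50.5. The first e_k >= P is
* e_2, so we interpolate between entries 1 and 2:
* v = 0 + (50 - 0.5) / (50.5 - 0.5) * (1 - 0) = 0.99,
* which is the 'slightly less than 1' mentioned in the class comment.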
*/
private static double getPercentile(
List<Map.Entry<LongWritable, DoubleWritable>> entriesList,
double position) {
int k = 0;
// while p_k < p
// => p_k S_N / 100 < p S_N / 100 = P
// => (S_k - w_k / 2) < P
// => e_k < P
while (k < entriesList.size() &&
entriesList.get(k).getValue().get() < position) {
k++;
}
// position is beyond the last entry, so clamp to the largest value.
if (k == entriesList.size()) {
return entriesList.get(k - 1).getKey().get();
}
// p_k >= p
double e_k = entriesList.get(k).getValue().get();
long v_k = entriesList.get(k).getKey().get();
if (e_k == position || k == 0) {
return v_k;
}
// Need to interpolate since:
// p_{k - 1} < p < p_k
double e_km1 = entriesList.get(k - 1).getValue().get();
long v_km1 = entriesList.get(k - 1).getKey().get();
// p_{k+1} - p_k
// = 100 / S_N (e_{k+1} - e_k)
// p - p_k
// = 100 / S_N (P - e_k)
// (p - p_k) / (p_{k+1} - p_k) = (P - e_k) / (e_{k+1} - e_k)
return v_km1 + (position - e_km1) / (e_k - e_km1) * (v_k - v_km1);
}
/**
* The evaluator for percentile computation based on long for an array of
* percentiles.
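*
* Hive drives this class through the standard UDAFEvaluator lifecycle:
* init() once per aggregation, iterate() once per input row,
* terminatePartial() and merge() to combine partial results across
* map/reduce stages, and terminate() to produce the final result.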
*/
public static class PercentileLongArrayEvaluator implements UDAFEvaluator {
private final State state;
public PercentileLongArrayEvaluator() {
state = new State();
}
public void init() {
if (state.counts != null) {
// We reuse the same hashmap to reduce new object allocation.
// This means counts can be empty when there is no input data.
state.counts.clear();
}
}
public boolean iterate(LongWritable o, DoubleWritable w,
List<DoubleWritable> percentiles) {
if (state.percentiles == null) {
for (int i = 0; i < percentiles.size(); i++) {
if (percentiles.get(i).get() < 0.0 ||
percentiles.get(i).get() > 1.0) {
throw new RuntimeException("Percentile value must be in [0,1]");
}
}
state.percentiles = new ArrayList<DoubleWritable>(percentiles);
}
// Discard rows where either the value or the weight is NULL.
if (o != null && w != null) {
increment(state, o, w.get());
}
return true;
}
public State terminatePartial() {
return state;
}
public boolean merge(State other) {
if (other == null || other.counts == null || other.percentiles == null) {
return true;
}
if (state.percentiles == null) {
state.percentiles = new ArrayList<DoubleWritable>(other.percentiles);
}
for (Map.Entry<LongWritable, DoubleWritable> e: other.counts.entrySet()) {
increment(state, e.getKey(), e.getValue().get());
}
return true;
}
// Reused across calls to terminate() to avoid reallocating the list.
private List<DoubleWritable> results;
public List<DoubleWritable> terminate() {
// No input data
if (state.counts == null || state.counts.size() == 0) {
return null;
}
// Copy all entries into a list and sort them by value
Set<Map.Entry<LongWritable, DoubleWritable>> entries =
state.counts.entrySet();
List<Map.Entry<LongWritable, DoubleWritable>> entriesList =
new ArrayList<Map.Entry<LongWritable, DoubleWritable>>(entries);
Collections.sort(entriesList, new MyComparator());
// Accumulate the counts in place: after this loop each entry's value
// holds e_n = S_n - w_n / 2, as expected by getPercentile(). Note that
// this mutates the DoubleWritable objects stored in state.counts.
double total = 0.0;
for (int i = 0; i < entriesList.size(); i++) {
DoubleWritable count = entriesList.get(i).getValue();
total += count.get();
count.set(total - count.get() / 2);
}
// Initialize the results
if (results == null) {
results = new ArrayList<DoubleWritable>();
for (int i = 0; i < state.percentiles.size(); i++) {
results.add(new DoubleWritable());
}
}
// Set the results
for (int i = 0; i < state.percentiles.size(); i++) {
double position = total * state.percentiles.get(i).get();
results.get(i).set(getPercentile(entriesList, position));
}
return results;
}
}
}