package hex;
import hex.quantile.Quantile;
import hex.quantile.QuantileModel;
import water.*;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.Vec;
import water.util.ArrayUtils;
import water.util.PrettyPrint;
import water.util.TwoDimTable;
import java.util.Arrays;
import java.util.Iterator;
import java.util.TreeSet;
public class GainsLift extends Iced {
private double[] _quantiles;
//INPUT
public int _groups = -1;
public Vec _labels;
public Vec _preds; //of length N, n_i = N/GROUPS
public Vec _weights;
//OUTPUT
public double[] response_rates; // p_i = e_i/n_i
public double avg_response_rate; // P
public long[] events; // e_i
public long[] observations; // n_i
TwoDimTable table;
public GainsLift(Vec preds, Vec labels) {
this(preds, labels, null);
}
public GainsLift(Vec preds, Vec labels, Vec weights) {
_preds = preds;
_labels = labels;
_weights = weights;
}
private void init(Job job) throws IllegalArgumentException {
_labels = _labels.toCategoricalVec();
if( _labels ==null || _preds ==null )
throw new IllegalArgumentException("Missing actualLabels or predictedProbs!");
if (_labels.length() != _preds.length())
throw new IllegalArgumentException("Both arguments must have the same length ("+ _labels.length()+"!="+ _preds.length()+")!");
if (!_labels.isInt())
throw new IllegalArgumentException("Actual column must be integer class labels!");
if (_labels.cardinality() != -1 && _labels.cardinality() != 2)
throw new IllegalArgumentException("Actual column must contain binary class labels, but found cardinality " + _labels.cardinality() + "!");
if (_preds.isCategorical())
throw new IllegalArgumentException("Predicted probabilities cannot be class labels, expect probabilities.");
if (_weights != null && !_weights.isNumeric())
throw new IllegalArgumentException("Observation weights must be numeric.");
// The vectors are from different groups => align them, but properly delete it after computation
if (!_labels.group().equals(_preds.group())) {
_preds = _labels.align(_preds);
Scope.track(_preds);
if (_weights !=null) {
_weights = _labels.align(_weights);
Scope.track(_weights);
}
}
boolean fast = false;
if (fast) {
// FAST VERSION: single-pass, only works with the specific pre-computed quantiles from rollupstats
assert(_groups == 10);
assert(Arrays.equals(Vec.PERCENTILES,
// 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15, 16
new double[]{0.001, 0.01, 0.1, 0.2, 0.25, 0.3, 1.0 / 3.0, 0.4, 0.5, 0.6, 2.0 / 3.0, 0.7, 0.75, 0.8, 0.9, 0.99, 0.999}));
//HACK: hardcoded quantiles for simplicity (0.9,0.8,...,0.1,0)
double[] rq = _preds.pctiles(); //might do a full pass over the Vec
_quantiles = new double[]{
rq[14], rq[13], rq[11], rq[9], rq[8], rq[7], rq[5], rq[3], rq[2], 0 /*ignored*/
};
} else {
// ACCURATE VERSION: multi-pass
Frame fr = null;
QuantileModel qm = null;
try {
QuantileModel.QuantileParameters qp = new QuantileModel.QuantileParameters();
if (_weights==null) {
fr = new Frame(Key.<Frame>make(), new String[]{"predictions"}, new Vec[]{_preds});
} else {
fr = new Frame(Key.<Frame>make(), new String[]{"predictions", "weights"}, new Vec[]{_preds, _weights});
qp._weights_column = "weights";
}
DKV.put(fr);
qp._train = fr._key;
if (_groups > 0) {
qp._probs = new double[_groups];
for (int i = 0; i < _groups; ++i) {
qp._probs[i] = (_groups - i - 1.) / _groups; // This is 0.9, 0.8, 0.7, 0.6, ..., 0.1, 0 for 10 groups
}
} else {
qp._probs = new double[]{0.99, 0.98, 0.97, 0.96, 0.95, 0.9, 0.85, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1, 0};
}
qm = job != null && !job.isDone() ? new Quantile(qp, job).trainModelNested(null) : new Quantile(qp).trainModel().get();
_quantiles = qm._output._quantiles[0];
// find uniques (is there a more elegant way?)
TreeSet<Double> hs = new TreeSet<>();
for (double d : _quantiles) hs.add(d);
_quantiles = new double[hs.size()];
Iterator<Double> it = hs.descendingIterator();
int i = 0;
while (it.hasNext()) _quantiles[i++] = it.next();
} finally {
if (qm!=null) qm.remove();
if (fr!=null) DKV.remove(fr._key);
}
}
}
public void exec() {
exec(null);
}
public void exec(Job job) {
Scope.enter();
init(job); //check parameters and obtain _quantiles from _preds
try {
GainsLiftBuilder gt = new GainsLiftBuilder(_quantiles);
gt = (_weights != null) ? gt.doAll(_labels, _preds, _weights) : gt.doAll(_labels, _preds);
response_rates = gt.response_rates();
avg_response_rate = gt.avg_response_rate();
events = gt.events();
observations = gt.observations();
} finally { // Delete adaptation vectors
Scope.exit();
}
}
@Override public String toString() {
TwoDimTable t = createTwoDimTable();
return t==null ? "" : t.toString();
}
public TwoDimTable createTwoDimTable() {
if (response_rates == null || Double.isNaN(avg_response_rate)) return null;
TwoDimTable table = new TwoDimTable(
"Gains/Lift Table",
"Avg response rate: " + PrettyPrint.formatPct(avg_response_rate),
new String[events.length],
new String[]{"Group", "Cumulative Data Fraction", "Lower Threshold", "Lift", "Cumulative Lift", "Response Rate", "Cumulative Response Rate", "Capture Rate", "Cumulative Capture Rate", "Gain", "Cumulative Gain"},
new String[]{"int", "double", "double", "double", "double", "double", "double", "double", "double", "double", "double"},
new String[]{"%d", "%.8f", "%5f", "%5f", "%5f", "%5f", "%5f", "%5f", "%5f", "%5f", "%5f"},
"");
long sum_e_i = 0;
long sum_n_i = 0;
double P = avg_response_rate; // E/N
long N = ArrayUtils.sum(observations);
long E = Math.round(N * P);
for (int i = 0; i < events.length; ++i) {
long e_i = events[i];
long n_i = observations[i];
double p_i = response_rates[i];
sum_e_i += e_i;
sum_n_i += n_i;
double lift=p_i/P; //can be NaN if P==0
double sum_lift=(double)sum_e_i/sum_n_i/P; //can be NaN if P==0
table.set(i,0,i+1); //group
table.set(i,1,(double)sum_n_i/N); //cumulative_data_fraction
table.set(i,2,_quantiles[i]); //lower_threshold
table.set(i,3,lift); //lift
table.set(i,4,sum_lift); //cumulative_lift
table.set(i,5,p_i); //response_rate
table.set(i,6,(double)sum_e_i/sum_n_i); //cumulative_response_rate
table.set(i,7,(double)e_i/E); //capture_rate
table.set(i,8,(double)sum_e_i/E); //cumulative_capture_rate
table.set(i,9,100*(lift-1)); //gain
table.set(i,10,100*(sum_lift-1)); //cumulative gain
if (i== events.length-1) {
assert(sum_n_i == N) : "Cumulative data fraction must be 1.0, but is " + (double)sum_n_i/N;
assert(sum_e_i == E) : "Cumulative capture rate must be 1.0, but is " + (double)sum_e_i/E;
if (!Double.isNaN(sum_lift)) assert(Math.abs(sum_lift - 1.0) < 1e-8) : "Cumulative lift must be 1.0, but is " + sum_lift;
assert(Math.abs((double)sum_e_i/sum_n_i - avg_response_rate) < 1e-8) : "Cumulative response rate must be " + avg_response_rate + ", but is " + (double)sum_e_i/sum_n_i;
}
}
return this.table = table;
}
// Compute Gains table via MRTask
public static class GainsLiftBuilder extends MRTask<GainsLiftBuilder> {
/* @OUT response_rates */
public final double[] response_rates() { return _response_rates; }
public final double avg_response_rate() { return _avg_response_rate; }
public final long[] events(){ return _events; }
public final long[] observations(){ return _observations; }
/* @IN quantiles/thresholds */
final private double[] _thresh;
private long[] _events;
private long[] _observations;
private long _avg_response;
private double _avg_response_rate;
private double[] _response_rates;
public GainsLiftBuilder(double[] thresh) {
_thresh = thresh.clone();
}
@Override public void map( Chunk ca, Chunk cp) { map(ca, cp, (Chunk)null); }
@Override public void map( Chunk ca, Chunk cp, Chunk cw) {
_events = new long[_thresh.length];
_observations = new long[_thresh.length];
_avg_response = 0;
final int len = Math.min(ca._len, cp._len);
for( int i=0; i < len; i++ ) {
if (ca.isNA(i)) continue;
final int a = (int)ca.at8(i);
if (a != 0 && a != 1) throw new IllegalArgumentException("Invalid values in actualLabels: must be binary (0 or 1).");
if (cp.isNA(i)) continue;
final double pr = cp.atd(i);
final double w = cw!=null?cw.atd(i):1;
perRow(pr, a, w);
}
}
public void perRow(double pr, int a, double w) {
if (w==0) return;
assert (!Double.isNaN(pr));
assert (!Double.isNaN(a));
assert (!Double.isNaN(w));
//for-loop is faster than binary search for small number of thresholds
for( int t=0; t < _thresh.length; t++ ) {
if (pr >= _thresh[t] && (t==0 || pr <_thresh[t-1])) {
_observations[t]+=w;
if (a == 1) _events[t]+=w;
break;
}
}
if (a == 1) _avg_response+=w;
}
@Override public void reduce(GainsLiftBuilder other) {
ArrayUtils.add(_events, other._events);
ArrayUtils.add(_observations, other._observations);
_avg_response += other._avg_response;
}
@Override public void postGlobal(){
_response_rates = new double[_thresh.length];
for (int i=0; i<_response_rates.length; ++i)
_response_rates[i] = _observations[i] == 0 ? 0 : (double) _events[i] / _observations[i];
_avg_response_rate = (double)_avg_response / ArrayUtils.sum(_observations);
}
}
}