package water.api;
import water.Func;
import water.MRTask2;
import water.UKV;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.Vec;
import water.util.Log;
import water.util.Utils;
/* Computes the Gains and Lift Table for a binary classifier. */
public class GainsLiftTable extends Func {
// ---- API weaver plumbing ----
static final int API_WEAVER = 1; // This file has auto-gen'd doc & json fields
static public DocGen.FieldDoc[] DOC_FIELDS; // Initialized from Auto-Gen code.
public static final String DOC_GET = "Gains/Lift Table";
// ---- Inputs ----
// Required frame argument. NOTE(review): the help text is empty; presumably this is the frame
// that holds the actual-label column - confirm and fill in the help string.
@API(help = "", required = true, filter = Default.class, json=true)
public Frame actual;
// Required column with the actual results. init() rejects it unless it is an integer column
// whose cardinality() is either -1 or 2.
@API(help="Column of the actual results", required=true, filter=actualVecSelect.class, json=true)
public Vec vactual;
// Filter for vactual, constructed with the argument name "actual" (presumably it limits the
// choice to columns of that frame - TODO confirm against VecClassSelect).
class actualVecSelect extends VecClassSelect { actualVecSelect() { super("actual"); } }
// Required frame argument; also passed to QuantilesPage as source_key in execImpl().
// NOTE(review): the help text is empty here as well.
@API(help = "", required = true, filter = Default.class, json=true)
public Frame predict;
// Required column with the predicted results. init() rejects enum columns; execImpl() checks
// that the computed quantiles lie within [0, 1].
@API(help="Column of the predicted results", required=true, filter=predictVecSelect.class, json=true)
public Vec vpredict;
// Filter for vpredict, constructed with the argument name "predict" (same caveat as actualVecSelect).
class predictVecSelect extends VecClassSelect { predictVecSelect() { super("predict"); } }
// Optional; defaults to 10. Also used as the number of quantile thresholds computed in execImpl().
@API(help = "The number of rows in the gains table", required = false, filter = Default.class, json = true)
public int groups = 10;
// helper - contains the probability thresholds for each of the groups
// (filled in execImpl(): entry i is the (groups-i-1)/groups quantile of vpredict)
double[] thresholds;
// Results (Output)
// All three are assigned in execImpl() from the GainsTask results; toHTML()/toASCII() emit
// nothing while response_rates is still null.
@API(help="Response rates", json=true)
public float[] response_rates;
@API(help="Average response rate", json=true)
public float avg_response_rate;
@API(help="Positive Responses Per Group", json=true)
public long[] positive_responses;
/**
 * Validates the request arguments before the table is computed.
 *
 * Checks, in order: both columns are present, they have the same length, the actual column
 * is an integer column whose cardinality() is -1 or 2 (-1 presumably meaning "not an enum" -
 * confirm against Vec), the predicted column is not an enum, and at least one group is
 * requested.
 *
 * @throws IllegalArgumentException if any of the checks fails
 */
@Override protected void init() throws IllegalArgumentException {
  // Input handling
  if( vactual==null || vpredict==null )
    throw new IllegalArgumentException("Missing vactual or vpredict!");
  if (vactual.length() != vpredict.length())
    throw new IllegalArgumentException("Both arguments must have the same length ("+vactual.length()+"!="+vpredict.length()+")!");
  if (!vactual.isInt())
    throw new IllegalArgumentException("Actual column must be integer class labels!");
  final int cardinality = vactual.cardinality();
  if (cardinality != -1 && cardinality != 2)
    throw new IllegalArgumentException("Actual column must contain binary class labels, but found cardinality " + cardinality + "!");
  if (vpredict.isEnum())
    throw new IllegalArgumentException("vpredict cannot be class labels, expect probabilities.");
  // groups sizes the thresholds array and is a divisor in the quantile computation, so a
  // non-positive value can never produce a meaningful table; reject it up front.
  if (groups < 1)
    throw new IllegalArgumentException("Number of groups must be at least 1, but got " + groups + "!");
}
/** No-argument constructor; every field keeps its declared default. */
public GainsLiftTable() {
}
/**
 * Creates a table from already computed results.
 *
 * @param response_rates    per-group response rates; the array is stored as-is, not copied
 * @param avg_response_rate overall response rate
 */
public GainsLiftTable(float[] response_rates, float avg_response_rate) {
  this.avg_response_rate = avg_response_rate;
  this.response_rates = response_rates;
}
/**
 * Computes the gains/lift table and stores it in the output fields.
 *
 * Steps: convert vactual to an enum vector, align vpredict to it when the two belong to
 * different vector groups, compute one quantile threshold of vpredict per group, run a
 * GainsTask over both columns, and finally log the ASCII rendering of the table.
 *
 * The computation is best effort: a failure is logged and leaves response_rates
 * unset (in which case toHTML()/toASCII() emit nothing) instead of propagating.
 */
@Override protected void execImpl() {
  Vec va = null;
  Vec vpAligned = null; // non-null only when a temporary aligned copy of vpredict was created
  try {
    va = vactual.toEnum(); // always returns TransfVec
    Vec vp = vpredict;
    // The vectors are from different groups => align them, but properly delete it after computation
    if (!va.group().equals(vp.group())) {
      vp = va.align(vp);
      if (vp != vpredict) vpAligned = vp; // never schedule the caller's own column for removal
    }
    // compute thresholds for each quantile: entry i is the (groups-i-1)/groups quantile,
    // i.e. the cut points run from the highest probabilities down to the lowest
    thresholds = new double[groups];
    for (int i=0; i<groups; ++i) {
      QuantilesPage q = new QuantilesPage();
      q.source_key = predict;
      q.column = vpredict;
      q.quantile = (groups-i-1.) / groups;
      q.invoke();
      thresholds[i] = q.result;
    }
    if (Utils.minValue(thresholds) < 0) throw new IllegalArgumentException("Minimum probability cannot be negative.");
    if (Utils.maxValue(thresholds) > 1) throw new IllegalArgumentException("Maximum probability cannot be greater than 1.");
    // Now compute the GainsTask
    GainsTask gt = new GainsTask(thresholds, va.length());
    gt.doAll(va, vp);
    response_rates = gt.response_rates();
    avg_response_rate = gt.avg_response_rate();
    positive_responses = gt.responses();
  } catch (Throwable t) {
    // Stay best effort (no table is produced), but leave a trace so the failure can be diagnosed.
    Log.info("Gains/Lift table computation failed: " + t);
  } finally {       // Delete adaptation vectors
    if (va!=null) UKV.remove(va._key);
    if (vpAligned!=null) UKV.remove(vpAligned._key);
  }
  StringBuilder sb = new StringBuilder();
  toASCII(sb);
  Log.info(sb);
}
/**
 * Renders the gains/lift table as HTML.
 *
 * Emits a reference link as a heading, then one table row per entry of response_rates
 * (quantile, response rate, lift, cumulative lift) followed by a "Total" row.
 * Lift is response_rates[i] / avg_response_rate; with a zero average this is a float
 * division and yields NaN/Infinity rather than an exception.
 *
 * @param sb buffer the markup is appended to
 * @return false (nothing appended) when no results are available, true otherwise
 */
@Override public boolean toHTML( StringBuilder sb ) {
  if (response_rates == null) return false;
  // Iterate over the results actually present: the two-argument constructor can install an
  // array whose length differs from groups.
  final int n = response_rates.length;
  // The heading goes in front of the table markup and gets a matching opening tag.
  sb.append("<h4>");
  sb.append("<a href=\"http://books.google.com/books?id=-JwptfFItaoC&pg=PA318&lpg=PA319&source=bl&ots=_S6fJI5Wds&sig=Uvff-MosTE7CR4e8LdE8TdJvo44&hl=en&sa=X&ei=b3EcVMnHB6T2iwK3koC4Cw&ved=0CF0Q6AEwBw#v=onepage&q&f=false\">"
          + "Gains/Lift Table Reference</a></h4>");
  DocGen.HTML.arrayHead(sb);
  // Header row
  sb.append("<tr class='warning' style='min-width:60px'>");
  sb.append("<th>Quantile</th><th>Response rate</th><th>Lift</th><th>Cumulative lift</th>");
  sb.append("</tr>");
  float cumulativelift = 0;
  for( int i=0; i<n; i++ ) {
    final float lift = response_rates[i]/ avg_response_rate;
    cumulativelift += lift/n;
    sb.append("<tr>");
    sb.append("<td>").append(Utils.formatPct((i + 1.) / n)).append("</td>");
    sb.append("<td>").append(Utils.formatPct(response_rates[i])).append("</td>");
    sb.append("<td>").append(lift).append("</td>");
    sb.append("<td>").append(Utils.formatPct(cumulativelift)).append("</td>");
    sb.append("</tr>");
  }
  // Summary row
  sb.append("<tr style='min-width:60px'><th>Total</th>");
  sb.append("<td>").append(Utils.formatPct(avg_response_rate)).append("</td>");
  sb.append("<td>").append(1.0).append("</td>");
  sb.append("<td></td>");
  sb.append("</tr>");
  DocGen.HTML.arrayTail(sb);
  return true;
}
/**
 * Renders the gains/lift table as plain text, one line per entry of response_rates
 * followed by a "Total" line. Appends nothing when no results are available.
 *
 * Lift is response_rates[i] / avg_response_rate; with a zero average this is a float
 * division and yields NaN/Infinity rather than an exception.
 *
 * @param sb buffer the text is appended to
 */
public void toASCII( StringBuilder sb ) {
  if (response_rates == null) return;
  // Iterate over the results actually present: the two-argument constructor can install an
  // array whose length differs from groups.
  final int n = response_rates.length;
  sb.append("Quantile Response rate Lift Cumulative lift\n");
  float cumulativelift = 0;
  for( int i=0; i<n; i++ ) {
    final float lift = response_rates[i]/ avg_response_rate;
    cumulativelift += lift/n;
    sb.append(Utils.formatPct((i + 1.) / n));
    sb.append(" ").append(Utils.formatPct(response_rates[i])).append(" ");
    sb.append(" ").append(lift).append(" ");
    sb.append(" ").append(Utils.formatPct(cumulativelift)).append("\n");
  }
  sb.append("Total ");
  sb.append(" ").append(Utils.formatPct(avg_response_rate)).append(" ");
  sb.append(" ").append(1.0).append(" ");
  sb.append(" \n");
}
// Compute Gains table via MRTask2
/**
 * MRTask2 that counts positive responses per threshold bucket and turns the counts
 * into per-bucket response rates plus an overall response rate.
 *
 * Bucket t collects rows whose prediction pr satisfies
 * _thresh[t] <= pr < _thresh[t-1] (bucket 0 has no upper bound).
 */
private static class GainsTask extends MRTask2<GainsTask> {
  /* @OUT response_rates */
  public final float[] response_rates() { return _response_rates; }
  /* @OUT overall response rate */
  public final float avg_response_rate() { return _avg_response_rate; }
  /* @OUT positive responses per bucket */
  public final long[] responses(){ return _responses; }

  /* @IN bucket thresholds */
  final private double[] _thresh;
  /* @IN total count of events */
  final private long _count;

  private long[] _responses;        // positive responses per bucket
  private long _avg_response;       // total number of positive responses counted (a count, not an average)
  private float _avg_response_rate;
  private float[] _response_rates;

  /**
   * @param thresh bucket thresholds; a private copy is kept
   * @param count  total number of rows, used as the denominator of the rates
   */
  GainsTask(double[] thresh, long count) {
    _thresh = thresh.clone();
    _count = count;
  }

  /**
   * Counts the positive responses of one chunk pair.
   * Rows with a missing actual or predicted value are skipped.
   *
   * @throws IllegalArgumentException if an actual value is neither 0 nor 1
   */
  @Override public void map( Chunk ca, Chunk cp ) {
    _responses = new long[_thresh.length];
    _avg_response = 0;
    final int len = Math.min(ca._len, cp._len);
    for( int i=0; i < len; i++ ) {
      if (ca.isNA0(i)) continue;
      final int a = (int)ca.at80(i);
      if (a != 0 && a != 1) throw new IllegalArgumentException("Invalid values in vactual: must be binary (0 or 1).");
      if (cp.isNA0(i)) continue;
      if (a != 1) continue; // only positive responses contribute to any counter
      final double pr = cp.at0(i);
      for( int t=0; t < _thresh.length; t++ ) {
        // count number of positive responses in bucket given by two thresholds
        if (pr >= _thresh[t] && (t == 0 || pr < _thresh[t-1])) _responses[t]++;
      }
      _avg_response++;
    }
  }

  /** Adds the counters of another partial result into this one. */
  @Override public void reduce( GainsTask other ) {
    for( int i=0; i<_responses.length; ++i) {
      _responses[i] += other._responses[i];
    }
    _avg_response += other._avg_response;
  }

  /**
   * Converts the accumulated counts into rates: each bucket count is divided by the
   * nominal bucket size (_count / number of buckets), and the overall rate is the
   * positive count divided by _count.
   */
  @Override public void postGlobal(){
    _response_rates = new float[_thresh.length];
    for (int i=0; i<_response_rates.length; ++i) {
      _response_rates[i] = (float) _responses[i];
    }
    Utils.div(_response_rates, (float)_count/_thresh.length);
    final int last = _response_rates.length - 1;
    for (int i=0; i<=last; ++i) {
      // spill over to next bucket - needed due to tie breaking in quantiles
      if(_response_rates[i] > 1) {
        final float excess = _response_rates[i] - 1;
        // The last bucket has no successor: writing to i+1 there would run past the array,
        // so its excess is dropped and the rate is simply capped.
        if (i < last) _response_rates[i+1] += excess;
        _response_rates[i] = 1;
      }
    }
    _avg_response_rate = (float)_avg_response / _count;
  }
}
}