package com.livingsocial.hive.udf; import org.apache.commons.math.MathException; import org.apache.commons.math.distribution.NormalDistribution; import org.apache.commons.math.distribution.NormalDistributionImpl; import org.apache.hadoop.hive.ql.exec.Description; import org.apache.hadoop.hive.ql.exec.UDF; import org.apache.hadoop.hive.serde2.io.DoubleWritable; import org.apache.hadoop.io.LongWritable; @Description( name = "p_value", value = "_FUNC_(double controlAvg, double controlStddev, long controlSize, double treatmentAvg, double treatmentStddev, long treatmentSize) - Returns the p_value for the control and treatment groups based on the passed in stats", extended = "Example:\n" + " > SELECT p_value(avg(if(control=1, revenue, 0)), stddev_pop(if(control=1, revenue, 0)), sum(if(control=1, 1, 0)), \n" + " avg(if(control=0, revenue, 0)), stddev_pop(if(control=0, revenue, 0)), sum(if(control=0, 1, 0))) \n" + " FROM revenue_table;\n" + "\n" + " Alternate format: p_value(critical_value). This skips the rest and just does a t-dist lookup" ) public class ZTest extends UDF { private static NormalDistribution distribution = new NormalDistributionImpl(); public double pval(double val){ try { return 2 * (1 - distribution.cumulativeProbability(val)); } catch (MathException e) { throw new RuntimeException(e); } } private double criticalValue(double controlAvg, double controlStddev, long controlSize, double treatmentAvg, double treatmentStddev, long treatmentSize) { return Math.abs(treatmentAvg - controlAvg) / Math.sqrt( (treatmentStddev*treatmentStddev/treatmentSize) + (controlStddev*controlStddev/controlSize)); } private double pval(final double controlAvg, final double controlStddev, final long controlSize, final double treatmentAvg, final double treatmentStddev, final long treatmentSize) { double critValue = criticalValue(controlAvg, controlStddev, controlSize, treatmentAvg, treatmentStddev, treatmentSize); return pval(critValue); } public DoubleWritable evaluate(final DoubleWritable criticalValue) { if (criticalValue == null) return null; double val = criticalValue.get(); return new DoubleWritable(pval(val)); } // For now ignore the degrees of freedom and use the infinite degrees model public DoubleWritable evaluate(final DoubleWritable criticalValue, final LongWritable degreesOfFreedom) { return evaluate(criticalValue); } public DoubleWritable evaluate(final DoubleWritable controlAvg, final DoubleWritable controlStddev, final LongWritable controlSize, final DoubleWritable treatmentAvg, final DoubleWritable treatmentStddev, final LongWritable treatmentSize) { if( controlAvg == null || controlSize == null || controlStddev == null || treatmentAvg == null || treatmentSize == null || treatmentStddev == null) { return null; } return new DoubleWritable(pval(controlAvg.get(), controlStddev.get(), controlSize.get(), treatmentAvg.get(), treatmentStddev.get(), treatmentSize.get())); } }