package hex;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Test;
import water.TestUtil;
import water.util.Log;
import java.util.Arrays;
import java.util.Comparator;
import java.util.Random;
public class ScoreKeeperTest extends TestUtil {
  @BeforeClass
  public static void stall() { stall_till_cloudsize(1); }

  /**
   * Reference implementation of the early-stopping rule, written from scratch so that
   * {@link ScoreKeeper#stopEarly} can be checked against it.
   *
   * <p>The very first entry of {@code val} is never used (it might be full of NaNs). Over the last
   * {@code 2*k} entries, {@code k+1} moving averages of width {@code k} are computed. The first of
   * them is the reference; each of the remaining {@code k} is an attempt to improve on it. An
   * attempt counts as an improvement only if it beats the reference by more than the relative
   * {@code tolerance}.
   *
   * <p>The relative comparison ({@code ref*(1+tolerance)} / {@code ref*(1-tolerance)}) assumes
   * positive scores; the tests in this class keep values around 10 to avoid zero-crossing issues.
   *
   * @param val          scoring history, oldest first
   * @param k            moving-average width and number of improvement attempts
   * @param tolerance    minimum relative improvement over the reference
   * @param moreIsBetter true if larger values are better (e.g. AUC), false if smaller are (e.g. logloss)
   * @param verbose      whether to log intermediate results
   * @return {@code true} if no moving average improved on the reference (stop);
   *         {@code false} if one did, or if there are fewer than {@code 2*k+1} values
   */
  static boolean stopEarly(double[] val, int k, double tolerance, boolean moreIsBetter, boolean verbose) {
    if (val.length-1 < 2*k) return false; //need 2k scoring events (+1 to skip the very first one, which might be full of NaNs)
    // one moving avg for the last k+1 scoring events (1 is reference, k consecutive attempts to improve);
    // Java zero-initializes the array, so no explicit reset is needed before accumulating.
    double[] movingAvg = new double[k+1];
    for (int i = 0; i < movingAvg.length; ++i) {
      int startIdx = val.length - 2*k + i;
      for (int j = 0; j < k; ++j)
        movingAvg[i] += val[startIdx + j];
      movingAvg[i] /= k;
    }
    if (verbose) Log.info("JUnit: moving averages: " + Arrays.toString(movingAvg));

    // check if any of the moving averages is better than the reference (by at least tolerance relative improvement)
    double ref = movingAvg[0];
    boolean improved = false;
    for (int i = 1; i < movingAvg.length; ++i) {
      boolean improvedHere = moreIsBetter
          ? movingAvg[i] > ref*(1+tolerance)
          : movingAvg[i] < ref*(1-tolerance);
      // Log only the attempts that themselves beat the reference, not every attempt that
      // happens to come after the first improvement.
      if (improvedHere && verbose)
        Log.info("JUnit: improved from " + ref + " to " + movingAvg[i] + " by at least " + tolerance + " relative tolerance");
      improved |= improvedHere;
    }

    if (improved) {
      if (verbose) Log.info("JUnit: Still improving.");
      return false;
    }
    if (verbose) Log.info("JUnit: Stopped.");
    return true;
  }

  /**
   * Wraps each value in a fresh {@link ScoreKeeper}, storing it in {@code _AUC} when
   * {@code moreIsBetter} is true and in {@code _logloss} otherwise.
   */
  private static ScoreKeeper[] fillScoreKeeperArray(double[] values, boolean moreIsBetter) {
    ScoreKeeper[] sk = new ScoreKeeper[values.length];
    for (int i = 0; i < values.length; ++i) {
      sk[i] = new ScoreKeeper();
      if (moreIsBetter)
        sk[i]._AUC = values[i];
      else
        sk[i]._logloss = values[i];
    }
    return sk;
  }

  @Test
  public void testConvergenceScoringHistory() {
    Random rng = new Random(0xC0FFEE);
    for (int count = 0; count < 100; ++count) {
      boolean moreIsBetter = rng.nextBoolean();
      ScoreKeeper.StoppingMetric metric = moreIsBetter ? ScoreKeeper.StoppingMetric.AUC : ScoreKeeper.StoppingMetric.logloss;
      double tol = rng.nextFloat()*1e-1;
      int N = 5+rng.nextInt(10);
      double[] values = new double[N];
      for (int i = 0; i < N; ++i) {
        // Gaussian noise around linearly increasing (or decreasing) values near 10
        // (not around 0 to avoid zero-crossing issues with the relative tolerance)
        values[i] = (moreIsBetter ? 10 + (double) i / N : 10 - (double) i / N) + rng.nextGaussian() * 0.33;
      }
      ScoreKeeper[] sk = fillScoreKeeperArray(values, moreIsBetter);
      Log.info();
      Log.info("series: " + Arrays.toString(values));
      Log.info("moreIsBetter: " + moreIsBetter);
      Log.info("relative tolerance: " + tol);
      for (int k = values.length-1; k > 0; k--) {
        boolean c = stopEarly(values, k, tol, moreIsBetter, false /*verbose*/);
        boolean d = ScoreKeeper.stopEarly(sk, k, true /*classification*/, metric, tol, "JUnit's", false /*verbose*/);
        if (c || d) Log.info("Stopped for k=" + k);
        if (c != d) {
          // Re-run both implementations verbosely so the failing case can be diagnosed from the log.
          Log.info("k="+ k);
          Log.info("tol="+ tol);
          Log.info("moreIsBetter="+ moreIsBetter);
          stopEarly(values, k, tol, moreIsBetter, true /*verbose*/);
          ScoreKeeper.stopEarly(sk, k, true /*classification*/, metric, tol, "JUnit", true /*verbose*/);
        }
        Assert.assertTrue("For k="+k+", JUnit: " + c + ", ScoreKeeper: " + d, c == d);
      }
    }
  }

  @Test
  public void testGridSearch() {
    Random rng = new Random(0xDECAF);
    for (int count = 0; count < 100; ++count) {
      final boolean moreIsBetter = rng.nextBoolean();
      Double[] Dvalues;
      double tol;
      // Developer toggle: flip to false to debug with the hand-written values below.
      if (true) {
        // option 1: random values
        int N = 5 + rng.nextInt(10);
        tol = rng.nextDouble() * 0.1;
        Dvalues = new Double[N];
        for (int i = 0; i < N; ++i)
          Dvalues[i] = 10 + rng.nextDouble(); //every grid search model has a random score in [10, 11) (not around 0 to avoid zero-crossing issues)
      } else {
        // option 2: manual values
        tol = 0;
        Dvalues = new Double[]{0.91, 0.92, 0.95, 0.94, 0.93}; //in order of occurrence
      }
      // sort to get "leaderboard": worst first, best last
      Arrays.sort(Dvalues, new Comparator<Double>() {
        @Override
        public int compare(Double o1, Double o2) {
          // ascending when larger is better, descending when smaller is better
          return moreIsBetter ? Double.compare(o1, o2) : Double.compare(o2, o1);
        }
      });
      double[] values = new double[Dvalues.length];
      for (int i = 0; i < values.length; ++i) values[i] = Dvalues[i].doubleValue();
      Log.info("Sorted values (leaderboard) - rightmost is best: " + Arrays.toString(values));
      ScoreKeeper.StoppingMetric metric = moreIsBetter ? ScoreKeeper.StoppingMetric.AUC : ScoreKeeper.StoppingMetric.logloss;
      for (int k = 1; k < values.length; ++k) {
        Log.info("Testing k=" + k);
        ScoreKeeper[] sk = fillScoreKeeperArray(values, moreIsBetter);
        boolean c = stopEarly(values, k, tol, moreIsBetter, true /*verbose*/);
        boolean d = ScoreKeeper.stopEarly(sk, k, true /*classification*/, metric, tol, "JUnit's", true /*verbose*/);
        Assert.assertTrue("For k=" + k + ", JUnit: " + c + ", ScoreKeeper: " + d, c == d);
      }
    }
  }
}