package hex;
import static hex.ModelMetricsMultinomial.updateHits;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Test;
import static water.TestUtil.stall_till_cloudsize;
import java.util.Arrays;
/**
 * Unit tests for the top-K hit bookkeeping performed by {@code updateHits}.
 *
 * <p>Each case calls {@code updateHits(1, actualLabel, predDist, hits)} on a fresh, zeroed
 * {@code hits} array of length {@link #K} and checks which rank slot, if any, was incremented.
 *
 * <p>In every case below, {@code predDist[0]} holds the index of a highest-scoring class,
 * followed by one score per class. NOTE(review): this layout is inferred from the test data;
 * confirm it against ModelMetricsMultinomial. The leading {@code 1} argument is presumably an
 * observation weight (TODO confirm).
 */
public class HitRatioTest {
  /** Number of ranks tracked, i.e. the length of the {@code hits} array. */
  private static final int K = 4;

  @BeforeClass
  public static void setup() {
    stall_till_cloudsize(1);
  }

  @Test
  public void testHits() {
    // No ties: exactly one rank slot is acceptable.
    assertHits(0, new double[]{0, .4f, .1f, .2f, .3f}, top(1));
    assertHits(3, new double[]{0, .4f, .1f, .2f, .3f}, top(2));
    assertHits(0, new double[]{3, .3f, .2f, .1f, .4f}, top(2));
    assertHits(1, new double[]{3, .3f, .2f, .1f, .4f}, top(3));
    assertHits(1, new double[]{0, .4f, .1f, .3f, .2f}, top(4));

    // 2 ties
    // Actual 1, predicted 0, with classes 0 and 1 tied: the tie-break must yield a top-2 hit.
    assertHits(1, new double[]{0, .3f, .3f, .2f, .2f}, top(2));
    // Actual class is tied for 2nd/3rd place.
    assertHits(1, new double[]{2, .3f, .3f, .35f, .05f}, top(2), top(3));
    // Actual class is tied for 3rd/4th place.
    assertHits(1, new double[]{0, .3f, .1f, .2f, .1f}, top(3), top(4));

    // 3 ties: any rank inside the tied group is acceptable.
    assertHits(1, new double[]{3, .3f, .3f, .1f, .3f}, top(1), top(2), top(3));
    assertHits(1, new double[]{3, .1f, .1f, .1f, .7f}, top(2), top(3), top(4));
    assertHits(2, new double[]{3, .1f, .1f, .1f, .7f}, top(2), top(3), top(4));

    // 4 ties: every rank is acceptable.
    assertHits(1, new double[]{2, .25f, .25f, .25f, .25f}, top(1), top(2), top(3), top(4));

    // More classes than K=4, and the actual label ranks outside the top-K: no slot is hit.
    assertHits(1, new double[]{4, .15, 0.1, 0.1, .25, .3, .15, 0.2, 0.2}, new double[K]);
    // More classes than K=4, and the actual label is tied for 3rd/4th: just inside the top-K.
    assertHits(6, new double[]{4, .15, 0.1, 0.1, .25, .3, .15, 0.2, 0.2}, top(3), top(4));
  }

  /**
   * Builds the expected {@code hits} array for a hit at a 1-based rank.
   *
   * @param rank 1-based rank, between 1 and {@link #K}
   * @return an array of length {@link #K} that is all zeros except a 1 at index {@code rank - 1}
   */
  private static double[] top(int rank) {
    double[] expected = new double[K];
    expected[rank - 1] = 1;
    return expected;
  }

  /**
   * Runs {@code updateHits} on a fresh, zeroed {@code hits} array and asserts that the result
   * equals (per {@link Arrays#equals(double[], double[])}) at least one of {@code acceptable}.
   * Several alternatives are allowed because this test does not pin down how ties are broken.
   *
   * @param actualLabel the true class label passed to {@code updateHits}
   * @param predDist the prediction vector passed to {@code updateHits}
   * @param acceptable every {@code hits} array that counts as a pass
   */
  private static void assertHits(int actualLabel, double[] predDist, double[]... acceptable) {
    double[] hits = new double[K];
    updateHits(1, actualLabel, predDist, hits);
    for (double[] expected : acceptable) {
      if (Arrays.equals(hits, expected)) {
        return;
      }
    }
    // Build a diagnostic message so a failure identifies the case and the observed result.
    StringBuilder allowed = new StringBuilder();
    for (double[] expected : acceptable) {
      if (allowed.length() > 0) {
        allowed.append(" or ");
      }
      allowed.append(Arrays.toString(expected));
    }
    Assert.fail("actual_label=" + actualLabel
        + ", pred_dist=" + Arrays.toString(predDist)
        + ": got hits=" + Arrays.toString(hits)
        + ", expected " + allowed);
  }
}