package water.rapids.ast.prims.reducers;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Test;
import water.DKV;
import water.Key;
import water.TestUtil;
import water.fvec.Frame;
import water.fvec.Vec;
import water.rapids.Rapids;
import water.rapids.Val;
import water.rapids.vals.ValFrame;
import water.rapids.vals.ValRow;
import java.util.ArrayList;
import static org.junit.Assert.*;
/**
* Test the AstMean.java class
*/
public class AstMeanTest extends TestUtil {
private static Vec vi1, vd1, vd2, vd3, vs1, vt1, vt2, vc1, vc2;
private static ArrayList<Frame> allFrames;
@BeforeClass public static void setup() {
stall_till_cloudsize(1);
vi1 = TestUtil.ivec(-1, -2, 0, 2, 1);
vd1 = TestUtil.dvec(1.5, 2.5, 3.5, 4.5, 8.0);
vd2 = TestUtil.dvec(0.2, 0.4, 0.6, 0.8, 1.0);
vd3 = TestUtil.dvec(1, 2, Double.NaN, 3, Double.NaN);
vs1 = TestUtil.svec("a", "b", "c", "d", "e");
vt1 = TestUtil.tvec(10000000, 10000020, 10000030, 10000040, 10000060);
vt2 = TestUtil.tvec(20000000, 20000020, 20000030, 20000040, 20000060);
vc1 = TestUtil.cvec(ar("N", "Y"), "Y", "N", "Y", "Y", "N");
vc2 = TestUtil.cvec("a", "c", "c", "b", "a");
allFrames = new ArrayList<>(10);
}
@AfterClass public static void teardown() {
for (Vec v : aro(vi1, vd1, vd2, vd3, vs1, vt1, vt2, vc1, vc2))
v.remove();
for (Frame f : allFrames)
f.delete();
}
//--------------------------------------------------------------------------------------------------------------------
// Tests
//--------------------------------------------------------------------------------------------------------------------
@Test public void testAstMeanGeneralStructure() {
AstMean a = new AstMean();
String[] args = a.args();
assertEquals(3, args.length);
String example = a.example();
assertTrue(example.startsWith("(mean "));
String description = a.description();
assertTrue("Description for AstMean is too short", description.length() > 100);
}
@Test public void testColumnwiseMeanWithoutNaRm() {
Frame fr = register(new Frame(Key.<Frame>make(),
ar("I", "D", "DD", "DN", "T", "S", "C"),
aro(vi1, vd1, vd2, vd3, vt1, vs1, vc2)
));
Val val1 = Rapids.exec("(mean " + fr._key + " 0 0)");
assertTrue(val1 instanceof ValFrame);
Frame res = register(val1.getFrame());
assertArrayEquals(fr.names(), res.names());
assertArrayEquals(ar(Vec.T_NUM, Vec.T_NUM, Vec.T_NUM, Vec.T_NUM, Vec.T_TIME, Vec.T_NUM, Vec.T_NUM), res.types());
assertRowFrameEquals(ard(0.0, 4.0, 0.6, Double.NaN, 10000030.0, Double.NaN, Double.NaN), res);
}
@Test public void testColumnwiseMeanWithNaRm() {
Frame fr = register(new Frame(Key.<Frame>make(),
ar("I", "D", "DD", "DN", "T", "S", "C"),
aro(vi1, vd1, vd2, vd3, vt1, vs1, vc2)
));
Val val = Rapids.exec("(mean " + fr._key + " 1 0)");
assertTrue(val instanceof ValFrame);
Frame res = register(val.getFrame());
assertArrayEquals(fr.names(), res.names());
assertArrayEquals(ar(Vec.T_NUM, Vec.T_NUM, Vec.T_NUM, Vec.T_NUM, Vec.T_TIME, Vec.T_NUM, Vec.T_NUM), res.types());
assertRowFrameEquals(ard(0.0, 4.0, 0.6, 2.0, 10000030.0, Double.NaN, Double.NaN), res);
}
@Test public void testColumnwiseMeanOnEmptyFrame() {
Frame fr = register(new Frame(Key.<Frame>make()));
Val val = Rapids.exec("(mean " + fr._key + " 0 0)");
assertTrue(val instanceof ValFrame);
Frame res = register(val.getFrame());
assertEquals(res.numCols(), 0);
assertEquals(res.numRows(), 0);
}
@Test public void testColumnwiseMeanBinaryVec() {
assertTrue(vc1.isBinary() && !vc2.isBinary());
Frame fr = register(new Frame(Key.<Frame>make(), ar("C1", "C2"), aro(vc1, vc2)));
Val val = Rapids.exec("(mean " + fr._key + " 1 0)");
assertTrue(val instanceof ValFrame);
Frame res = register(val.getFrame());
assertArrayEquals(fr.names(), res.names());
assertArrayEquals(ar(Vec.T_NUM, Vec.T_NUM), res.types());
assertRowFrameEquals(ard(0.6, Double.NaN), res);
}
@Test public void testRowwiseMeanWithoutNaRm() {
Frame fr = register(new Frame(Key.<Frame>make(),
ar("i1", "d1", "d2", "d3"),
aro(vi1, vd1, vd2, vd3)
));
Val val = Rapids.exec("(mean " + fr._key + " 0 1)");
assertTrue(val instanceof ValFrame);
Frame res = register(val.getFrame());
assertColFrameEquals(ard(1.7/4, 2.9/4, Double.NaN, 10.3/4, Double.NaN), res);
assertEquals("mean", res.name(0));
}
@Test public void testRowwiseMeanWithoutNaRmAndNonnumericColumn() {
Frame fr = register(new Frame(Key.<Frame>make(),
ar("i1", "d1", "d2", "d3", "s1"),
aro(vi1, vd1, vd2, vd3, vs1)
));
Val val = Rapids.exec("(mean " + fr._key + " 0 1)");
assertTrue(val instanceof ValFrame);
Frame res = register(val.getFrame());
assertColFrameEquals(ard(Double.NaN, Double.NaN, Double.NaN, Double.NaN, Double.NaN), res);
assertEquals("mean", res.name(0));
}
@Test public void testRowwiseMeanWithNaRm() {
Frame fr = register(new Frame(Key.<Frame>make(),
ar("i1", "d1", "d2", "d3", "s1"),
aro(vi1, vd1, vd2, vd3, vs1)
));
Val val = Rapids.exec("(mean " + fr._key + " 1 1)");
assertTrue(val instanceof ValFrame);
Frame res = register(val.getFrame());
assertEquals("Unexpected column name", "mean", res.name(0));
assertEquals("Unexpected column type", Vec.T_NUM, res.types()[0]);
assertColFrameEquals(ard(1.7/4, 2.9/4, 4.1/3, 10.3/4, 10.0/3), res);
}
@Test public void testRowwiseMeanOnFrameWithTimeColumnsOnly() {
Frame fr = register(new Frame(Key.<Frame>make(), ar("t1", "s", "t2"), aro(vt1, vs1, vt2)));
Val val = Rapids.exec("(mean " + fr._key + " 1 1)");
assertTrue(val instanceof ValFrame);
Frame res = register(val.getFrame());
assertEquals("Unexpected column name", "mean", res.name(0));
assertEquals("Unexpected column type", Vec.T_TIME, res.types()[0]);
assertColFrameEquals(ard(15000000, 15000020, 15000030, 15000040, 15000060), res);
}
@Test public void testRowwiseMeanOnFrameWithTimeAndNumericColumn() {
Frame fr = register(new Frame(Key.<Frame>make(), ar("t1", "i1"), aro(vt1, vi1)));
Val val = Rapids.exec("(mean " + fr._key + " 1 1)");
assertTrue(val instanceof ValFrame);
Frame res = register(val.getFrame());
assertColFrameEquals(ard(-1, -2, 0, 2, 1), res);
}
@Test public void testRowwiseMeanOnEmptyFrame() {
Frame fr = register(new Frame(Key.<Frame>make()));
Val val = Rapids.exec("(mean " + fr._key + " 0 1)");
assertTrue(val instanceof ValFrame);
Frame res = register(val.getFrame());
assertEquals(res.numCols(), 0);
assertEquals(res.numRows(), 0);
}
@Test public void testRowwiseMeanOnFrameWithNonnumericColumnsOnly() {
Frame fr = register(new Frame(Key.<Frame>make(), ar("c1", "s1"), aro(vc2, vs1)));
Val val = Rapids.exec("(mean " + fr._key + " 1 1)");
assertTrue(val instanceof ValFrame);
Frame res = register(val.getFrame());
assertEquals("Unexpected column name", "mean", res.name(0));
assertEquals("Unexpected column type", Vec.T_NUM, res.types()[0]);
assertColFrameEquals(ard(Double.NaN, Double.NaN, Double.NaN, Double.NaN, Double.NaN), res);
}
@Test public void testBadFirstArgument() {
try {
Rapids.exec("(mean " + vi1._key + " 1 0)");
fail();
} catch (IllegalArgumentException ignored) {}
try {
Rapids.exec("(mean hello 1 0)");
fail();
} catch (IllegalArgumentException ignored) {}
try {
Rapids.exec("(mean 2 1 0)");
fail();
} catch (IllegalArgumentException ignored) {}
}
@Test public void testValRowArgument() {
Frame fr = register(new Frame(Key.<Frame>make(),
ar("i1", "d1", "d2", "d3"),
aro(vi1, vd1, vd2, vd3)
));
Val val = Rapids.exec("(apply " + fr._key + " 1 {x . (mean x 1)})"); // skip NAs
assertTrue(val instanceof ValFrame);
Frame res = register(val.getFrame());
assertColFrameEquals(ard(1.7/4, 2.9/4, 4.1/3, 10.3/4, 10.0/3), res);
Val val2 = Rapids.exec("(apply " + fr._key + " 1 {x . (mean x 0)})"); // do not skip NAs
assertTrue(val2 instanceof ValFrame);
Frame res2 = register(val2.getFrame());
assertColFrameEquals(ard(1.7/4, 2.9/4, Double.NaN, 10.3/4, Double.NaN), res2);
}
//--------------------------------------------------------------------------------------------------------------------
// Helpers
//--------------------------------------------------------------------------------------------------------------------
private static void assertRowFrameEquals(double[] expected, Frame actual) {
assertEquals(1, actual.numRows());
assertEquals(expected.length, actual.numCols());
for (int i = 0; i < expected.length; i++) {
assertEquals("Wrong average in column " + actual.name(i), expected[i], actual.vec(i).at(0), 1e-8);
}
}
private static void assertColFrameEquals(double[] expected, Frame actual) {
assertEquals(1, actual.numCols());
assertEquals(expected.length, actual.numRows());
for (int i = 0; i < expected.length; i++) {
assertEquals("Wrong average in row " + i, expected[i], actual.vec(0).at(i), 1e-8);
}
}
private static Frame register(Frame f) {
if (f._key != null) DKV.put(f._key, f);
allFrames.add(f);
return f;
}
}