package water.rapids.ast.prims.reducers;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Test;
import water.DKV;
import water.Key;
import water.TestUtil;
import water.fvec.Frame;
import water.fvec.Vec;
import water.rapids.Rapids;
import water.rapids.Val;
import water.rapids.vals.ValFrame;
import java.util.ArrayList;
import static org.junit.Assert.*;
/**
 * Tests for the {@code sumaxis} Rapids primitive implemented by {@link AstSumAxis}.
 *
 * <p>Most tests execute an expression of the form {@code (sumaxis <frame> <na_rm> <axis>)}.
 * Judging by the expectations below, the second argument toggles skipping of NAs
 * ({@code 1} = skip) and the third selects the direction ({@code 0} = one sum per
 * column, {@code 1} = one sum per row).
 *
 * <p>All vecs are created once in {@link #setup()} and shared between tests; every frame
 * created or returned during a test is recorded via {@link #register(Frame)} so that
 * {@link #teardown()} can delete it.
 */
public class AstSumAxisTest extends TestUtil {
  // Shared fixture columns: integer, doubles (vd3 contains NaNs), string, times, categoricals.
  private static Vec vi1, vd1, vd2, vd3, vs1, vt1, vt2, vc1, vc2;
  // Every frame created by the tests; deleted in teardown().
  private static ArrayList<Frame> allFrames;

  /** Waits for a single-node cloud and builds the shared five-row fixture vecs. */
  @BeforeClass public static void setup() {
    stall_till_cloudsize(1);
    vi1 = TestUtil.ivec(-1, -2, 0, 2, 1);
    vd1 = TestUtil.dvec(1.5, 2.5, 3.5, 4.5, 8.0);
    vd2 = TestUtil.dvec(0.2, 0.4, 0.6, 0.8, 1.0);
    vd3 = TestUtil.dvec(1, 2, Double.NaN, 3, Double.NaN);
    vs1 = TestUtil.svec("a", "b", "c", "d", "e");
    vt1 = TestUtil.tvec(10000000, 10000020, 10000030, 10000040, 10000060);
    vt2 = TestUtil.tvec(20000000, 20000020, 20000030, 20000040, 20000060);
    vc1 = TestUtil.cvec(ar("N", "Y"), "Y", "N", "Y", "Y", "N");
    vc2 = TestUtil.cvec("a", "c", "c", "b", "a");
    allFrames = new ArrayList<>(10);
  }

  /**
   * Removes the fixture vecs and deletes every registered frame.
   *
   * <p>Null checks guard against a partially failed {@link #setup()}, so that a failure
   * there is not masked by a secondary {@code NullPointerException} here.
   */
  @AfterClass public static void teardown() {
    for (Vec v : aro(vi1, vd1, vd2, vd3, vs1, vt1, vt2, vc1, vc2)) {
      if (v != null) {
        v.remove();
      }
    }
    if (allFrames != null) {
      for (Frame f : allFrames) {
        f.delete();
      }
    }
  }


  //--------------------------------------------------------------------------------------------------------------------
  // Tests
  //--------------------------------------------------------------------------------------------------------------------

  /** Checks the primitive's metadata: argument count, example prefix and description length. */
  @Test public void testAstSumGeneralStructure() {
    AstSumAxis a = new AstSumAxis();
    String[] args = a.args();
    assertEquals(3, args.length);
    String example = a.example();
    assertTrue(example.startsWith("(sumaxis "));
    String description = a.description();
    assertTrue("Description for AstSumAxis is too short", description.length() > 100);
  }

  /** Column sums with na_rm=0: the column containing NaNs sums to NaN; string and non-binary categorical yield NaN. */
  @Test public void testColumnwisesumWithoutNaRm() {
    Frame fr = register(new Frame(Key.<Frame>make(),
        ar("I", "D", "DD", "DN", "T", "S", "C"),
        aro(vi1, vd1, vd2, vd3, vt1, vs1, vc2)
    ));
    Frame res = execToFrame("(sumaxis " + fr._key + " 0 0)");
    assertArrayEquals(fr.names(), res.names());
    assertArrayEquals(ar(Vec.T_NUM, Vec.T_NUM, Vec.T_NUM, Vec.T_NUM, Vec.T_TIME, Vec.T_NUM, Vec.T_NUM), res.types());
    assertRowFrameEquals(ard(0.0, 20.0, 3.0, Double.NaN, 50000150.0, Double.NaN, Double.NaN), res);
  }

  /** Column sums with na_rm=1: NaNs in "DN" are skipped, giving 6.0 instead of NaN. */
  @Test public void testColumnwiseSumWithNaRm() {
    Frame fr = register(new Frame(Key.<Frame>make(),
        ar("I", "D", "DD", "DN", "T", "S", "C"),
        aro(vi1, vd1, vd2, vd3, vt1, vs1, vc2)
    ));
    Frame res = execToFrame("(sumaxis " + fr._key + " 1 0)");
    assertArrayEquals(fr.names(), res.names());
    assertArrayEquals(ar(Vec.T_NUM, Vec.T_NUM, Vec.T_NUM, Vec.T_NUM, Vec.T_TIME, Vec.T_NUM, Vec.T_NUM), res.types());
    assertRowFrameEquals(ard(0.0, 20.0, 3.0, 6.0, 50000150.0, Double.NaN, Double.NaN), res);
  }

  /** Column sums of a frame without columns produce an empty (0 x 0) frame. */
  @Test public void testColumnwisesumOnEmptyFrame() {
    Frame fr = register(new Frame(Key.<Frame>make()));
    Frame res = execToFrame("(sumaxis " + fr._key + " 0 0)");
    assertEquals(0, res.numCols());
    assertEquals(0, res.numRows());
  }

  /** A binary categorical column is summed (3.0 here); a non-binary categorical yields NaN. */
  @Test public void testColumnwisesumBinaryVec() {
    assertTrue(vc1.isBinary() && !vc2.isBinary());
    Frame fr = register(new Frame(Key.<Frame>make(), ar("C1", "C2"), aro(vc1, vc2)));
    Frame res = execToFrame("(sumaxis " + fr._key + " 1 0)");
    assertArrayEquals(fr.names(), res.names());
    assertArrayEquals(ar(Vec.T_NUM, Vec.T_NUM), res.types());
    assertRowFrameEquals(ard(3.0, Double.NaN), res);
  }

  /** Row sums with na_rm=0: rows where "d3" is NaN sum to NaN; the result column is named "sum". */
  @Test public void testRowwisesumWithoutNaRm() {
    Frame fr = register(new Frame(Key.<Frame>make(),
        ar("i1", "d1", "d2", "d3"),
        aro(vi1, vd1, vd2, vd3)
    ));
    Frame res = execToFrame("(sumaxis " + fr._key + " 0 1)");
    assertColFrameEquals(ard(1.7, 2.9, Double.NaN, 10.3, Double.NaN), res);
    assertEquals("sum", res.name(0));
  }

  /** Row sums with na_rm=0 and a string column present: every row sum is NaN. */
  @Test public void testRowwisesumWithoutNaRmAndNonnumericColumn() {
    Frame fr = register(new Frame(Key.<Frame>make(),
        ar("i1", "d1", "d2", "d3", "s1"),
        aro(vi1, vd1, vd2, vd3, vs1)
    ));
    Frame res = execToFrame("(sumaxis " + fr._key + " 0 1)");
    assertColFrameEquals(ard(Double.NaN, Double.NaN, Double.NaN, Double.NaN, Double.NaN), res);
    assertEquals("sum", res.name(0));
  }

  /** Row sums with na_rm=1: NaNs and the string column do not contribute to the sums. */
  @Test public void testRowwisesumWithNaRm() {
    Frame fr = register(new Frame(Key.<Frame>make(),
        ar("i1", "d1", "d2", "d3", "s1"),
        aro(vi1, vd1, vd2, vd3, vs1)
    ));
    Frame res = execToFrame("(sumaxis " + fr._key + " 1 1)");
    assertEquals("Unexpected column name", "sum", res.name(0));
    assertEquals("Unexpected column type", Vec.T_NUM, res.types()[0]);
    assertColFrameEquals(ard(1.7, 2.9, 4.1, 10.3, 10.0), res);
  }

  /** Row sums over two time columns (plus a string column) give a time-typed column holding t1 + t2. */
  @Test public void testRowwisesumOnFrameWithTimeColumnsOnly() {
    Frame fr = register(new Frame(Key.<Frame>make(), ar("t1", "s", "t2"), aro(vt1, vs1, vt2)));
    Frame res = execToFrame("(sumaxis " + fr._key + " 1 1)");
    assertEquals("Unexpected column name", "sum", res.name(0));
    assertEquals("Unexpected column type", Vec.T_TIME, res.types()[0]);
    assertColFrameEquals(ard(30000000, 30000040, 30000060, 30000080, 30000120), res);
  }

  /** With a time and an integer column mixed, the expected row sums equal the integer column alone. */
  @Test public void testRowwisesumOnFrameWithTimeandNumericColumn() {
    Frame fr = register(new Frame(Key.<Frame>make(), ar("t1", "i1"), aro(vt1, vi1)));
    Frame res = execToFrame("(sumaxis " + fr._key + " 1 1)");
    assertColFrameEquals(ard(-1, -2, 0, 2, 1), res);
  }

  /** Row sums of a frame without columns produce an empty (0 x 0) frame. */
  @Test public void testRowwisesumOnEmptyFrame() {
    Frame fr = register(new Frame(Key.<Frame>make()));
    Frame res = execToFrame("(sumaxis " + fr._key + " 0 1)");
    assertEquals(0, res.numCols());
    assertEquals(0, res.numRows());
  }

  /** Row sums over only categorical/string columns give a numeric "sum" column of NaNs. */
  @Test public void testRowwisesumOnFrameWithNonnumericColumnsOnly() {
    Frame fr = register(new Frame(Key.<Frame>make(), ar("c1", "s1"), aro(vc2, vs1)));
    Frame res = execToFrame("(sumaxis " + fr._key + " 1 1)");
    assertEquals("Unexpected column name", "sum", res.name(0));
    assertEquals("Unexpected column type", Vec.T_NUM, res.types()[0]);
    assertColFrameEquals(ard(Double.NaN, Double.NaN, Double.NaN, Double.NaN, Double.NaN), res);
  }

  /**
   * The first argument must be a frame: a vec key, an unknown identifier and a plain
   * number are each expected to be rejected with an {@link IllegalArgumentException}.
   */
  @Test public void testBadFirstArgument() {
    assertRejected("(sumaxis " + vi1._key + " 1 0)");
    assertRejected("(sumaxis hello 1 0)");
    assertRejected("(sumaxis 2 1 0)");
  }

  /** Invokes {@code sumaxis} on a per-row argument through {@code apply}, with and without NA skipping. */
  @Test public void testValRowArgument() {
    Frame fr = register(new Frame(Key.<Frame>make(),
        ar("i1", "d1", "d2", "d3"),
        aro(vi1, vd1, vd2, vd3)
    ));
    Frame res = execToFrame("(apply " + fr._key + " 1 {x . (sumaxis x 1)})");  // skip NAs
    assertColFrameEquals(ard(1.7, 2.9, 4.1, 10.3, 10.0), res);

    Frame res2 = execToFrame("(apply " + fr._key + " 1 {x . (sumaxis x 0)})");  // do not skip NAs
    assertColFrameEquals(ard(1.7, 2.9, Double.NaN, 10.3, Double.NaN), res2);
  }


  //--------------------------------------------------------------------------------------------------------------------
  // Helpers
  //--------------------------------------------------------------------------------------------------------------------

  /**
   * Executes a Rapids expression, asserts that it evaluates to a frame, and registers
   * that frame for cleanup.
   *
   * @param rapids the Rapids expression to execute
   * @return the resulting frame
   */
  private static Frame execToFrame(String rapids) {
    Val val = Rapids.exec(rapids);
    assertTrue(val instanceof ValFrame);
    return register(val.getFrame());
  }

  /**
   * Asserts that executing the given Rapids expression throws an
   * {@link IllegalArgumentException}.
   *
   * @param rapids the Rapids expression expected to be rejected
   */
  private static void assertRejected(String rapids) {
    try {
      Rapids.exec(rapids);
      fail("Expected IllegalArgumentException for: " + rapids);
    } catch (IllegalArgumentException ignored) {
      // Expected: the expression has an invalid first argument.
    }
  }

  /**
   * Asserts that {@code actual} has exactly one row whose cells match {@code expected}
   * (tolerance 1e-8; JUnit treats NaN as equal to NaN).
   */
  private static void assertRowFrameEquals(double[] expected, Frame actual) {
    assertEquals(1, actual.numRows());
    assertEquals(expected.length, actual.numCols());
    for (int i = 0; i < expected.length; i++) {
      assertEquals("Wrong sum in column " + actual.name(i), expected[i], actual.vec(i).at(0), 1e-8);
    }
  }

  /**
   * Asserts that {@code actual} has exactly one column whose cells match {@code expected}
   * (tolerance 1e-8; JUnit treats NaN as equal to NaN).
   */
  private static void assertColFrameEquals(double[] expected, Frame actual) {
    assertEquals(1, actual.numCols());
    assertEquals(expected.length, actual.numRows());
    for (int i = 0; i < expected.length; i++) {
      assertEquals("Wrong sum in row " + i, expected[i], actual.vec(0).at(i), 1e-8);
    }
  }

  /**
   * Records a frame for deletion in {@link #teardown()}, first publishing it to the DKV
   * when it has a key.
   *
   * @param f the frame to track
   * @return the same frame, for call chaining
   */
  private static Frame register(Frame f) {
    if (f._key != null) DKV.put(f._key, f);
    allFrames.add(f);
    return f;
  }
}