package hex;

import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Test;
import water.Key;
import water.MRTask;
import water.TestUtil;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.InteractionWrappedVec;
import water.fvec.Vec;

// test cases:
//  skipMissing = TRUE/FALSE
//  useAllLevels = TRUE/FALSE
//  limit enums
//  (don't) standardize predictor columns

// DataInfo tests with interactions
public class DataInfoTest extends TestUtil {

  @BeforeClass static public void setup() { stall_till_cloudsize(1); }

  @Test public void testAirlines1() { // just test that it works at all
    Frame fr = parse_test_file(Key.make("a.hex"), "smalldata/airlines/allyears2k_headers.zip");
    try {
      DataInfo dinfo = new DataInfo(
              fr.clone(),   // train
              null,         // valid
              1,            // num responses
              true,         // use all factor levels
              DataInfo.TransformType.STANDARDIZE,  // predictor transform
              DataInfo.TransformType.NONE,         // response transform
              true,         // skip missing
              false,        // impute missing
              false,        // missing bucket
              false,        // weight
              false,        // offset
              false,        // fold
              new String[]{fr.name(8), fr.name(16), fr.name(2)}  // interactions
      );
      dinfo.dropInteractions();
      dinfo.remove();
    } finally {
      fr.delete();
    }
  }

  @Test public void testAirlines2() {
    Frame fr = parse_test_file(Key.make("a.hex"), "smalldata/airlines/allyears2k_headers.zip");
    try {
      Frame interactions = Model.makeInteractions(fr, false, Model.InteractionPair.generatePairwiseInteractionsFromList(8, 16, 2), true, true, true);
      int len = 0;
      for (Vec v : interactions.vecs()) len += ((InteractionWrappedVec) v).expandedLength();
      interactions.delete();
      Assert.assertTrue(len == 290 + 132 + 10);

      DataInfo dinfo__noInteractions = new DataInfo(
              fr.clone(),   // train
              null,         // valid
              1,            // num responses
              true,         // use all factor levels
              DataInfo.TransformType.STANDARDIZE,  // predictor transform
              DataInfo.TransformType.NONE,         // response transform
              true,         // skip missing
              false,        // impute missing
              false,        // missing bucket
              false,        // weight
              false,        // offset
              false,        // fold
              null          // interactions
      );
      System.out.println(dinfo__noInteractions.fullN());
      System.out.println(dinfo__noInteractions.numNums());

      DataInfo dinfo__withInteractions = new DataInfo(
              fr.clone(),   // train
              null,         // valid
              1,            // num responses
              true,         // use all factor levels
              DataInfo.TransformType.STANDARDIZE,  // predictor transform
              DataInfo.TransformType.NONE,         // response transform
              true,         // skip missing
              false,        // impute missing
              false,        // missing bucket
              false,        // weight
              false,        // offset
              false,        // fold
              new String[]{fr.name(8), fr.name(16), fr.name(2)}  // interactions
      );
      System.out.println(dinfo__withInteractions.fullN());
      Assert.assertTrue(dinfo__withInteractions.fullN() == dinfo__noInteractions.fullN() + len);
      dinfo__withInteractions.dropInteractions();
      dinfo__noInteractions.remove();
      dinfo__withInteractions.remove();
    } finally {
      fr.delete();
    }
  }
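  // Note on the hard-coded lengths in testAirlines2/testAirlines3: makeInteractions builds one
  // InteractionWrappedVec per pairwise combination of the listed columns (8, 16, 2), and
  // expandedLength() reports how many expanded (one-hot or numeric) columns that interaction
  // contributes. The constants (290 + 132 + 10 with all factor levels, 426 without) are observed
  // values for smalldata/airlines/allyears2k_headers.zip and would change with the dataset.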
  @Test public void testAirlines3() {
    Frame fr = parse_test_file(Key.make("a.hex"), "smalldata/airlines/allyears2k_headers.zip");
    try {
      Frame interactions = Model.makeInteractions(fr, false, Model.InteractionPair.generatePairwiseInteractionsFromList(8, 16, 2), false, true, true);
      int len = 0;
      for (Vec v : interactions.vecs()) len += ((InteractionWrappedVec) v).expandedLength();
      interactions.delete();
      Assert.assertTrue(len == 426);

      DataInfo dinfo__noInteractions = new DataInfo(
              fr.clone(),   // train
              null,         // valid
              1,            // num responses
              false,        // use all factor levels
              DataInfo.TransformType.STANDARDIZE,  // predictor transform
              DataInfo.TransformType.NONE,         // response transform
              true,         // skip missing
              false,        // impute missing
              false,        // missing bucket
              false,        // weight
              false,        // offset
              false,        // fold
              null          // interactions
      );
      System.out.println(dinfo__noInteractions.fullN());

      DataInfo dinfo__withInteractions = new DataInfo(
              fr.clone(),   // train
              null,         // valid
              1,            // num responses
              false,        // use all factor levels
              DataInfo.TransformType.STANDARDIZE,  // predictor transform
              DataInfo.TransformType.NONE,         // response transform
              true,         // skip missing
              false,        // impute missing
              false,        // missing bucket
              false,        // weight
              false,        // offset
              false,        // fold
              new String[]{fr.name(8), fr.name(16), fr.name(2)}  // interactions
      );
      System.out.println(dinfo__withInteractions.fullN());
      Assert.assertTrue(dinfo__withInteractions.fullN() == dinfo__noInteractions.fullN() + len);
      dinfo__withInteractions.dropInteractions();
      dinfo__noInteractions.remove();
      dinfo__withInteractions.remove();
    } finally {
      fr.delete();
    }
  }

  @Test public void testIris1() { // test that getting sparseRows and denseRows produce the same results
    Frame fr = parse_test_file(Key.make("a.hex"), "smalldata/iris/iris_wheader.csv");
    fr.swap(1, 4);
    Model.InteractionPair[] ips = Model.InteractionPair.generatePairwiseInteractionsFromList(0, 1);
    DataInfo di = null;
    try {
      di = new DataInfo(
              fr.clone(),   // train
              null,         // valid
              1,            // num responses
              true,         // use all factor levels
              DataInfo.TransformType.NONE,  // predictor transform
              DataInfo.TransformType.NONE,  // response transform
              true,         // skip missing
              false,        // impute missing
              false,        // missing bucket
              false,        // weight
              false,        // offset
              false,        // fold
              new String[]{fr.name(0), fr.name(1)}  // interactions
      );
      checker(di, false);
    } finally {
      fr.delete();
      if (di != null) {
        di.dropInteractions();
        di.remove();
      }
    }
  }

  @Test public void testIris2() { // test that getting sparseRows and denseRows produce the same results
    Frame fr = parse_test_file(Key.make("a.hex"), "smalldata/iris/iris_wheader.csv");
    fr.swap(1, 4);
    Model.InteractionPair[] ips = Model.InteractionPair.generatePairwiseInteractionsFromList(0, 1);
    DataInfo di = null;
    try {
      di = new DataInfo(
              fr.clone(),   // train
              null,         // valid
              1,            // num responses
              true,         // use all factor levels
              DataInfo.TransformType.STANDARDIZE,  // predictor transform
              DataInfo.TransformType.NONE,         // response transform
              true,         // skip missing
              false,        // impute missing
              false,        // missing bucket
              false,        // weight
              false,        // offset
              false,        // fold
              new String[]{fr.name(0), fr.name(1)}  // interactions
      );
      checker(di, true);
    } finally {
      fr.delete();
      if (di != null) {
        di.dropInteractions();
        di.remove();
      }
    }
  }

  @Test public void testIris3() { // test that getting sparseRows and denseRows produce the same results
    Frame fr = parse_test_file(Key.make("a.hex"), "smalldata/iris/iris_wheader.csv");
    fr.swap(2, 4);
    Model.InteractionPair[] ips = Model.InteractionPair.generatePairwiseInteractionsFromList(0, 1, 2, 3);
    DataInfo di = null;
    try {
      di = new DataInfo(
              fr.clone(),   // train
              null,         // valid
              1,            // num responses
              true,         // use all factor levels
              DataInfo.TransformType.STANDARDIZE,  // predictor transform
              DataInfo.TransformType.NONE,         // response transform
              true,         // skip missing
              false,        // impute missing
              false,        // missing bucket
              false,        // weight
              false,        // offset
              false,        // fold
              new String[]{fr.name(0), fr.name(1), fr.name(2), fr.name(3)}  // interactions
      );
      checker(di, true);
    } finally {
      fr.delete();
      if (di != null) {
        di.dropInteractions();
        di.remove();
      }
    }
  }
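  // In the iris tests above, fr.swap(...) moves the categorical "class" column in among the
  // predictors so it can take part in interactions, leaving a numeric column in the last
  // position to serve as the single response. testIris1 and testIris2 differ only in the
  // predictor transform (NONE vs. STANDARDIZE); testIris3 takes pairwise interactions over a
  // mix of numeric columns and the categorical column.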
  @Test public void testAirlines4() {
    Frame fr = parse_test_file(Key.make("a.hex"), "smalldata/airlines/allyears2k_headers.zip");
    Model.InteractionPair[] ips = Model.InteractionPair.generatePairwiseInteractionsFromList(8, 16, 2);
    DataInfo di = null;
    try {
      di = new DataInfo(
              fr.clone(),   // train
              null,         // valid
              1,            // num responses
              true,         // use all factor levels
              DataInfo.TransformType.STANDARDIZE,  // predictor transform
              DataInfo.TransformType.NONE,         // response transform
              true,         // skip missing
              false,        // impute missing
              false,        // missing bucket
              false,        // weight
              false,        // offset
              false,        // fold
              new String[]{fr.name(8), fr.name(16), fr.name(2)}  // interactions
      );
      checker(di, true);
    } finally {
      fr.delete();
      if (di != null) {
        di.dropInteractions();
        di.remove();
      }
    }
  }

  @Test public void testAirlines5() {
    Frame fr = parse_test_file(Key.make("a.hex"), "smalldata/airlines/allyears2k_headers.zip");
    Model.InteractionPair[] ips = Model.InteractionPair.generatePairwiseInteractionsFromList(8, 16, 2);
    DataInfo di = null;
    try {
      di = new DataInfo(
              fr.clone(),   // train
              null,         // valid
              1,            // num responses
              false,        // use all factor levels
              DataInfo.TransformType.STANDARDIZE,  // predictor transform
              DataInfo.TransformType.NONE,         // response transform
              true,         // skip missing
              false,        // impute missing
              false,        // missing bucket
              false,        // weight
              false,        // offset
              false,        // fold
              new String[]{fr.name(8), fr.name(16), fr.name(2)}  // interactions
      );
      checker(di, true);
    } finally {
      fr.delete();
      if (di != null) {
        di.dropInteractions();
        di.remove();
      }
    }
  }

//  @Test public void personalChecker() {
//    final Frame gold = parse_test_file(Key.make("gold"), "/Users/spencer/Desktop/ffff.csv");
//    Frame fr = parse_test_file(Key.make("a.hex"), "/Users/spencer/Desktop/iris.csv");
//    fr.swap(3, 4);
//    DataInfo di0 = null;
//    try {
//      di0 = new DataInfo(
//              fr.clone(),   // train
//              null,         // valid
//              1,            // num responses
//              false,        // use all factor levels
//              DataInfo.TransformType.STANDARDIZE,  // predictor transform
//              DataInfo.TransformType.NONE,         // response transform
//              true,         // skip missing
//              false,        // impute missing
//              false,        // missing bucket
//              false,        // weight
//              false,        // offset
//              false,        // fold
//              new String[]{"Species", "Sepal.Length", "Petal.Length"}  // interactions
//      );
//      final DataInfo di = di0;
//      new MRTask() {
//        @Override public void map(Chunk[] cs) {
//          DataInfo.Row[] sparseRows = di.extractSparseRows(cs);
//          for (int i = 0; i < cs[0]._len; ++i) {
////            di.extractDenseRow(cs, i, r);
//            DataInfo.Row r = sparseRows[i];
//            int idx = 1;
//            for (int j = di.numStart(); j < di.fullN(); ++j) {
//              double goldValue = gold.vec(idx++).at(i + cs[0].start());
//              double thisValue = r.get(j) - (di._normSub[j - di.numStart()] * di._normMul[j - di.numStart()]);
//              double diff = Math.abs(goldValue - thisValue);
//              if (diff > 1e-12)
//                throw new RuntimeException("bonk");
//            }
//          }
//        }
//      }.doAll(di0._adaptedFrame);
//    } finally {
//      fr.delete();
//      gold.delete();
//      if (di0 != null) {
//        di0.dropInteractions();
//        di0.remove();
//      }
//    }
//  }

  private static void printVals(DataInfo di, DataInfo.Row denseRow, DataInfo.Row sparseRow) {
    System.out.println("col|dense|sparse|sparseScaled");
    double sparseScaled;
    String line;
    for (int i = 0; i < di.fullN(); ++i) {
      sparseScaled = sparseRow.get(i);
      if (i >= di.numStart())
        sparseScaled -= (di._normSub[i - di.numStart()] * di._normMul[i - di.numStart()]);
      line = i + "|" + denseRow.get(i) + "|" + sparseRow.get(i) + "|" + sparseScaled;
      if (Math.abs(denseRow.get(i) - sparseScaled) > 1e-14)
        System.out.println(">" + line + "<");
    }
  }
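  // checker maps over di._adaptedFrame and asserts, for every row and every expanded column,
  // that the dense and sparse extraction paths agree to within 1e-14. As noted in the loop
  // below, extractSparseRows leaves numeric values only partially standardized (scaled by
  // _normMul but not shifted), so when standardize is true the comparison first subtracts
  // _normSub[j]*_normMul[j] to finish the scaling; a row flagged bad by one path must also be
  // flagged bad by the other.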
  private static void checker(final DataInfo di, final boolean standardize) {
    new MRTask() {
      @Override public void map(Chunk[] cs) {
        DataInfo.Row[] sparseRows = di.extractSparseRows(cs);
        DataInfo.Row r = di.newDenseRow();
        for (int i = 0; i < cs[0]._len; ++i) {
          di.extractDenseRow(cs, i, r);
          for (int j = 0; j < di.fullN(); ++j) {
            double sparseDoubleScaled = sparseRows[i].get(j);  // extracting sparse rows does not do the full scaling!!
            if (j >= di.numStart()) {  // finish scaling the sparse value
              sparseDoubleScaled -= (standardize ? (di._normSub[j - di.numStart()] * di._normMul[j - di.numStart()]) : 0);
            }
            if (r.isBad() || sparseRows[i].isBad()) {
              if (sparseRows[i].isBad() && r.isBad()) continue;  // both bad OK
              throw new RuntimeException("dense row was " + (r.isBad() ? "bad" : "not bad") + "; but sparse row was " + (sparseRows[i].isBad() ? "bad" : "not bad"));
            }
            if (Math.abs(r.get(j) - sparseDoubleScaled) > 1e-14) {
              printVals(di, r, sparseRows[i]);
              throw new RuntimeException("Row mismatch on row " + i);
            }
          }
        }
      }
    }.doAll(di._adaptedFrame);
  }
}