package hex;
import hex.glm.GLMModel;
import hex.splitframe.ShuffleSplitFrame;
import org.junit.BeforeClass;
import org.junit.Test;
import water.Key;
import water.MRTask;
import water.TestUtil;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.Vec;
import java.util.Random;
// data info tests with interactions
/**
 * Verifies that a {@link DataInfo} built with interaction columns yields the same
 * expanded rows as the "golden" frame produced by {@code GLMModel.GLMOutput.expand},
 * both on shuffled train/test splits and after the test split has been adapted to a
 * {@link DataInfo} built on the train split.
 */
public class DataInfoTestAdapt extends TestUtil {
  /** Absolute tolerance used when comparing an expanded value against the golden frame. */
  private static final double TOLERANCE = 1e-12;

  /** Runs once before the tests; delegates to {@code stall_till_cloudsize(1)}. */
  @BeforeClass static public void setup() { stall_till_cloudsize(1); }

  /**
   * Iris data set, interaction between {@code class} and {@code sepal_len}.
   *
   * Steps: expand the full frame into the golden frame, split both the raw and the
   * golden frame with the same seed, compare every split (check1), then adapt the raw
   * test split to a DataInfo built on the raw train split and compare the resulting
   * scoring DataInfo against the golden test split.
   */
  @Test public void testInteractionTrainTestSplitAdapt() {
    DataInfo dinfo = null, scoreInfo = null;
    Frame fr = null, expanded = null;
    Frame[] frSplits = null, expandSplits = null;
    String[] interactions = new String[]{"class", "sepal_len"};
    boolean useAll = false;
    boolean standardize = false; // golden frame is standardized before splitting, while frame we want to check would be standardized post-split (not exactly what we want!)
    boolean skipMissing = true;
    try {
      fr = parse_test_file(Key.make("a.hex"), "smalldata/iris/iris_wheader.csv");
      // presumably moves the response column (petal_wid) to the last position -- TODO confirm
      fr.swap(3, 4);
      expanded = GLMModel.GLMOutput.expand(fr, interactions, useAll, standardize, skipMissing); // here's the "golden" frame

      // now split fr and expanded; the same seed is used for both splits
      final long seed = new Random().nextLong();
      // report the seed so that a failing run can be reproduced
      System.out.println("testInteractionTrainTestSplitAdapt split seed: " + seed);
      frSplits = ShuffleSplitFrame.shuffleSplitFrame(fr, new Key[]{Key.make(), Key.make()}, new double[]{0.8, 0.2}, seed);
      expandSplits = ShuffleSplitFrame.shuffleSplitFrame(expanded, new Key[]{Key.make(), Key.make()}, new double[]{0.8, 0.2}, seed);

      // check1: verify splits. expand frSplits with DataInfo and check against expandSplits
      // NOTE(review): this overload of checkSplits passes skipMissing=false, unlike the
      // skipMissing=true used for the golden frame above -- confirm this is intended.
      checkSplits(frSplits, expandSplits, interactions, useAll, standardize);

      // now take the test frame from frSplits, and adapt it to a DataInfo built on the train frame
      dinfo = makeInfo(frSplits[0], interactions, useAll, standardize);
      GLMModel.GLMParameters parms = new GLMModel.GLMParameters();
      parms._response_column = "petal_wid";
      Model.adaptTestForTrain(frSplits[1],null,null,dinfo._adaptedFrame.names(),dinfo._adaptedFrame.domains(),parms,true,false,interactions,null,null, false);
      scoreInfo = dinfo.scoringInfo(dinfo._adaptedFrame._names, frSplits[1]);
      // checkFrame releases scoreInfo itself; the cleanup below still covers the paths
      // on which checkFrame was never reached.
      checkFrame(scoreInfo, expandSplits[1]);
    } finally {
      cleanup(fr, expanded);
      cleanup(frSplits);
      cleanup(expandSplits);
      cleanup(dinfo, scoreInfo);
    }
  }

  /**
   * Airlines data set restricted to {@code keepColumns}, interaction between
   * {@code CRSDepTime} and {@code Origin}, with {@code skipMissing=false}.
   *
   * Follows the same steps as {@link #testInteractionTrainTestSplitAdapt()}.
   */
  @Test public void testInteractionTrainTestSplitAdaptAirlines() {
    DataInfo dinfo = null, scoreInfo = null;
    Frame frA = null, fr = null, expanded = null;
    Frame[] frSplits = null, expandSplits = null;
    String[] interactions = new String[]{"CRSDepTime", "Origin"};
    String[] keepColumns = new String[]{
      "Year", "Month", "DayofMonth", "DayOfWeek",
      "CRSDepTime", "CRSArrTime", "UniqueCarrier", "CRSElapsedTime",
      "Origin", "Dest", "Distance", "IsDepDelayed",
    };
    boolean useAll = false;
    boolean standardize = false; // golden frame is standardized before splitting, while frame we want to check would be standardized post-split (not exactly what we want!)
    boolean skipMissing = false;
    try {
      frA = parse_test_file(Key.make("a.hex"), "smalldata/airlines/allyears2k_headers.zip");
      fr = frA.subframe(keepColumns);
      expanded = GLMModel.GLMOutput.expand(fr, interactions, useAll, standardize, skipMissing); // here's the "golden" frame

      // now split fr and expanded; the same seed is used for both splits
      final long seed = new Random().nextLong();
      // report the seed so that a failing run can be reproduced
      System.out.println("testInteractionTrainTestSplitAdaptAirlines split seed: " + seed);
      frSplits = ShuffleSplitFrame.shuffleSplitFrame(fr, new Key[]{Key.make(), Key.make()}, new double[]{0.8, 0.2}, seed);
      expandSplits = ShuffleSplitFrame.shuffleSplitFrame(expanded, new Key[]{Key.make(), Key.make()}, new double[]{0.8, 0.2}, seed);

      // check1: verify splits. expand frSplits with DataInfo and check against expandSplits
      checkSplits(frSplits, expandSplits, interactions, useAll, standardize, skipMissing);

      // now take the test frame from frSplits, and adapt it to a DataInfo built on the train frame
      dinfo = makeInfo(frSplits[0], interactions, useAll, standardize, skipMissing);
      GLMModel.GLMParameters parms = new GLMModel.GLMParameters();
      parms._response_column = "IsDepDelayed";
      Model.adaptTestForTrain(frSplits[1],null,null,dinfo._adaptedFrame.names(),dinfo._adaptedFrame.domains(),parms,true,false,interactions,null,null, false);
      scoreInfo = dinfo.scoringInfo(dinfo._adaptedFrame._names, frSplits[1]);
      // checkFrame releases scoreInfo itself; the cleanup below still covers the paths
      // on which checkFrame was never reached.
      checkFrame(scoreInfo, expandSplits[1], skipMissing);
    } finally {
      cleanup(fr, frA, expanded);
      cleanup(frSplits);
      cleanup(expandSplits);
      cleanup(dinfo, scoreInfo);
    }
  }

  /** Deletes every non-null frame in {@code fr}. */
  private void cleanup(Frame... fr) {
    for (Frame f : fr) {
      if (null != f) f.delete();
    }
  }

  /** Drops the interactions of, and then removes, every non-null DataInfo in {@code di}. */
  private void cleanup(DataInfo... di) {
    for (DataInfo d : di) {
      if (null != d) {
        d.dropInteractions();
        d.remove();
      }
    }
  }

  /**
   * Same as the six-argument overload with {@code skipMissing=false}.
   * NOTE(review): {@code makeInfo} and {@code checkFrame} default skipMissing to
   * {@code true}; this overload defaults to {@code false} -- confirm the asymmetry.
   */
  private void checkSplits(Frame[] frSplits, Frame[] goldSplits, String[] interactions, boolean useAll, boolean standardize) {
    checkSplits(frSplits, goldSplits, interactions, useAll, standardize, false);
  }

  /**
   * For each index {@code i}, builds a DataInfo on {@code frSplits[i]} and compares it
   * against {@code goldSplits[i]}. The DataInfo is released by {@code checkFrame}.
   */
  private void checkSplits(Frame[] frSplits, Frame[] goldSplits, String[] interactions, boolean useAll, boolean standardize, boolean skipMissing) {
    for (int i = 0; i < frSplits.length; ++i) {
      checkFrame(makeInfo(frSplits[i], interactions, useAll, standardize, skipMissing), goldSplits[i], skipMissing);
    }
  }

  /** Same as the five-argument overload with {@code skipMissing=true}. */
  private static DataInfo makeInfo(Frame fr, String[] interactions, boolean useAll, boolean standardize) {
    return makeInfo(fr, interactions, useAll, standardize, true);
  }

  /**
   * Builds a DataInfo on {@code fr} with one response, no validation frame and the
   * given interactions.
   *
   * @param standardize when true the predictor transform is STANDARDIZE, otherwise NONE
   */
  private static DataInfo makeInfo(Frame fr, String[] interactions, boolean useAll, boolean standardize, boolean skipMissing) {
    return new DataInfo(
      fr,           // train
      null,         // valid
      1,            // num responses
      useAll,       // use all factor levels
      standardize ? DataInfo.TransformType.STANDARDIZE : DataInfo.TransformType.NONE, // predictor transform
      DataInfo.TransformType.NONE, // response transform
      skipMissing,  // skip missing
      false,        // impute missing
      false,        // missing bucket
      false,        // weight
      false,        // offset
      false,        // fold
      interactions  // interactions
    );
  }

  /**
   * Compares column {@code i} of {@code checkMe} with column {@code i} of {@code gold}
   * for every column of {@code checkMe}, row by row, and throws a RuntimeException on
   * the first difference larger than {@link #TOLERANCE}.
   *
   * {@code gold} must have at least as many columns as {@code checkMe}, since the task
   * reads chunk {@code i + checkMe.numCols()} for each compared column.
   */
  private void checkFrame(final Frame checkMe, final Frame gold) {
    final int off = checkMe.numCols();
    Vec[] vecs = new Vec[off + gold.numCols()];
    // Lay out checkMe's vecs first and gold's vecs after them; the array was previously
    // left unpopulated, so the task had no data to iterate over.
    System.arraycopy(checkMe.vecs(), 0, vecs, 0, off);
    System.arraycopy(gold.vecs(), 0, vecs, off, gold.numCols());
    new MRTask() {
      @Override public void map(Chunk[] cs) {
        for (int i = 0; i < off; ++i) {
          for (int r = 0; r < cs[0]._len; ++r) {
            double check = cs[i].atd(r);
            double gold = cs[i + off].atd(r);
            if (Math.abs(check - gold) > TOLERANCE)
              throw new RuntimeException("bonk");
          }
        }
      }
    }.doAll(vecs);
  }

  /** Same as the three-argument overload with {@code skipMissing=true}. */
  private void checkFrame(final DataInfo di, final Frame gold) { checkFrame(di, gold, true); }

  /**
   * Extracts every row of {@code di}'s adapted frame as a dense row and compares its
   * first {@code di.fullN()} entries against the matching columns of {@code gold}.
   *
   * With {@code skipMissing=true}, rows flagged bad are skipped and any difference
   * above {@link #TOLERANCE} throws. With {@code skipMissing=false}, differences below
   * 10 are only printed (the message attributes them to means taken on split frames);
   * larger ones throw.
   *
   * {@code di} is always released (interactions dropped, then removed) before this
   * method returns, whether or not the comparison succeeded.
   */
  private void checkFrame(final DataInfo di, final Frame gold, final boolean skipMissing) {
    try {
      Vec[] vecs = new Vec[di._adaptedFrame.numCols() + gold.numCols()];
      // adapted-frame vecs first, golden vecs after them
      System.arraycopy(di._adaptedFrame.vecs(), 0, vecs, 0, di._adaptedFrame.numCols());
      System.arraycopy(gold.vecs(), 0, vecs, di._adaptedFrame.numCols(), gold.numCols());
      new MRTask() {
        @Override public void map(Chunk[] cs) {
          int off = di._adaptedFrame.numCols(); // index of the first golden chunk
          DataInfo.Row r = di.newDenseRow();
          // DataInfo.Row rows[] = di.extractSparseRows(cs);
          for (int i = 0; i < cs[0]._len; ++i) {
            // DataInfo.Row r = rows[i];
            di.extractDenseRow(cs, i, r);
            if (skipMissing && r.isBad()) continue;
            for (int j = 0; j < di.fullN(); ++j) {
              double goldValue = cs[off + j].atd(i);
              double thisValue = r.get(j); // - (di._normSub[j - di.numStart()] * di._normMul[j-di.numStart()]);
              double diff = Math.abs(goldValue - thisValue);
              if (diff > TOLERANCE) {
                if (!skipMissing && diff < 10)
                  System.out.println("row mismatch: " + i + " column= " + j + "; diff= " + diff + " but not skipping missing, so due to discrepancies in taking mean on split frames");
                else throw new RuntimeException("bonk");
              }
            }
          }
        }
      }.doAll(vecs);
    } finally {
      di.dropInteractions();
      di.remove();
    }
  }
}