package water.util; import hex.CreateFrame; import hex.Model; import hex.ToEigenVec; import org.junit.Assert; import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; import org.junit.BeforeClass; import org.junit.Test; import org.hamcrest.CoreMatchers; import water.DKV; import water.Key; import water.Scope; import water.TestUtil; import water.fvec.Frame; import water.fvec.TestFrameBuilder; import water.fvec.Vec; import java.util.HashMap; /** * Test FrameUtils interface. */ public class FrameUtilsTest extends TestUtil { @BeforeClass static public void setup() { stall_till_cloudsize(1); } @Test public void testCategoricalColumnsBinaryEncoding() { int numNoncatColumns = 10; int[] catSizes = {2, 3, 4, 5, 7, 8, 9, 15, 16, 30, 31, 127, 255, 256}; int[] expBinarySizes = {2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 7, 8, 9}; String[] catNames = {"duo", "Trinity", "Quart", "Star rating", "Dwarves", "Octopus legs", "Planets", "Game of Fifteen", "Halfbyte", "Days30", "Days31", "Periodic Table", "AlmostByte", "Byte"}; Assert.assertEquals(catSizes.length, expBinarySizes.length); Assert.assertEquals(catSizes.length, catNames.length); int totalExpectedColumns = numNoncatColumns; for (int s : expBinarySizes) totalExpectedColumns += s; Key<Frame> frameKey = Key.make(); CreateFrame cf = new CreateFrame(frameKey); cf.rows = 100; cf.cols = numNoncatColumns; cf.categorical_fraction = 0.0; cf.integer_fraction = 0.3; cf.binary_fraction = 0.1; cf.time_fraction = 0.2; cf.string_fraction = 0.1; Frame mainFrame = cf.execImpl().get(); assert mainFrame != null : "Unable to create a frame"; Frame[] auxFrames = new Frame[catSizes.length]; Frame transformedFrame = null; try { for (int i = 0; i < catSizes.length; ++i) { CreateFrame ccf = new CreateFrame(); ccf.rows = 100; ccf.cols = 1; ccf.categorical_fraction = 1; ccf.integer_fraction = 0; ccf.binary_fraction = 0; ccf.time_fraction = 0; ccf.string_fraction = 0; ccf.factors = catSizes[i]; auxFrames[i] = ccf.execImpl().get(); auxFrames[i]._names[0] = catNames[i]; mainFrame.add(auxFrames[i]); } FrameUtils.CategoricalBinaryEncoder cbed = new FrameUtils.CategoricalBinaryEncoder(mainFrame, null); transformedFrame = cbed.exec().get(); assert transformedFrame != null : "Unable to transform a frame"; Assert.assertEquals("Wrong number of columns after converting to binary encoding", totalExpectedColumns, transformedFrame.numCols()); for (int i = 0; i < numNoncatColumns; ++i) { Assert.assertEquals(mainFrame.name(i), transformedFrame.name(i)); Assert.assertEquals(mainFrame.types()[i], transformedFrame.types()[i]); } for (int i = 0, colOffset = numNoncatColumns; i < catSizes.length; colOffset += expBinarySizes[i++]) { for (int j = 0; j < expBinarySizes[i]; ++j) { int jj = colOffset + j; Assert.assertTrue("A categorical column should be transformed into several binary ones (col "+jj+")", transformedFrame.vec(jj).isBinary()); Assert.assertThat("Transformed categorical column should carry the name of the original column", transformedFrame.name(jj), CoreMatchers.startsWith(mainFrame.name(numNoncatColumns+i) + ":")); } } } catch (Throwable e) { e.printStackTrace(); throw e; } finally { mainFrame.delete(); if (transformedFrame != null) transformedFrame.delete(); for (Frame f : auxFrames) if (f != null) f.delete(); } } @Test public void testOneHotExplicitEncoder() { Scope.enter(); try { Frame f = new TestFrameBuilder() .withName("testFrame") .withColNames("NumCol", "CatCol1", "CatCol2") .withVecTypes(Vec.T_NUM, Vec.T_CAT, Vec.T_CAT) .withDataForCol(0, ard(Double.NaN, 1, 2, 3, 4, 5.6, 7)) .withDataForCol(1, ar("A", "B", "C", "E", "F", "I", "J")) .withDataForCol(2, ar("A", "B", "A", "C", null, "B", "A")) .withChunkLayout(2, 2, 2, 1) .build(); Frame result = FrameUtils.categoricalEncoder(f, new String[]{"CatCol1"}, Model.Parameters.CategoricalEncodingScheme.OneHotExplicit, null); Scope.track(result); assertArrayEquals( new String[]{"NumCol", "CatCol2.A", "CatCol2.B", "CatCol2.C", "CatCol2.missing(NA)", "CatCol1"}, result.names()); // check that original columns are the same assertVecEquals(f.vec("NumCol"), result.vec("NumCol"), 1e-6); assertCatVecEquals(f.vec("CatCol1"), result.vec("CatCol1")); // validate 1-hot encoding Vec catVec = f.vec("CatCol2"); for (long i = 0; i < catVec.length(); i++) { String hotCol = "CatCol2." + (catVec.isNA(i) ? "missing(NA)" : catVec.domain()[(int) catVec.at8(i)]); for (String col : result.names()) if (col.startsWith("CatCol2.")) { long expectedVal = hotCol.equals(col) ? 1 : 0; assertEquals("Value of column " + col + " in row = " + i + " matches", expectedVal, result.vec(col).at8(i)); } } } finally { Scope.exit(); } } }