package hex; import hex.KMeans2.KMeans2Model; import hex.KMeans2.Initialization; import org.junit.Assert; import org.junit.BeforeClass; import org.junit.Test; import water.*; import water.fvec.FVecTest; import water.fvec.Frame; import water.fvec.ParseDataset2; import water.util.Log; import water.util.Log.Tag.Sys; import java.util.Arrays; import java.util.Random; public class KMeans2Test extends TestUtil { private static final long SEED = 8683452581122892189L; private static final double SIGMA = 3; public static final void testHTML(KMeans2Model m) { StringBuilder sb = new StringBuilder(); KMeans2.KMeans2ModelView kmv = new KMeans2.KMeans2ModelView(); kmv.model = m; kmv.toHTML(sb); assert(sb.length() > 0); } @BeforeClass public static void stall() { stall_till_cloudsize(JUnitRunnerDebug.NODES); } @Test public void test1Dimension() { double[] data = new double[] { 1.2, 5.6, 3.7, 0.6, 0.1, 2.6 }; double[][] rows = new double[data.length][1]; for( int i = 0; i < rows.length; i++ ) rows[i][0] = data[i]; Frame frame = frame(new String[] { "C0" }, rows); KMeans2 algo; try { algo = new KMeans2(); algo.source = frame; algo.k = 2; algo.initialization = Initialization.Furthest; algo.max_iter = 100; algo.seed = SEED; algo.invoke(); KMeans2Model res = UKV.get(algo.dest()); testHTML(res); Assert.assertTrue(res.get_params().state == Job.JobState.DONE); //HEX-1817 double[][] clusters = res.centers; Assert.assertEquals(1.125, clusters[0][0], 0.000001); Assert.assertEquals(4.65, clusters[1][0], 0.000001); res.delete(); } finally { frame.delete(); } } @Test public void testGaussian() { testGaussian(10000); } public void testGaussian(int rows) { final int columns = 100; double[][] goals = new double[8][columns]; double[][] array = gauss(columns, rows, goals); String[] names = new String[columns]; for( int i = 0; i < names.length; i++ ) names[i] = "C" + i; Frame frame = frame(names, array); KMeans2 algo; try { algo = new KMeans2(); algo.source = frame; algo.k = goals.length; algo.initialization = Initialization.Furthest; algo.max_iter = 100; algo.seed = SEED; Timer t = new Timer(); algo.invoke(); KMeans2Model res = UKV.get(algo.dest()); testHTML(res); Log.debug(Sys.KMEAN, " testGaussian rows:" + rows + ", ms:" + t); double[][] clusters = res.centers; for( double[] goal : goals ) { boolean found = false; for( double[] cluster : clusters ) { if( match(cluster, goal) ) { found = true; break; } } Assert.assertTrue(found); } res.delete(); } finally { frame.delete(); } } public static double[][] gauss(int columns, int rows, double[][] goals) { // rows and cols are reversed on this one for va_maker Random rand = new Random(SEED); for( int goal = 0; goal < goals.length; goal++ ) for( int c = 0; c < columns; c++ ) goals[goal][c] = rand.nextDouble() * 100; double[][] array = new double[rows][columns]; gauss(goals, array); return array; } public static void gauss(double[][] goals, double[][] array) { Random rand = new Random(SEED); for( int r = 0; r < array.length; r++ ) { final int goal = rand.nextInt(goals.length); for( int c = 0; c < array[r].length; c++ ) array[r][c] = goals[goal][c] + rand.nextGaussian() * SIGMA; } } static boolean match(double[] cluster, double[] goal) { for( int i = 0; i < cluster.length; i++ ) if( Math.abs(cluster[i] - goal[i]) > 1 ) return false; return true; } static double dist(double[] cluster, double[] goal) { double sum = 0; for( int i = 0; i < cluster.length; i++ ) { double d = cluster[i] - goal[i]; sum += d * d; } return Math.sqrt(sum / cluster.length); } @Test public void testAirline() { Key dest = Key.make("dest"); Frame frame = parseFrame(dest, "smalldata/airlines/allyears2k.zip"); KMeans2 algo = new KMeans2(); algo.source = frame; algo.k = 8; algo.initialization = Initialization.Furthest; algo.max_iter = 100; algo.seed = SEED; Timer t = new Timer(); algo.invoke(); Log.debug(Sys.KMEAN, "ms= " + t); KMeans2Model res = UKV.get(algo.dest()); testHTML(res); Assert.assertEquals(algo.k, res.centers.length); frame.delete(); res.delete(); } @Test public void testSphere() { Key dest = Key.make("dest"); Frame frame = parseFrame(dest, "smalldata/syn_sphere2.csv"); KMeans2 algo = new KMeans2(); algo.source = frame; algo.k = 3; algo.initialization = Initialization.Furthest; algo.max_iter = 100; algo.seed = SEED; Timer t = new Timer(); algo.invoke(); Log.debug(Sys.KMEAN, "ms= " + t); KMeans2Model res = UKV.get(algo.dest()); testHTML(res); Assert.assertEquals(algo.k, res.centers.length); frame.delete(); res.delete(); } private double[] d(double... ds) { return ds; } boolean close(double[] a, double[] b) { for (int i=0;i<a.length;++i) { if (Math.abs(a[i]-b[i]) > 1e-8) return false; } return true; } @Test public void testCentroids(){ String data = "1, 0, 0\n" + "0, 1, 0\n" + "0, 0, 1\n"; Frame fr = null; try { Key k = FVecTest.makeByteVec("yada", data); fr = ParseDataset2.parse(Key.make(), new Key[]{k}); for( boolean normalize : new boolean[]{false, true}) { for( Initialization init : new Initialization[]{Initialization.None, Initialization.PlusPlus, Initialization.Furthest}) { KMeans2 parms = new KMeans2(); parms.source = fr; parms.k = 3; parms.normalize = normalize; parms.max_iter = 100; parms.initialization = init; parms.seed = 0; parms.invoke(); KMeans2Model kmm = UKV.get(parms.dest()); Assert.assertTrue(kmm.centers[0][0] + kmm.centers[0][1] + kmm.centers[0][2] == 1); Assert.assertTrue(kmm.centers[1][0] + kmm.centers[1][1] + kmm.centers[1][2] == 1); Assert.assertTrue(kmm.centers[2][0] + kmm.centers[2][1] + kmm.centers[2][2] == 1); Assert.assertTrue(kmm.centers[0][0] + kmm.centers[1][0] + kmm.centers[2][0] == 1); Assert.assertTrue(kmm.centers[0][0] + kmm.centers[1][0] + kmm.centers[2][0] == 1); Assert.assertTrue(kmm.centers[0][0] + kmm.centers[1][0] + kmm.centers[2][0] == 1); testHTML(kmm); kmm.delete(); } } } finally { if( fr != null ) fr.delete(); } } @Test public void testNAColLast(){ String[] datas = new String[]{ new String( "1, 0, ?\n" + //33% NA in col 3 "0, 2, 0\n" + "0, 0, 3\n" ), new String( "1, ?, 0\n" + //33% NA in col 2 "0, 2, 0\n" + "0, 0, 3\n" ), new String( "?, 0, 0\n" + //33% NA in col 1 "0, 2, 0\n" + "0, 0, 3\n" )}; Frame fr = null; for (String data : datas){ try { Key k = FVecTest.makeByteVec("yada", data); fr = ParseDataset2.parse(Key.make(), new Key[]{k}); for (boolean drop_na : new boolean[]{false, true}) { for (boolean normalize : new boolean[]{false, true}) { for (Initialization init : new Initialization[]{Initialization.None, Initialization.PlusPlus, Initialization.Furthest}) { KMeans2 parms = new KMeans2(); parms.source = fr; parms.k = 3; parms.normalize = normalize; parms.max_iter = 100; parms.initialization = init; parms.drop_na_cols = drop_na; parms.seed = 0; parms.invoke(); KMeans2Model kmm = UKV.get(parms.dest()); testHTML(kmm); kmm.delete(); } } } } finally { if( fr != null ) fr.delete(); } } } }