package hex.kmeans; import hex.Model; import org.junit.Assert; import org.junit.BeforeClass; import org.junit.Test; import water.TestUtil; import water.fvec.Frame; import water.fvec.Vec; import water.util.Log; import java.util.Random; public class KMeansRandomTest extends TestUtil { @BeforeClass() public static void setup() { stall_till_cloudsize(1); } @Test public void run() { long seed = 0xDECAF; Random rng = new Random(seed); String[] datasets = new String[]{ "smalldata/logreg/prostate.csv", "smalldata/iris/iris_wheader.csv", "smalldata/junit/weather.csv" }; int testcount = 0; int count = 0; for( String dataset : datasets ) { Frame frame = parse_test_file(dataset); try { for (int centers : new int[]{1, 2, 10, 100}) { for (int max_iter : new int[]{1, 10}) { for (boolean estimate_k : new boolean[]{false, true}) { for (boolean standardize : new boolean[]{false, true}) { for (Model.Parameters.CategoricalEncodingScheme catEncoding : Model.Parameters.CategoricalEncodingScheme.values()) { for (KMeans.Initialization init : new KMeans.Initialization[]{ KMeans.Initialization.Random, KMeans.Initialization.Furthest, KMeans.Initialization.PlusPlus, }) { if (catEncoding == Model.Parameters.CategoricalEncodingScheme.SortByResponse) continue; count++; // if (count!=1303) { // rng.nextDouble(); // rng.nextLong(); // continue; // } if (rng.nextDouble() > 0.2) continue; Frame score = null; KMeansModel.KMeansParameters parms; KMeansModel m = null; try { parms = new KMeansModel.KMeansParameters(); parms._train = frame._key; if(dataset != null && dataset.equals("smalldata/iris/iris_wheader.csv")) parms._ignored_columns = new String[] {"class"}; parms._k = centers; parms._seed = rng.nextLong(); parms._max_iterations = max_iter; parms._standardize = standardize; parms._init = init; parms._estimate_k = estimate_k; parms._categorical_encoding = catEncoding; KMeans job = new KMeans(parms); m = job.trainModel().get(); Assert.assertTrue("Progress not 100%, but " + job._job.progress() *100, job._job.progress() == 1.0); for (int j = 0; j < m._output._k[m._output._k.length-1]; j++) Assert.assertTrue(m._output._size[j] != 0); Assert.assertTrue(m._output._iterations <= max_iter); for (double d : m._output._withinss) Assert.assertFalse(Double.isNaN(d)); Assert.assertFalse(Double.isNaN(m._output._tot_withinss)); for (long o : m._output._size) Assert.assertTrue(o > 0); //have at least one point per centroid for (double[] dc : m._output._centers_raw) for (double d : dc) Assert.assertFalse(Double.isNaN(d)); // make prediction (cluster assignment) score = m.score(frame); Vec.Reader vr = score.anyVec().new Reader(); for (long j = 0; j < score.numRows(); ++j) Assert.assertTrue(vr.at8(j) >= 0 && vr.at8(j) < m._output._k[m._output._k.length-1]); Log.info("Parameters combination " + count + ": PASS"); testcount++; } finally { if (m!=null) m.delete(); if (score!=null) score.delete(); } } } } } } } } finally { frame.delete(); } } Log.info("\n\n============================================="); Log.info("Tested " + testcount + " out of " + count + " parameter combinations."); Log.info("============================================="); } }