package hex;
import junit.framework.Assert;
import org.junit.BeforeClass;
import org.junit.Test;
import water.*;
import water.fvec.Frame;
import water.fvec.NFSFileVec;
import water.fvec.ParseDataset2;
import water.util.Log;
import java.util.Random;
public class KMeans2RandomTest extends TestUtil {
@BeforeClass public static void stall() {
stall_till_cloudsize(JUnitRunnerDebug.NODES);
}
@Test
public void run() {
long seed = 0xDECAF;
Random rng = new Random(seed);
String[] datasets = new String[2];
int[][] responses = new int[datasets.length][];
datasets[0] = "smalldata/./logreg/prostate.csv";
responses[0] = new int[]{1, 2, 8}; //CAPSULE (binomial), AGE (regression), GLEASON (multi-class)
datasets[1] = "smalldata/iris/iris.csv";
responses[1] = new int[]{4}; //Iris-type (multi-class)
int testcount = 0;
int count = 0;
for (int i = 0; i < datasets.length; ++i) {
String dataset = datasets[i];
Key file = NFSFileVec.make(find_test_file(dataset));
Frame frame = ParseDataset2.parse(Key.make(), new Key[]{file});
try {
for (int clusters : new int[]{1, 10}) {
for (int max_iter : new int[]{1, 10, 100}) {
for (boolean normalize : new boolean[]{false, true}) {
for (boolean drop_na_cols : new boolean[]{false, true}) {
for (KMeans2.Initialization init : new KMeans2.Initialization[]{
KMeans2.Initialization.Furthest,
KMeans2.Initialization.None,
KMeans2.Initialization.PlusPlus}) {
count++;
KMeans2 k = new KMeans2();
k.k = clusters;
k.initialization = init;
k.destination_key = Key.make();
k.seed = rng.nextLong();
k.source = frame;
k.max_iter = max_iter;
k.normalize = normalize;
k.drop_na_cols = drop_na_cols;
k.invoke();
KMeans2.KMeans2Model m = null;
Frame score = null;
Frame ref = null;
try {
m = UKV.get(k.dest());
for (double d : m.within_cluster_variances) Assert.assertFalse(Double.isNaN(d));
Assert.assertFalse(Double.isNaN(m.total_within_SS));
for (long o : m.size) Assert.assertTrue(o > 0); //have at least one point per centroid
for (double[] dc : m.centers) for (double d : dc) Assert.assertFalse(Double.isNaN(d));
KMeans2Test.testHTML(m);
// make prediction (cluster assignment)
score = m.score(frame);
ref = UKV.get(m._clustersKey);
for (long j = 0; j < score.numRows(); ++j) {
org.junit.Assert.assertTrue(score.anyVec().at8(j) >= 0 && score.anyVec().at8(j) < clusters); //check sanity
}
Log.info("Parameters combination " + count + ": PASS");
testcount++;
} finally {
if (m != null) m.delete();
if (score != null) score.delete();
if (ref != null) ref.delete();
}
}
}
}
}
}
} finally {
frame.delete();
}
}
Log.info("\n\n=============================================");
Log.info("Tested " + testcount + " out of " + count + " parameter combinations.");
Log.info("=============================================");
}
}