package hex.gapstat;
import hex.KMeans2;
import water.*;
import water.Job;
import water.api.DocGen;
import water.fvec.*;
import water.util.Log;
import java.util.Random;
import static water.util.Utils.getDeterRNG;
/**
 * Gap Statistic.
 *
 * <p>This is an algorithm for estimating the optimal number of clusters in p-dimensional data.
 * For every candidate cluster count {@code k} in {@code 1..k_max} the job clusters the source
 * frame with KMeans2 and records {@code log(mse)}; it then clusters {@code b_max} synthetic
 * reference frames (numeric columns drawn uniformly over each column's [min, max] range) and
 * records the mean and the scaled standard deviation of their {@code log(mse)}. The gap for
 * {@code k} is the reference mean minus the observed value. Results are written into a
 * {@code GapStatisticModel} stored under the job's destination key.
 *
 * @author spencer_aiello
 */
public class GapStatistic extends Job.ColumnsJob {
  static final int API_WEAVER = 1;
  static public DocGen.FieldDoc[] DOC_FIELDS;
  static final String DOC_GET = "gap statistic";

  /** Number of synthetic reference frames clustered for each candidate k. */
  @API(help = "Number of Monte Carlo Bootstrap Replicates", filter = Default.class, lmin = 1, lmax = 100000, json = true)
  public int b_max = 10;

  /** Largest cluster count evaluated; candidates are 1..k_max inclusive. */
  @API(help = "The number maximum number of clusters to consider, must be at least 2.", filter = Default.class, json = true, lmin = 2, lmax = 10000)
  public int k_max = 10;

  /** Each reference frame gets floor(bootstrap_fraction * chunk length) rows per source chunk. */
  @API(help = "Fraction of data size to replicate in each MC simulation.", filter = Default.class, json = true, dmin = 0, dmax = 1)
  public double bootstrap_fraction = .1;

  /** Passed to every KMeans2 run as its max_iter. */
  @API(help = "Max iteratiors per clustering.")
  public int max_iter = 50;

  /** Base seed for the reference-frame generators; defaults to a fresh random value. */
  @API(help = "A random seed.", filter = Default.class, json = true)
  public long seed = new Random().nextLong();

  /** Job entry point: creates the model, fills it in, then cleans up and removes the job. */
  @Override protected void execImpl() {
    logStart();
    GapStatisticModel model = initModel();
    buildModel(model);
    cleanup();
    remove();
  }

  /**
   * Creates an empty model with one slot per candidate k (1..k_max).
   *
   * @return a new, not yet locked model bound to the job's destination key
   */
  private GapStatisticModel initModel() {
    // Lock before entering the try so that a failed lock is never followed by an unlock.
    source.read_lock(self());
    try {
      double[] wks = new double[k_max];
      double[] wkbs = new double[k_max];
      double[] sk = new double[k_max];
      return new GapStatisticModel(destination_key, source._key, source, k_max, wks, wkbs, sk, k_max, b_max, 1, 0);
    }
    finally {
      source.unlock(self());
    }
  }

  /**
   * Fills in the model: for each k in 1..k_max stores log(mse) of the source clustering
   * ({@code wks}), the mean log(mse) over the reference frames ({@code wkbs}), the scaled
   * standard deviation ({@code sk}) and the gap ({@code gap_stats}), publishing progress via
   * {@code update} after every replicate and every k.
   *
   * <p>A cancelled job is logged and returns normally; any other exception is rethrown wrapped
   * in a {@link RuntimeException}.
   *
   * @param gs_model the model to fill in; when null it is fetched from the destination key
   */
  private void buildModel(GapStatisticModel gs_model) {
    try {
      source.read_lock(self());
      if (gs_model == null) gs_model = UKV.get(dest());
      gs_model.delete_and_lock(self());
      for (int k = 1; k <= k_max; ++k) {
        if (this.isCancelledOrCrashed()) {
          throw new JobCancelledException();
        }
        gs_model.wks[k - 1] = logMse(source, k);

        double[] bwkbs = new double[b_max];
        for (int b = 0; b < b_max; ++b) {
          if (this.isCancelledOrCrashed()) {
            throw new JobCancelledException();
          }
          // Every replicate needs its own random stream, otherwise all b_max reference frames
          // are the same data. The large odd stride keeps "replicate seed + chunk index" from
          // colliding across replicates; the same b_max reference sets are reused for each k.
          final long replicateSeed = seed + b * 0x9E3779B97F4A7C15L;
          // NOTE(review): the reference frame is not deleted after use - confirm whether
          // cleanup()/emptyLTrash() reclaims it.
          Frame bs = referenceFrame(replicateSeed);
          bwkbs[b] = logMse(bs, k);
          gs_model.b = b + 1;
          gs_model.update(self());
        }

        double sum_bwkbs = 0.;
        for (double d : bwkbs) sum_bwkbs += d;
        final double mean = sum_bwkbs / b_max;
        gs_model.wkbs[k - 1] = mean;

        // Population variance of the replicate values.
        double sk_2 = 0.;
        for (double d : bwkbs) {
          sk_2 += (d - mean) * (d - mean) * 1. / (double) b_max;
        }
        // Standard deviation inflated by sqrt(1 + 1/B) to account for simulation error.
        gs_model.sk[k - 1] = Math.sqrt(sk_2) * Math.sqrt(1 + 1. / (double) b_max);
        gs_model.k = k;
        for (int i = 0; i < gs_model.wks.length; ++i) {
          gs_model.gap_stats[i] = gs_model.wkbs[i] - gs_model.wks[i];
        }
        gs_model.update(self());
      }
      gs_model.compute_k_best();
    }
    catch (JobCancelledException ex) {
      Log.info("Gap Statistic Computation was cancelled.");
    }
    catch (Exception ex) {
      ex.printStackTrace();
      throw new RuntimeException(ex);
    }
    finally {
      if (gs_model != null) {
        // Unlock the stored copy; it can be absent if the failure happened before the model
        // was written, and dereferencing null here would mask the original exception.
        GapStatisticModel stored = UKV.get(dest());
        if (stored != null) stored.unlock(self());
      }
      source.unlock(self());
      emptyLTrash();
    }
  }

  /**
   * Runs KMeans2 (Furthest initialization, this job's cols and max_iter) on the given frame
   * and removes the "_clusters" key derived from the KMeans2 destination.
   *
   * @param data frame to cluster
   * @param k number of clusters
   * @return natural log of the resulting model's mse
   */
  private double logMse(Frame data, int k) {
    KMeans2 km = new KMeans2();
    km.source = data;
    km.cols = cols;
    km.max_iter = max_iter;
    km.k = k;
    km.initialization = KMeans2.Initialization.Furthest;
    km.invoke();
    // NOTE(review): the KMeans2 model itself stays under km.dest() - confirm who removes it.
    KMeans2.KMeans2Model res = UKV.get(km.dest());
    Futures fs = new Futures();
    DKV.remove(Key.make(km.dest()+"_clusters"), fs);
    fs.blockForPending(); // wait for the removal instead of abandoning the pending future
    return Math.log(res.mse());
  }

  /**
   * Builds one synthetic reference frame with the same columns as the source. Per source chunk
   * it emits floor(bootstrap_fraction * chunk length) rows: constant columns repeat their
   * value, enum columns copy the level of the source row at the same index, and all other
   * columns are drawn uniformly from [min, max] of the source column.
   *
   * @param replicateSeed seed for this replicate; the chunk index is added per chunk
   * @return the generated frame, named and typed like the source
   */
  private Frame referenceFrame(final long replicateSeed) {
    return new MRTask2() {
      @Override public void map(Chunk[] chks, NewChunk[] nchks) {
        final Random rng = getDeterRNG(replicateSeed + chks[0].cidx());
        final Vec[] vecs = source.vecs();
        final int rows = (int) Math.floor(bootstrap_fraction * chks[0]._len);
        for (int row = 0; row < rows; ++row) {
          for (int col = 0; col < chks.length; ++col) {
            final Vec vec = vecs[col];
            if (vec.isConst()) {
              nchks[col].addNum(vec.max());
              continue;
            }
            if (vec.isEnum()) {
              nchks[col].addEnum((int) chks[col].at8(row));
              continue;
            }
            // Uniform over the column's range: min + u * (max - min), with u in [0, 1).
            final double lo = vec.min();
            nchks[col].addNum(lo + rng.nextDouble() * (vec.max() - lo));
          }
        }
      }
    }.doAll(source.numCols(), source).outputFrame(source.names(), source.domains());
  }

  /** Sends the client to the progress page for this job and its destination key. */
  @Override protected Response redirect() {
    return GapStatisticProgressPage.redirect(this, self(), dest());
  }
}