package hex.kmeans; import hex.*; import hex.util.LinearAlgebraUtils; import org.joda.time.format.DateTimeFormat; import org.joda.time.format.DateTimeFormatter; import water.*; import water.exceptions.H2OModelBuilderIllegalArgumentException; import water.fvec.Chunk; import water.fvec.Frame; import water.fvec.Vec; import water.util.*; import java.util.ArrayList; import java.util.Arrays; import java.util.List; import java.util.Random; import static hex.genmodel.GenModel.Kmeans_preprocessData; /** * Scalable K-Means++ (KMeans||)<br> * http://theory.stanford.edu/~sergei/papers/vldb12-kmpar.pdf<br> * http://www.youtube.com/watch?v=cigXAxV3XcY */ public class KMeans extends ClusteringModelBuilder<KMeansModel,KMeansModel.KMeansParameters,KMeansModel.KMeansOutput> { @Override public ToEigenVec getToEigenVec() { return LinearAlgebraUtils.toEigen; } // Convergence tolerance final static private double TOLERANCE = 1e-4; @Override public ModelCategory[] can_build() { return new ModelCategory[]{ ModelCategory.Clustering }; } @Override public boolean havePojo() { return true; } @Override public boolean haveMojo() { return true; } public enum Initialization { Random, PlusPlus, Furthest, User } /** Start the KMeans training Job on an F/J thread. */ @Override protected KMeansDriver trainModelImpl() { return new KMeansDriver(); } // Called from an http request public KMeans( KMeansModel.KMeansParameters parms ) { super(parms ); init(false); } public KMeans( KMeansModel.KMeansParameters parms, Job job) { super(parms,job); init(false); } public KMeans(boolean startup_once) { super(new KMeansModel.KMeansParameters(),startup_once); } @Override protected void checkMemoryFootPrint() { long mem_usage = 8 /*doubles*/ * _parms._k * _train.numCols() * (_parms._standardize ? 2 : 1); long max_mem = H2O.SELF._heartbeat.get_free_mem(); if (mem_usage > max_mem) { String msg = "Centroids won't fit in the driver node's memory (" + PrettyPrint.bytes(mem_usage) + " > " + PrettyPrint.bytes(max_mem) + ") - try reducing the number of columns and/or the number of categorical factors."; error("_train", msg); } } /** Initialize the ModelBuilder, validating all arguments and preparing the * training frame. This call is expected to be overridden in the subclasses * and each subclass will start with "super.init();". * * Validate K, max_iterations and the number of rows. */ @Override public void init(boolean expensive) { super.init(expensive); if(expensive) if(_parms._fold_column != null) _train.remove(_parms._fold_column); if( _parms._max_iterations <= 0 || _parms._max_iterations > 1e6) error("_max_iterations", " max_iterations must be between 1 and 1e6"); if (_train == null) return; if (_parms._init == Initialization.User && _parms._user_points == null) error("_user_y","Must specify initial cluster centers"); if (_parms._user_points != null) { // Check dimensions of user-specified centers Frame user_points = _parms._user_points.get(); if (user_points == null) error("_user_y", "User-specified points do not refer to a valid frame"); else if (user_points.numCols() != _train.numCols() - numSpecialCols()) error("_user_y", "The user-specified points must have the same number of columns (" + (_train.numCols() - numSpecialCols()) + ") as the training observations"); else if( user_points.numRows() != _parms._k) error("_user_y", "The number of rows in the user-specified points is not equal to k = " + _parms._k); } if (_parms._estimate_k) { if (_parms._user_points!=null) error("_estimate_k", "Cannot estimate k if user_points are provided."); info("_seed", "seed is ignored when estimate_k is enabled."); info("_init", "Initialization scheme is ignored when estimate_k is enabled - algorithm is deterministic."); if (expensive) { boolean numeric = false; for (Vec v : _train.vecs()) { if (v.isNumeric()) { numeric = true; break; } } if (!numeric) { error("_estimate_k", "Cannot estimate k if data has no numeric columns."); } } } if (expensive && error_count() == 0) checkMemoryFootPrint(); } // ---------------------- private final class KMeansDriver extends Driver { private String[][] _isCats; // Categorical columns // Initialize cluster centers double[][] initial_centers(KMeansModel model, final Vec[] vecs, final double[] means, final double[] mults, final int[] modes, int k) { // Categoricals use a different distance metric than numeric columns. model._output._categorical_column_count=0; _isCats = new String[vecs.length][]; for( int v=0; v<vecs.length; v++ ) { _isCats[v] = vecs[v].isCategorical() ? new String[0] : null; if (_isCats[v] != null) model._output._categorical_column_count++; } Random rand = water.util.RandomUtils.getRNG(_parms._seed-1); double centers[][]; // Cluster centers if( null != _parms._user_points ) { // User-specified starting points Frame user_points = _parms._user_points.get(); int numCenters = (int)user_points.numRows(); int numCols = model._output.nfeatures(); centers = new double[numCenters][numCols]; Vec[] centersVecs = user_points.vecs(); // Get the centers and standardize them if requested for (int r=0; r<numCenters; r++) { for (int c=0; c<numCols; c++){ centers[r][c] = centersVecs[c].at(r); centers[r][c] = Kmeans_preprocessData(centers[r][c], c, means, mults, modes); } } } else { // Random, Furthest, or PlusPlus initialization if (_parms._init == Initialization.Random) { // Initialize all cluster centers to random rows centers = new double[k][model._output.nfeatures()]; for (double[] center : centers) randomRow(vecs, rand, center, means, mults, modes); } else { centers = new double[1][model._output.nfeatures()]; // Initialize first cluster center to random row randomRow(vecs, rand, centers[0], means, mults, modes); model._output._iterations = 0; while (model._output._iterations < 5) { // Sum squares distances to cluster center SumSqr sqr = new SumSqr(centers, means, mults, modes, _isCats).doAll(vecs); // Sample with probability inverse to square distance Sampler sampler = new Sampler(centers, means, mults, modes, _isCats, sqr._sqr, k * 3, _parms.getOrMakeRealSeed(), hasWeightCol()).doAll(vecs); centers = ArrayUtils.append(centers, sampler._sampled); // Fill in sample centers into the model if (stop_requested()) return null; // Stopped/cancelled model._output._centers_raw = destandardize(centers, _isCats, means, mults); model._output._tot_withinss = sqr._sqr / _train.numRows(); model._output._iterations++; // One iteration done model.update(_job); // Make early version of model visible, but don't update progress using update(1) } // Recluster down to k cluster centers centers = recluster(centers, rand, k, _parms._init, _isCats); model._output._iterations = 0; // Reset iteration count } } assert(centers.length == k); return centers; } // Number of reinitialization attempts for preventing empty clusters transient private int _reinit_attempts; // Handle the case where some centers go dry. Rescue only 1 cluster // per iteration ('cause we only tracked the 1 worst row) boolean cleanupBadClusters( LloydsIterationTask task, final Vec[] vecs, final double[][] centers, final double[] means, final double[] mults, final int[] modes ) { // Find any bad clusters int clu; for( clu=0; clu<centers.length; clu++ ) if( task._size[clu] == 0 ) break; if( clu == centers.length ) return false; // No bad clusters long row = task._worst_row; Log.warn("KMeans: Re-initializing cluster " + clu + " to row " + row); data(centers[clu] = task._cMeans[clu], vecs, row, means, mults, modes); task._size[clu] = 1; //FIXME: PUBDEV-871 Some other cluster had their membership count reduced by one! (which one?) // Find any MORE bad clusters; we only fixed the first one for( clu=0; clu<centers.length; clu++ ) if( task._size[clu] == 0 ) break; if( clu == centers.length ) return false; // No MORE bad clusters // If we see 2 or more bad rows, just re-run Lloyds to get the // next-worst row. We don't count this as an iteration, because // we're not really adjusting the centers, we're trying to get // some centers *at-all*. Log.warn("KMeans: Re-running Lloyds to re-init another cluster"); if (_reinit_attempts++ < centers.length) { return true; // Rerun Lloyds, and assign points to centroids } else { _reinit_attempts = 0; return false; } } // Compute all interesting KMeans stats (errors & variances of clusters, // etc). Return new centers. double[][] computeStatsFillModel(LloydsIterationTask task, KMeansModel model, final Vec[] vecs, final double[] means, final double[] mults, final int[] modes, int k) { // Fill in the model based on original destandardized centers if (model._parms._standardize) { model._output._centers_std_raw = task._cMeans; } model._output._centers_raw = destandardize(task._cMeans, _isCats, means, mults); model._output._size = task._size; model._output._withinss = task._cSqr; double ssq = 0; // sum squared error for( int i=0; i<k; i++ ) ssq += model._output._withinss[i]; // sum squared error all clusters model._output._tot_withinss = ssq; // Sum-of-square distance from grand mean if(k == 1) { model._output._totss = model._output._tot_withinss; } else { // If data already standardized, grand mean is just the origin TotSS totss = new TotSS(means,mults,modes, train().domains(), train().cardinality()).doAll(vecs); model._output._totss = totss._tss; } model._output._betweenss = model._output._totss - model._output._tot_withinss; // MSE between-cluster model._output._iterations++; model._output._history_withinss = ArrayUtils.copyAndFillOf( model._output._history_withinss, model._output._history_withinss.length+1, model._output._tot_withinss); model._output._k = ArrayUtils.copyAndFillOf(model._output._k, model._output._k.length+1, k); model._output._training_time_ms = ArrayUtils.copyAndFillOf(model._output._training_time_ms, model._output._training_time_ms.length+1, System.currentTimeMillis()); model._output._reassigned_count = ArrayUtils.copyAndFillOf(model._output._reassigned_count, model._output._reassigned_count.length+1, task._reassigned_count); // Two small TwoDimTables - cheap model._output._model_summary = createModelSummaryTable(model._output); model._output._scoring_history = createScoringHistoryTable(model._output); // Take the cluster stats from the model, and assemble them into a model metrics object model._output._training_metrics = makeTrainingMetrics(model); return task._cMeans; // New centers } // Main worker thread @Override public void computeImpl() { KMeansModel model = null; Key bestOutputKey = Key.make(); try { init(true); // Do lock even before checking the errors, since this block is finalized by unlock // (not the best solution, but the code is more readable) // Something goes wrong if( error_count() > 0 ) throw H2OModelBuilderIllegalArgumentException.makeFromBuilder(KMeans.this); // The model to be built // Set fold_column to null and will be added back into model parameter after String fold_column = _parms._fold_column; _parms._fold_column = null; model = new KMeansModel(dest(), _parms, new KMeansModel.KMeansOutput(KMeans.this)); model.delete_and_lock(_job); int startK = _parms._estimate_k ? 1 : _parms._k; // final Vec vecs[] = _train.vecs(); // mults & means for standardization final double[] means = _train.means(); // means are used to impute NAs final double[] mults = _parms._standardize ? _train.mults() : null; final int [] impute_cat = new int[vecs.length]; for(int i = 0; i < vecs.length; i++) impute_cat[i] = vecs[i].isNumeric() ? -1 : DataInfo.imputeCat(vecs[i],true); model._output._normSub = means; model._output._normMul = mults; model._output._mode = impute_cat; // Initialize cluster centers and standardize if requested double[][] centers = initial_centers(model,vecs,means,mults,impute_cat, startK); if( centers==null ) return; // Stopped/cancelled during center-finding boolean work_unit_iter = !_parms._estimate_k; // --- // Run the main KMeans Clustering loop // Stop after enough iterations or reassigned_count < TOLERANCE * num_rows double sum_squares = 0; final double rel_improvement_cutoff = Math.min(0.02 + 10. / _train.numRows() + 2.5 / Math.pow(model._output.nfeatures(), 2), 0.8); if (_parms._estimate_k) Log.info("Cutoff for relative improvement in within_cluster_sum_of_squares: " + rel_improvement_cutoff); Vec[] vecs2 = Arrays.copyOf(vecs, vecs.length+1); vecs2[vecs2.length-1] = vecs2[0].makeCon(-1); for (int k = startK; k <= _parms._k; ++k) { Log.info("Running Lloyds iteration for " + k + " centroids."); model._output._iterations = 0; // Loop ends only when iterations > max_iterations with strict inequality double[][] lo=null, hi=null; boolean stop = false; do { //Lloyds algorithm assert(centers.length == k); LloydsIterationTask task = new LloydsIterationTask(centers, means, mults, impute_cat, _isCats, k, hasWeightCol()).doAll(vecs2); //1 PASS OVER THE DATA // Pick the max categorical level for cluster center max_cats(task._cMeans, task._cats, _isCats); // Handle the case where some centers go dry. Rescue only 1 cluster // per iteration ('cause we only tracked the 1 worst row) if( !_parms._estimate_k && cleanupBadClusters(task,vecs,centers,means,mults,impute_cat) ) continue; // Compute model stats; update standardized cluster centers centers = computeStatsFillModel(task, model, vecs, means, mults, impute_cat, k); if (model._parms._score_each_iteration) Log.info(model._output._model_summary); lo = task._lo; hi = task._hi; if (work_unit_iter) { model.update(_job); // Update model in K/V store _job.update(1); //1 more Lloyds iteration } stop = (task._reassigned_count < Math.max(1,train().numRows()*TOLERANCE) || model._output._iterations >= _parms._max_iterations); if (stop) { if (model._output._iterations < _parms._max_iterations) Log.info("Lloyds converged after " + model._output._iterations + " iterations."); else Log.info("Lloyds stopped after " + model._output._iterations + " iterations."); } } while (!stop); double sum_squares_now = model._output._tot_withinss; double rel_improvement; if (sum_squares==0) { rel_improvement = 1; } else { rel_improvement = (sum_squares - sum_squares_now) / sum_squares; } Log.info("Relative improvement in total withinss: " + rel_improvement); sum_squares = sum_squares_now; if (_parms._estimate_k && k > 1) { boolean outerConverged = rel_improvement < rel_improvement_cutoff; if (outerConverged) { KMeansModel.KMeansOutput best = DKV.getGet(bestOutputKey); model._output = best; Log.info("Converged. Retrieving the best model with k=" + model._output._k[model._output._k.length-1]); break; } } if (!work_unit_iter) { DKV.put(bestOutputKey, IcedUtils.deepCopy(model._output)); //store a clone to avoid sharing the state between DKV and here model.update(_job); // Update model in K/V store _job.update(1); //1 more round for auto-clustering } if (lo != null && hi != null && _parms._estimate_k) centers = splitLargestCluster(centers, lo, hi, means, mults, impute_cat, vecs2, k); } //k-finder vecs2[vecs2.length-1].remove(); // Create metrics by scoring on training set otherwise scores are based on last Lloyd iteration model.score(_train).delete(); model._output._training_metrics = ModelMetrics.getFromDKV(model,_train); Log.info(model._output._model_summary); Log.info(model._output._scoring_history); Log.info(((ModelMetricsClustering)model._output._training_metrics).createCentroidStatsTable().toString()); // At the end: validation scoring (no need to gather scoring history) if (_valid != null) { model.score(_parms.valid()).delete(); //this appends a ModelMetrics on the validation set model._output._validation_metrics = ModelMetrics.getFromDKV(model,_parms.valid()); } model._parms._fold_column = fold_column; model.update(_job); // Update model in K/V store } finally { if( model != null ) model.unlock(_job); DKV.remove(bestOutputKey); } } double[][] splitLargestCluster(double[][] centers, double[][] lo, double[][] hi, double[] means, double[] mults, int[] impute_cat, Vec[] vecs2, int k) { double[][] newCenters = Arrays.copyOf(centers, centers.length + 1); for (int i = 0; i < centers.length; ++i) newCenters[i] = centers[i].clone(); double maxRange=0; int clusterToSplit=0; int dimToSplit=0; for (int i = 0; i < centers.length; ++i) { double[] range = new double[hi[i].length]; for( int col=0; col<hi[i].length; col++ ) { if (_isCats[col]!=null) continue; // can't split a cluster along categorical direction range[col] = hi[i][col] - lo[i][col]; if ((float)range[col] > (float)maxRange) { //break ties clusterToSplit = i; dimToSplit = col; maxRange = range[col]; } } // Log.info("Range for cluster " + i + ": " + Arrays.toString(range)); } // start out new centroid as a copy of the one to split assert (_isCats[dimToSplit] == null); double splitPoint = newCenters[clusterToSplit][dimToSplit]; // Log.info("Splitting cluster " + clusterToSplit + " in half in dimension " + dimToSplit + " at splitpoint: " + splitPoint); // compute the centroids of the two sub-clusters SplitTask task = new SplitTask(newCenters, means, mults, impute_cat, _isCats, k+1, hasWeightCol(), clusterToSplit, dimToSplit, splitPoint).doAll(vecs2); // Log.info("Splitting: " + Arrays.toString(newCenters[clusterToSplit])); newCenters[clusterToSplit] = task._cMeans[clusterToSplit].clone(); // Log.info("Into One: " + Arrays.toString(newCenters[clusterToSplit])); newCenters[newCenters.length-1] = task._cMeans[newCenters.length-1].clone(); // Log.info(" Two: " + Arrays.toString(newCenters[newCenters.length-1])); return newCenters; } private TwoDimTable createModelSummaryTable(KMeansModel.KMeansOutput output) { List<String> colHeaders = new ArrayList<>(); List<String> colTypes = new ArrayList<>(); List<String> colFormat = new ArrayList<>(); colHeaders.add("Number of Rows"); colTypes.add("long"); colFormat.add("%d"); colHeaders.add("Number of Clusters"); colTypes.add("long"); colFormat.add("%d"); colHeaders.add("Number of Categorical Columns"); colTypes.add("long"); colFormat.add("%d"); colHeaders.add("Number of Iterations"); colTypes.add("long"); colFormat.add("%d"); colHeaders.add("Within Cluster Sum of Squares"); colTypes.add("double"); colFormat.add("%.5f"); colHeaders.add("Total Sum of Squares"); colTypes.add("double"); colFormat.add("%.5f"); colHeaders.add("Between Cluster Sum of Squares"); colTypes.add("double"); colFormat.add("%.5f"); final int rows = 1; TwoDimTable table = new TwoDimTable( "Model Summary", null, new String[rows], colHeaders.toArray(new String[0]), colTypes.toArray(new String[0]), colFormat.toArray(new String[0]), ""); int row = 0; int col = 0; table.set(row, col++, Math.round(_train.numRows() * (hasWeightCol() ? _train.lastVec().mean() : 1))); table.set(row, col++, output._centers_raw.length); table.set(row, col++, output._categorical_column_count); table.set(row, col++, output._k.length-1); table.set(row, col++, output._tot_withinss); table.set(row, col++, output._totss); table.set(row, col++, output._betweenss); return table; } private TwoDimTable createScoringHistoryTable(KMeansModel.KMeansOutput output) { List<String> colHeaders = new ArrayList<>(); List<String> colTypes = new ArrayList<>(); List<String> colFormat = new ArrayList<>(); colHeaders.add("Timestamp"); colTypes.add("string"); colFormat.add("%s"); colHeaders.add("Duration"); colTypes.add("string"); colFormat.add("%s"); colHeaders.add("Iteration"); colTypes.add("long"); colFormat.add("%d"); if (_parms._estimate_k) { colHeaders.add("Number of Clusters"); colTypes.add("long"); colFormat.add("%d"); } colHeaders.add("Number of Reassigned Observations"); colTypes.add("long"); colFormat.add("%d"); colHeaders.add("Within Cluster Sum Of Squares"); colTypes.add("double"); colFormat.add("%.5f"); final int rows = output._history_withinss.length; TwoDimTable table = new TwoDimTable( "Scoring History", null, new String[rows], colHeaders.toArray(new String[0]), colTypes.toArray(new String[0]), colFormat.toArray(new String[0]), ""); int row = 0; for( int i = 0; i<rows; i++ ) { int col = 0; assert(row < table.getRowDim()); assert(col < table.getColDim()); DateTimeFormatter fmt = DateTimeFormat.forPattern("yyyy-MM-dd HH:mm:ss"); table.set(row, col++, fmt.print(output._training_time_ms[i])); table.set(row, col++, PrettyPrint.msecs(output._training_time_ms[i]-_job.start_time(), true)); table.set(row, col++, i); if (_parms._estimate_k) table.set(row, col++, output._k[i]); table.set(row, col++, output._reassigned_count[i]); table.set(row, col++, output._history_withinss[i]); row++; } return table; } } // ------------------------------------------------------------------------- // Initial sum-of-square-distance to nearest cluster center private static class TotSS extends MRTask<TotSS> { // IN final double[] _means, _mults; final int[] _modes; final String[][] _isCats; final int[] _card; // OUT double _tss; double[] _gc; // Grand center (mean of cols) TotSS(double[] means, double[] mults, int[] modes, String[][] isCats, int[] card) { _means = means; _mults = mults; _modes = modes; _tss = 0; _isCats = isCats; _card = card; // Mean of numeric col is zero when standardized _gc = mults!=null ? new double[means.length] : Arrays.copyOf(means, means.length); for(int i=0; i<means.length; i++) { if(isCats[i] != null) _gc[i] = _modes[i]; } } @Override public void map(Chunk[] cs) { for( int row = 0; row < cs[0]._len; row++ ) { double[] values = new double[cs.length]; // fetch the data - using consistent NA and categorical data handling (same as for training) data(values, cs, row, _means, _mults, _modes); // compute the distance from the (standardized) cluster centroids _tss += hex.genmodel.GenModel.KMeans_distance(_gc, values, _isCats); } } @Override public void reduce(TotSS other) { _tss += other._tss; } } // ------------------------------------------------------------------------- // Initial sum-of-square-distance to nearest cluster center private static class SumSqr extends MRTask<SumSqr> { // IN double[][] _centers; double[] _means, _mults; // Standardization int[] _modes; // Imputation of missing categoricals final String[][] _isCats; // OUT double _sqr; SumSqr( double[][] centers, double[] means, double[] mults, int[] modes, String[][] isCats ) { _centers = centers; _means = means; _mults = mults; _modes = modes; _isCats = isCats; } @Override public void map(Chunk[] cs) { double[] values = new double[cs.length]; ClusterDist cd = new ClusterDist(); for( int row = 0; row < cs[0]._len; row++ ) { data(values, cs, row, _means, _mults, _modes); _sqr += minSqr(_centers, values, _isCats, cd); } _means = _mults = null; _modes = null; _centers = null; } @Override public void reduce(SumSqr other) { _sqr += other._sqr; } } // ------------------------------------------------------------------------- // Sample rows with increasing probability the farther they are from any // cluster center. private static class Sampler extends MRTask<Sampler> { // IN double[][] _centers; double[] _means, _mults; // Standardization int[] _modes; // Imputation of missing categoricals final String[][] _isCats; final double _sqr; // Min-square-error final double _probability; // Odds to select this point final long _seed; boolean _hasWeight; // OUT double[][] _sampled; // New cluster centers Sampler( double[][] centers, double[] means, double[] mults, int[] modes, String[][] isCats, double sqr, double prob, long seed, boolean hasWeight ) { _centers = centers; _means = means; _mults = mults; _modes = modes; _isCats = isCats; _sqr = sqr; _probability = prob; _seed = seed; _hasWeight = hasWeight; } @Override public void map(Chunk[] cs) { int N = cs.length - (_hasWeight?1:0); double[] values = new double[N]; ArrayList<double[]> list = new ArrayList<>(); Random rand = RandomUtils.getRNG(0); ClusterDist cd = new ClusterDist(); for( int row = 0; row < cs[0]._len; row++ ) { rand.setSeed(_seed + cs[0].start()+row); data(values, cs, row, _means, _mults, _modes); double sqr = minSqr(_centers, values, _isCats, cd); if( _probability * sqr > rand.nextDouble() * _sqr ) list.add(values.clone()); } _sampled = new double[list.size()][]; list.toArray(_sampled); _centers = null; _means = _mults = null; _modes = null; } @Override public void reduce(Sampler other) { _sampled = ArrayUtils.append(_sampled, other._sampled); } } // --------------------------------------- // A Lloyd's pass: // Find nearest cluster center for every point // Compute new mean/center & variance & rows for each cluster // Compute distance between clusters // Compute total sqr distance private static class LloydsIterationTask extends MRTask<LloydsIterationTask> { // IN double[][] _centers; double[] _means, _mults; // Standardization int[] _modes; // Imputation of missing categoricals final int _k; final String[][] _isCats; boolean _hasWeight; // OUT double[][] _lo, _hi; // Bounding box double _reassigned_count; double[][] _cMeans; // Means for each cluster long[/*k*/][/*features*/][/*nfactors*/] _cats; // Histogram of cat levels double[] _cSqr; // Sum of squares for each cluster long[] _size; // Number of rows in each cluster long _worst_row; // Row with max err double _worst_err; // Max-err-row's max-err LloydsIterationTask(double[][] centers, double[] means, double[] mults, int[] modes, String[][] isCats, int k, boolean hasWeight ) { _centers = centers; _means = means; _mults = mults; _modes = modes; _isCats = isCats; _k = k; _hasWeight = hasWeight; } @Override public void map(Chunk[] cs) { int N = cs.length - (_hasWeight ? 1:0) - 1 /*clusterassignment*/; assert _centers[0].length==N; _lo = new double[_k][N]; for( int clu=0; clu< _k; clu++ ) Arrays.fill(_lo[clu], Double.MAX_VALUE); _hi = new double[_k][N]; for( int clu=0; clu< _k; clu++ ) Arrays.fill(_hi[clu], -Double.MAX_VALUE); _cMeans = new double[_k][N]; _cSqr = new double[_k]; _size = new long[_k]; // Space for cat histograms _cats = new long[_k][N][]; for( int clu=0; clu< _k; clu++ ) for( int col=0; col<N; col++ ) _cats[clu][col] = _isCats[col]==null ? null : new long[cs[col].vec().cardinality()]; _worst_err = 0; Chunk assignment = cs[cs.length-1]; // Find closest cluster center for each row double[] values = new double[N]; // Temp data to hold row as doubles ClusterDist cd = new ClusterDist(); for( int row = 0; row < cs[0]._len; row++ ) { double weight = _hasWeight ? cs[N].atd(row) : 1; if (weight == 0) continue; //skip holdout rows assert(weight == 1); //K-Means only works for weight 1 (or weight 0 for holdout) data(values, cs, row, _means, _mults, _modes); // Load row as doubles closest(_centers, values, _isCats, cd); // Find closest cluster center if (cd._cluster != assignment.at8(row)) { _reassigned_count+=weight; assignment.set(row, cd._cluster); } for( int clu=0; clu< _k; clu++ ) { for( int col=0; col<N; col++ ) { if (cd._cluster == clu) { _lo[clu][col] = Math.min(values[col], _lo[clu][col]); _hi[clu][col] = Math.max(values[col], _hi[clu][col]); } } } int clu = cd._cluster; assert clu != -1; // No broken rows _cSqr[clu] += cd._dist; // Add values and increment counter for chosen cluster for( int col = 0; col < N; col++ ) if( _isCats[col] != null ) _cats[clu][col][(int)values[col]]++; // Histogram the cats else _cMeans[clu][col] += values[col]; // Sum the column centers _size[clu]++; // Track worst row if( cd._dist > _worst_err) { _worst_err = cd._dist; _worst_row = cs[0].start()+row; } } // Scale back down to local mean for( int clu = 0; clu < _k; clu++ ) if( _size[clu] != 0 ) ArrayUtils.div(_cMeans[clu], _size[clu]); _centers = null; _means = _mults = null; _modes = null; } @Override public void reduce(LloydsIterationTask mr) { _reassigned_count += mr._reassigned_count; for( int clu = 0; clu < _k; clu++ ) { long ra = _size[clu]; long rb = mr._size[clu]; double[] ma = _cMeans[clu]; double[] mb = mr._cMeans[clu]; for( int c = 0; c < ma.length; c++ ) // Recursive mean if( ra+rb > 0 ) ma[c] = (ma[c] * ra + mb[c] * rb) / (ra + rb); } ArrayUtils.add(_cats, mr._cats); ArrayUtils.add(_cSqr, mr._cSqr); ArrayUtils.add(_size, mr._size); for( int clu=0; clu< _k; clu++ ) { for( int col=0; col<_lo[clu].length; col++ ) { _lo[clu][col] = Math.min(mr._lo[clu][col], _lo[clu][col]); _hi[clu][col] = Math.max(mr._hi[clu][col], _hi[clu][col]); } } // track global worst-row if( _worst_err < mr._worst_err) { _worst_err = mr._worst_err; _worst_row = mr._worst_row; } } } // A pair result: nearest cluster center and the square distance private static final class ClusterDist { int _cluster; double _dist; } private static double minSqr(double[][] centers, double[] point, String[][] isCats, ClusterDist cd) { return closest(centers, point, isCats, cd, centers.length)._dist; } private static double minSqr(double[][] centers, double[] point, String[][] isCats, ClusterDist cd, int count) { return closest(centers,point,isCats,cd,count)._dist; } private static ClusterDist closest(double[][] centers, double[] point, String[][] isCats, ClusterDist cd) { return closest(centers, point, isCats, cd, centers.length); } /** Return both nearest of N cluster center/centroids, and the square-distance. */ private static ClusterDist closest(double[][] centers, double[] point, String[][] isCats, ClusterDist cd, int count) { int min = -1; double minSqr = Double.MAX_VALUE; for( int cluster = 0; cluster < count; cluster++ ) { double sqr = hex.genmodel.GenModel.KMeans_distance(centers[cluster],point,isCats); if( sqr < minSqr ) { // Record nearest cluster min = cluster; minSqr = sqr; } } cd._cluster = min; // Record nearest cluster cd._dist = minSqr; // Record square-distance return cd; // Return for flow-coding } // KMeans++ re-clustering private static double[][] recluster(double[][] points, Random rand, int N, Initialization init, String[][] isCats) { double[][] res = new double[N][]; res[0] = points[0]; int count = 1; ClusterDist cd = new ClusterDist(); switch( init ) { case Random: break; case PlusPlus: { // k-means++ while( count < res.length ) { double sum = 0; for (double[] point1 : points) sum += minSqr(res, point1, isCats, cd, count); for (double[] point : points) { if (minSqr(res, point, isCats, cd, count) >= rand.nextDouble() * sum) { res[count++] = point; break; } } } break; } case Furthest: { // Takes cluster center further from any already chosen ones while( count < res.length ) { double max = 0; int index = 0; for( int i = 0; i < points.length; i++ ) { double sqr = minSqr(res, points[i], isCats, cd, count); if( sqr > max ) { max = sqr; index = i; } } res[count++] = points[index]; } break; } default: throw H2O.fail(); } return res; } private void randomRow(Vec[] vecs, Random rand, double[] center, double[] means, double[] mults, int[] modes) { long row = Math.max(0, (long) (rand.nextDouble() * vecs[0].length()) - 1); data(center, vecs, row, means, mults, modes); } // Pick most common cat level for each cluster_centers' cat columns private static double[][] max_cats(double[][] centers, long[][][] cats, String[][] isCats) { for( int clu = 0; clu < centers.length; clu++ ) for( int col = 0; col < centers[0].length; col++ ) if( isCats[col] != null ) centers[clu][col] = ArrayUtils.maxIndex(cats[clu][col]); return centers; } private static double[][] destandardize(double[][] centers, String[][] isCats, double[] means, double[] mults) { int K = centers.length; int N = centers[0].length; double[][] value = new double[K][N]; for( int clu = 0; clu < K; clu++ ) { System.arraycopy(centers[clu],0,value[clu],0,N); if( mults!=null ) { // Reverse standardization for( int col = 0; col < N; col++) if( isCats[col] == null ) value[clu][col] = value[clu][col] / mults[col] + means[col]; } } return value; } private static void data(double[] values, Vec[] vecs, long row, double[] means, double[] mults, int[] modes) { for( int i = 0; i < values.length; i++ ) { values[i] = Kmeans_preprocessData(vecs[i].at(row), i, means, mults, modes); } } private static void data(double[] values, Chunk[] chks, int row, double[] means, double[] mults, int[] modes) { for( int i = 0; i < values.length; i++ ) { values[i] = Kmeans_preprocessData(chks[i].atd(row), i, means, mults, modes); } } /** * This helper creates a ModelMetricsClustering from a trained model * @param model, must contain valid statistics from training, such as _betweenss etc. */ private ModelMetricsClustering makeTrainingMetrics(KMeansModel model) { ModelMetricsClustering mm = new ModelMetricsClustering(model, train()); mm._size = model._output._size; mm._withinss = model._output._withinss; mm._betweenss = model._output._betweenss; mm._totss = model._output._totss; mm._tot_withinss = model._output._tot_withinss; model.addMetrics(mm); return mm; } private static class SplitTask extends MRTask<SplitTask> { // IN double[][] _centers; double[] _means, _mults; // Standardization int[] _modes; // Imputation of missing categoricals final int _k; final String[][] _isCats; final boolean _hasWeight; final int _clusterToSplit; final int _dimToSplit; final double _splitPoint; // OUT double[][] _cMeans; // Means for each cluster long[] _size; // Number of rows in each cluster SplitTask(double[][] centers, double[] means, double[] mults, int[] modes, String[][] isCats, int k, boolean hasWeight, int clusterToSplit, int dimToSplit, double splitPoint) { _centers = centers; _means = means; _mults = mults; _modes = modes; _isCats = isCats; _k = k; _hasWeight = hasWeight; _clusterToSplit = clusterToSplit; _dimToSplit = dimToSplit; _splitPoint = splitPoint; } @Override public void map(Chunk[] cs) { int N = cs.length - (_hasWeight ? 1:0) - 1 /*clusterassignment*/; assert _centers[0].length==N; _cMeans = new double[_k][N]; _size = new long[_k]; Chunk assignment = cs[cs.length-1]; // Find closest cluster center for each row double[] values = new double[N]; // Temp data to hold row as doubles ClusterDist cd = new ClusterDist(); for( int row = 0; row < cs[0]._len; row++ ) { if (assignment.at8(row) != _clusterToSplit) continue; double weight = _hasWeight ? cs[N].atd(row) : 1; if (weight == 0) continue; //skip holdout rows assert(weight == 1); //K-Means only works for weight 1 (or weight 0 for holdout) data(values, cs, row, _means, _mults, _modes); // Load row as doubles assert (_isCats[_dimToSplit]==null); if (values[_dimToSplit] > _centers[_clusterToSplit][_dimToSplit]) { cd._cluster = _centers.length-1; assignment.set(row, cd._cluster); } else { cd._cluster = _clusterToSplit; } int clu = cd._cluster; assert clu != -1; // No broken rows // Add values and increment counter for chosen cluster for( int col = 0; col < N; col++ ) _cMeans[clu][col] += values[col]; // Sum the column centers _size[clu]++; } // Scale back down to local mean for( int clu = 0; clu < _k; clu++ ) if( _size[clu] != 0 ) ArrayUtils.div(_cMeans[clu], _size[clu]); _centers = null; _means = _mults = null; _modes = null; } @Override public void reduce(SplitTask mr) { for( int clu = 0; clu < _k; clu++ ) { long ra = _size[clu]; long rb = mr._size[clu]; double[] ma = _cMeans[clu]; double[] mb = mr._cMeans[clu]; for( int c = 0; c < ma.length; c++ ) // Recursive mean if( ra+rb > 0 ) ma[c] = (ma[c] * ra + mb[c] * rb) / (ra + rb); } ArrayUtils.add(_size, mr._size); } } }