package hex; import hex.ClusteringModel.ClusteringOutput; import hex.ClusteringModel.ClusteringParameters; import water.exceptions.H2OIllegalArgumentException; import water.fvec.Frame; import water.util.ArrayUtils; import water.util.TwoDimTable; import java.util.ArrayList; import java.util.Arrays; import java.util.List; public class ModelMetricsClustering extends ModelMetricsUnsupervised { public long[/*k*/] _size; public double[/*k*/] _withinss; public double _totss; public double _tot_withinss; public double _betweenss; // public TwoDimTable _centroid_stats; public double totss() { return _totss; } public double tot_withinss() { return _tot_withinss; } public double betweenss() { return _betweenss; } public ModelMetricsClustering(Model model, Frame frame) { super(model, frame, 0, Double.NaN); _size = null; _withinss = null; _totss = _tot_withinss = _betweenss = Double.NaN; } @Override public String toString() { StringBuilder sb = new StringBuilder(); sb.append(super.toString()); sb.append(" total sum of squares: " + (float)_totss + "\n"); sb.append(" total within sum of squares: " + (float)_tot_withinss + "\n"); sb.append(" total between sum of squares: " + (float)_betweenss + "\n"); if (_size != null) sb.append(" per cluster sizes: " + Arrays.toString(_size) + "\n"); if (_withinss != null) sb.append(" per cluster within sum of squares: " + Arrays.toString(_withinss) + "\n"); return sb.toString(); } /** * Populate TwoDimTable from members _size and _withinss * @return TwoDimTable */ public TwoDimTable createCentroidStatsTable() { if (_size == null || _withinss == null) return null; List<String> colHeaders = new ArrayList<>(); List<String> colTypes = new ArrayList<>(); List<String> colFormat = new ArrayList<>(); colHeaders.add("Centroid"); colTypes.add("long"); colFormat.add("%d"); colHeaders.add("Size"); colTypes.add("double"); colFormat.add("%.5f"); colHeaders.add("Within Cluster Sum of Squares"); colTypes.add("double"); colFormat.add("%.5f"); final int K = _size.length; assert(_withinss.length == K); TwoDimTable table = new TwoDimTable( "Centroid Statistics", null, new String[K], colHeaders.toArray(new String[0]), colTypes.toArray(new String[0]), colFormat.toArray(new String[0]), ""); for (int k =0; k<K; ++k) { int col = 0; table.set(k, col++, k+1); table.set(k, col++, _size[k]); table.set(k, col, _withinss[k]); } return table; } public static class MetricBuilderClustering extends MetricBuilderUnsupervised<MetricBuilderClustering> { public long[] _size; // Number of elements in cluster public double[] _within_sumsqe; // Within-cluster sum of squared error private double[/*features*/] _colSum; // Sum of each column private double[/*features*/] _colSumSq; // Sum of squared values of each column public MetricBuilderClustering(int ncol, int nclust) { _work = new double[ncol]; _size = new long[nclust]; _within_sumsqe = new double[nclust]; Arrays.fill(_size, 0); Arrays.fill(_within_sumsqe, 0); _colSum = new double[ncol]; _colSumSq = new double[ncol]; Arrays.fill(_colSum, 0); Arrays.fill(_colSumSq, 0); } // Compare row (dataRow) against centroid it was assigned to (preds[0]) @Override public double[] perRow(double[] preds, float[] dataRow, Model m) { assert m instanceof ClusteringModel; assert !Double.isNaN(preds[0]); ClusteringModel clm = (ClusteringModel) m; boolean standardize = ((((ClusteringOutput) clm._output)._centers_std_raw) != null); double[][] centers = standardize ? ((ClusteringOutput) clm._output)._centers_std_raw: ((ClusteringOutput) clm._output)._centers_raw; double[] sub = standardize ? ((ClusteringOutput) clm._output)._normSub : null; double[] mul = standardize ? ((ClusteringOutput) clm._output)._normMul : null; int clus = (int)preds[0]; double [] colSum = new double[_colSum.length]; double [] colSumSq = new double[_colSumSq.length]; double sqr = hex.genmodel.GenModel.KMeans_distance(centers[clus], dataRow, ((ClusteringOutput) clm._output)._mode, colSum, colSumSq); // System.out.println(Arrays.toString(colSumSq)); ArrayUtils.add(_colSum, colSum); ArrayUtils.add(_colSumSq, colSumSq); _count++; _size[clus]++; _sumsqe += sqr; _within_sumsqe[clus] += sqr; if (Double.isNaN(_sumsqe)) throw new H2OIllegalArgumentException("Sum of Squares is invalid (Double.NaN) - Check for missing values in the dataset."); return preds; // Flow coding } @Override public void reduce(MetricBuilderClustering mm) { super.reduce(mm); ArrayUtils.add(_size, mm._size); ArrayUtils.add(_within_sumsqe, mm._within_sumsqe); ArrayUtils.add(_colSum, mm._colSum); ArrayUtils.add(_colSumSq, mm._colSumSq); } @Override public ModelMetrics makeModelMetrics(Model m, Frame f, Frame adaptedFrame, Frame preds) { assert m instanceof ClusteringModel; ClusteringModel clm = (ClusteringModel) m; ModelMetricsClustering mm = new ModelMetricsClustering(m, f); mm._size = _size; mm._tot_withinss = _sumsqe; mm._withinss = new double[_size.length]; for (int i = 0; i < mm._withinss.length; i++) mm._withinss[i] = _within_sumsqe[i]; long numRows = f.numRows(); if( m._parms._weights_column != null) numRows = _count; // Sum-of-square distance from grand mean if ( ((ClusteringParameters) clm._parms)._k == 1 ) mm._totss = mm._tot_withinss; else { mm._totss = 0; for (int i = 0; i < _colSum.length; i++) { if(((ClusteringOutput)clm._output)._mode[i] == -1) mm._totss += _colSumSq[i] - (_colSum[i] * _colSum[i]) / numRows; else mm._totss += _colSum[i]; // simply add x[i] != modes[i] for categoricals } } mm._betweenss = mm._totss - mm._tot_withinss; return m.addMetrics(mm); } } }