/*
* Apache License
* Version 2.0, January 2004
* http://www.apache.org/licenses/
*
* Copyright 2013 Aurelian Tutuianu
* Copyright 2014 Aurelian Tutuianu
* Copyright 2015 Aurelian Tutuianu
* Copyright 2016 Aurelian Tutuianu
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package rapaio.ml.clustering;
import rapaio.core.stat.Mean;
import rapaio.data.*;
import rapaio.data.filter.Filters;
import rapaio.ml.common.distance.Distance;
import rapaio.ml.common.distance.KMeansInitMethod;
import rapaio.printer.Printable;
import rapaio.sys.WS;
import rapaio.util.Pair;
import rapaio.util.Tag;
import rapaio.printer.Summary;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Consumer;
import java.util.stream.IntStream;
import static java.util.stream.Collectors.toList;
import static rapaio.core.CoreTools.*;
/**
* KMeans clusterization algorithm
*
* @author <a href="mailto:padreati@yahoo.com>Aurelian Tutuianu</a>
*/
public class KMeans implements Printable {
private int k = 2;
private int runs = Integer.MAX_VALUE;
private Tag<KMeansInitMethod> init = KMeansInitMethod.FORGY;
private Tag<Distance> distance = Distance.EUCLIDEAN;
private Consumer<KMeans> runningHook = null;
private Frame summary;
private double eps = 1e-20;
private boolean learned = false;
private boolean debug = false;
// clustering artifacts
private String[] inputs;
private Frame centroids;
private int[] arrows;
private Numeric errors;
private Map<Integer, Numeric> clusterErrors;
// summary artifacts
private Numeric summaryAllDist;
public KMeans withK(int k) {
this.k = k;
return this;
}
public KMeans withEps(double eps) {
this.eps = eps;
return this;
}
public KMeans withRuns(int runs) {
this.runs = runs;
return this;
}
public KMeans withDebug(boolean debug) {
this.debug = debug;
return this;
}
public final KMeans withRunningHook(Consumer<KMeans> hook) {
runningHook = hook;
return this;
}
public void cluster(Frame df, String... varNames) {
validate(df, varNames);
inputs = VRange.of(varNames).parseVarNames(df).stream().toArray(String[]::new);
centroids = init.get().init(df, inputs, k);
arrows = new int[df.rowCount()];
errors = Numeric.empty().withName("errors");
clusterErrors = new HashMap<>();
Index.seq(k).stream().forEach(c -> clusterErrors.put(c.index(), Numeric.empty().withName("c" + (c.index() + 1) + "_errors")));
assignToCentroids(df);
int rounds = runs;
while (rounds-- > 0) {
recomputeCentroids(df);
assignToCentroids(df);
if (runningHook != null) {
runningHook.accept(this);
}
int erc = errors.rowCount();
if (erc > 1 && Math.abs(errors.value(erc - 1) - errors.value(erc - 2)) < eps) {
break;
}
}
buildSummary(df);
learned = true;
}
private void validate(Frame df, String... varNames) {
List<String> nameList = VRange.of(varNames).parseVarNames(df);
for (String varName : nameList) {
if (!df.var(varName).type().isNumeric())
throw new IllegalArgumentException("all matched vars must be numeric: check var " + varName);
if (df.var(varName).stream().complete().count() != df.rowCount()) {
throw new IllegalArgumentException("all matched vars must have non-missing values: check var " + varName);
}
}
}
private void assignToCentroids(Frame df) {
if (debug) WS.println("assignToCentroids called ..");
double totalError = 0.0;
double[] err = new double[centroids.rowCount()];
List<Pair<Integer, Double>> pairs = IntStream.range(0, df.rowCount()).parallel().boxed().map(i -> {
double d = Double.NaN;
int cluster = -1;
for (int j = 0; j < centroids.rowCount(); j++) {
double dd = distance.get().distance(df, i, centroids, j, inputs);
if (!Double.isFinite(dd)) continue;
if (Double.isNaN(dd)) continue;
if (!Double.isNaN(d)) {
if (dd < d) {
d = dd;
cluster = j;
}
} else {
d = dd;
cluster = j;
}
}
if (cluster == -1) {
throw new RuntimeException("cluster could not be computed");
}
double error = Math.pow(d, 2);
arrows[i] = cluster;
return Pair.from(cluster, error);
}).collect(toList());
for (Pair<Integer, Double> p : pairs) {
totalError += p._2;
err[p._1] += p._2;
}
for (int i = 0; i < err.length; i++) {
clusterErrors.get(i).addValue(err[i]);
}
errors.addValue(totalError);
}
private void recomputeCentroids(Frame df) {
if (debug) WS.println("recomputeCentroids called ..");
Var[] means = IntStream.range(0, k).boxed().map(i -> Numeric.fill(df.rowCount(), 0)).toArray(Numeric[]::new);
for (String input : inputs) {
for (int j = 0; j < k; j++) {
means[j].clear();
}
for (int j = 0; j < df.rowCount(); j++) {
means[arrows[j]].addValue(df.value(j, input));
}
for (int j = 0; j < k; j++) {
if (means[j].rowCount() == 0)
continue;
double mean = Mean.from(means[j]).value();
centroids.setValue(j, input, mean);
}
}
}
public Var getClusterAssignment() {
Var var = Index.empty(arrows.length);
for (int i = 0; i < arrows.length; i++) {
var.setIndex(i, arrows[i] + 1);
}
return var;
}
public Numeric getRunningErrors() {
return errors.solidCopy();
}
public double getError() {
return errors.rowCount() == 0 ? Double.NaN : errors.value(errors.rowCount() - 1);
}
public Numeric getRunningClusterError(int c) {
if (c >= k)
throw new IllegalArgumentException("cluster " + c + " does not exists");
return clusterErrors.get(c);
}
public double getClusterError(int c) {
if (c >= k)
throw new IllegalArgumentException("cluster " + c + " does not exists");
return clusterErrors.get(c).value(clusterErrors.get(c).rowCount() - 1);
}
private void buildSummary(Frame df) {
Index summaryId = Index.seq(1, centroids.rowCount() + 1).withName("ID");
Index summaryCount = Index.fill(centroids.rowCount(), 0).withName("count");
Numeric summaryMean = Numeric.fill(centroids.rowCount(), 0).withName("mean");
Numeric summaryVar = Numeric.fill(centroids.rowCount(), 0).withName("var");
Numeric summaryVarP = Numeric.fill(centroids.rowCount(), 0).withName("var/total");
Numeric summarySd = Numeric.fill(centroids.rowCount(), 0).withName("sd");
summaryAllDist = Numeric.empty().withName("all dist");
Map<Integer, Numeric> distances = new HashMap<>();
for (int i = 0; i < df.rowCount(); i++) {
double d = distance.get().distance(centroids, arrows[i], df, i, inputs);
if (!distances.containsKey(arrows[i]))
distances.put(arrows[i], Numeric.empty());
distances.get(arrows[i]).addValue(d);
summaryAllDist.addValue(d);
}
double tvar = var(summaryAllDist).value();
for (Map.Entry<Integer, Numeric> e : distances.entrySet()) {
summaryCount.setIndex(e.getKey(), e.getValue().rowCount());
summaryMean.setValue(e.getKey(), mean(e.getValue()).value());
double v = var(e.getValue()).value();
summaryVar.setValue(e.getKey(), v);
summaryVarP.setValue(e.getKey(), v / tvar);
summarySd.setValue(e.getKey(), Math.sqrt(v));
}
summary = SolidFrame.byVars(summaryId, summaryCount, summaryMean, summaryVar, summaryVarP, summarySd);
}
@Override
public String summary() {
StringBuilder sb = new StringBuilder();
sb.append("KMeans clustering model\n");
sb.append("=======================\n");
sb.append("\n");
sb.append("Parameters: \n");
sb.append("> K = ").append(k).append("\n");
sb.append("> init = ").append(init.name()).append("\n");
sb.append("> distance = ").append(distance.name()).append("\n");
sb.append("> eps = ").append(eps).append("\n");
sb.append("> debug = ").append(debug).append("\n");
sb.append("\n");
sb.append("Learned clusters\n");
sb.append("----------------\n");
if (!learned) {
sb.append("KMeans did not clustered anything yet!\n");
} else {
sb.append("Overall: \n");
sb.append("> count: ").append(summaryAllDist.rowCount()).append("\n");
sb.append("> mean: ").append(WS.formatFlex(mean(summaryAllDist).value())).append("\n");
sb.append("> var: ").append(WS.formatFlex(var(summaryAllDist).value())).append("\n");
sb.append("> sd: ").append(WS.formatFlex(var(summaryAllDist).sdValue())).append("\n");
sb.append("\n");
sb.append("Per cluster: \n");
sb.append(Summary.headString(Filters.refSort(summary, summary.var("count").refComparator(false))));
}
return sb.toString();
}
}