package com.github.lwhite1.tablesaw.api.ml.clustering;
import com.github.lwhite1.tablesaw.api.Table;
import com.github.lwhite1.tablesaw.api.plot.Scatter;
import java.io.IOException;
import java.util.Arrays;
/**
*
*/
public class KmeansExample {
public static void main(String[] args) throws IOException {
Table t = Table.createFromCsv("data/whiskey.csv");
Kmeans model = new Kmeans(
5,
t.nCol(2),
t.nCol(3),
t.nCol(4),
t.nCol(5),
t.nCol(6),
t.nCol(7),
t.nCol(8),
t.nCol(9),
t.nCol(10),
t.nCol(11),
t.nCol(12),
t.nCol(13)
);
out("Distortion: " + model.distortion());
out("Cluster count: " + model.getClusterCount());
out(Arrays.toString(model.getClusterLabels()));
out(Arrays.toString(model.getClusterSizes()));
//out(model.clustered(t.column(1)).printHtml());
out(model.labeledCentroids().print());
int n = t.rowCount();
double[] kValues = new double[n - 2];
double[] distortions = new double[n - 2];
for (int k = 2; k < n; k++) {
kValues[k - 2] = k;
Kmeans kmeans = new Kmeans(k,
t.nCol(2),
t.nCol(3),
t.nCol(4),
t.nCol(5),
t.nCol(6),
t.nCol(7),
t.nCol(8),
t.nCol(9),
t.nCol(10),
t.nCol(11),
t.nCol(12),
t.nCol(13)
);
distortions[k - 2] = kmeans.distortion();
}
Scatter.show(kValues, "k", distortions, "distortion");
}
private static void out(Object object) {
System.out.println(String.valueOf(object));
}
}