/*******************************************************************************
* Copyright (C) 2006-2012 Dominik Jain.
*
* This file is part of ProbCog.
*
* ProbCog is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* ProbCog is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with ProbCog. If not, see <http://www.gnu.org/licenses/>.
******************************************************************************/
package probcog.clustering;
import java.util.Arrays;
import java.util.Comparator;
import java.util.Locale;
import weka.clusterers.Clusterer;
import weka.clusterers.SimpleKMeans;
import weka.core.Instances;
/**
* An interface for use with DomainLearner that contains a function that, when given
* a WEKA clusterer, returns an array of cluster names
* @author Dominik Jain
*/
public interface ClusterNamer<Cl extends Clusterer> {
public String[] getNames(Cl clusterer) throws Exception;
/**
* the most basic cluster namer, which simply adds a prefix to each cluster index
* @author Dominik Jain
*/
public static class SimplePrefix implements ClusterNamer<Clusterer> {
	/** the string that is prepended to each cluster index */
	protected String prefix;

	/**
	 * @param prefix the string to prepend to each cluster index
	 */
	public SimplePrefix(String prefix) {
		this.prefix = prefix;
	}

	/**
	 * builds one name per cluster by concatenating the prefix and the cluster's index
	 * @param clusterer the clusterer whose number of clusters determines the number of names
	 * @return an array in which entry i is the prefix followed by i
	 * @throws Exception if the clusterer fails to report its number of clusters
	 */
	public String[] getNames(Clusterer clusterer) throws Exception {
		String[] result = new String[clusterer.numberOfClusters()];
		int index = 0;
		while(index < result.length) {
			result[index] = prefix + index;
			index++;
		}
		return result;
	}
}
/**
* a basic cluster namer that simply returns a fixed list of predetermined names
* @author Dominik Jain
*/
public static class Fixed implements ClusterNamer<Clusterer> {
	/** the predetermined cluster names (stored by reference, not copied) */
	protected String[] names;

	/**
	 * @param names the names to return, one per cluster
	 */
	public Fixed(String[] names) {
		this.names = names;
	}

	/**
	 * returns the predetermined names after checking that there is exactly one per cluster
	 * @param clusterer the clusterer whose number of clusters is compared to the number of names
	 * @return the array of names that was passed to the constructor
	 * @throws Exception if the number of clusters differs from the number of names
	 */
	public String[] getNames(Clusterer clusterer) throws Exception {
		int clusterCount = clusterer.numberOfClusters();
		if(clusterCount == names.length)
			return names;
		throw new Exception("Number of clusters does not match number of names.");
	}
}
/**
* a K-Means cluster namer which simply returns the string "~E +/- S" for each cluster, where
* E is the expected value and S the standard deviation of the cluster.
* @author Dominik Jain
*/
public static class MeanStdDev implements ClusterNamer<SimpleKMeans> {
	/**
	 * names each cluster after the centroid value and the standard deviation of the
	 * first attribute (index 0), both formatted with two decimal places
	 * @param clusterer the K-Means clusterer to read centroids and standard deviations from
	 * @return an array in which entry i is the name of cluster i
	 */
	public String[] getNames(SimpleKMeans clusterer) {
		final int count = clusterer.getNumClusters();
		final String[] labels = new String[count];
		final Instances means = clusterer.getClusterCentroids();
		final Instances deviations = clusterer.getClusterStandardDevs();
		int cluster = 0;
		while(cluster < count) {
			double mean = means.instance(cluster).value(0);
			double deviation = deviations.instance(cluster).value(0);
			labels[cluster] = String.format("~%.2f +/- %.2f", mean, deviation);
			cluster++;
		}
		return labels;
	}
}
/**
* a K-Means cluster namer which names each cluster by the range of values (i.e. the interval)
* it covers, formatted as a string; the interval boundaries are obtained by calculating the
* intersections of the Gaussian distributions of neighbouring clusters
* @author Dominik Jain
*/
public static class Intervals implements ClusterNamer<SimpleKMeans> {
	/**
	 * calculates the intersection of two Gaussian distributions, i.e. a point at which
	 * both density functions have the same value. If a distribution has no deviation,
	 * its expected value is returned; if both deviations are equal, the midpoint between
	 * the expected values is returned. Otherwise the densities intersect at two points,
	 * and the one lying between the two expected values is preferred.
	 * @param e1 the expected value of the first distribution
	 * @param s1 the standard deviation of the first distribution
	 * @param e2 the expected value of the second distribution
	 * @param s2 the standard deviation of the second distribution
	 * @return the x-coordinate of the intersection
	 */
	public static double getIntersection(double e1, double s1, double e2, double s2) {
		// a distribution without deviation is concentrated in a single point, which
		// then serves as the boundary (the second distribution takes precedence)
		if(s2 == 0)
			return e2;
		if(s1 == 0) // without this check, 0 * log(s2/0) below would yield NaN
			return e1;
		if(s1 == s2)
			return (e1 + e2) / 2;
		double var1 = s1 * s1;
		double var2 = s2 * s2;
		double meanDiff = e1 - e2;
		// discriminant of the quadratic equation obtained by equating the two densities;
		// it is kept in factored form because log(s2/s1) and (var2 - var1) always have
		// the same sign for positive deviations, so rounding cannot make it negative
		double discriminant = var1 * var2 * (meanDiff * meanDiff + 2 * Math.log(s2 / s1) * (var2 - var1));
		double root = Math.sqrt(discriminant);
		double base = var1 * e2 - var2 * e1;
		double denominator = var1 - var2;
		double r1 = (base + root) / denominator;
		if((e1 <= r1 && r1 <= e2) || (e2 <= r1 && r1 <= e1))
			return r1;
		return (base - root) / denominator;
	}

	/**
	 * determines the order of the clusters when sorted by ascending centroid value
	 * @param centroids the centroid value of each cluster
	 * @param numClusters the number of clusters to consider
	 * @return an array whose i-th entry is the index of the cluster with the i-th smallest
	 * centroid; clusters with equal centroids keep their relative order, such that every
	 * cluster index appears exactly once
	 */
	private static int[] sortByCentroid(final double[] centroids, int numClusters) {
		Integer[] indices = new Integer[numClusters];
		for(int i = 0; i < numClusters; i++)
			indices[i] = i;
		// sorting of object arrays is stable, which resolves ties between equal centroids
		Arrays.sort(indices, new Comparator<Integer>() {
			public int compare(Integer a, Integer b) {
				return Double.compare(centroids[a], centroids[b]);
			}
		});
		int[] order = new int[numClusters];
		for(int i = 0; i < numClusters; i++)
			order[i] = indices[i];
		return order;
	}

	/**
	 * names each cluster after the interval of values between the intersections with its
	 * neighbouring clusters, considering only the first attribute (index 0). The cluster
	 * with the smallest centroid is named "&lt; B (~C)", the one with the largest centroid
	 * "&gt; B (~C)" and all others "L - R (~C)", where B, L and R are boundaries and C is
	 * the centroid. A cluster without deviation is named after its centroid alone, and
	 * a single cluster with deviation, which has no boundaries at all, is named "~C".
	 * @param clusterer the K-Means clusterer to read centroids and standard deviations from
	 * @return an array in which entry i is the name of cluster i
	 */
	public String[] getNames(SimpleKMeans clusterer) {
		int numClusters = clusterer.getNumClusters();
		String[] ret = new String[numClusters];
		double[] centroids = clusterer.getClusterCentroids().attributeToDoubleArray(0);
		double[] stdDevs = clusterer.getClusterStandardDevs().attributeToDoubleArray(0);
		int[] sortOrder = sortByCentroid(centroids, numClusters);
		for(int i = 0; i < numClusters; i++) {
			int idx = sortOrder[i];
			if(stdDevs[idx] == 0.0) { // no deviation -> no range
				ret[idx] = String.format("%.2f", centroids[idx]);
				continue;
			}
			if(numClusters == 1) { // no neighbours at all -> no boundaries
				ret[idx] = String.format("~%.2f", centroids[idx]);
				continue;
			}
			if(i == 0) { // no left neighbour
				ret[idx] = String.format("< %.2f (~%.2f)", getIntersection(centroids[idx], stdDevs[idx], centroids[sortOrder[1]], stdDevs[sortOrder[1]]), centroids[idx]);
				continue;
			}
			if(i == numClusters-1) { // no right neighbour
				ret[idx] = String.format("> %.2f (~%.2f)", getIntersection(centroids[idx], stdDevs[idx], centroids[sortOrder[i-1]], stdDevs[sortOrder[i-1]]), centroids[idx]);
				continue;
			}
			double left = getIntersection(centroids[idx], stdDevs[idx], centroids[sortOrder[i-1]], stdDevs[sortOrder[i-1]]);
			double right = getIntersection(centroids[idx], stdDevs[idx], centroids[sortOrder[i+1]], stdDevs[sortOrder[i+1]]);
			ret[idx] = String.format("%.2f - %.2f (~%.2f)", left, right, centroids[idx]);
		}
		return ret;
	}
}
/**
* variant of Intervals that produces names that are compatible with SRL databases
* @author Dominik Jain
*/
public static class IntervalsPlain implements ClusterNamer<SimpleKMeans> {
	/**
	 * formats a number with two decimal places such that the result contains neither
	 * a minus sign nor a decimal separator: a leading '-' becomes 'm' and the decimal
	 * separator becomes 'p' (e.g. -1.5 is formatted as "m1p50")
	 * @param f the number to format
	 * @return the formatted number
	 */
	protected static String strFloat(double f) {
		// a fixed locale guarantees '.' as the decimal separator, so that the result
		// does not depend on the platform's default locale
		String s = String.format(Locale.ROOT, "%.2f", f);
		return s.replace('-', 'm').replace('.', 'p');
	}

	/**
	 * determines the order of the clusters when sorted by ascending centroid value
	 * @param centroids the centroid value of each cluster
	 * @param numClusters the number of clusters to consider
	 * @return an array whose i-th entry is the index of the cluster with the i-th smallest
	 * centroid; clusters with equal centroids keep their relative order, such that every
	 * cluster index appears exactly once
	 */
	private static int[] sortByCentroid(final double[] centroids, int numClusters) {
		Integer[] indices = new Integer[numClusters];
		for(int i = 0; i < numClusters; i++)
			indices[i] = i;
		// sorting of object arrays is stable, which resolves ties between equal centroids
		Arrays.sort(indices, new Comparator<Integer>() {
			public int compare(Integer a, Integer b) {
				return Double.compare(centroids[a], centroids[b]);
			}
		});
		int[] order = new int[numClusters];
		for(int i = 0; i < numClusters; i++)
			order[i] = indices[i];
		return order;
	}

	/**
	 * names each cluster after the interval of values between the intersections with its
	 * neighbouring clusters, considering only the first attribute (index 0). The cluster
	 * with the smallest centroid is named "C_lt_B_C", the one with the largest centroid
	 * "C_gt_B_C" and all others "C_L_to_R_C", where B, L and R are boundaries and the
	 * final C is the centroid, all formatted by {@link #strFloat(double)}. A cluster
	 * without deviation, as well as a single cluster (which has no boundaries at all),
	 * is named "C_" followed by its centroid.
	 * @param clusterer the K-Means clusterer to read centroids and standard deviations from
	 * @return an array in which entry i is the name of cluster i
	 */
	public String[] getNames(SimpleKMeans clusterer) {
		int numClusters = clusterer.getNumClusters();
		String[] ret = new String[numClusters];
		double[] centroids = clusterer.getClusterCentroids().attributeToDoubleArray(0);
		double[] stdDevs = clusterer.getClusterStandardDevs().attributeToDoubleArray(0);
		int[] sortOrder = sortByCentroid(centroids, numClusters);
		for(int i = 0; i < numClusters; i++) {
			int idx = sortOrder[i];
			// no deviation -> no range; a single cluster has no neighbours -> no boundaries
			if(stdDevs[idx] == 0.0 || numClusters == 1) {
				ret[idx] = String.format("C_%s", strFloat(centroids[idx]));
				continue;
			}
			if(i == 0) { // no left neighbour
				ret[idx] = String.format("C_lt_%s_%s", strFloat(Intervals.getIntersection(centroids[idx], stdDevs[idx], centroids[sortOrder[1]], stdDevs[sortOrder[1]])), strFloat(centroids[idx]));
				continue;
			}
			if(i == numClusters-1) { // no right neighbour
				ret[idx] = String.format("C_gt_%s_%s", strFloat(Intervals.getIntersection(centroids[idx], stdDevs[idx], centroids[sortOrder[i-1]], stdDevs[sortOrder[i-1]])), strFloat(centroids[idx]));
				continue;
			}
			double left = Intervals.getIntersection(centroids[idx], stdDevs[idx], centroids[sortOrder[i-1]], stdDevs[sortOrder[i-1]]);
			double right = Intervals.getIntersection(centroids[idx], stdDevs[idx], centroids[sortOrder[i+1]], stdDevs[sortOrder[i+1]]);
			ret[idx] = String.format("C_%s_to_%s_%s", strFloat(left), strFloat(right), strFloat(centroids[idx]));
		}
		return ret;
	}
}
}