package net.varkhan.data.learn.cluster;
import net.varkhan.base.functor.Functional;
import net.varkhan.base.functor._;
import net.varkhan.data.learn.distance.Distance;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
/**
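* <b>Base class for clusterings whose clusters are centered on one of their own points and scored with a Gaussian-shaped confidence</b>.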
* <p/>
*
* @author varkhan
* @date 12/7/13
* @time 1:18 PM
*/
public abstract class GaussianClustering<K extends GaussianClustering.Cluster<T,C>,T,C> implements Clustering<K,T,C> {

    /** The distance used to compare observations. */
    protected final Distance<T,C> ds;

    /**
     * The cluster model returned by {@link #model()}.
     * This class never assigns it, so it stays {@code null} until a subclass sets it.
     */
    protected Clusters<K,T,C> cs;

    /**
     * Creates a clustering from a distance given in functional form.
     *
     * @param ds the distance functional, converted with {@code Distance.wrap}
     */
    public GaussianClustering(Functional<_<T,_<T,_>>,C> ds) {
        this(Distance.wrap(ds));
    }

    /**
     * Creates a clustering based on a distance.
     *
     * @param ds the distance used to compare observations
     */
    public GaussianClustering(Distance<T,C> ds) {
        this.ds=ds;
    }

    /**
     * @return the distance this clustering was built with
     */
    @Override
    public Distance<T,C> distance() {
        return ds;
    }

    /**
     * @return the current cluster model (the {@link #cs} field, possibly {@code null})
     */
    @Override
    public Clusters<K,T,C> model() {
        return cs;
    }


    /**
     * A cluster that keeps all its points, and uses as center the point
     * that has the smallest sum of distances to the other points.
     *
     * @param <T> the type of the points
     * @param <C> the type of the context passed to the distance
     */
    public static class Cluster<T,C> implements Clustering.Cluster<T,C> {

        /** The distance used to compare points. */
        protected final Distance<T,C> ds;

        /** The list of all points in the cluster */
        protected final List<T> values=new ArrayList<T>();

        /**
         * The minimum of the sum of distances between one point and all other points {@code min(v, distsum(v, ...))}.
         * Holds {@code Double.MAX_VALUE} while the cluster has fewer than two points.
         */
        protected double mindds=Double.MAX_VALUE;

        /** The arg-min of the sum of distances between one point and all other points {@code argmin(v, distsum(v, ...))} */
        protected T center=null;

        /** The sum for all pairs of points of their distance {@code sum(v, distsum(v, ...))} */
        protected double sumdds=0;

        /**
         * Creates an empty cluster.
         *
         * @param ds the distance used to compare points
         */
        public Cluster(Distance<T,C> ds) {
            this.ds=ds;
        }

        /**
         * Adds a point to the cluster, and updates the center and the distance sums.
         * <p/>
         * The center is updated incrementally when the new point can be shown to
         * beat the current center, and fully recomputed through {@link #update} otherwise.
         *
         * @param val the point to add
         * @param ctx the context passed to the distance
         */
        protected void add(T val, C ctx) {
            // Computed before the insertion, so the sum does not include the distance from val to itself
            double sd = distsum(val, ctx);
            values.add(val);
            // No previous center?
            if(center==null) {
                center = val;
                // A single point has no spread yet: keep the "undefined" marker
                mindds = Double.MAX_VALUE;
                sumdds += sd;
            }
            // This total distance is smaller than the previous one
            else if(mindds>sd) {
                center = val;
                mindds = sd;
                sumdds += sd;
            }
            else {
                // This total distance is smaller than the previous one updated by the new point
                double d = ds.invoke(val, center, ctx);
                if(mindds+d>sd) {
                    center = val;
                    mindds = sd;
                    sumdds += sd;
                }
                // OK, not sure so we have to recompute everything
                else update(ctx);
            }
        }

        /**
         * Recomputes the center and the distance sums from scratch, from all the points in the cluster.
         * <p/>
         * On an empty cluster this resets the center to {@code null}, {@code mindds}
         * to {@code Double.MAX_VALUE} and {@code sumdds} to 0.
         *
         * @param ctx the context passed to the distance
         */
        protected void update(C ctx) {
            double cd = Double.MAX_VALUE;
            double sd = 0;
            T cv = null;
            for(T v : values) {
                double d = distsum(v, ctx);
                // NOTE(review): this counts each pair once per ordering, while add() adds
                // the distance sum of a new point only once -- confirm the intended convention
                sd += d;
                // Keep the point with the smallest distance sum; the first point always
                // seeds the search, so a non-empty cluster never ends up without a center
                if(cv==null || d<cd) {
                    cd = d;
                    cv = v;
                }
            }
            center = cv;
            mindds = cd;
            sumdds = sd;
        }

        /**
         * Computes the sum of the distances between a value and all the points in the cluster.
         *
         * @param val the value to compare to the points
         * @param ctx the context passed to the distance
         * @return the sum of {@code ds.invoke(val, v, ctx)} for every point {@code v}, 0 if the cluster is empty
         */
        public double distsum(T val, C ctx) {
            double sd = 0;
            for(T v: values) {
                double d = ds.invoke(val,v,ctx);
                sd += d;
            }
            return sd;
        }

        /**
         * Computes the distance between a value and the center of the cluster.
         *
         * @param val the value to compare to the center
         * @param ctx the context passed to the distance
         * @return the distance from the value to the center, or from the value to itself if there is no center
         */
        public double distance(T val, C ctx) {
            if(center==null) return ds.invoke(val,val, ctx);
            return ds.invoke(val,center,ctx);
        }

        /**
         * @return the current center of the cluster, or {@code null} if there is none
         */
        @Override
        public T center() {
            return center;
        }

        /**
         * Computes a Gaussian-shaped score from the distance between an observation and the center:
         * {@code exp(-d/(2*mindds))/sqrt(2*PI*mindds)}, with {@code mindds} used as the spread.
         * <p/>
         * While {@code mindds} is still {@code Double.MAX_VALUE} the score is 0 for a finite distance.
         * If {@code mindds} is 0 the expression is not defined, and evaluates to {@code NaN}
         * for a non-negative distance.
         *
         * @param obs the observation to score
         * @param ctx the context passed to the distance
         * @return the score of the observation, 0 if the cluster has no center
         */
        @Override
        public double confidence(T obs, C ctx) {
            if(center==null) return 0;
            double d = ds.invoke(obs, center, ctx);
            return Math.exp(-d/(2*mindds))/Math.sqrt(2*mindds*Math.PI);
        }

        /**
         * Computes the distance to the center at which {@link #confidence} takes a given value,
         * by inverting the confidence formula: {@code -2*mindds*log(p*sqrt(2*PI*mindds))}.
         *
         * @param p   the confidence level
         * @param ctx the context (not used by this implementation)
         * @return 0 if {@code p>=1}, {@code Double.MAX_VALUE} if {@code p<=0},
         *         the inverted confidence formula otherwise
         */
        @Override
        public double diameter(double p, C ctx) {
            if(p>=1) return 0;
            if(p<=0) return Double.MAX_VALUE;
            return -(2*mindds)*Math.log(p*Math.sqrt(2*mindds*Math.PI));
        }
    }


    /**
     * A set of clusters, that assigns an observation to the cluster with the closest center.
     *
     * @param <K> the type of the clusters
     * @param <T> the type of the points
     * @param <C> the type of the context passed to the distance
     */
    public static class Clusters<K extends Cluster<T,C>,T,C> implements Clustering.Clusters<K,T,C> {

        /** The clusters, in insertion order. */
        protected final List<K> cset = new ArrayList<K>();

        /** Creates an empty set of clusters. */
        protected Clusters() { }

        /**
         * Creates a set of clusters from an initial collection.
         *
         * @param cls the clusters to copy into this set
         */
        protected Clusters(Collection<K> cls) { cset.addAll(cls); }

        /**
         * @return a read-only view of the clusters
         */
        @Override
        public Collection<K> clusters() {
            return Collections.unmodifiableCollection(cset);
        }

        /**
         * Adds a cluster to the set.
         *
         * @param c the cluster to add
         */
        public void add(K c) {
            cset.add(c);
        }

        /**
         * @return the number of clusters in the set
         */
        public int size() {
            return cset.size();
        }

        /**
         * Finds the cluster closest to an observation, as measured by {@link Cluster#distance}.
         * On equal distances the cluster added first wins.
         *
         * @param obs the observation to assign
         * @param ctx the context passed to the distance
         * @return the closest cluster, or {@code null} if no cluster has a distance
         *         below {@code Double.MAX_VALUE} (in particular if the set is empty)
         */
        @Override
        public K invoke(T obs, C ctx) {
            K cm = null;
            double dm = Double.MAX_VALUE;
            for(K c: cset) {
                double d = c.distance(obs, ctx);
                if(dm>d) {
                    dm = d;
                    cm = c;
                }
            }
            return cm;
        }

        /**
         * Computes the confidence of an observation for a cluster, by delegating to the cluster.
         *
         * @param cls the cluster to score against
         * @param obs the observation to score
         * @param ctx the context passed to the distance
         * @return the value of {@code cls.confidence(obs, ctx)}
         */
        @Override
        public double confidence(K cls, T obs, C ctx) {
            return cls.confidence(obs, ctx);
        }
    }
}