package net.varkhan.data.learn.distance;

import net.varkhan.base.containers.set.WeightingSet;


/**
 * <b>A symmetrised Kullback-Leibler style distance between two weighted sets</b>.
 * <p/>
 * Each set is turned into a distribution by dividing every element weight by
 * the total weight of its set; the distance is then built from the terms
 * {@code (p - q) * log(p / q)} of the normalized weights.
 *
 * @author varkhan
 * @date 12/7/13
 * @time 5:13 PM
 */
public class KLDistance<T,C> extends Distance<WeightingSet<T>,C> {

    /**
     * Computes the distance between two weighted sets.
     * <p/>
     * The terms {@code (p - q) * log(p / q)} are accumulated while iterating
     * over the elements of {@code lvalue}, then again while iterating over the
     * elements of {@code rvalue}, and the total is multiplied by {@code 0.5}.
     *
     * @param lvalue the left-hand weighted set
     * @param rvalue the right-hand weighted set
     * @param ctx    a context object; not used by this implementation
     * @return {@code 0} if both sets have a total weight of zero;
     *         {@link Double#POSITIVE_INFINITY} if exactly one total weight is
     *         zero, or if any iterated element has a zero normalized weight
     *         in either set; otherwise half of the accumulated sum
     */
    @Override
    @SuppressWarnings("unchecked")
    public double invoke(WeightingSet<T> lvalue, WeightingSet<T> rvalue, C ctx) {
        final double totalL = lvalue.weight();
        final double totalR = rvalue.weight();
        // Degenerate totals: two empty distributions coincide, a single empty one is infinitely far
        if(totalL==0 || totalR==0) {
            return (totalL==0 && totalR==0) ? 0.0 : Double.POSITIVE_INFINITY;
        }

        // Normalization factors, applied by multiplication to every element weight
        final double scaleL = 1.0/totalL;
        final double scaleR = 1.0/totalR;

        double sum = 0;

        // First pass: elements enumerated by the left-hand set
        for(T elem : (Iterable<T>) lvalue) {
            final double probL = scaleL*lvalue.weight(elem);
            final double probR = scaleR*rvalue.weight(elem);
            // A zero probability on either side makes the distance infinite
            if(probL==0 || probR==0) {
                return Double.POSITIVE_INFINITY;
            }
            sum += (probL-probR)*Math.log(probL/probR);
        }

        // Second pass: elements enumerated by the right-hand set
        for(T elem : (Iterable<T>) rvalue) {
            final double probL = scaleL*lvalue.weight(elem);
            final double probR = scaleR*rvalue.weight(elem);
            if(probL==0 || probR==0) {
                return Double.POSITIVE_INFINITY;
            }
            sum += (probL-probR)*Math.log(probL/probR);
        }

        // Terms were accumulated over both passes, hence the halving
        return 0.5*sum;
    }

}