package net.varkhan.data.learn.distance;
import net.varkhan.base.containers.Iterator;
import net.varkhan.base.containers.set.WeightingSet;
import net.varkhan.base.functor.Functional;
import java.util.HashSet;
import java.util.Set;
/**
 * <b>A distance between weighting sets that applies a back-off weight before delegating</b>.
 * <p/>
 * Both operands are wrapped in read-only {@link BackoffSet} views. In a view, the weight
 * of an element is the observed weight plus the value of the back-off functional for that
 * element, and the total weight is the observed total plus the sum of the back-off
 * functional over every distinct element found in either operand. The wrapped operands
 * are then handed to the underlying distance.
 *
 * @param <T> the type of the elements held by the compared sets
 * @param <C> the type of the context passed to the functional and to the distance
 *
 * @author varkhan
 * @date 12/7/13
 * @time 6:03 PM
 */
public class BackoffDistance<T,C> extends Distance<WeightingSet<T>,C> {

    /** Functional giving the back-off weight of a single element in a context. */
    protected final Functional<T,C> bkf;

    /** Underlying distance, evaluated on the back-off views of the operands. */
    protected final Distance<WeightingSet<T>,C> dst;

    /**
     * Builds a back-off distance.
     *
     * @param bkf the functional that yields the back-off weight of an element
     * @param dst the distance applied to the wrapped operands
     */
    public BackoffDistance(Functional<T,C> bkf, Distance<WeightingSet<T>,C> dst) {
        this.bkf=bkf;
        this.dst=dst;
    }

    /**
     * Computes the underlying distance between the back-off views of two sets.
     *
     * @param lvalue the left-hand set
     * @param rvalue the right-hand set
     * @param ctx    the context forwarded to the back-off functional and to the distance
     * @return the value returned by the underlying distance for the two wrapped sets
     */
    @Override
    public double invoke(WeightingSet<T> lvalue, WeightingSet<T> rvalue, C ctx) {
        // Distinct elements of both operands; left is gathered first, then right.
        final Set<T> keys = new HashSet<T>();
        gather(lvalue, keys);
        gather(rvalue, keys);

        // Total back-off weight, shared by both views.
        double total = 0;
        for(T key: keys) {
            total += bkf.invoke(key, ctx);
        }

        final BackoffSet<T, C> left = new BackoffSet<T, C>(lvalue, bkf, total, ctx);
        final BackoffSet<T, C> right = new BackoffSet<T, C>(rvalue, bkf, total, ctx);
        return dst.invoke(left, right, ctx);
    }

    /**
     * Adds every element iterated from a weighting set to a target set.
     *
     * @param source the weighting set to read elements from
     * @param target the set receiving the elements
     */
    @SuppressWarnings("unchecked")
    private static <E> void gather(WeightingSet<E> source, Set<E> target) {
        for(E element: (Iterable<E>)source) {
            target.add(element);
        }
    }

    /**
     * A read-only view of a weighting set whose weights are increased by a back-off term.
     * <p/>
     * Membership, size, iteration and visiting are delegated to the observed set;
     * modification methods leave the observed set untouched.
     *
     * @param <T> the type of the elements
     * @param <C> the type of the context passed to the back-off functional
     */
    protected static class BackoffSet<T,C> implements WeightingSet<T> {

        /** Functional giving the back-off weight of a single element. */
        protected final Functional<T,C> bkf;

        /** The observed set this view wraps. */
        protected final WeightingSet<T> obs;

        /** Context handed to the back-off functional. */
        protected final C ctx;

        /** Back-off amount added to the observed total weight. */
        protected final double bkw;

        /**
         * Wraps an observed set.
         *
         * @param obs the observed set
         * @param bkf the functional that yields the back-off weight of an element
         * @param bkw the back-off amount added to the total weight
         * @param ctx the context handed to the functional
         */
        public BackoffSet(WeightingSet<T> obs, Functional<T,C> bkf, double bkw, C ctx) {
            this.obs=obs;
            this.bkf=bkf;
            this.bkw=bkw;
            this.ctx=ctx;
        }

        /** @return the size reported by the observed set */
        public long size() {
            return obs.size();
        }

        /** @return whether the observed set reports itself as empty */
        public boolean isEmpty() {
            return obs.isEmpty();
        }

        /** @return the back-off amount plus the total weight of the observed set */
        public double weight() {
            return bkw+obs.weight();
        }

        /** @return whether the observed set contains the element */
        public boolean has(T t) {
            return obs.has(t);
        }

        /** @return the back-off weight of the element plus its weight in the observed set */
        public double weight(T t) {
            // The functional is evaluated before the observed set is queried.
            final double backoff = bkf.invoke(t, ctx);
            final double observed = obs.weight(t);
            return backoff+observed;
        }

        /** Does nothing: the view cannot be modified. */
        public void clear() {
            // intentionally empty
        }

        /** @return always {@code false}; nothing is added */
        public boolean add(T t) {
            return false;
        }

        /** @return always {@code false}; nothing is added */
        public boolean add(T t, double wgh) {
            return false;
        }

        /** @return always {@code false}; nothing is removed */
        public boolean del(T t) {
            return false;
        }

        /** @return the iterator of the observed set */
        public Iterator<? extends T> iterator() {
            return obs.iterator();
        }

        /** @return the result of visiting the observed set with the given visitor and parameter */
        public <Par> long visit(Visitor<T,Par> vis, Par par) {
            return obs.visit(vis, par);
        }
    }
}