package net.varkhan.data.learn.decision;
import net.varkhan.base.containers.map.ArrayOpenHashObj2LongMap;
import net.varkhan.base.containers.map.Obj2LongMap;
import net.varkhan.base.containers.set.ArrayOpenHashCountingSet;
import net.varkhan.base.containers.set.CountingSet;
import net.varkhan.base.functor.Mapper;
import net.varkhan.base.functor.Ordinal;
import net.varkhan.base.functor.curry.Pair;
import net.varkhan.data.learn.stats.Purity;
import java.util.*;
/**
 * <b>A partition factory for discrete attributes</b>.
 * <p/>
 * Groups observations by the attribute value that an attribute mapper extracts
 * from them, keeps the largest groups as individual bags with non-zero indices,
 * and merges every remaining group into a single default bag at index 0.
 *
 * @param <K> the type of the right-hand value of each observation (the values counted in the bags)
 * @param <A> the type of the attribute values produced by the attribute mapper
 * @param <T> the type of the left-hand value of each observation (the argument of the attribute mapper)
 * @param <C> the type of the context passed through to the attribute mapper and the purity function
 *
 * @author varkhan
 * @date 12/28/13
 * @time 1:33 PM
 */
public class DiscretePartitionFactory<K,A,T,C> implements Partition.Factory<K,A,T,C> {

    /** Orders collections by decreasing size, so that the largest collection sorts first. */
    protected static final Comparator<Collection> LARGEST=new Comparator<Collection>() {
        @Override
        public int compare(Collection o1, Collection o2) {
            if(o1.size()>o2.size()) return -1;
            if(o1.size()<o2.size()) return +1;
            return 0;
        }
    };

    /** The mapper extracting an attribute value from the left-hand value of an observation. */
    protected final Mapper<A,T,C> attr;
    /** The upper bound on the number of bags (default bag included) a partition may have. */
    protected final long card;
    /** The function scoring a set of bags against the union of all their values. */
    protected final Purity<K,C> pure;

    /**
     * Creates a factory for partitions over a discrete attribute.
     *
     * @param attr the attribute mapper, applied to the left-hand value of each observation
     * @param card the maximum number of bags, counting the default bag
     * @param pure the purity function used to score the resulting bags
     */
    public DiscretePartitionFactory(Mapper<A,T,C> attr, long card, Purity<K,C> pure) {
        this.attr=attr;
        this.card=card;
        this.pure=pure;
    }

    /**
     * Builds a partition of the observations according to their attribute value.
     * <p/>
     * Observations are grouped by attribute value and the groups are sorted by
     * decreasing size. With {@code n = min(card, number of groups)}, the
     * {@code n-1} largest groups each become a bag of their own, with indices
     * {@code n-1} down to {@code 1}; all other groups are merged into the
     * default bag, at index {@code 0}. The ordinal of the returned partition
     * also maps attribute values it has no positive index for to {@code 0}.
     *
     * @param obs the observations, as (item, value) pairs
     * @param ctx the context, passed to the attribute mapper and the purity function
     * @return a partition made of the attribute mapper, the ordinal indexing
     *         attribute values into bags, and the score computed by the purity function
     */
    public Partition<A,T,C> invoke(Iterable<? extends Pair<T,K>> obs, C ctx) {
        // Group observations by attribute value, remembering the value in each pair
        Map<A,Collection<Pair<A,K>>> classes=new HashMap<A,Collection<Pair<A,K>>>();
        for(Pair<T,K> o : obs) {
            K k=o.rvalue();
            T t=o.lvalue();
            A a=attr.invoke(t, ctx);
            Collection<Pair<A,K>> g=classes.get(a);
            if(g==null) classes.put(a, g=new ArrayList<Pair<A,K>>());
            g.add(new Pair.Value<A,K>(a,k));
        }
//        if(classes.size()<=1) return new Partition<A,T,C>((Mapper<A,T,C>) ConstMapper.NULL(), (Ordinal<A,C>) ConstOrdinal.UNITY(), 1.0);
        // Largest groups first: they are the ones that get a bag of their own
        List<Collection<Pair<A,K>>> parts=new ArrayList<Collection<Pair<A,K>>>(classes.values());
        Collections.sort(parts, LARGEST);
        CountingSet<K> all = new ArrayOpenHashCountingSet<K>();
        Collection<CountingSet<K>> sets = new ArrayList<CountingSet<K>>();
        CountingSet<K> def = new ArrayOpenHashCountingSet<K>();
        final Obj2LongMap<A> index = new ArrayOpenHashObj2LongMap<A>(parts.size());
        // Never use more bags than there are distinct attribute values
        long idx = card;
        if(idx>parts.size()) idx = parts.size();
        // NOTE(review): with no observations, or a configured card <= 0, this is
        // non-positive although the ordinal below still returns index 0 -- confirm
        // that callers tolerate such a partition.
        final long bags = idx;
        for(Collection<Pair<A,K>> part: parts) {
            idx --;
            if(idx>0) {
                // Take the card-1 largest parts as main bags, with non-zero indices
                CountingSet<K> set = new ArrayOpenHashCountingSet<K>();
                for(Pair<A,K> val : part) {
                    index.add(val.lvalue(), idx);
                    set.add(val.rvalue());
                    all.add(val.rvalue());
                }
                sets.add(set);
            }
            else {
                // For all other parts (indices 0 and <0) we add to the default bag
                for(Pair<A,K> val : part) {
                    index.add(val.lvalue(), 0);
                    def.add(val.rvalue());
                    all.add(val.rvalue());
                }
            }
        }
        // The default bag always comes last in the scored sets
        sets.add(def);
        double conf = pure.invoke(sets, all, ctx);
//        System.out.println("Partitions "+card+" / "+parts.size()+" at "+conf+" on "+attr+" for \t"+obs+"\n\t"+sets);
        return new Partition<A,T,C>(
                attr,
                new Ordinal<A,C>() {
                    @Override
                    public long cardinal() { return bags; }

                    @Override
                    public long invoke(A arg, C ctx) {
                        long idx=index.getLong(arg);
                        // If 0, default bag or unknown value -- those get index 0
                        if(idx<=0) return 0;
                        return idx;
                    }

                    @Override
                    public String toString() {
                        StringBuilder buf = new StringBuilder();
                        buf.append("($?=");
                        boolean f = true;
                        for(Map.Entry<A,Long> e: (Iterable<? extends Map.Entry<A,Long>>) index) {
                            if(f) { f = false; buf.append(' '); }
                            else buf.append(", ");
                            buf.append(e.getKey()).append(": ").append(e.getValue());
                        }
                        buf.append(')');
                        return buf.toString();
                    }
                },
                conf
        );
    }
}