package net.varkhan.data.learn.decision;
import net.varkhan.base.containers.set.ArrayOpenHashCountingSet;
import net.varkhan.base.containers.set.CountingSet;
import net.varkhan.base.functor._;
import net.varkhan.base.functor.curry.Pair;
import net.varkhan.data.learn.SupervisedLearner;
import java.util.*;
/**
* <b>A greedy decision-tree learner.</b>
* <p/>
* Builds a {@link DecisionTree} by repeatedly picking, among the configured attribute
* partition factories, the partition with the highest confidence on the observed samples,
* and splitting recursively until a purity, confidence or depth limit is reached.
*
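* A minimal usage sketch (the attribute factory {@code colorFactory} and the plain
* {@code String} observation and key types are illustrative assumptions, not part of this API):
* <pre>{@code
* DecisionLearner<String,String,Object> learner =
*     new DecisionLearner<String,String,Object>(0.5, 8, colorFactory);
* learner.train("red,round", "apple", null);      // single-sample update
* learner.train("yellow,long", "banana", null);
* DecisionTree<String,String,Object> model = learner.model();
* }</pre>
*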
* @author varkhan
* @date 12/28/13
* @time 12:39 PM
*/
public class DecisionLearner<K,T,C> implements SupervisedLearner<K,T,C> {
protected final Collection<Partition.Factory<K,?,T,C>> attributes=new HashSet<Partition.Factory<K,?,T,C>>();  // Candidate attribute partition factories
protected final Collection<Pair<T,K>> observed =new ArrayList<Pair<T,K>>();  // All (observation, key) samples seen so far
protected final DecisionTree<K,T,C> tree =new DecisionTree<K,T,C>();  // The decision tree learned so far
protected final double minc;  // Minimum partition confidence required to split a node
protected final long maxd;    // Maximum depth of the learned tree
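/**
* Creates a decision-tree learner.
*
* @param minc the minimum partition confidence required to split a node
* @param maxd the maximum depth of the learned tree
* @param attr the attribute partition factories the learner can choose splits from
*/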
public DecisionLearner(double minc, long maxd, Partition.Factory<K,?,T,C>... attr) {
this.minc=minc;
this.maxd=maxd;
for(Partition.Factory<K,?,T,C> a : attr) this.attributes.add(a);
}
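/**
* Adds a batch of (observation, key) samples and re-learns the tree over all observations seen so far.
*
* @return {@code true} if the model changed as a result of this call
*/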
public boolean train(Iterable<? extends _<T,? extends _<K,_>>> dat, C ctx) {
boolean m = false;
for(_<T,? extends _<K,_>> d: dat) m |= observed.add(new Pair.Value<T,K>(d));
if(!m) return false;
DecisionTree.Node<K,?,T,C> l = learnTree(tree.tree, observed, ctx, 0);
if(l!=null) {
tree.tree = l;
return true;
}
return false;
}
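/**
* Adds a single (observation, key) sample and updates only the path of the tree that classifies it.
*
* @return {@code true} if the model changed as a result of this call
*/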
@Override
public boolean train(T obs, K key, C ctx) {
if(!observed.add(new Pair.Value<T,K>(obs, key))) return false;
DecisionTree.Node<K,?,T,C> l = learnPath(tree.tree, observed, obs, ctx, 0);
if(l!=null) {
tree.tree = l;
return true;
}
return false;
}
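/**
* Re-learns the whole tree over all the observations accumulated so far.
*
* @return {@code true} if the model changed as a result of this call
*/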
public boolean train(C ctx) {
DecisionTree.Node<K,?,T,C> l = learnTree(tree.tree, observed, ctx, 0);
if(l!=null) {
tree.tree = l;
return true;
}
return false;
}
/**
* Learning algorithm:
* - select an attribute to partition on, by asking each attribute factory for a candidate partition
* - keep the candidate partition with the highest confidence
* - if no attribute could be used, or the partition is not good enough, create a single leaf carrying the majority key
* - otherwise create a root node, and recurse down on each partitioned subset of the values
* - set the node created at each step into its parent
*/
protected DecisionTree.Node<K,?,T,C> learnPath(DecisionTree.Node<K,?,T,C> node, Collection<Pair<T,K>> values, T obs, C ctx, int l) {
Set<K> classes=classes(values);
if(classes.size()<=1) {
if(node!=null && node.isLeaf() && classes.contains(((DecisionTree.Leaf<K,?,T,C>) node).key())) {
// System.out.println("Preserving leaf at "+l+" for \t"+values+"\n\t"+node);
return null;
}
return subLeaf(values);
}
// System.out.println("Learning subtree at "+l+" on "+obs+" for "+values+"\n\t"+node);
@SuppressWarnings("unchecked")
Partition<Object,T,C> cand = (Partition<Object, T, C>) optPart(values, ctx);
if(node!=null) {
@SuppressWarnings("unchecked")
Partition<Object,T,C> part = (Partition<Object, T, C>) node.partition();
if(Partition.identical_(part, cand, values, ctx)) {
// System.out.println("Preserving subtree at "+l+" for \t"+values+"\n\t"+part);
if(node.isLeaf()) return null; // No change!
long s = part.invoke(obs, ctx);
@SuppressWarnings("unchecked")
DecisionTree.Root<K,?,T,C> root = (DecisionTree.Root<K,?,T,C>) node;
DecisionTree.Node<K,?,T,C> d = root.get(s);
List<List<Pair<T,K>>> groups = Partition.partition_(part, values, ctx);
DecisionTree.Node<K,?,T,C> n = learnPath(d, groups.get((int) s), obs, ctx, l+1);
if(n==null) return null; // No change!
if(n!=d) root.set(n, s);
return node;
}
}
return subTree(cand, values, ctx, l);
}
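/**
* Recursively re-learns the subtree rooted at {@code node} over the given values, reusing
* existing nodes whenever the previously learned partition is still the optimal one.
*
* @return the (possibly new) subtree root, or {@code null} if the existing node is unchanged
*/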
protected DecisionTree.Node<K,?,T,C> learnTree(DecisionTree.Node<K,?,T,C> node, Collection<Pair<T,K>> values, C ctx, int l) {
Set<K> classes=classes(values);
if(classes.size()<=1) {
if(node!=null && node.isLeaf() && classes.contains(((DecisionTree.Leaf<K,?,T,C>) node).key())) {
// System.out.println("Preserving leaf at "+l+" for \t"+values+"\n\t"+node);
return null;
}
return subLeaf(values);
}
// System.out.println("Learning subtree at "+l+" for "+values+"\n\t"+node);
@SuppressWarnings("unchecked")
Partition<Object,T,C> cand = (Partition<Object, T, C>) optPart(values, ctx);
if(node!=null) {
@SuppressWarnings("unchecked")
Partition<Object,T,C> part = (Partition<Object, T, C>) node.partition();
if(Partition.identical_(part, cand, values, ctx)) {
if(node.isLeaf()) return null; // No change!
@SuppressWarnings("unchecked")
DecisionTree.Root<K,?,T,C> root = (DecisionTree.Root<K,?,T,C>) node;
boolean c = false;
l ++;
List<List<Pair<T,K>>> groups = Partition.partition_(part, values, ctx);
for(long s=0; s<part.cardinal(); s++) {
DecisionTree.Node<K,?,T,C> d = root.get(s);
DecisionTree.Node<K,?,T,C> n = learnTree(d, groups.get((int) s), ctx, l);
if(n!=null) {
c = true;
if(n!=d) root.set(n, s);
}
}
return c?node:null;
}
// Fall through to re-learn the whole root
}
return subTree(cand, values, ctx, l);
}
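/**
* Builds a new subtree for the given values, split on {@code part}; falls back to a
* majority-key leaf when the partition is missing, not confident enough, trivial, or
* when the depth limit has been reached.
*/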
protected DecisionTree.Node<K,?,T,C> subTree(Partition<?,T,C> part, Collection<Pair<T,K>> values, C ctx, int l) {
Set<K> classes=classes(values);
if(classes.size()<=1) {
return subLeaf(values);
}
// System.out.println("Computing subtree at "+l+" for \t"+values+"\n\t"+part);
if(part==null || part.confidence()<minc || l>maxd || part.cardinal()<=1) {
return subLeaf(values);
}
@SuppressWarnings("unchecked")
DecisionTree.Root<K,?,T,C> r = new DecisionTree.Root<K,Object,T,C>((Partition<Object, T, C>) part);
List<List<Pair<T,K>>> groups = Partition.partition_(part, values, ctx);
l ++;
for(int i=0; i<part.cardinal(); i++) {
List<Pair<T,K>> v=groups.get(i);
r.set(subTree(optPart(v, ctx), v, ctx, l),i);
}
return r;
}
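/** Collects the distinct keys (classes) present in a set of (observation, key) samples. */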
protected Set<K> classes(Collection<Pair<T, K>> values) {
Set<K> c = new HashSet<K>();
for(Pair<T, K> p : values) {
c.add(p.rvalue());
}
return c;
}
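/**
* Creates a leaf carrying the majority key among the given values, along with the
* fraction of samples that carry that key.
*/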
protected DecisionTree.Node<K,?,T,C> subLeaf(Collection<Pair<T,K>> values) {
CountingSet<K> c = new ArrayOpenHashCountingSet<K>();
for(Pair<T,K> v: values) c.add(v.rvalue());
K b = null;
long m = -1;
for(K k: (Iterable<K>) c) {
long n=c.count(k);
if(m<n) {
m = n;
b = k;
}
}
// System.out.println("Creating leaf "+b+" for "+values);
return new DecisionTree.Leaf<K,Object,T,C>(b,m/(double)c.count());
}
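/**
* Invokes every attribute partition factory on the values, and returns the candidate
* partition with the highest confidence.
*/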
protected Partition<?,T,C> optPart(Collection<Pair<T,K>> values, C ctx) {
Partition<?,T,C> part = null;
for(Partition.Factory<K,?,T,C> a: attributes) {
Partition<?,T,C> p = a.invoke(values, ctx);
if(part==null || p.confidence()>part.confidence()) part=p;
}
return part;
}
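/** Returns the decision tree learned so far. */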
@Override
public DecisionTree<K,T,C> model() {
return tree;
}
}