package net.varkhan.data.learn.decision;
import net.varkhan.base.functor.Mapper;
import net.varkhan.base.functor.Ordinal;
import net.varkhan.base.functor.mapper.ConstMapper;
import net.varkhan.base.functor.ordinal.ConstOrdinal;
import net.varkhan.data.learn.Classifier;
import net.varkhan.data.learn.SingletonClassifier;
import java.util.*;
/**
* <b>A decision tree classifier that delegates every operation to its root node</b>.
* <p/>
*
* @author varkhan
* @date 12/26/13
* @time 7:32 PM
*/
public class DecisionTree<K,T,C> implements Decision.Tree<K,T,C> {
/** Root node of the tree; every operation below is forwarded to it. */
protected Node<K,?,T,C> tree;

/**
 * Gives access to the root node.
 *
 * @return the node currently held in the {@code tree} field
 */
public Node<K,?,T,C> tree() {
    return this.tree;
}

/**
 * Classifies an observation by forwarding the call to the root node.
 *
 * @param obs the observation
 * @param ctx the context
 * @return whatever the root node returns for this observation
 */
public K invoke(T obs, C ctx) {
    return this.tree.invoke(obs, ctx);
}

/**
 * Resolves the decision taken for an observation by forwarding to the root node.
 *
 * @param obs the observation
 * @param ctx the context
 * @return the decision returned by the root node
 */
public Decision<K,T,C> decision(T obs, C ctx) {
    return this.tree.decision(obs, ctx);
}

/**
 * Computes the confidence that an observation belongs to a given class,
 * as reported by the root node.
 *
 * @param key the class key
 * @param obs the observation
 * @param ctx the context
 * @return the confidence value returned by the root node
 */
public double confidence(K key, T obs, C ctx) {
    return this.tree.confidence(key, obs, ctx);
}

/**
 * Exposes the partition used by the root node.
 *
 * @return the root node's partition
 */
public Ordinal<T,C> partition() {
    return this.tree.partition();
}

/**
 * Lists the classifiers reported by the root node.
 *
 * @return the root node's classes
 */
public List<? extends Classifier<K,T,C>> classes() {
    return this.tree.classes();
}

/**
 * Lists the attribute mappers reported by the root node.
 *
 * @return the root node's attributes
 */
public Collection<? extends Mapper<?,T,C>> attributes() {
    return this.tree.attributes();
}
// public Node<K,?,T,C> get(long... s) {
// Node<K,?,T,C> t = tree;
// if(s!=null) for(long i: s) {
// if(t.isLeaf()) throw new IllegalArgumentException("No node at "+LongArrays.toString(s));
// if(i>=t.partition().cardinal()) throw new IllegalArgumentException("No node at "+LongArrays.toString(s));
// t = ((Root<K,?,T,C>)t).get(i);
// }
// return t;
// }
//
// public void set(Node<K,?,T,C> n, long... s) {
// Root<K,?,T,C> p = null;
// long c = -1;
// Node<K,?,T,C> t = tree;
// if(s!=null) for(long i: s) {
// if(t.isLeaf()) throw new IllegalArgumentException("No node at "+LongArrays.toString(s));
// if(i>=t.partition().cardinal()) throw new IllegalArgumentException("No node at "+LongArrays.toString(s));
// p = (Root<K,?,T,C>)t;
// c = i;
// t = p.get(i);
// }
// if(p==null) {
// tree = n;
// }
// else {
// p.set(n,c);
// }
// }
/**
 * A node of a decision tree, itself usable as a {@link Decision.Tree}.
 *
 * @param <K> the type of class keys produced by classification
 * @param <A> the type of attribute values handled by this node's partition
 * @param <T> the type of observations
 * @param <C> the type of the context passed along with each observation
 */
public static interface Node<K,A,T,C> extends Decision.Tree<K,T,C> {

    /**
     * Classifies an observation.
     *
     * @param obs the observation
     * @param ctx the context
     * @return the class key selected for the observation
     */
    K invoke(T obs, C ctx);

    /**
     * Computes the confidence that an observation belongs to a class.
     *
     * @param key the class key
     * @param obs the observation
     * @param ctx the context
     * @return the confidence value
     */
    double confidence(K key, T obs, C ctx);

    /**
     * @return the attribute mappers used by this node
     */
    Collection<? extends Mapper<?,T,C>> attributes();

    /**
     * Selects the node that handles an observation.
     *
     * @param obs the observation
     * @param ctx the context
     * @return the selected node
     */
    Node<K,?,T,C> decision(T obs, C ctx);

    /**
     * @return the partition attached to this node
     */
    Partition<A,T,C> partition();

    /**
     * @return {@code true} if this node is a leaf, {@code false} otherwise
     */
    boolean isLeaf();

    /**
     * @return the child nodes of this node
     */
    List<? extends Node<K,?,T,C>> children();

    /**
     * Appends a textual rendering of this node to a buffer.
     *
     * @param buf the buffer to append to
     * @param ind the indentation prefix
     * @param tab the string used for one extra indentation level
     * @param sep the separator written after an entry
     * @return a string builder holding the rendering
     */
    StringBuilder toString(StringBuilder buf, String ind, String tab, String sep);
}
/**
 * Terminal node of a decision tree: a {@link SingletonClassifier} for a single
 * key that also satisfies the {@link Node} contract. It has no children, uses
 * no attributes, and resolves every decision to itself.
 *
 * @param <K> the type of class keys
 * @param <A> the type of attribute values of the (placeholder) partition
 * @param <T> the type of observations
 * @param <C> the type of the context
 */
public static class Leaf<K,A,T,C> extends SingletonClassifier<K,T,C> implements Node<K,A,T,C> {
    /** Placeholder partition built from a constant mapper and a constant ordinal. */
    protected final Partition<A,T,C> part;

    /**
     * Creates a leaf for a key, with a confidence of 1.
     *
     * @param key the class key held by this leaf
     */
    protected Leaf(K key) { this(key,1); }

    /**
     * Creates a leaf for a key with a given confidence.
     *
     * @param key the class key held by this leaf
     * @param cnf the confidence, passed both to the superclass and to the partition
     */
    // The casts only re-parameterize the shared constant functors; the warning
    // is suppressed on the constructor alone, as in Root's constructor.
    @SuppressWarnings("unchecked")
    protected Leaf(K key, double cnf) {
        super(key, cnf);
        part = new Partition<A,T,C>((Mapper<A,T,C>) ConstMapper.NULL(), (Ordinal<A,C>) ConstOrdinal.UNITY(), cnf);
    }

    /**
     * @return an empty collection: a leaf does not use any attribute
     */
    @Override
    public Collection<? extends Mapper<?,T,C>> attributes() { return Collections.emptySet(); }

    /**
     * @param obs the observation (unused)
     * @param ctx the context (unused)
     * @return this leaf
     */
    @Override
    public Node<K,?,T,C> decision(T obs, C ctx) { return this; }

    /**
     * @return the placeholder partition created by the constructor
     */
    @Override
    public Partition<A,T,C> partition() { return part; }

    /**
     * @return an immutable single-element list containing this leaf
     */
    public List<? extends Node<K,?,T,C>> classes() {
        return Collections.<Node<K,?,T,C>>singletonList(this);
    }

    /**
     * @return always {@code true}
     */
    @Override
    public boolean isLeaf() { return true; }

    /**
     * @return an immutable empty list: a leaf has no children
     */
    @Override
    public List<Node<K,?,T,C>> children() { return Collections.emptyList(); }

    /**
     * Appends the superclass string form of this leaf, followed by the separator.
     *
     * @param buf the buffer to append to
     * @param ind the indentation prefix (unused)
     * @param tab the indentation increment (unused)
     * @param sep the separator appended after the leaf
     * @return {@code buf}
     */
    @Override
    public StringBuilder toString(StringBuilder buf, String ind, String tab, String sep) {
        return buf.append(super.toString()).append(sep);
    }

    /**
     * @return the string form provided by the superclass
     */
    @Override
    public String toString() {
        return super.toString();
    }
}
public static class Root<K,A,T,C> implements Node<K,A,T,C> {
protected final Partition<A,T,C> part;
protected final Node<K,?,T,C>[] child;
@SuppressWarnings("unchecked")
protected Root(Partition<A,T,C> part) {
this.part = part;
this.child = new Node[(int)part.cardinal()];
}
protected void set(Node<K,?,T,C> n, long c) {
child[(int)c] = n;
}
protected Node<K,?,T,C> get(long c) {
return child[(int)c];
}
public Node<K,?,T,C> decision(T obs, C ctx) {
long s = part.invoke(obs, ctx);
if(s>=child.length) return null;
return s<child.length?child[(int)s]:null;
}
public K invoke(T obs, C ctx) {
Node<K,?,T,C> d = decision(obs, ctx);
return d==null?null:d.invoke(obs, ctx);
}
public double confidence(K key, T obs, C ctx) {
long s = part.invoke(obs, ctx);
if(s>=child.length) return 0;
Node<K,?,T,C> d = child[(int)s];
double c = d.confidence(key, obs, ctx);
for(Node<K,?,T,C> u: child) {
if(u!=d) c *= 1-u.confidence(key, obs, ctx);
}
return c;
}
public Collection<? extends Mapper<?,T,C>> attributes() {
Set<Mapper<?,T,C>> at = new HashSet<Mapper<?,T,C>>();
at.add(part.attribute());
for(Node<K,?,T,C> u: child) at.addAll(u.attributes());
return at;
}
public Partition<A,T,C> partition() { return part; }
public List<Node<K,?,T,C>> classes() { return Arrays.asList(child); }
public boolean isLeaf() { return false; }
public List<Node<K,?,T,C>> children() { return Arrays.asList(child); }
@Override
public StringBuilder toString(StringBuilder buf, String ind, String tab, String sep) {
buf.append(part.toString()).append(sep);
ind=ind+tab;
for(int i=0;i<child.length;i++) {
Node<K,?,T,C> u=child[i];
buf.append(ind).append(i).append(tab).append(": ");
u.toString(buf, ind, tab, sep);
}
return buf;
}
@Override
public String toString() {
return toString(new StringBuilder(), "","\t","\n").toString();
}
}
/**
 * Renders the whole tree, using tabs for indentation and newlines as
 * separators.
 *
 * @return the rendering produced by the root node, or {@code "null"} if no
 *         root node is set
 */
@Override
public String toString() {
    // toString() must not throw on an instance whose root has not been set
    if(tree==null) return "null";
    return tree.toString(new StringBuilder(), "","\t","\n").toString();
}
}