package net.varkhan.data.learn.decision;
import junit.framework.TestCase;
import net.varkhan.base.functor.Mapper;
import net.varkhan.base.functor.curry.Pair;
import net.varkhan.base.functor.generator.GaussianNumberGenerator;
import net.varkhan.base.functor.generator.UniformPRNGDef;
import net.varkhan.data.learn.stats.InformationGain;
import java.util.Iterator;
import java.util.Random;
/**
 * Tests for {@link DecisionLearner}: trains decision trees on noisy samples of a simple
 * bucketing function and checks that the classification error stays within expected bounds.
 * <p/>
 *
 * @author varkhan
 * @date 1/25/14
 * @time 10:11 PM
 */
public class DecisionLearnerTest extends TestCase {
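/**
 * Reference labeling function: maps a value to one of six interval labels.
 * This is the ground truth the learned decision tree is compared against.
 */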
protected static Mapper<String,Double,Object> bucket=new Mapper<String,Double,Object>() {
@Override
public String invoke(Double arg, Object ctx) {
if(arg==null) return null;
double v = arg;
if(v<0) return "]...0[";
if(v<1) return "[0..1[";
if(v<2) return "[1..2[";
if(v<3) return "[2..3[";
if(v<4) return "[3..4[";
return "[4...[";
}
};
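/**
 * Produces {@code num} (value, label) observations. Values are drawn uniformly from [-3, 7);
 * labels are computed from a copy of the value perturbed by Gaussian noise controlled by
 * {@code err}, so observations near bucket boundaries may carry a wrong label.
 */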
protected static Iterable<Pair<Double,String>> sample(final Random rand, final int num, final double err) {
return new Iterable<Pair<Double,String>>() {
@Override
public Iterator<Pair<Double,String>> iterator() {
final GaussianNumberGenerator<Object> rng=new GaussianNumberGenerator<Object>(new UniformPRNGDef(rand), 0, err);
return new Iterator<Pair<Double,String>>() {
protected volatile int cnt=0;
@Override
public boolean hasNext() {
return cnt<num;
}
@Override
public Pair<Double,String> next() {
cnt ++;
double v = rand.nextDouble()*10-3;
double b = v + rng.invoke(null);
String s = bucket.invoke(b,null);
return new Pair.Value<Double,String>(v, s);
}
@Override
public void remove() { throw new UnsupportedOperationException(); }
};
}
};
}
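/**
 * Builds a partition factory that splits values three ways around the threshold {@code s}
 * (below, equal, above), using {@link InformationGain} to score candidate splits.
 */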
protected static DiscretePartitionFactory<String,Long,Double,Object> createSplitPartition(final int s) {
return new DiscretePartitionFactory<String,Long,Double,Object>(
new Mapper<Long,Double,Object>() {
@Override
public Long invoke(Double arg, Object ctx) {
return arg==null ? null : arg>s ? Long.valueOf(+1) : arg<s ? Long.valueOf(-1) : Long.valueOf(0);
}
@Override
public String toString() {
return "<"+s+">($)";
}
},
3,
new InformationGain<String,Object>()
);
}
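// Incremental training: observations are fed to the learner one at a time (see runTrainInc).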
public void testTrainInc100() throws Exception {
Random rand=new Random();
testTrainInc(rand, 100, 100, 0.03, 0.05);
testTrainInc(rand, 100, 100, 0.05, 0.10);
testTrainInc(rand, 100, 100, 0.10, 0.15);
testTrainInc(rand, 100, 100, 0.20, 0.20);
testTrainInc(rand, 100, 100, 0.30, 0.25);
testTrainInc(rand, 100, 100, 0.40, 0.30);
}
public void testTrainInc300() throws Exception {
Random rand=new Random();
testTrainInc(rand, 30, 300, 0.03, 0.02);
testTrainInc(rand, 30, 300, 0.05, 0.03);
testTrainInc(rand, 30, 300, 0.10, 0.10);
testTrainInc(rand, 30, 300, 0.20, 0.20);
testTrainInc(rand, 30, 300, 0.30, 0.27);
testTrainInc(rand, 30, 300, 0.40, 0.30);
}
public void testTrainInc1000() throws Exception {
Random rand=new Random();
testTrainInc(rand, 10, 1000, 0.10, 0.05);
testTrainInc(rand, 10, 1000, 0.20, 0.15);
testTrainInc(rand, 10, 1000, 0.30, 0.25);
testTrainInc(rand, 10, 1000, 0.40, 0.30);
}
public void testTrainInc3000() throws Exception {
Random rand=new Random();
// Not much improvement is to be expected at these noise levels by just increasing the # of samples
testTrainInc(rand, 10, 3000, 0.20, 0.15);
testTrainInc(rand, 10, 3000, 0.30, 0.25);
testTrainInc(rand, 10, 3000, 0.40, 0.30);
}
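// Batch training: the learner receives the whole sample set in a single call (see runTrainAll).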
public void testTrainAll100() throws Exception {
Random rand=new Random();
testTrainAll(rand, 100, 100, 0.03, 0.05);
testTrainAll(rand, 100, 100, 0.05, 0.10);
testTrainAll(rand, 100, 100, 0.10, 0.15);
testTrainAll(rand, 100, 100, 0.20, 0.25);
testTrainAll(rand, 100, 100, 0.30, 0.25);
testTrainAll(rand, 100, 100, 0.40, 0.30);
}
public void testTrainAll300() throws Exception {
Random rand=new Random();
testTrainAll(rand, 30, 300, 0.03, 0.02);
testTrainAll(rand, 30, 300, 0.05, 0.03);
testTrainAll(rand, 30, 300, 0.10, 0.10);
testTrainAll(rand, 30, 300, 0.20, 0.20);
testTrainAll(rand, 30, 300, 0.30, 0.30);
testTrainAll(rand, 30, 300, 0.40, 0.35);
}
public void testTrainAll1000() throws Exception {
Random rand=new Random();
testTrainAll(rand, 10, 1000, 0.10, 0.10);
testTrainAll(rand, 10, 1000, 0.20, 0.20);
testTrainAll(rand, 10, 1000, 0.30, 0.25);
testTrainAll(rand, 10, 1000, 0.40, 0.35);
}
public void testTrainAll3000() throws Exception {
Random rand=new Random();
// Not much improvement is to be expected at these noise levels by just increasing the # of samples
testTrainAll(rand, 10, 3000, 0.20, 0.20);
testTrainAll(rand, 10, 3000, 0.30, 0.25);
testTrainAll(rand, 10, 3000, 0.40, 0.35);
}
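/**
 * Averages the error rate of {@code run} incremental training runs of {@code num} noisy
 * samples each, and asserts that the average stays below {@code max}.
 */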
public void testTrainInc(Random rand, int run, int num, double err, double max) throws Exception {
double inv = 0;
for(int i=0; i<run; i++) {
inv += runTrainInc(rand, num, err);
}
inv /= run;
System.out.printf("Inc for %3d runs of %5d +/-%4f = %5f / %5f\n",run,num,err,inv,max);
assertTrue(String.format("Inc for %3d runs of %5d +/-%4f = %5f / %5f\n",run,num,err,inv,max),inv<max);
}
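/**
 * Averages the error rate of {@code run} batch training runs of {@code num} noisy
 * samples each, and asserts that the average stays below {@code max}.
 */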
public void testTrainAll(Random rand, int run, int num, double err, double max) throws Exception {
double inv = 0;
for(int i=0; i<run; i++) {
inv += runTrainAll(rand, num, err);
}
inv /= run;
System.out.printf("All for %3d runs of %5d +/-%4f = %5f / %5f\n",run,num,err,inv,max);
assertTrue(String.format("All for %3d runs of %5d +/-%4f = %5f / %5f\n",run,num,err,inv,max),inv<max);
}
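/**
 * Trains a learner one observation at a time, then returns the fraction of points in
 * [-3, 7) (step 0.1) where the learned tree disagrees with the reference bucket function.
 */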
public double runTrainInc(Random rand, int num, double err) throws Exception {
DecisionLearner<String,Double,Object> learner = new DecisionLearner<String,Double,Object>(
0.7, // Min confidence
10, // Max depth
createSplitPartition(0),
createSplitPartition(1),
createSplitPartition(2),
createSplitPartition(3),
createSplitPartition(4)
);
for(Pair<Double,String> obs: sample(rand, num, err)) {
learner.train(obs.lvalue(),obs.rvalue(), null);
// System.out.println("Tree:\n"+learner.model().toString());
}
DecisionTree<String,Double,Object> model=learner.model();
// System.out.println("Tree for "+num+" +/-"+err+":\n"+model.toString());
int tst = 0;
int inv = 0;
for(double v=-3;v<7;v+=0.1) {
if(!bucket.invoke(v,null).equals(model.invoke(v,null))) inv++;
// System.out.printf("%4f:\t%s\t%s\n",v,bucket.invoke(v,null),model.invoke(v,null));
tst ++;
}
return (double)inv/(double)tst;
}
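/**
 * Trains a learner on the full sample set in one call, then returns the fraction of points in
 * [-3, 7) (step 0.1) where the learned tree disagrees with the reference bucket function.
 */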
public double runTrainAll(Random rand, int num, double err) throws Exception {
DecisionLearner<String,Double,Object> learner = new DecisionLearner<String,Double,Object>(
0.7, // Min confidence
10, // Max depth
createSplitPartition(0),
createSplitPartition(1),
createSplitPartition(2),
createSplitPartition(3),
createSplitPartition(4)
);
learner.train(sample(rand, num, err), null);
DecisionTree<String,Double,Object> model=learner.model();
// System.out.println("Tree for "+num+" +/-"+err+":\n"+model.toString());
int tst = 0;
int inv = 0;
for(double v=-3;v<7;v+=0.1) {
if(!bucket.invoke(v,null).equals(model.invoke(v,null))) inv++;
// System.out.printf("%4f:\t%s\t%s\n",v,bucket.invoke(v,null),model.invoke(v,null));
tst ++;
}
return (double)inv/(double)tst;
}
}