package rl.test;
import rl.EpsilonGreedyStrategy;
import rl.Policy;
import rl.PolicyIteration;
import rl.QLambda;
import rl.SarsaLambda;
import rl.SimpleMarkovDecisionProcess;
import rl.ValueIteration;
import shared.FixedIterationTrainer;
import shared.ThresholdTrainer;
/**
* A markov decision process test
* @author Andrew Guillory gtg008g@mail.gatech.edu
* @version 1.0
*/
public class MDPTest {
/**
* The main method
* @param args ignored
*/
public static void main(String[] args) {
// the andrew moore tutorial mdp
SimpleMarkovDecisionProcess mdp = new SimpleMarkovDecisionProcess();
mdp.setRewards(new double[] {0, 0, 10, 10});
mdp.setTransitionMatrices(new double[][][] {
{{ 1.0, 0, 0, 0}, {.5, .5, 0, 0 }},
{{ .5, 0, 0, .5}, { 0, 1, 0, 0 }},
{{ .5, 0, .5, 0}, {.5, .5, 0, 0 }},
{{ 0, 0, .5, .5}, {0, 1, 0, 0 }}});
mdp.setInitialState(0);
// solve it
ValueIteration vi = new ValueIteration(.9, mdp);
ThresholdTrainer tt = new ThresholdTrainer(vi);
long startTime = System.currentTimeMillis();
tt.train();
Policy p = vi.getPolicy();
long finishTime = System.currentTimeMillis();
System.out.println("Value iteration learned : " + p);
System.out.println("in " + tt.getIterations() + " iterations");
System.out.println("and " + (finishTime - startTime) + " ms");
PolicyIteration pi = new PolicyIteration(.9, mdp);
tt = new ThresholdTrainer(pi);
startTime = System.currentTimeMillis();
tt.train();
p = pi.getPolicy();
finishTime = System.currentTimeMillis();
System.out.println("Policy iteration learned : " + p);
System.out.println("in " + tt.getIterations() + " iterations");
System.out.println("and " + (finishTime - startTime) + " ms");
QLambda ql = new QLambda(.5, .9, .2, .995, new EpsilonGreedyStrategy(.3), mdp);
FixedIterationTrainer fit = new FixedIterationTrainer(ql, 100);
startTime = System.currentTimeMillis();
fit.train();
p = ql.getPolicy();
finishTime = System.currentTimeMillis();
System.out.println("Q lambda learned : " + p);
System.out.println("in " + 100 + " iterations");
System.out.println("and " + (finishTime - startTime) + " ms");
System.out.println("Acquiring " + ql.getTotalReward() + " reward");
SarsaLambda sl = new SarsaLambda(.5, .9, .2, .995, new EpsilonGreedyStrategy(.3), mdp);
fit = new FixedIterationTrainer(sl, 100);
startTime = System.currentTimeMillis();
fit.train();
p = sl.getPolicy();
finishTime = System.currentTimeMillis();
System.out.println("Sarsa lambda learned : " + p);
System.out.println("in " + 100 + " iterations");
System.out.println("and " + (finishTime - startTime) + " ms");
System.out.println("Acquiring " + sl.getTotalReward() + " reward");
}
}