package rl.test;

import rl.EpsilonGreedyStrategy;
import rl.MazeMarkovDecisionProcess;
import rl.MazeMarkovDecisionProcessVisualization;
import rl.Policy;
import rl.PolicyIteration;
import rl.QLambda;
import rl.SarsaLambda;
import rl.ValueIteration;
import shared.FixedIterationTrainer;
import shared.ThresholdTrainer;

/**
 * Tests the maze Markov decision process classes by running value iteration,
 * policy iteration, Q(lambda), and Sarsa(lambda) on a maze loaded from a file.
 * @author guillory
 * @version 1.0
 */
public class MazeMDPTest {
    /**
     * Runs each solver on the test maze and prints the learned policy,
     * the iteration count, and the elapsed time.
     * @param args ignored
     * @throws Exception if the maze file cannot be loaded
     */
    public static void main(String[] args) throws Exception {
        MazeMarkovDecisionProcess maze = MazeMarkovDecisionProcess.load("testmaze.txt");
        System.out.println(maze);

        // Value iteration, trained until the change in values falls below a threshold
        ValueIteration vi = new ValueIteration(.95, maze);
        ThresholdTrainer tt = new ThresholdTrainer(vi);
        long startTime = System.currentTimeMillis();
        tt.train();
        Policy p = vi.getPolicy();
        long finishTime = System.currentTimeMillis();
        System.out.println("Value iteration learned : " + p);
        System.out.println("in " + tt.getIterations() + " iterations");
        System.out.println("and " + (finishTime - startTime) + " ms");
        MazeMarkovDecisionProcessVisualization mazeVis =
            new MazeMarkovDecisionProcessVisualization(maze);
        System.out.println(mazeVis.toString(p));

        // Policy iteration on the same maze, also trained to a threshold
        PolicyIteration pi = new PolicyIteration(.95, maze);
        tt = new ThresholdTrainer(pi);
        startTime = System.currentTimeMillis();
        tt.train();
        p = pi.getPolicy();
        finishTime = System.currentTimeMillis();
        System.out.println("Policy iteration learned : " + p);
        System.out.println("in " + tt.getIterations() + " iterations");
        System.out.println("and " + (finishTime - startTime) + " ms");
        System.out.println(mazeVis.toString(p));

        // Q(lambda) with an epsilon-greedy exploration strategy, run for a fixed number of iterations
        int iterations = 50000;
        QLambda ql = new QLambda(.5, .95, .2, 1, new EpsilonGreedyStrategy(.3), maze);
        FixedIterationTrainer fit = new FixedIterationTrainer(ql, iterations);
        startTime = System.currentTimeMillis();
        fit.train();
        p = ql.getPolicy();
        finishTime = System.currentTimeMillis();
        System.out.println("Q lambda learned : " + p);
        System.out.println("in " + iterations + " iterations");
        System.out.println("and " + (finishTime - startTime) + " ms");
        System.out.println("Acquiring " + ql.getTotalReward() + " reward");
        System.out.println(mazeVis.toString(p));

        // Sarsa(lambda) with the same parameters and exploration strategy
        SarsaLambda sl = new SarsaLambda(.5, .95, .2, 1, new EpsilonGreedyStrategy(.3), maze);
        fit = new FixedIterationTrainer(sl, iterations);
        startTime = System.currentTimeMillis();
        fit.train();
        p = sl.getPolicy();
        finishTime = System.currentTimeMillis();
        System.out.println("Sarsa lambda learned : " + p);
        System.out.println("in " + iterations + " iterations");
        System.out.println("and " + (finishTime - startTime) + " ms");
        System.out.println("Acquiring " + sl.getTotalReward() + " reward");
        System.out.println(mazeVis.toString(p));
    }
}