package aima.test.core.unit.learning.reinforcement.agent;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.junit.Assert;

import aima.core.environment.cellworld.Cell;
import aima.core.environment.cellworld.CellWorld;
import aima.core.environment.cellworld.CellWorldAction;
import aima.core.environment.cellworld.CellWorldFactory;
import aima.core.learning.reinforcement.agent.ReinforcementAgent;
import aima.core.learning.reinforcement.example.CellWorldEnvironment;
import aima.core.probability.example.MDPFactory;
import aima.core.util.JavaRandomizer;

/**
 * Static helpers that run a reinforcement-learning agent in the Figure 17.1
 * cell world and evaluate the utility it learns for cell (1,1).
 */
public abstract class ReinforcementLearningAgentTest {

    /**
     * Reference value that learned estimates of U(1,1) are compared against
     * when computing the root mean square error.
     */
    private static final double EXPECTED_UTILITY_1_1 = 0.705;

    /**
     * Runs the agent for {@code numRuns} independent runs of
     * {@code numTrialsPerRun} trials each (the agent is reset before every
     * run) and asserts that the root mean square error of the learned utility
     * of cell (1,1), taken across the runs, is below
     * {@code expectedErrorLessThan}.
     *
     * <p>With zero runs the error evaluates to NaN, so the assertion fails.
     *
     * @param reinforcementAgent    the agent under test
     * @param numRuns               number of independent runs
     * @param numTrialsPerRun       number of trials executed in each run
     * @param expectedErrorLessThan exclusive upper bound for the RMS error
     * @throws IllegalStateException if a run ends without a utility for (1,1)
     */
    public static void test_RMSeiu_for_1_1(
            ReinforcementAgent<Cell<Double>, CellWorldAction> reinforcementAgent,
            int numRuns, int numTrialsPerRun, double expectedErrorLessThan) {
        CellWorld<Double> cw = CellWorldFactory.createCellWorldForFig17_1();
        CellWorldEnvironment cwe = createEnvironment(cw, reinforcementAgent);

        // Collect the utilities reported at the end of every run first, so
        // that all runs are executed before any of them is inspected.
        List<Map<Cell<Double>, Double>> utilitiesPerRun =
                new ArrayList<Map<Cell<Double>, Double>>();
        for (int r = 0; r < numRuns; r++) {
            reinforcementAgent.reset();
            cwe.executeTrials(numTrialsPerRun);
            utilitiesPerRun.add(reinforcementAgent.getUtility());
        }

        // Root mean square error for the utility of (1,1) across all runs.
        Cell<Double> cell1_1 = cw.getCellAt(1, 1);
        double sumSquaredErrors = 0;
        for (int r = 0; r < utilitiesPerRun.size(); r++) {
            Map<Cell<Double>, Double> u = utilitiesPerRun.get(r);
            Double val1_1 = u.get(cell1_1);
            if (null == val1_1) {
                throw new IllegalStateException(
                        "U(1,1,) is not present: r=" + r + ", u=" + u);
            }
            sumSquaredErrors += squaredError(val1_1);
        }
        double rmse = rootMeanSquare(sumSquaredErrors, utilitiesPerRun.size());

        Assert.assertTrue("" + rmse + " is not < " + expectedErrorLessThan,
                rmse < expectedErrorLessThan);
    }

    /**
     * Runs the agent for {@code numRuns} runs of {@code numTrialsPerRun}
     * trials, sampling the agent's utilities after every
     * {@code reportEveryN}-th trial (starting with the first), and prints to
     * standard output:
     * <ul>
     * <li>tab-separated utility series from the last run for cells (4,3),
     * (3,3), (1,3), (1,1), (3,2) and (2,1), using 0.0 where a cell has no
     * utility yet;</li>
     * <li>the RMS error of U(1,1) across all runs for each of the first
     * {@code rmseTrialsToReport} samples.</li>
     * </ul>
     *
     * @param reinforcementAgent the agent under test
     * @param numRuns            number of independent runs, at least 1
     * @param numTrialsPerRun    number of trials executed in each run
     * @param rmseTrialsToReport number of samples to report an RMS error for;
     *                           at most {@code numTrialsPerRun / reportEveryN}
     * @param reportEveryN       sampling interval in trials, at least 1
     * @throws IllegalArgumentException if {@code reportEveryN} or
     *                                  {@code numRuns} is less than 1, or if
     *                                  more RMSE samples are requested than
     *                                  are available
     * @throws IllegalStateException    if a sample has no utility for (1,1)
     */
    public static void test_utility_learning_rates(
            ReinforcementAgent<Cell<Double>, CellWorldAction> reinforcementAgent,
            int numRuns, int numTrialsPerRun, int rmseTrialsToReport,
            int reportEveryN) {
        // Fail with a clear message instead of "/ by zero" further down.
        if (reportEveryN < 1) {
            throw new IllegalArgumentException(
                    "reportEveryN must be >= 1, was " + reportEveryN);
        }
        // The per-cell report reads the last run, so one must exist.
        if (numRuns < 1) {
            throw new IllegalArgumentException(
                    "numRuns must be >= 1, was " + numRuns);
        }
        int reportableTrials = numTrialsPerRun / reportEveryN;
        if (rmseTrialsToReport > reportableTrials) {
            throw new IllegalArgumentException(
                    "Requesting to report too many RMSE trials, max allowed for args is "
                            + (numTrialsPerRun / reportEveryN));
        }

        CellWorld<Double> cw = CellWorldFactory.createCellWorldForFig17_1();
        CellWorldEnvironment cwe = createEnvironment(cw, reinforcementAgent);
        Cell<Double> cell1_1 = cw.getCellAt(1, 1);

        // runs.get(r).get(i) holds the utilities sampled at the i-th
        // reporting point of run r.
        List<List<Map<Cell<Double>, Double>>> runs =
                new ArrayList<List<Map<Cell<Double>, Double>>>();
        for (int r = 0; r < numRuns; r++) {
            reinforcementAgent.reset();
            List<Map<Cell<Double>, Double>> trials =
                    new ArrayList<Map<Cell<Double>, Double>>();
            for (int t = 0; t < numTrialsPerRun; t++) {
                cwe.executeTrial();
                if (0 == t % reportEveryN) {
                    Map<Cell<Double>, Double> u = reinforcementAgent.getUtility();
                    if (null == u.get(cell1_1)) {
                        throw new IllegalStateException(
                                "Bad Utility State Encountered: r=" + r
                                        + ", t=" + t + ", u=" + u);
                    }
                    trials.add(u);
                }
            }
            runs.add(trials);
        }

        // Per-cell utility curves, taken from the last run only.
        List<Map<Cell<Double>, Double>> lastRun = runs.get(numRuns - 1);
        System.out.println("(4,3)" + "\t"
                + utilitySeries(lastRun, cw.getCellAt(4, 3), reportableTrials));
        System.out.println("(3,3)" + "\t"
                + utilitySeries(lastRun, cw.getCellAt(3, 3), reportableTrials));
        System.out.println("(1,3)" + "\t"
                + utilitySeries(lastRun, cw.getCellAt(1, 3), reportableTrials));
        System.out.println("(1,1)" + "\t"
                + utilitySeries(lastRun, cw.getCellAt(1, 1), reportableTrials));
        System.out.println("(3,2)" + "\t"
                + utilitySeries(lastRun, cw.getCellAt(3, 2), reportableTrials));
        System.out.println("(2,1)" + "\t"
                + utilitySeries(lastRun, cw.getCellAt(2, 1), reportableTrials));

        StringBuilder rmseValues = new StringBuilder();
        for (int t = 0; t < rmseTrialsToReport; t++) {
            // Root mean square error for the utility of (1,1) at this
            // sample index across all runs.
            double sumSquaredErrors = 0;
            for (int r = 0; r < numRuns; r++) {
                Map<Cell<Double>, Double> u = runs.get(r).get(t);
                Double val1_1 = u.get(cell1_1);
                if (null == val1_1) {
                    throw new IllegalStateException(
                            "U(1,1,) is not present: r=" + r + ", t=" + t
                                    + ", runs.size=" + runs.size()
                                    + ", runs(r).size()=" + runs.get(r).size()
                                    + ", u=" + u);
                }
                sumSquaredErrors += squaredError(val1_1);
            }
            rmseValues.append(rootMeanSquare(sumSquaredErrors, runs.size()));
            rmseValues.append("\t");
        }
        System.out.println("RMSeiu" + "\t" + rmseValues);
    }

    /**
     * Builds the cell world environment with cell (1,1) as the first
     * constructor argument and registers the agent with it.
     */
    private static CellWorldEnvironment createEnvironment(
            CellWorld<Double> cw,
            ReinforcementAgent<Cell<Double>, CellWorldAction> reinforcementAgent) {
        CellWorldEnvironment cwe = new CellWorldEnvironment(
                cw.getCellAt(1, 1),
                cw.getCells(),
                MDPFactory.createTransitionProbabilityFunctionForFigure17_1(cw),
                new JavaRandomizer());
        cwe.addAgent(reinforcementAgent);
        return cwe;
    }

    /**
     * Renders the utility of {@code cell} over the first {@code count}
     * samples as tab-terminated values, using 0.0 where the cell is absent.
     */
    private static String utilitySeries(
            List<Map<Cell<Double>, Double>> samples, Cell<Double> cell,
            int count) {
        StringBuilder series = new StringBuilder();
        for (int i = 0; i < count; i++) {
            Map<Cell<Double>, Double> u = samples.get(i);
            double value = u.containsKey(cell) ? u.get(cell) : 0.0;
            series.append(value).append("\t");
        }
        return series.toString();
    }

    /** Squared deviation of a learned U(1,1) from the reference value. */
    private static double squaredError(double learnedUtility) {
        return Math.pow(EXPECTED_UTILITY_1_1 - learnedUtility, 2);
    }

    /** Square root of the mean of {@code sampleCount} squared errors. */
    private static double rootMeanSquare(double sumSquaredErrors,
            int sampleCount) {
        return Math.sqrt(sumSquaredErrors / sampleCount);
    }
}