import AgentProvider.Implementation.AgentProviderComponent; import AgentProvider.Interface.IAgentProvider; import AgentSystemPluginAPI.Contract.IStateActionGenerator; import AgentSystemPluginAPI.Contract.StateAction; import AgentSystemPluginAPI.Services.IAgent; import AgentSystemPluginAPI.Services.LearningAlgorithm; import EnvironmentPluginAPI.Exceptions.TechnicalException; import ZeroTypes.Settings.SettingException; import org.junit.Assert; import org.junit.BeforeClass; import org.junit.Test; import java.util.HashSet; import java.util.LinkedList; import java.util.List; import java.util.Set; /** * This test is supposed to determine if the concrete implementations of the learning algorithm show expected behaviour. */ public class TestReinforcementAlgorithms implements IStateActionGenerator { private static IAgentProvider agentProvider; private static int[][] cliffWorld; private static int width = 10; private static int height = 5; private int agentX = 0; private int agentY = 0; @BeforeClass public static void setup() { try { agentProvider = new AgentProviderComponent(); } catch (TechnicalException e) { e.printStackTrace(); //To change body of catch statement use File | Settings | File Templates. } try { agentProvider.loadAgentSystem("testAgentSystem"); } catch (TechnicalException e) { e.printStackTrace(); //To change body of catch statement use File | Settings | File Templates. } catch (SettingException e) { e.printStackTrace(); //To change body of catch statement use File | Settings | File Templates. } initializeCliff(); } public static void initializeCliff() { cliffWorld = new int[height][width]; //set way costs to -1 for all fields except the cliff for (int x = 0; x < height; x++) { for (int y = 0; y < width; y++) { if (x < height - 1 || y == 0 || y == width - 1) { //the normal fields cliffWorld[x][y] = -1; } else { //the cliff cliffWorld[x][y] = -100; } } } cliffWorld[height - 1][width - 1] = 200; } public void printState() { for (int x = 0; x < height; x++) { for (int y = 0; y < width; y++) { if (x == agentX && y == agentY) { System.out.print("A"); } else { System.out.print(cliffWorld[x][y]); } } System.out.print("\n"); } System.out.println("=================================================================================================\n"); } private void makeMove(StateAction action) { if (action.getActionDescription().equals("U")) { agentX--; } else if (action.getActionDescription().equals("D")) { agentX++; } else if (action.getActionDescription().equals("L")) { agentY--; } else if (action.getActionDescription().equals("R")) { agentY++; } } private int average(List<Integer> list) { int sum = 0; if (list.size() == 0) { return Integer.MAX_VALUE; } for (Integer i : list) { sum += i; } return sum / list.size(); } @Override public Set<StateAction> getAllPossibleActions(StateAction state) { Set<StateAction> result = new HashSet<StateAction>(); if (agentX > 1) { result.add(new StateAction(state.getStateDescription(), "U")); } if (agentX < height - 1) { result.add(new StateAction(state.getStateDescription(), "D")); } if (agentY > 1) { result.add(new StateAction(state.getStateDescription(), "L")); } if (agentY < width - 1) { result.add(new StateAction(state.getStateDescription(), "R")); } return result; } @Test public void testQLearning() throws TechnicalException, SettingException { IAgent agent = agentProvider.getTableAgent("QLearningGridWorld", LearningAlgorithm.QLearning, this); agent.setAlpha(0.6f); agent.setEpsilon(0.1f); agent.setGamma(0.7f); List<Integer> steps = new LinkedList<Integer>(); StateAction action; int step; for (int i = 0; i < 200; i++) { agentX = height - 1; agentY = 0; step = 0; action = agent.startEpisode(new StateAction("" + agentX + agentY)); while (agentX != height - 1 || agentY != width - 1) { //printState(); makeMove(action); action = agent.step(cliffWorld[agentX][agentY], new StateAction("" + agentX + agentY)); step++; } if (i >= 80) { steps.add(step); } //System.out.println("steps in this round: " + step + " average: " + average(steps)); agent.endEpisode(new StateAction("" + agentX + agentY), -1); } Assert.assertTrue(average(steps) < 15); } @Test public void testSARSALambda() throws TechnicalException, SettingException { IAgent agent = agentProvider.getTableAgent("SARSALambdaGridworld", LearningAlgorithm.SARSALambda, this); agent.setAlpha(0.6f); agent.setEpsilon(0.001f); agent.setGamma(0.7f); agent.setLambda(0.5f); List<Integer> steps = new LinkedList<Integer>(); StateAction action; int step; for (int i = 0; i < 100; i++) { agentX = height - 1; agentY = 0; step = 0; action = agent.startEpisode(new StateAction("" + agentX + agentY)); while (agentX != height - 1 || agentY != width - 1) { //printState(); makeMove(action); action = agent.step(cliffWorld[agentX][agentY], new StateAction("" + agentX + agentY)); step++; } if (i >= 80) { steps.add(step); //System.out.println( i + "th round, steps in this round: " + step + " average: " + average(steps)); } agent.endEpisode(new StateAction("" + agentX + agentY), -1); } Assert.assertTrue(average(steps) < 15); } @Test public void testSARSA() throws TechnicalException, SettingException { IAgent agent = agentProvider.getTableAgent("SARSAGridworld", LearningAlgorithm.SARSA, this); agent.setAlpha(0.6f); agent.setEpsilon(0.1f); agent.setGamma(0.7f); List<Integer> steps = new LinkedList<Integer>(); StateAction action; int step; for (int i = 0; i < 250; i++) { agentX = height - 1; agentY = 0; step = 0; action = agent.startEpisode(new StateAction("" + agentX + agentY)); while (agentX != height - 1 || agentY != width - 1) { //printState(); makeMove(action); action = agent.step(cliffWorld[agentX][agentY], new StateAction("" + agentX + agentY)); step++; } if (i >= 80) { steps.add(step); } //System.out.println("steps in this round: " + step + " average: " + average(steps)); agent.endEpisode(new StateAction("" + agentX + agentY), -1); } Assert.assertTrue(average(steps) < 20); } }