import Actions.ActionDescription; import Actions.Direction; import Actions.EnvironmentState; import AgentSystemPluginAPI.Contract.IAgentSystem; import AgentSystemPluginAPI.Contract.IStateActionGenerator; import AgentSystemPluginAPI.Contract.StateAction; import AgentSystemPluginAPI.Services.IAgent; import AgentSystemPluginAPI.Services.IPluginServiceProvider; import AgentSystemPluginAPI.Services.LearningAlgorithm; import EnvironmentPluginAPI.Exceptions.TechnicalException; import Logic.GridWorldConfiguration; import java.io.ByteArrayInputStream; import java.io.DataInputStream; import java.io.IOException; import java.util.ArrayList; import java.util.HashSet; import java.util.List; import java.util.Set; /** * A simple agent system for the a grid world, consisting of one q-learning agent. */ public class CliffQLearningAgentSystem implements IAgentSystem<GridWorldConfiguration, EnvironmentState, ActionDescription> { private final IPluginServiceProvider pluginServiceProvider; IAgent sarsaLambdaAgent; private int gridWidth; private int gridHeight; public CliffQLearningAgentSystem(IPluginServiceProvider pluginServiceProvider) throws TechnicalException { this.pluginServiceProvider = pluginServiceProvider; } @Override public void start(GridWorldConfiguration configuration) throws TechnicalException { gridWidth = configuration.getWidth(); gridHeight = configuration.getHeight(); if (sarsaLambdaAgent == null) { sarsaLambdaAgent = pluginServiceProvider.getTableAgent("qlearning", LearningAlgorithm.SARSALambda, new IStateActionGenerator() { @Override public Set<StateAction> getAllPossibleActions(StateAction stateAction) { DataInputStream in = new DataInputStream(new ByteArrayInputStream(stateAction.getStateDescription().getBytes())); try { int x = in.read(); int y = in.read(); Set<StateAction> possibleActions = new HashSet<StateAction>(); if (x < gridWidth - 1) possibleActions.add(new StateAction(stateAction.getStateDescription(), "RIGHT")); if (x > 0) possibleActions.add(new StateAction(stateAction.getStateDescription(), "LEFT")); if (y > 0) possibleActions.add(new StateAction(stateAction.getStateDescription(), "DOWN")); if (y < gridHeight - 1) possibleActions.add(new StateAction(stateAction.getStateDescription(), "UP")); return possibleActions; } catch (IOException e) { e.printStackTrace(); return new HashSet<StateAction>(); } } }); // set learning parameters sarsaLambdaAgent.setAlpha(0.6f); sarsaLambdaAgent.setEpsilon(0.001f); sarsaLambdaAgent.setGamma(0.7f); sarsaLambdaAgent.setLambda(0.5f); } sarsaLambdaAgent.startEpisode(new StateAction(new EnvironmentState(0, 0, false, 0.0f).getCompressedRepresentation())); } @Override public ActionDescription getActionsForEnvironmentStatus(EnvironmentState environmentState) throws TechnicalException { System.err.println("environmentState erhalten: "); StateAction newStateAction = sarsaLambdaAgent.step(environmentState.getReward(), new StateAction(environmentState.getCompressedRepresentation())); return new ActionDescription(Direction.valueOf(newStateAction.getActionDescription())); } @Override public void end() throws TechnicalException { System.err.println("beendet"); } }