package AgentProvider.Implementation.Agents; import AgentSystemPluginAPI.Contract.IStateActionGenerator; import AgentSystemPluginAPI.Contract.StateAction; import EnvironmentPluginAPI.Exceptions.TechnicalException; import ZeroTypes.Exceptions.ErrorMessages; public class SarsaAgent extends EpsilonGreedyAgent { private StateAction sa; private StateAction s_a_; public SarsaAgent(String name, IDictionary qValues, IStateActionGenerator stateActionGenerator, IAgentSettingUpdatedListener agentSettingUpdatedListener) { super(name, qValues, stateActionGenerator, agentSettingUpdatedListener); } @Override public StateAction startEpisode(StateAction state) throws TechnicalException { sa = getEpsilonInfluencedAction(state); return sa; } @Override public StateAction getCurrentState() { return sa; } @Override public StateAction step(float rewardForLastStep, StateAction newState) throws TechnicalException { if(sa == null) { throw new RuntimeException(ErrorMessages.get("startStateNotInitialized", getName())); } s_a_ = getEpsilonInfluencedAction(newState); updateQ(s_a_, rewardForLastStep); sa = s_a_; return sa; } private void updateQ(StateAction s_a_, float reward) throws TechnicalException { qValues.setValue(sa, qValues.getValue(sa) + (getAlpha() * (reward + (getGamma()* qValues.getValue(s_a_)) - qValues.getValue(sa)))); } @Override public void endEpisode(StateAction stateAction, float reward) throws TechnicalException { updateQ(stateAction, reward); } }