package AgentProvider.Implementation.Agents;

import AgentSystemPluginAPI.Contract.IStateActionGenerator;
import AgentSystemPluginAPI.Contract.StateAction;
import EnvironmentPluginAPI.Exceptions.TechnicalException;
import ZeroTypes.Exceptions.ErrorMessages;

import java.util.LinkedList;
import java.util.List;

/**
 * Epsilon-greedy agent that learns with Sarsa(lambda) using accumulating eligibility traces.
 *
 * <p>Q-values live in the {@code qValues} dictionary inherited from {@link EpsilonGreedyAgent};
 * eligibility traces live in a second dictionary, {@code eValues}. On every update the temporal
 * difference error of the current step is applied to every state-action recorded during the
 * running episode, weighted by that pair's trace, and each trace is then decayed by
 * {@code gamma * lambda}.
 */
public class SarsaLambdaAgent extends EpsilonGreedyAgent {

    /** Eligibility trace per state-action. */
    private final IDictionary eValues;

    /**
     * State-actions recorded since the last {@link #startEpisode(StateAction)}, most recently
     * added first. Each pair is stored at most once so a sweep touches every trace exactly once.
     */
    private final List<StateAction> history;

    /** The state-action chosen most recently; {@code null} until an episode has been started. */
    private StateAction sa;

    /**
     * Creates the agent.
     *
     * @param name                        agent name, forwarded to the superclass
     * @param qValues                     dictionary holding the Q-values, forwarded to the superclass
     * @param eValues                     dictionary holding the eligibility traces
     * @param stateActionGenerator        forwarded to the superclass
     * @param agentSettingUpdatedListener forwarded to the superclass
     */
    public SarsaLambdaAgent(String name,
                            IDictionary qValues,
                            IDictionary eValues,
                            IStateActionGenerator stateActionGenerator,
                            IAgentSettingUpdatedListener agentSettingUpdatedListener) {
        super(name, qValues, stateActionGenerator, agentSettingUpdatedListener);
        this.eValues = eValues;
        this.history = new LinkedList<StateAction>();
    }

    /**
     * @return the state-action selected by the most recent {@code startEpisode} or {@code step}
     *         call, or {@code null} if no episode has been started yet
     */
    @Override
    public StateAction getCurrentState() {
        return sa;
    }

    /**
     * Begins a new episode: zeroes the traces of the pairs recorded so far, forgets them, and
     * selects the first action.
     *
     * @param state the initial state
     * @return the epsilon-influenced state-action chosen for {@code state}
     * @throws TechnicalException if a dictionary access or the action selection fails
     */
    @Override
    public StateAction startEpisode(StateAction state) throws TechnicalException {
        resetTraces();
        sa = getEpsilonInfluencedAction(state);
        remember(sa);
        return sa;
    }

    /**
     * Performs one learning step: selects the next action, updates Q-values and traces with the
     * reward received for the previous action, and makes the new pair the current one.
     *
     * @param rewardForLastStep reward obtained for the previously returned state-action
     * @param newState          the state reached after the previous action
     * @return the epsilon-influenced state-action chosen for {@code newState}
     * @throws TechnicalException    if a dictionary access or the action selection fails
     * @throws IllegalStateException if no episode has been started
     */
    @Override
    public StateAction step(float rewardForLastStep, StateAction newState) throws TechnicalException {
        if (sa == null) {
            throw new IllegalStateException(ErrorMessages.get("startStateNotInitialized", getName()));
        }

        StateAction nextStateAction = getEpsilonInfluencedAction(newState);
        updateValues(rewardForLastStep, nextStateAction);

        sa = nextStateAction;
        remember(sa);
        return sa;
    }

    /**
     * Applies the final update of an episode, bootstrapping from the Q-value stored for
     * {@code stateAction}.
     *
     * @param stateAction the state-action the episode ended in
     * @param reward      the final reward
     * @throws TechnicalException if a dictionary access fails
     */
    @Override
    public void endEpisode(StateAction stateAction, float reward) throws TechnicalException {
        updateValues(reward, stateAction);
    }

    /**
     * Sarsa(lambda) backup with accumulating traces.
     *
     * <p>Computes {@code delta = reward + gamma * Q(next) - Q(current)}, raises the trace of the
     * current pair by one, then for every recorded pair adds {@code alpha * delta * e} to its
     * Q-value and decays its trace by {@code gamma * lambda}.
     *
     * @param reward          reward received for the current state-action
     * @param nextStateAction the state-action whose Q-value is used as the bootstrap target
     * @throws TechnicalException if a dictionary access fails
     */
    private void updateValues(float reward, StateAction nextStateAction) throws TechnicalException {
        float delta = reward + (getGamma() * qValues.getValue(nextStateAction)) - qValues.getValue(sa);

        // Accumulating trace: the pair that was just executed gets full additional credit.
        eValues.setValue(sa, eValues.getValue(sa) + 1.0f);

        float scaledError = getAlpha() * delta;
        float decay = getGamma() * getLambda();

        // Propagate the TD error to every pair visited in this episode, not only the current one.
        for (StateAction visited : history) {
            float trace = eValues.getValue(visited);
            qValues.setValue(visited, qValues.getValue(visited) + (scaledError * trace));
            eValues.setValue(visited, decay * trace);
        }
    }

    /**
     * Records a state-action for the trace sweep unless it is already recorded.
     *
     * <p>NOTE(review): de-duplication uses {@code List.contains} and therefore
     * {@code StateAction.equals}; this assumes equality there matches the key identity used by the
     * dictionaries — confirm.
     */
    private void remember(StateAction stateAction) {
        if (!history.contains(stateAction)) {
            history.add(0, stateAction);
        }
    }

    /**
     * Sets the trace of every recorded pair to zero and empties the history so that credit from a
     * finished episode cannot leak into the next one.
     */
    private void resetTraces() throws TechnicalException {
        for (StateAction visited : history) {
            eValues.setValue(visited, 0.0f);
        }
        history.clear();
    }
}