/*
* JABM - Java Agent-Based Modeling Toolkit
* Copyright (C) 2013 Steve Phelps
*
* This program is free software; you can redistribute it and/or
* modify it under the terms of the GNU General Public License as
* published by the Free Software Foundation; either version 3 of
* the License, or (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
* See the GNU General Public License for more details.
*/
package net.sourceforge.jabm.strategy;
import java.io.Serializable;
import java.util.List;
import net.sourceforge.jabm.EventScheduler;
import net.sourceforge.jabm.agent.Agent;
import net.sourceforge.jabm.learning.MDPLearner;
import net.sourceforge.jabm.learning.StatelessQLearner;
import net.sourceforge.jabm.report.Taggable;
import org.apache.log4j.Logger;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.beans.factory.ObjectFactory;
import org.springframework.beans.factory.annotation.Required;
/**
 * Abstract reinforcement-learning meta-strategy whose action choice can depend
 * on a state supplied by the subclass through {@link #getState()}.
 *
 * <p>On every {@link #execute(List)} the agent's payoff delta and the current
 * state are passed to an {@link MDPLearner}. The learner's chosen action index
 * selects one of the inherited sub-strategies ({@code actions}), which is then
 * executed. This class keeps the sub-strategies' agent and event subscriptions
 * in step with its own.
 */
public abstract class RlStrategyWithState extends AbstractRlStrategy
		implements Serializable, InitializingBean, Taggable {

	/** Learner that receives (reward, state) observations and picks actions. */
	protected MDPLearner learner;

	static final Logger logger = Logger.getLogger(RlStrategyWithState.class);

	/**
	 * Creates the strategy and initialises it immediately via the inherited
	 * {@code initialise()}.
	 *
	 * @param agent           the agent using this strategy; may be null (the
	 *                        two-argument constructor passes null)
	 * @param strategyFactory factory for the sub-strategies (presumably used by
	 *                        {@code initialise()} to build {@code actions} —
	 *                        confirm in AbstractRlStrategy)
	 * @param learner         the learner that selects among the sub-strategies
	 */
	public RlStrategyWithState(Agent agent,
			ObjectFactory<Strategy> strategyFactory,
			MDPLearner learner) {
		super(agent);
		this.learner = learner;
		this.strategyFactory = strategyFactory;
		initialise();
	}

	/**
	 * Creates and initialises the strategy without an agent; the agent is
	 * expected to be supplied later through {@link #setAgent(Agent)}.
	 *
	 * @param strategyFactory factory for the sub-strategies
	 * @param learner         the learner that selects among the sub-strategies
	 */
	public RlStrategyWithState(ObjectFactory<Strategy> strategyFactory,
			MDPLearner learner) {
		this(null, strategyFactory, learner);
	}

	/**
	 * No-arg constructor for setter-based (e.g. Spring) configuration.
	 * Initialisation happens when {@link #setLearner(MDPLearner)} is called.
	 */
	public RlStrategyWithState() {
	}

	/**
	 * Subscribes this strategy and each of its sub-strategies to the scheduler.
	 */
	@Override
	public void subscribeToEvents(EventScheduler scheduler) {
		super.subscribeToEvents(scheduler);
		for (int i = 0; i < actions.length; i++) {
			actions[i].subscribeToEvents(scheduler);
		}
	}

	/**
	 * Performs one learning step and then acts.
	 *
	 * <p>The agent's payoff delta and the current {@link #getState() state} are
	 * passed to {@code learner.newState}. The learner is then asked for an action
	 * index, and the corresponding sub-strategy becomes the current strategy and
	 * is executed against {@code otherAgents}.
	 *
	 * @param otherAgents the other agents, passed through to the chosen
	 *                    sub-strategy
	 */
	public void execute(List<Agent> otherAgents) {
		assert this.agent != null;
		double reward = agent.getPayoffDelta();
		int state = getState();
		learner.newState(reward, state);
		int action = learner.act();
		currentStrategy = actions[action];
		assert currentStrategy.getAgent() != null;
		currentStrategy.execute(otherAgents);
	}

	/** @return the learner that selects among the sub-strategies */
	public MDPLearner getLearner() {
		return learner;
	}

	/**
	 * Sets the learner and re-runs the inherited {@code initialise()}.
	 *
	 * @param learner the new learner
	 */
	@Required
	public void setLearner(MDPLearner learner) {
		this.learner = learner;
		initialise();
	}

	/**
	 * Sets the agent on this strategy and on every sub-strategy that exists.
	 *
	 * @param agent the agent using this strategy
	 */
	@Override
	public void setAgent(Agent agent) {
		super.setAgent(agent);
		// With the no-arg constructor the agent may be injected before
		// initialise() has run, in which case there are no sub-strategies yet.
		if (actions == null) {
			return;
		}
		for (int i = 0; i < actions.length; i++) {
			actions[i].setAgent(agent);
		}
	}

	/**
	 * Cloning is not supported for this strategy.
	 *
	 * @throws CloneNotSupportedException always
	 */
	@Override
	public Strategy clone() throws CloneNotSupportedException {
		throw new CloneNotSupportedException();
	}

	/**
	 * Unsubscribes every sub-strategy, then this strategy itself.
	 */
	@Override
	public void unsubscribeFromEvents() {
		for (int i = 0; i < actions.length; i++) {
			actions[i].unsubscribeFromEvents();
		}
		super.unsubscribeFromEvents();
	}

	/** @return the factory used for the sub-strategies */
	public ObjectFactory<Strategy> getStrategyFactory() {
		return strategyFactory;
	}

	/**
	 * Sets the sub-strategy factory. This does not re-run initialisation by
	 * itself.
	 *
	 * @param strategyFactory the factory for the sub-strategies
	 */
	@Required
	public void setStrategyFactory(ObjectFactory<Strategy> strategyFactory) {
		this.strategyFactory = strategyFactory;
	}

	/**
	 * @return the values most recently passed to
	 *         {@link #setInitialPropensities(double[])} (the array itself, not
	 *         a copy)
	 */
	public double[] getInitialPropensities() {
		return initialPropensities;
	}

	/**
	 * Seeds the learner's value estimates for state 0 and records the values so
	 * that {@link #getInitialPropensities()} returns them.
	 *
	 * <p>Only a {@link StatelessQLearner} learner is supported. The values are
	 * copied into the array returned by
	 * {@code getqLearner().getValueEstimates(0)}, which is assumed to be the
	 * learner's live estimate array (TODO confirm it is not a copy). The
	 * sub-strategies must already exist, i.e. the learner must have been set.
	 * All arguments are validated before anything is written, so a failure
	 * leaves the learner untouched.
	 *
	 * @param initialPropensities one value per action; entries beyond the
	 *                            number of actions are ignored
	 * @throws NullPointerException     if {@code initialPropensities} is null
	 * @throws IllegalStateException    if the learner is not a
	 *                                  {@code StatelessQLearner} or the
	 *                                  sub-strategies have not been created
	 * @throws IllegalArgumentException if fewer values than actions are given
	 */
	public void setInitialPropensities(double[] initialPropensities) {
		if (initialPropensities == null) {
			throw new NullPointerException("initialPropensities");
		}
		if (!(this.learner instanceof StatelessQLearner)) {
			throw new IllegalStateException(
					"Initial propensities require a StatelessQLearner, but the learner is "
					+ (this.learner == null ? "null" : this.learner.getClass().getName()));
		}
		if (actions == null) {
			throw new IllegalStateException(
					"Actions have not been initialised (has the learner been set?)");
		}
		if (initialPropensities.length < actions.length) {
			throw new IllegalArgumentException("Expected at least " + actions.length
					+ " initial propensities but got " + initialPropensities.length);
		}
		StatelessQLearner qLearner = (StatelessQLearner) this.learner;
		double[] propensities = qLearner.getqLearner().getValueEstimates(0);
		System.arraycopy(initialPropensities, 0, propensities, 0, actions.length);
		// Previously the argument was never stored, so the getter returned a
		// stale value.
		this.initialPropensities = initialPropensities;
	}

	/**
	 * Intentionally a no-op. Initialisation already happens eagerly in the full
	 * constructor and in {@link #setLearner(MDPLearner)}.
	 */
	@Override
	public void afterPropertiesSet() throws Exception {
	}

	/**
	 * @return {@code "SRL: "} followed by the current sub-strategy's tag when
	 *         that sub-strategy is {@link Taggable}; otherwise this object's
	 *         class description
	 */
	@Override
	public String getTag() {
		// instanceof is false for null, so no separate null check is needed.
		if (currentStrategy instanceof Taggable) {
			return "SRL: " + ((Taggable) currentStrategy).getTag();
		}
		return this.getClass().toString();
	}

	/**
	 * Ignored: the tag is derived from the current sub-strategy.
	 *
	 * @param tag unused
	 */
	@Override
	public void setTag(String tag) {
	}

	/** @return the number of actions reported by the learner */
	public int getNumberOfActions() {
		return learner.getNumberOfActions();
	}

	/**
	 * Returns the current environment state, which is passed to
	 * {@code learner.newState} on each {@link #execute(List)}.
	 *
	 * @return the current state index
	 */
	public abstract int getState();
}