package aima.core.learning.reinforcement.agent;
import java.util.HashMap;
import java.util.Map;
import aima.core.agent.Action;
import aima.core.learning.reinforcement.PerceptStateReward;
import aima.core.probability.mdp.ActionsFunction;
import aima.core.util.FrequencyCounter;
import aima.core.util.datastructure.Pair;
/**
* Artificial Intelligence A Modern Approach (3rd Edition): page 844.<br>
* <br>
*
* <pre>
* function Q-LEARNING-AGENT(percept) returns an action
* inputs: percept, a percept indicating the current state s' and reward signal r'
* persistent: Q, a table of action values indexed by state and action, initially zero
* N<sub>sa</sub>, a table of frequencies for state-action pairs, initially zero
* s,a,r, the previous state, action, and reward, initially null
*
 * if TERMINAL?(s) then Q[s,None] <- r'
* if s is not null then
* increment N<sub>sa</sub>[s,a]
* Q[s,a] <- Q[s,a] + α(N<sub>sa</sub>[s,a])(r + γmax<sub>a'</sub>Q[s',a'] - Q[s,a])
* s,a,r <- s',argmax<sub>a'</sub>f(Q[s',a'],N<sub>sa</sub>[s',a']),r'
* return a
* </pre>
*
* Figure 21.8 An exploratory Q-learning agent. It is an active learner that
* learns the value Q(s,a) of each action in each situation. It uses the same
* exploration function f as the exploratory ADP agent, but avoids having to
* learn the transition model because the Q-value of a state can be related
* directly to those of its neighbors.<br>
* <br>
 * <b>Note:</b> There appear to be two minor defects in the algorithm outlined
 * in the book:<br>
 * if TERMINAL?(s) then Q[s,None] <- r'<br>
 * should be:<br>
 * if TERMINAL?(s') then Q[s',None] <- r'<br>
 * so that the correct value for Q[s',a'] is used in the Q[s,a] update rule when
 * a terminal state is reached.<br>
* <br>
* s,a,r <- s',argmax<sub>a'</sub>f(Q[s',a'],N<sub>sa</sub>[s',a']),r'<br>
* should be:
*
* <pre>
 * if TERMINAL?(s') then s,a,r <- null else s,a,r <- s',argmax<sub>a'</sub>f(Q[s',a'],N<sub>sa</sub>[s',a']),r'
* </pre>
*
 * otherwise, at the beginning of a subsequent trial, s will still be the
 * prior terminal state, and it is Q[s,a] for that terminal state that gets
 * updated. This appears incorrect, as no action was performed in the
 * terminal state and the new trial's initial state is not reachable from
 * it. Comments welcome.
*
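 * <p>
 * A minimal usage sketch (the state/action types, the actions function and
 * the parameter values below are hypothetical, for illustration only):
 *
 * <pre>{@code
 * ActionsFunction<MyState, MyAction> af = ...; // enumerates the legal actions per state
 * QLearningAgent<MyState, MyAction> agent = new QLearningAgent<MyState, MyAction>(
 *         af, MyAction.NONE, 0.2, 1.0, 5, 2.0);
 * // on each time step, pass in the current percept (state s' and reward r'):
 * MyAction next = agent.execute(percept);
 * }</pre>
 *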
* @param <S>
* the state type.
* @param <A>
* the action type.
*
* @author Ciaran O'Reilly
* @author Ravi Mohan
*
*/
public class QLearningAgent<S, A extends Action> extends
ReinforcementAgent<S, A> {
// persistent: Q, a table of action values indexed by state and action,
// initially zero
    private Map<Pair<S, A>, Double> Q = new HashMap<Pair<S, A>, Double>();
// N<sub>sa</sub>, a table of frequencies for state-action pairs, initially
// zero
private FrequencyCounter<Pair<S, A>> Nsa = new FrequencyCounter<Pair<S, A>>();
// s,a,r, the previous state, action, and reward, initially null
private S s = null;
private A a = null;
private Double r = null;
//
private ActionsFunction<S, A> actionsFunction = null;
private A noneAction = null;
private double alpha = 0.0;
private double gamma = 0.0;
private int Ne = 0;
private double Rplus = 0.0;
/**
* Constructor.
*
* @param actionsFunction
* a function that lists the legal actions from a state.
* @param noneAction
* an action representing None, i.e. a NoOp.
* @param alpha
* a fixed learning rate.
* @param gamma
* discount to be used.
* @param Ne
     * a fixed parameter for use in the method f(u, n).
* @param Rplus
* R+ is an optimistic estimate of the best possible reward
* obtainable in any state, which is used in the method f(u, n).
*/
public QLearningAgent(ActionsFunction<S, A> actionsFunction,
A noneAction, double alpha,
double gamma, int Ne, double Rplus) {
this.actionsFunction = actionsFunction;
this.noneAction = noneAction;
this.alpha = alpha;
this.gamma = gamma;
this.Ne = Ne;
this.Rplus = Rplus;
}
/**
* An exploratory Q-learning agent. It is an active learner that learns the
* value Q(s,a) of each action in each situation. It uses the same
* exploration function f as the exploratory ADP agent, but avoids having to
* learn the transition model because the Q-value of a state can be related
* directly to those of its neighbors.
*
* @param percept
* a percept indicating the current state s' and reward signal
* r'.
* @return an action
*/
@Override
public A execute(PerceptStateReward<S> percept) {
S sPrime = percept.state();
double rPrime = percept.reward();
        // if TERMINAL?(s') then Q[s',None] <- r'
if (isTerminal(sPrime)) {
Q.put(new Pair<S, A>(sPrime, noneAction), rPrime);
}
// if s is not null then
if (null != s) {
// increment N<sub>sa</sub>[s,a]
Pair<S, A> sa = new Pair<S, A>(s, a);
Nsa.incrementFor(sa);
// Q[s,a] <- Q[s,a] + α(N<sub>sa</sub>[s,a])(r +
// γmax<sub>a'</sub>Q[s',a'] - Q[s,a])
Double Q_sa = Q.get(sa);
if (null == Q_sa) {
Q_sa = 0.0;
}
Q.put(sa, Q_sa + alpha(Nsa, s, a)
* (r + gamma * maxAPrime(sPrime) - Q_sa));
}
        // if TERMINAL?(s') then s,a,r <- null else
// s,a,r <- s',argmax<sub>a'</sub>f(Q[s',a'],N<sub>sa</sub>[s',a']),r'
if (isTerminal(sPrime)) {
s = null;
a = null;
r = null;
} else {
s = sPrime;
a = argmaxAPrime(sPrime);
r = rPrime;
}
// return a
return a;
}
@Override
public void reset() {
Q.clear();
Nsa.clear();
s = null;
a = null;
r = null;
}
@Override
public Map<S, Double> getUtility() {
// Q-values are directly related to utility values as follows
// (AIMA3e pg. 843 - 21.6) :
// U(s) = max<sub>a</sub>Q(s,a).
Map<S, Double> U = new HashMap<S, Double>();
        // iterate over entries to avoid a second lookup per key
        for (Map.Entry<Pair<S, A>, Double> entry : Q.entrySet()) {
            S state = entry.getKey().getFirst();
            Double q = entry.getValue();
            Double u = U.get(state);
            if (null == u || u < q) {
                U.put(state, q);
            }
        }
return U;
}
//
// PROTECTED METHODS
//
/**
     * AIMA3e pg. 836: 'if we change α from a fixed parameter to a function
     * that decreases as the number of times a state-action pair has been
     * observed increases, then U<sup>π</sup>(s) itself will converge to the
     * correct value.'<br>
     * <br>
     * <b>Note:</b> override this method to obtain the desired behavior.
*
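     * <p>
     * For example, a decaying schedule of the form used for the examples in
     * AIMA3e (a sketch, not the required schedule):
     *
     * <pre>{@code
     * protected double alpha(FrequencyCounter<Pair<S, A>> Nsa, S s, A a) {
     *     // decreases as (s, a) is visited more often; a schedule whose terms
     *     // sum to infinity while their squares sum to a finite value
     *     // guarantees convergence, e.g. alpha(n) = 60/(59 + n)
     *     return 60.0 / (59.0 + Nsa.getCount(new Pair<S, A>(s, a)));
     * }
     * }</pre>
     *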
* @param Nsa
* a frequency counter of observed state action pairs.
* @param s
* the current state.
     * @param a
     * the current action.
* @return the learning rate to use based on the frequency of the state
* passed in.
*/
protected double alpha(FrequencyCounter<Pair<S, A>> Nsa, S s, A a) {
// Default implementation is just to return a fixed parameter value
// irrespective of the # of times a state action has been encountered
return alpha;
}
/**
     * AIMA3e pg. 842: 'f(u, n) is called the <b>exploration function</b>. It
     * determines how greed (preferences for high values of u) is traded off
     * against curiosity (preferences for actions that have not been tried often
     * and have low n). The function f(u, n) should be increasing in u and
     * decreasing in n.'<br>
     * <br>
     * <b>Note:</b> override this method to obtain the desired behavior.
*
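     * <p>
     * One possible alternative override (illustrative only; any definition
     * that is increasing in u and decreasing in n fits the book's
     * requirements):
     *
     * <pre>{@code
     * protected double f(Double u, int n) {
     *     double optimism = 2.0; // hypothetical optimistic bonus, analogous to R+
     *     // increasing in u and decreasing in n: the bonus fades with experience
     *     return (null == u ? 0.0 : u) + optimism / (n + 1);
     * }
     * }</pre>
     *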
* @param u
* the currently estimated utility.
* @param n
* the number of times this situation has been encountered.
* @return the exploration value.
*/
protected double f(Double u, int n) {
        // A simple definition of f(u, n), from AIMA3e pg. 842:
if (null == u || n < Ne) {
return Rplus;
}
return u;
}
//
// PRIVATE METHODS
//
private boolean isTerminal(S s) {
boolean terminal = false;
if (null != s && actionsFunction.actions(s).size() == 0) {
// No actions possible in state is considered terminal.
terminal = true;
}
return terminal;
}
private double maxAPrime(S sPrime) {
double max = Double.NEGATIVE_INFINITY;
        if (actionsFunction.actions(sPrime).size() == 0) {
            // a terminal state: its Q value was assigned directly from the
            // reward signal; guard against a missing entry (avoids an NPE)
            Double qTerminal = Q.get(new Pair<S, A>(sPrime, noneAction));
            if (null != qTerminal) {
                max = qTerminal;
            }
} else {
for (A aPrime : actionsFunction.actions(sPrime)) {
Double Q_sPrimeAPrime = Q.get(new Pair<S, A>(sPrime, aPrime));
if (null != Q_sPrimeAPrime && Q_sPrimeAPrime > max) {
max = Q_sPrimeAPrime;
}
}
}
if (max == Double.NEGATIVE_INFINITY) {
            // Assign 0, which mimics Q being initialized to 0 up front.
max = 0.0;
}
return max;
}
// argmax<sub>a'</sub>f(Q[s',a'],N<sub>sa</sub>[s',a'])
    private A argmaxAPrime(S sPrime) {
        // named aBest to avoid shadowing the field 'a'
        A aBest = null;
        double max = Double.NEGATIVE_INFINITY;
        for (A aPrime : actionsFunction.actions(sPrime)) {
            Pair<S, A> sPrimeAPrime = new Pair<S, A>(sPrime, aPrime);
            double explorationValue = f(Q.get(sPrimeAPrime),
                    Nsa.getCount(sPrimeAPrime));
            if (explorationValue > max) {
                max = explorationValue;
                aBest = aPrime;
            }
        }
        return aBest;
    }
}