package mdpsSolution;
import java.util.Random;
import mdps.MDP;
import mdps.QValueFunction;
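/**
 * Tabular Q-learning for a finite MDP: the agent explores with an
 * epsilon-greedy policy and learns Q-values with the standard update
 * Q(s,a) <- (1 - alpha) * Q(s,a) + alpha * (r + gamma * max_a' Q(s',a')).
 */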
public class QLearning {

    private final TableLookupQValueFunction valFunc;
    private final MDP mdp;
    private double epsilon = 0.3; // exploration rate
    private double alpha = 0.2;   // learning rate
    private double gamma = 0.9;   // discount factor
    private final Random rand = new Random();

    QLearning(MDP mdp) {
        this.mdp = mdp;
        valFunc = new TableLookupQValueFunction(mdp, 0); // all Q-values start at 0
    }

    /** Returns the action with the highest current Q-value in the given state. */
    public int greedyAction(int state) {
        double maxVal = Double.NEGATIVE_INFINITY;
        int maxAct = 0;
        for (int i = 0; i < mdp.numActions(); i++) {
            double val = valFunc.getValue(state, i);
            if (val > maxVal) {
                maxVal = val;
                maxAct = i;
            }
        }
        return maxAct;
    }

    /** Samples a successor state from the MDP's transition distribution. */
    public int nextState(int state, int action) {
        double[] stateDistr = mdp.nextStateDistribution(state, action);
        double choiceIndex = rand.nextDouble();
        double sum = 0;
        for (int i = 0; i < stateDistr.length; i++) {
            sum += stateDistr[i];
            if (sum >= choiceIndex)
                return i;
        }
        // floating-point rounding can leave sum just below 1;
        // fall back to the last state rather than returning an invalid index
        return stateDistr.length - 1;
    }

    /**
     * Applies the Q-learning update
     * Q(s,a) <- (1 - alpha) * Q(s,a) + alpha * (r + gamma * max_a' Q(s',a')),
     * where alpha is the learning rate. For a terminal state there is no
     * successor, so the target is just the immediate reward.
     */
    public void updateValues(int state, int action, int nextState) {
        double reward = mdp.getReward(state, action);
        double currentValue = valFunc.getValue(state, action);
        if (mdp.isTerminalState(state)) {
            // end the episode: no discounted future value
            double newValue = (1 - alpha) * currentValue + alpha * reward;
            valFunc.updateValue(state, action, newValue);
        } else {
            // getValue(nextState) is assumed to return max over actions of Q(nextState, a)
            double nextValue = valFunc.getValue(nextState);
            double newValue = (1 - alpha) * currentValue + alpha * (reward + gamma * nextValue);
            valFunc.updateValue(state, action, newValue);
        }
    }
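
    // Worked example of the update above, using the field values alpha = 0.2 and
    // gamma = 0.9 (the Q-value, reward, and successor value are illustrative only):
    // if Q(s,a) = 1.0, reward = 0.5, and max_a' Q(s',a') = 2.0, the new value is
    //     0.8 * 1.0 + 0.2 * (0.5 + 0.9 * 2.0) = 1.26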

    /**
     * Takes one epsilon-greedy step from currentState, updates the Q-values,
     * and returns the next state. A terminal state ends the episode and
     * restarts in a uniformly random state.
     */
    public int epsilonGreedyStep(int currentState) {
        int action = greedyAction(currentState);
        // with probability epsilon, explore with a uniformly random action
        if (rand.nextDouble() < epsilon)
            action = rand.nextInt(mdp.numActions());
        if (mdp.isTerminalState(currentState)) {
            updateValues(currentState, action, -1); // next state is unused for terminal states
            return rand.nextInt(mdp.numStates());
        } else {
            int newState = nextState(currentState, action);
            updateValues(currentState, action, newState);
            return newState;
        }
    }

    /** Exposes the learned Q-value table (a shared reference, not a copy). */
    public QValueFunction getValueFunctionReference() {
        return valFunc;
    }

    public MDP getMDP() {
        return mdp;
    }
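
    // Usage sketch (illustrative only; assumes some concrete MDP implementation
    // from the mdps package, and runs from within mdpsSolution since the
    // constructor is package-private):
    //
    //     MDP mdp = ...;                                  // any finite MDP
    //     QLearning learner = new QLearning(mdp);
    //     int state = new Random().nextInt(mdp.numStates());
    //     for (int step = 0; step < 100_000; step++)
    //         state = learner.epsilonGreedyStep(state);
    //     QValueFunction q = learner.getValueFunctionReference();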
}