package rl;
/**
* A policy learner that learns policies through value iteration
* @author Andrew Guillory gtg008g@mail.gatech.edu
* @version 1.0
*/
public class ValueIteration implements PolicyLearner {
    /**
     * The discount (decay) factor applied to future values
     */
    private final double gamma;
    /**
     * The Markov decision process being solved
     */
    private final MarkovDecisionProcess process;
    /**
     * The current value estimate for each state, indexed by state
     */
    private final double[] values;

    /**
     * Make a new value iteration learner.
     * Values are initialized to each state's best immediate reward.
     * @param gamma the discount (decay) factor, typically in [0, 1)
     * @param process the Markov decision process to solve
     */
    public ValueIteration(double gamma, MarkovDecisionProcess process) {
        this.gamma = gamma;
        this.process = process;
        // seed each state's value with the best one-step reward
        values = new double[process.getStateCount()];
        for (int i = 0; i < values.length; i++) {
            double maxActionVal = process.reward(i, 0);
            for (int a = 1; a < process.getActionCount(); a++) {
                maxActionVal = Math.max(maxActionVal, process.reward(i, a));
            }
            values[i] = maxActionVal;
        }
    }

    /**
     * Perform one full sweep of value iteration over all states,
     * applying a Bellman backup to each state's value.
     * @see shared.Trainer#train()
     * @return the largest absolute change in any state's value during
     *         this sweep (useful as a convergence measure)
     */
    public double train() {
        int stateCount = process.getStateCount();
        double difference = 0;
        for (int i = 0; i < stateCount; i++) {
            // Bellman backup: new value is the best Q-value over actions
            double maxActionVal = maxQValue(i);
            difference = Math.max(Math.abs(values[i] - maxActionVal), difference);
            values[i] = maxActionVal;
        }
        return difference;
    }

    /**
     * Extract the greedy policy with respect to the current values.
     * @see rl.PolicyLearner#getPolicy()
     * @return a policy that chooses, in each state, the action with the
     *         highest expected discounted return (ties broken by the
     *         lowest action index)
     */
    public Policy getPolicy() {
        int stateCount = process.getStateCount();
        int actionCount = process.getActionCount();
        int[] policy = new int[stateCount];
        for (int i = 0; i < stateCount; i++) {
            // find the action maximizing the Q-value for this state
            double maxActionVal = Double.NEGATIVE_INFINITY;
            int maxAction = 0;
            for (int a = 0; a < actionCount; a++) {
                double actionVal = qValue(i, a);
                if (actionVal > maxActionVal) {
                    maxActionVal = actionVal;
                    maxAction = a;
                }
            }
            policy[i] = maxAction;
        }
        return new Policy(policy);
    }

    /**
     * Compute the expected discounted return (Q-value) of taking an
     * action in a state under the current value estimates.
     * @param state the state index
     * @param action the action index
     * @return reward(state, action) + gamma * expected next-state value
     */
    private double qValue(int state, int action) {
        double expected = 0;
        int stateCount = process.getStateCount();
        for (int j = 0; j < stateCount; j++) {
            expected += process.transitionProbability(state, j, action) * values[j];
        }
        return process.reward(state, action) + gamma * expected;
    }

    /**
     * Compute the maximum Q-value over all actions for a state.
     * @param state the state index
     * @return the highest Q-value achievable in the state
     */
    private double maxQValue(int state) {
        double best = Double.NEGATIVE_INFINITY;
        int actionCount = process.getActionCount();
        for (int a = 0; a < actionCount; a++) {
            best = Math.max(best, qValue(state, a));
        }
        return best;
    }
}