package rl; /** * A policy learner that learns policies through policy iteration * @author Andrew Guillory gtg008g@mail.gatech.edu * @version 1.0 */ public class PolicyIteration implements PolicyLearner { /** * The tolerance for changes */ private static final double TOLERANCE = 1E-6; /** * The policy */ private Policy policy; /** * The process */ private MarkovDecisionProcess process; /** * The decay value */ private double gamma; /** * Make a new value iteration * @param gamma the gamma decay value */ public PolicyIteration(double gamma, MarkovDecisionProcess process) { this.gamma = gamma; this.process = process; policy = new Policy(process.getStateCount(), process.getActionCount()); } /** * @see shared.Trainer#train() */ public double train() { int stateCount = process.getStateCount(); int actionCount = process.getActionCount(); // perform value iteration with the policy double[] values = new double[stateCount]; boolean valuesChanged = false; do { valuesChanged = false; // loop through all the states for (int i = 0; i < stateCount; i++) { // calculate the new value int action = policy.getAction(i); double actionVal = 0; for (int j = 0; j < stateCount; j++) { actionVal += process.transitionProbability(i, j, action) * values[j]; } // val = reward + decay * expected value double val = process.reward(i, action) + gamma * actionVal; // check if we're done if (Math.abs(values[i] - val) > TOLERANCE) { valuesChanged = true; } values[i] = val; } } while (valuesChanged); int changed = 0; // calculate the new policy for (int i = 0; i < stateCount; i++) { // find the maximum action double maxActionVal = -Double.MAX_VALUE; int maxAction = 0; for (int a = 0; a < actionCount; a++) { double actionVal = 0; for (int j = 0; j < stateCount; j++) { actionVal += process.transitionProbability(i, j, a) * values[j]; } actionVal = process.reward(i, a) + gamma * actionVal; if (actionVal > maxActionVal) { maxActionVal = actionVal; maxAction = a; } } if (policy.getAction(i) != maxAction) { changed++; } policy.setAction(i, maxAction); } return changed; } /** * @see rl.PolicyLearner#getPolicy() */ public Policy getPolicy() { return policy; } }