package aima.test.core.unit.probability.mdp;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import aima.core.environment.cellworld.Cell;
import aima.core.environment.cellworld.CellWorld;
import aima.core.environment.cellworld.CellWorldAction;
import aima.core.environment.cellworld.CellWorldFactory;
import aima.core.probability.example.MDPFactory;
import aima.core.probability.mdp.MarkovDecisionProcess;
import aima.core.probability.mdp.Policy;
import aima.core.probability.mdp.impl.ModifiedPolicyEvaluation;
import aima.core.probability.mdp.search.PolicyIteration;
/**
 * Runs {@link PolicyIteration} on the MDP built by
 * {@code MDPFactory.createMDPForFigure17_3} over the cell world from
 * {@code CellWorldFactory.createCellWorldForFig17_1}, and compares the
 * resulting policy, cell by cell, with AIMA3e Figure 17.2 (a).
 *
 * @author Ravi Mohan
 * @author Ciaran O'Reilly
 */
public class PolicyIterationTest {
    // Fixture state; rebuilt from scratch by setUp() before each test method.
    private CellWorld<Double> cw;
    private MarkovDecisionProcess<Cell<Double>, CellWorldAction> mdp;
    private PolicyIteration<Cell<Double>, CellWorldAction> pi;

    /**
     * Creates the cell world, the MDP defined over it, and the policy
     * iteration instance under test.
     */
    @Before
    public void setUp() {
        cw = CellWorldFactory.createCellWorldForFig17_1();
        mdp = MDPFactory.createMDPForFigure17_3(cw);

        // NOTE(review): the arguments (50, 1.0) are presumably the number of
        // evaluation iterations and the discount factor - confirm against
        // ModifiedPolicyEvaluation's constructor.
        ModifiedPolicyEvaluation<Cell<Double>, CellWorldAction> evaluator =
                new ModifiedPolicyEvaluation<Cell<Double>, CellWorldAction>(50, 1.0);
        pi = new PolicyIteration<Cell<Double>, CellWorldAction>(evaluator);
    }

    /**
     * Checks the policy returned by policy iteration against the actions
     * shown in AIMA3e Figure 17.2 (a).
     */
    @Test
    public void testPolicyIterationForFig17_2() {
        Policy<Cell<Double>, CellWorldAction> computed = pi.policyIteration(mdp);

        // Column x = 1
        expectAction(computed, 1, 1, CellWorldAction.Up);
        expectAction(computed, 1, 2, CellWorldAction.Up);
        expectAction(computed, 1, 3, CellWorldAction.Right);

        // Column x = 2 (no expectation is stated for cell (2, 2))
        expectAction(computed, 2, 1, CellWorldAction.Left);
        expectAction(computed, 2, 3, CellWorldAction.Right);

        // Column x = 3
        expectAction(computed, 3, 1, CellWorldAction.Left);
        expectAction(computed, 3, 2, CellWorldAction.Up);
        expectAction(computed, 3, 3, CellWorldAction.Right);

        // Column x = 4: only (4, 1) has an action; the other two must map to null.
        expectAction(computed, 4, 1, CellWorldAction.Left);
        expectNoAction(computed, 4, 2);
        expectNoAction(computed, 4, 3);
    }

    /**
     * Asserts that the policy selects the expected action for the cell at
     * the given coordinates of the fixture cell world.
     *
     * @param policy   the policy being checked
     * @param x        x coordinate passed to {@code cw.getCellAt}
     * @param y        y coordinate passed to {@code cw.getCellAt}
     * @param expected the action the policy must return for that cell
     */
    private void expectAction(Policy<Cell<Double>, CellWorldAction> policy,
            int x, int y, CellWorldAction expected) {
        Assert.assertEquals(expected, policy.action(cw.getCellAt(x, y)));
    }

    /**
     * Asserts that the policy returns {@code null} for the cell at the given
     * coordinates of the fixture cell world.
     *
     * @param policy the policy being checked
     * @param x      x coordinate passed to {@code cw.getCellAt}
     * @param y      y coordinate passed to {@code cw.getCellAt}
     */
    private void expectNoAction(Policy<Cell<Double>, CellWorldAction> policy,
            int x, int y) {
        Assert.assertNull(policy.action(cw.getCellAt(x, y)));
    }
}