package aima.gui.demo.probability; import java.util.ArrayList; import java.util.List; import java.util.Map; import aima.core.environment.cellworld.Cell; import aima.core.environment.cellworld.CellWorld; import aima.core.environment.cellworld.CellWorldAction; import aima.core.environment.cellworld.CellWorldFactory; import aima.core.probability.CategoricalDistribution; import aima.core.probability.FiniteProbabilityModel; import aima.core.probability.bayes.approx.BayesInferenceApproxAdapter; import aima.core.probability.bayes.approx.GibbsAsk; import aima.core.probability.bayes.approx.LikelihoodWeighting; import aima.core.probability.bayes.approx.ParticleFiltering; import aima.core.probability.bayes.approx.RejectionSampling; import aima.core.probability.bayes.exact.EliminationAsk; import aima.core.probability.bayes.exact.EnumerationAsk; import aima.core.probability.bayes.model.FiniteBayesModel; import aima.core.probability.example.BayesNetExampleFactory; import aima.core.probability.example.DynamicBayesNetExampleFactory; import aima.core.probability.example.ExampleRV; import aima.core.probability.example.FullJointDistributionBurglaryAlarmModel; import aima.core.probability.example.FullJointDistributionToothacheCavityCatchModel; import aima.core.probability.example.GenericTemporalModelFactory; import aima.core.probability.example.HMMExampleFactory; import aima.core.probability.example.MDPFactory; import aima.core.probability.hmm.exact.FixedLagSmoothing; import aima.core.probability.mdp.MarkovDecisionProcess; import aima.core.probability.mdp.Policy; import aima.core.probability.mdp.impl.ModifiedPolicyEvaluation; import aima.core.probability.mdp.search.PolicyIteration; import aima.core.probability.mdp.search.ValueIteration; import aima.core.probability.proposition.AssignmentProposition; import aima.core.probability.proposition.DisjunctiveProposition; import aima.core.probability.temporal.generic.ForwardBackward; import aima.core.probability.util.ProbabilityTable; import aima.core.util.MockRandomizer; /** * @author Ravi Mohan * @author Ciaran O'Reilly */ public class ProbabilityDemo { // Note: You should increase this to 1000000+ // in order to get answers from the approximate // algorithms (i.e. Rejection, Likelihood and Gibbs) // that look close to their exact inference // counterparts. public static final int NUM_SAMPLES = 1000; public static void main(String[] args) { // Chapter 13 fullJointDistributionModelDemo(); // Chapter 14 - Exact bayesEnumerationAskDemo(); bayesEliminationAskDemo(); // Chapter 14 - Approx bayesRejectionSamplingDemo(); bayesLikelihoodWeightingDemo(); bayesGibbsAskDemo(); // Chapter 15 forwardBackWardDemo(); fixedLagSmoothingDemo(); particleFilterinfDemo(); // Chapter 17 valueIterationDemo(); policyIterationDemo(); } public static void fullJointDistributionModelDemo() { System.out.println("DEMO: Full Joint Distribution Model"); System.out.println("==================================="); demoToothacheCavityCatchModel(new FullJointDistributionToothacheCavityCatchModel()); demoBurglaryAlarmModel(new FullJointDistributionBurglaryAlarmModel()); System.out.println("==================================="); } public static void bayesEnumerationAskDemo() { System.out.println("DEMO: Bayes Enumeration Ask"); System.out.println("==========================="); demoToothacheCavityCatchModel(new FiniteBayesModel( BayesNetExampleFactory.constructToothacheCavityCatchNetwork(), new EnumerationAsk())); demoBurglaryAlarmModel(new FiniteBayesModel( BayesNetExampleFactory.constructBurglaryAlarmNetwork(), new EnumerationAsk())); System.out.println("==========================="); } public static void bayesEliminationAskDemo() { System.out.println("DEMO: Bayes Elimination Ask"); System.out.println("==========================="); demoToothacheCavityCatchModel(new FiniteBayesModel( BayesNetExampleFactory.constructToothacheCavityCatchNetwork(), new EliminationAsk())); demoBurglaryAlarmModel(new FiniteBayesModel( BayesNetExampleFactory.constructBurglaryAlarmNetwork(), new EliminationAsk())); System.out.println("==========================="); } public static void bayesRejectionSamplingDemo() { System.out.println("DEMO: Bayes Rejection Sampling N = " + NUM_SAMPLES); System.out.println("=============================="); demoToothacheCavityCatchModel(new FiniteBayesModel( BayesNetExampleFactory.constructToothacheCavityCatchNetwork(), new BayesInferenceApproxAdapter(new RejectionSampling(), NUM_SAMPLES))); demoBurglaryAlarmModel(new FiniteBayesModel( BayesNetExampleFactory.constructBurglaryAlarmNetwork(), new BayesInferenceApproxAdapter(new RejectionSampling(), NUM_SAMPLES))); System.out.println("=============================="); } public static void bayesLikelihoodWeightingDemo() { System.out.println("DEMO: Bayes Likelihood Weighting N = " + NUM_SAMPLES); System.out.println("================================"); demoToothacheCavityCatchModel(new FiniteBayesModel( BayesNetExampleFactory.constructToothacheCavityCatchNetwork(), new BayesInferenceApproxAdapter(new LikelihoodWeighting(), NUM_SAMPLES))); demoBurglaryAlarmModel(new FiniteBayesModel( BayesNetExampleFactory.constructBurglaryAlarmNetwork(), new BayesInferenceApproxAdapter(new LikelihoodWeighting(), NUM_SAMPLES))); System.out.println("================================"); } public static void bayesGibbsAskDemo() { System.out.println("DEMO: Bayes Gibbs Ask N = " + NUM_SAMPLES); System.out.println("====================="); demoToothacheCavityCatchModel(new FiniteBayesModel( BayesNetExampleFactory.constructToothacheCavityCatchNetwork(), new BayesInferenceApproxAdapter(new GibbsAsk(), NUM_SAMPLES))); demoBurglaryAlarmModel(new FiniteBayesModel( BayesNetExampleFactory.constructBurglaryAlarmNetwork(), new BayesInferenceApproxAdapter(new GibbsAsk(), NUM_SAMPLES))); System.out.println("====================="); } public static void forwardBackWardDemo() { System.out.println("DEMO: Forward-BackWard"); System.out.println("======================"); System.out.println("Umbrella World"); System.out.println("--------------"); ForwardBackward uw = new ForwardBackward( GenericTemporalModelFactory.getUmbrellaWorldTransitionModel(), GenericTemporalModelFactory.getUmbrellaWorld_Xt_to_Xtm1_Map(), GenericTemporalModelFactory.getUmbrellaWorldSensorModel()); CategoricalDistribution prior = new ProbabilityTable(new double[] { 0.5, 0.5 }, ExampleRV.RAIN_t_RV); // Day 1 List<List<AssignmentProposition>> evidence = new ArrayList<List<AssignmentProposition>>(); List<AssignmentProposition> e1 = new ArrayList<AssignmentProposition>(); e1.add(new AssignmentProposition(ExampleRV.UMBREALLA_t_RV, Boolean.TRUE)); evidence.add(e1); List<CategoricalDistribution> smoothed = uw.forwardBackward(evidence, prior); System.out.println("Day 1 (Umbrealla_t=true) smoothed:\nday 1 = " + smoothed.get(0)); // Day 2 List<AssignmentProposition> e2 = new ArrayList<AssignmentProposition>(); e2.add(new AssignmentProposition(ExampleRV.UMBREALLA_t_RV, Boolean.TRUE)); evidence.add(e2); smoothed = uw.forwardBackward(evidence, prior); System.out.println("Day 2 (Umbrealla_t=true) smoothed:\nday 1 = " + smoothed.get(0) + "\nday 2 = " + smoothed.get(1)); // Day 3 List<AssignmentProposition> e3 = new ArrayList<AssignmentProposition>(); e3.add(new AssignmentProposition(ExampleRV.UMBREALLA_t_RV, Boolean.FALSE)); evidence.add(e3); smoothed = uw.forwardBackward(evidence, prior); System.out.println("Day 3 (Umbrealla_t=false) smoothed:\nday 1 = " + smoothed.get(0) + "\nday 2 = " + smoothed.get(1) + "\nday 3 = " + smoothed.get(2)); System.out.println("======================"); } public static void fixedLagSmoothingDemo() { System.out.println("DEMO: Fixed-Lag-Smoothing"); System.out.println("========================="); System.out.println("Lag = 1"); System.out.println("-------"); FixedLagSmoothing uw = new FixedLagSmoothing( HMMExampleFactory.getUmbrellaWorldModel(), 1); // Day 1 - Lag 1 List<AssignmentProposition> e1 = new ArrayList<AssignmentProposition>(); e1.add(new AssignmentProposition(ExampleRV.UMBREALLA_t_RV, Boolean.TRUE)); CategoricalDistribution smoothed = uw.fixedLagSmoothing(e1); System.out.println("Day 1 (Umbrella_t=true) smoothed:\nday 1=" + smoothed); // Day 2 - Lag 1 List<AssignmentProposition> e2 = new ArrayList<AssignmentProposition>(); e2.add(new AssignmentProposition(ExampleRV.UMBREALLA_t_RV, Boolean.TRUE)); smoothed = uw.fixedLagSmoothing(e2); System.out.println("Day 2 (Umbrella_t=true) smoothed:\nday 1=" + smoothed); // Day 3 - Lag 1 List<AssignmentProposition> e3 = new ArrayList<AssignmentProposition>(); e3.add(new AssignmentProposition(ExampleRV.UMBREALLA_t_RV, Boolean.FALSE)); smoothed = uw.fixedLagSmoothing(e3); System.out.println("Day 3 (Umbrella_t=false) smoothed:\nday 2=" + smoothed); System.out.println("-------"); System.out.println("Lag = 2"); System.out.println("-------"); uw = new FixedLagSmoothing(HMMExampleFactory.getUmbrellaWorldModel(), 2); // Day 1 - Lag 2 e1 = new ArrayList<AssignmentProposition>(); e1.add(new AssignmentProposition(ExampleRV.UMBREALLA_t_RV, Boolean.TRUE)); smoothed = uw.fixedLagSmoothing(e1); System.out.println("Day 1 (Umbrella_t=true) smoothed:\nday 1=" + smoothed); // Day 2 - Lag 2 e2 = new ArrayList<AssignmentProposition>(); e2.add(new AssignmentProposition(ExampleRV.UMBREALLA_t_RV, Boolean.TRUE)); smoothed = uw.fixedLagSmoothing(e2); System.out.println("Day 2 (Umbrella_t=true) smoothed:\nday 1=" + smoothed); // Day 3 - Lag 2 e3 = new ArrayList<AssignmentProposition>(); e3.add(new AssignmentProposition(ExampleRV.UMBREALLA_t_RV, Boolean.FALSE)); smoothed = uw.fixedLagSmoothing(e3); System.out.println("Day 3 (Umbrella_t=false) smoothed:\nday 1=" + smoothed); System.out.println("========================="); } public static void particleFilterinfDemo() { System.out.println("DEMO: Particle-Filtering"); System.out.println("========================"); System.out.println("Figure 15.18"); System.out.println("------------"); MockRandomizer mr = new MockRandomizer(new double[] { // Prior Sample: // 8 with Rain_t-1=true from prior distribution 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, // 2 with Rain_t-1=false from prior distribution 0.6, 0.6, // (a) Propagate 6 samples Rain_t=true 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, // 4 samples Rain_t=false 0.71, 0.71, 0.31, 0.31, // (b) Weight should be for first 6 samples: // Rain_t-1=true, Rain_t=true, Umbrella_t=false = 0.1 // Next 2 samples: // Rain_t-1=true, Rain_t=false, Umbrealla_t=false= 0.8 // Final 2 samples: // Rain_t-1=false, Rain_t=false, Umbrella_t=false = 0.8 // gives W[] = // [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.8, 0.8, 0.8, 0.8] // normalized = // [0.026, ...., 0.211, ....] is approx. 0.156 = true // the remainder is false // (c) Resample 2 Rain_t=true, 8 Rain_t=false 0.15, 0.15, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, // // Next Sample: // (a) Propagate 1 samples Rain_t=true 0.7, // 9 samples Rain_t=false 0.71, 0.31, 0.31, 0.31, 0.31, 0.31, 0.31, 0.31, 0.31, // (c) resample 1 Rain_t=true, 9 Rain_t=false 0.0001, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2 }); int N = 10; ParticleFiltering pf = new ParticleFiltering(N, DynamicBayesNetExampleFactory.getUmbrellaWorldNetwork(), mr); AssignmentProposition[] e = new AssignmentProposition[] { new AssignmentProposition( ExampleRV.UMBREALLA_t_RV, false) }; System.out.println("First Sample Set:"); AssignmentProposition[][] S = pf.particleFiltering(e); for (int i = 0; i < N; i++) { System.out.println("Sample " + (i + 1) + " = " + S[i][0]); } System.out.println("Second Sample Set:"); S = pf.particleFiltering(e); for (int i = 0; i < N; i++) { System.out.println("Sample " + (i + 1) + " = " + S[i][0]); } System.out.println("========================"); } public static void valueIterationDemo() { System.out.println("DEMO: Value Iteration"); System.out.println("====================="); System.out.println("Figure 17.3"); System.out.println("-----------"); CellWorld<Double> cw = CellWorldFactory.createCellWorldForFig17_1(); MarkovDecisionProcess<Cell<Double>, CellWorldAction> mdp = MDPFactory .createMDPForFigure17_3(cw); ValueIteration<Cell<Double>, CellWorldAction> vi = new ValueIteration<Cell<Double>, CellWorldAction>( 1.0); Map<Cell<Double>, Double> U = vi.valueIteration(mdp, 0.0001); System.out.println("(1,1) = " + U.get(cw.getCellAt(1, 1))); System.out.println("(1,2) = " + U.get(cw.getCellAt(1, 2))); System.out.println("(1,3) = " + U.get(cw.getCellAt(1, 3))); System.out.println("(2,1) = " + U.get(cw.getCellAt(2, 1))); System.out.println("(2,3) = " + U.get(cw.getCellAt(2, 3))); System.out.println("(3,1) = " + U.get(cw.getCellAt(3, 1))); System.out.println("(3,2) = " + U.get(cw.getCellAt(3, 2))); System.out.println("(3,3) = " + U.get(cw.getCellAt(3, 3))); System.out.println("(4,1) = " + U.get(cw.getCellAt(4, 1))); System.out.println("(4,2) = " + U.get(cw.getCellAt(4, 2))); System.out.println("(4,3) = " + U.get(cw.getCellAt(4, 3))); System.out.println("========================="); } public static void policyIterationDemo() { System.out.println("DEMO: Policy Iteration"); System.out.println("======================"); System.out.println("Figure 17.3"); System.out.println("-----------"); CellWorld<Double> cw = CellWorldFactory.createCellWorldForFig17_1(); MarkovDecisionProcess<Cell<Double>, CellWorldAction> mdp = MDPFactory .createMDPForFigure17_3(cw); PolicyIteration<Cell<Double>, CellWorldAction> pi = new PolicyIteration<Cell<Double>, CellWorldAction>( new ModifiedPolicyEvaluation<Cell<Double>, CellWorldAction>(50, 1.0)); Policy<Cell<Double>, CellWorldAction> policy = pi.policyIteration(mdp); System.out.println("(1,1) = " + policy.action(cw.getCellAt(1, 1))); System.out.println("(1,2) = " + policy.action(cw.getCellAt(1, 2))); System.out.println("(1,3) = " + policy.action(cw.getCellAt(1, 3))); System.out.println("(2,1) = " + policy.action(cw.getCellAt(2, 1))); System.out.println("(2,3) = " + policy.action(cw.getCellAt(2, 3))); System.out.println("(3,1) = " + policy.action(cw.getCellAt(3, 1))); System.out.println("(3,2) = " + policy.action(cw.getCellAt(3, 2))); System.out.println("(3,3) = " + policy.action(cw.getCellAt(3, 3))); System.out.println("(4,1) = " + policy.action(cw.getCellAt(4, 1))); System.out.println("(4,2) = " + policy.action(cw.getCellAt(4, 2))); System.out.println("(4,3) = " + policy.action(cw.getCellAt(4, 3))); System.out.println("========================="); } // // PRIVATE METHODS // private static void demoToothacheCavityCatchModel( FiniteProbabilityModel model) { System.out.println("Toothache, Cavity, and Catch Model"); System.out.println("----------------------------------"); AssignmentProposition atoothache = new AssignmentProposition( ExampleRV.TOOTHACHE_RV, Boolean.TRUE); AssignmentProposition acavity = new AssignmentProposition( ExampleRV.CAVITY_RV, Boolean.TRUE); AssignmentProposition anotcavity = new AssignmentProposition( ExampleRV.CAVITY_RV, Boolean.FALSE); AssignmentProposition acatch = new AssignmentProposition( ExampleRV.CATCH_RV, Boolean.TRUE); // AIMA3e pg. 485 System.out.println("P(cavity) = " + model.prior(acavity)); System.out.println("P(cavity | toothache) = " + model.posterior(acavity, atoothache)); // AIMA3e pg. 492 DisjunctiveProposition cavityOrToothache = new DisjunctiveProposition( acavity, atoothache); System.out.println("P(cavity OR toothache) = " + model.prior(cavityOrToothache)); // AIMA3e pg. 493 System.out.println("P(~cavity | toothache) = " + model.posterior(anotcavity, atoothache)); // AIMA3e pg. 493 // P<>(Cavity | toothache) = <0.6, 0.4> System.out.println("P<>(Cavity | toothache) = " + model.posteriorDistribution(ExampleRV.CAVITY_RV, atoothache)); // AIMA3e pg. 497 // P<>(Cavity | toothache AND catch) = <0.871, 0.129> System.out.println("P<>(Cavity | toothache AND catch) = " + model.posteriorDistribution(ExampleRV.CAVITY_RV, atoothache, acatch)); } private static void demoBurglaryAlarmModel(FiniteProbabilityModel model) { System.out.println("--------------------"); System.out.println("Burglary Alarm Model"); System.out.println("--------------------"); AssignmentProposition aburglary = new AssignmentProposition( ExampleRV.BURGLARY_RV, Boolean.TRUE); AssignmentProposition anotburglary = new AssignmentProposition( ExampleRV.BURGLARY_RV, Boolean.FALSE); AssignmentProposition anotearthquake = new AssignmentProposition( ExampleRV.EARTHQUAKE_RV, Boolean.FALSE); AssignmentProposition aalarm = new AssignmentProposition( ExampleRV.ALARM_RV, Boolean.TRUE); AssignmentProposition anotalarm = new AssignmentProposition( ExampleRV.ALARM_RV, Boolean.FALSE); AssignmentProposition ajohnCalls = new AssignmentProposition( ExampleRV.JOHN_CALLS_RV, Boolean.TRUE); AssignmentProposition amaryCalls = new AssignmentProposition( ExampleRV.MARY_CALLS_RV, Boolean.TRUE); // AIMA3e pg. 514 System.out.println("P(j,m,a,~b,~e) = " + model.prior(ajohnCalls, amaryCalls, aalarm, anotburglary, anotearthquake)); System.out.println("P(j,m,~a,~b,~e) = " + model.prior(ajohnCalls, amaryCalls, anotalarm, anotburglary, anotearthquake)); // AIMA3e. pg. 514 // P<>(Alarm | JohnCalls = true, MaryCalls = true, Burglary = false, // Earthquake = false) // = <0.558, 0.442> System.out .println("P<>(Alarm | JohnCalls = true, MaryCalls = true, Burglary = false, Earthquake = false) = " + model.posteriorDistribution(ExampleRV.ALARM_RV, ajohnCalls, amaryCalls, anotburglary, anotearthquake)); // AIMA3e pg. 523 // P<>(Burglary | JohnCalls = true, MaryCalls = true) = <0.284, 0.716> System.out .println("P<>(Burglary | JohnCalls = true, MaryCalls = true) = " + model.posteriorDistribution(ExampleRV.BURGLARY_RV, ajohnCalls, amaryCalls)); // AIMA3e pg. 528 // P<>(JohnCalls | Burglary = true) System.out.println("P<>(JohnCalls | Burglary = true) = " + model.posteriorDistribution(ExampleRV.JOHN_CALLS_RV, aburglary)); } }