package aima.gui.demo.learning;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import aima.core.environment.cellworld.Cell;
import aima.core.environment.cellworld.CellWorld;
import aima.core.environment.cellworld.CellWorldAction;
import aima.core.environment.cellworld.CellWorldFactory;
import aima.core.learning.framework.DataSet;
import aima.core.learning.framework.DataSetFactory;
import aima.core.learning.framework.Learner;
import aima.core.learning.inductive.DLTestFactory;
import aima.core.learning.inductive.DecisionTree;
import aima.core.learning.learners.AdaBoostLearner;
import aima.core.learning.learners.DecisionListLearner;
import aima.core.learning.learners.DecisionTreeLearner;
import aima.core.learning.learners.StumpLearner;
import aima.core.learning.neural.BackPropLearning;
import aima.core.learning.neural.FeedForwardNeuralNetwork;
import aima.core.learning.neural.IrisDataSetNumerizer;
import aima.core.learning.neural.IrisNNDataSet;
import aima.core.learning.neural.NNConfig;
import aima.core.learning.neural.NNDataSet;
import aima.core.learning.neural.Numerizer;
import aima.core.learning.neural.Perceptron;
import aima.core.learning.reinforcement.agent.PassiveADPAgent;
import aima.core.learning.reinforcement.agent.PassiveTDAgent;
import aima.core.learning.reinforcement.agent.QLearningAgent;
import aima.core.learning.reinforcement.agent.ReinforcementAgent;
import aima.core.learning.reinforcement.example.CellWorldEnvironment;
import aima.core.probability.example.MDPFactory;
import aima.core.probability.mdp.impl.ModifiedPolicyEvaluation;
import aima.core.util.JavaRandomizer;
import aima.core.util.Util;

/**
 * Demonstrates the learning algorithms from AIMA Chapter 18 (learning from
 * examples) and Chapter 21 (reinforcement learning), using the restaurant,
 * iris, and 4x3 cell-world examples.
 */
public class LearningDemo {

    public static void main(String[] args) {
        // Chapter 18
        decisionTreeDemo();
        decisionListDemo();
        ensembleLearningDemo();
        perceptronDemo();
        backPropagationDemo();

        // Chapter 21
        passiveADPAgentDemo();
        passiveTDAgentDemo();
        qLearningAgentDemo();
    }

    public static void decisionTreeDemo() {
        System.out.println(Util.ntimes("*", 100));
        System.out.println("\nDecisionTree Demo - Inducing a DecisionTree from the Restaurant DataSet\n");
        System.out.println(Util.ntimes("*", 100));
        try {
            DataSet ds = DataSetFactory.getRestaurantDataSet();
            DecisionTreeLearner learner = new DecisionTreeLearner();
            learner.train(ds);
            System.out.println("The Induced Decision Tree is ");
            System.out.println(learner.getDecisionTree());

            int[] result = learner.test(ds);
            System.out.println("\nThis Decision Tree classifies the data set with "
                    + result[0] + " successes and " + result[1] + " failures");
            System.out.println("\n");
        } catch (Exception e) {
            System.out.println("Decision Tree Demo Failed");
            e.printStackTrace();
        }
    }

    public static void decisionListDemo() {
        try {
            System.out.println(Util.ntimes("*", 100));
            System.out.println("DecisionList Demo - Inducing a DecisionList from the Restaurant DataSet\n");
            System.out.println(Util.ntimes("*", 100));
            DataSet ds = DataSetFactory.getRestaurantDataSet();
            // "Yes"/"No" are the positive and negative classification labels.
            DecisionListLearner learner = new DecisionListLearner("Yes", "No",
                    new DLTestFactory());
            learner.train(ds);
            System.out.println("The Induced DecisionList is");
            System.out.println(learner.getDecisionList());

            int[] result = learner.test(ds);
            System.out.println("\nThis Decision List classifies the data set with "
                    + result[0] + " successes and " + result[1] + " failures");
            System.out.println("\n");
        } catch (Exception e) {
            System.out.println("Decision List Demo Failed");
            e.printStackTrace();
        }
    }
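
    /**
     * Ensemble learning demo: builds the set of decision stumps for the
     * restaurant data set and combines their weighted votes with AdaBoost.
     */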
    public static void ensembleLearningDemo() {
        System.out.println(Util.ntimes("*", 100));
        System.out.println("\n Ensemble Learning Demo - Weak learners cooperating to give superior decisions ");
        System.out.println(Util.ntimes("*", 100));
        try {
            DataSet ds = DataSetFactory.getRestaurantDataSet();
            List<DecisionTree> stumps = DecisionTree.getStumpsFor(ds, "Yes", "No");
            List<Learner> learners = new ArrayList<Learner>();

            System.out.println("\nStump Learners vote to decide in this algorithm");
            for (DecisionTree stump : stumps) {
                learners.add(new StumpLearner(stump, "No"));
            }
            AdaBoostLearner learner = new AdaBoostLearner(learners, ds);
            learner.train(ds);

            int[] result = learner.test(ds);
            System.out.println("\nThis Ensemble Learner classifies the data set with "
                    + result[0] + " successes and " + result[1] + " failures");
            System.out.println("\n");
        } catch (Exception e) {
            System.out.println("Ensemble Learning Demo Failed");
            e.printStackTrace();
        }
    }

    public static void perceptronDemo() {
        try {
            System.out.println(Util.ntimes("*", 100));
            System.out.println("\n Perceptron Demo - Running Perceptron on Iris DataSet with 10 epochs of learning ");
            System.out.println(Util.ntimes("*", 100));
            DataSet irisDataSet = DataSetFactory.getIrisDataSet();
            Numerizer numerizer = new IrisDataSetNumerizer();
            NNDataSet innds = new IrisNNDataSet();
            innds.createExamplesFromDataSet(irisDataSet, numerizer);

            // 3 output neurons (one per iris class), 4 inputs (one per attribute)
            Perceptron perc = new Perceptron(3, 4);
            perc.trainOn(innds, 10);

            innds.refreshDataset();
            int[] result = perc.testOnDataSet(innds);
            System.out.println(result[0] + " right, " + result[1] + " wrong");
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    public static void backPropagationDemo() {
        try {
            System.out.println(Util.ntimes("*", 100));
            System.out.println("\n BackPropagation Demo - Running BackProp on Iris DataSet with 10 epochs of learning ");
            System.out.println(Util.ntimes("*", 100));
            DataSet irisDataSet = DataSetFactory.getIrisDataSet();
            Numerizer numerizer = new IrisDataSetNumerizer();
            NNDataSet innds = new IrisNNDataSet();
            innds.createExamplesFromDataSet(irisDataSet, numerizer);

            NNConfig config = new NNConfig();
            config.setConfig(FeedForwardNeuralNetwork.NUMBER_OF_INPUTS, 4);
            config.setConfig(FeedForwardNeuralNetwork.NUMBER_OF_OUTPUTS, 3);
            config.setConfig(FeedForwardNeuralNetwork.NUMBER_OF_HIDDEN_NEURONS, 6);
            config.setConfig(FeedForwardNeuralNetwork.LOWER_LIMIT_WEIGHTS, -2.0);
            config.setConfig(FeedForwardNeuralNetwork.UPPER_LIMIT_WEIGHTS, 2.0);

            FeedForwardNeuralNetwork ffnn = new FeedForwardNeuralNetwork(config);
            // learning rate = 0.1, momentum = 0.9
            ffnn.setTrainingScheme(new BackPropLearning(0.1, 0.9));
            ffnn.trainOn(innds, 10);

            innds.refreshDataset();
            int[] result = ffnn.testOnDataSet(innds);
            System.out.println(result[0] + " right, " + result[1] + " wrong");
        } catch (Exception e) {
            e.printStackTrace();
        }
    }
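
    /**
     * Passive ADP agent demo: the agent follows a fixed policy through the
     * 4x3 cell world of Fig. 17.1 and learns the state utilities via adaptive
     * dynamic programming with modified policy evaluation. The learning-rate
     * data printed corresponds to Figure 21.3.
     */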
    public static void passiveADPAgentDemo() {
        System.out.println("=======================");
        System.out.println("DEMO: Passive-ADP-Agent");
        System.out.println("=======================");
        System.out.println("Figure 21.3");
        System.out.println("-----------");

        CellWorld<Double> cw = CellWorldFactory.createCellWorldForFig17_1();
        CellWorldEnvironment cwe = new CellWorldEnvironment(
                cw.getCellAt(1, 1),
                cw.getCells(),
                MDPFactory.createTransitionProbabilityFunctionForFigure17_1(cw),
                new JavaRandomizer());

        // The fixed policy to be evaluated by the passive learning agent
        Map<Cell<Double>, CellWorldAction> fixedPolicy = new HashMap<Cell<Double>, CellWorldAction>();
        fixedPolicy.put(cw.getCellAt(1, 1), CellWorldAction.Up);
        fixedPolicy.put(cw.getCellAt(1, 2), CellWorldAction.Up);
        fixedPolicy.put(cw.getCellAt(1, 3), CellWorldAction.Right);
        fixedPolicy.put(cw.getCellAt(2, 1), CellWorldAction.Left);
        fixedPolicy.put(cw.getCellAt(2, 3), CellWorldAction.Right);
        fixedPolicy.put(cw.getCellAt(3, 1), CellWorldAction.Left);
        fixedPolicy.put(cw.getCellAt(3, 2), CellWorldAction.Up);
        fixedPolicy.put(cw.getCellAt(3, 3), CellWorldAction.Right);
        fixedPolicy.put(cw.getCellAt(4, 1), CellWorldAction.Left);

        PassiveADPAgent<Cell<Double>, CellWorldAction> padpa = new PassiveADPAgent<Cell<Double>, CellWorldAction>(
                fixedPolicy, cw.getCells(), cw.getCellAt(1, 1),
                MDPFactory.createActionsFunctionForFigure17_1(cw),
                // policy evaluation: k = 10 iterations, gamma = 1.0
                new ModifiedPolicyEvaluation<Cell<Double>, CellWorldAction>(10, 1.0));

        cwe.addAgent(padpa);

        output_utility_learning_rates(padpa, 20, 100, 100, 1);

        System.out.println("=========================");
    }

    public static void passiveTDAgentDemo() {
        System.out.println("======================");
        System.out.println("DEMO: Passive-TD-Agent");
        System.out.println("======================");
        System.out.println("Figure 21.5");
        System.out.println("-----------");

        CellWorld<Double> cw = CellWorldFactory.createCellWorldForFig17_1();
        CellWorldEnvironment cwe = new CellWorldEnvironment(
                cw.getCellAt(1, 1),
                cw.getCells(),
                MDPFactory.createTransitionProbabilityFunctionForFigure17_1(cw),
                new JavaRandomizer());

        // The fixed policy to be evaluated by the passive learning agent
        Map<Cell<Double>, CellWorldAction> fixedPolicy = new HashMap<Cell<Double>, CellWorldAction>();
        fixedPolicy.put(cw.getCellAt(1, 1), CellWorldAction.Up);
        fixedPolicy.put(cw.getCellAt(1, 2), CellWorldAction.Up);
        fixedPolicy.put(cw.getCellAt(1, 3), CellWorldAction.Right);
        fixedPolicy.put(cw.getCellAt(2, 1), CellWorldAction.Left);
        fixedPolicy.put(cw.getCellAt(2, 3), CellWorldAction.Right);
        fixedPolicy.put(cw.getCellAt(3, 1), CellWorldAction.Left);
        fixedPolicy.put(cw.getCellAt(3, 2), CellWorldAction.Up);
        fixedPolicy.put(cw.getCellAt(3, 3), CellWorldAction.Right);
        fixedPolicy.put(cw.getCellAt(4, 1), CellWorldAction.Left);

        // alpha (learning rate) = 0.2, gamma (discount) = 1.0
        PassiveTDAgent<Cell<Double>, CellWorldAction> ptda = new PassiveTDAgent<Cell<Double>, CellWorldAction>(
                fixedPolicy, 0.2, 1.0);

        cwe.addAgent(ptda);

        output_utility_learning_rates(ptda, 20, 500, 100, 1);

        System.out.println("=========================");
    }

    public static void qLearningAgentDemo() {
        System.out.println("======================");
        System.out.println("DEMO: Q-Learning-Agent");
        System.out.println("======================");

        CellWorld<Double> cw = CellWorldFactory.createCellWorldForFig17_1();
        CellWorldEnvironment cwe = new CellWorldEnvironment(
                cw.getCellAt(1, 1),
                cw.getCells(),
                MDPFactory.createTransitionProbabilityFunctionForFigure17_1(cw),
                new JavaRandomizer());

        // alpha = 0.2, gamma = 1.0, exploration parameters Ne = 5, R+ = 2.0
        QLearningAgent<Cell<Double>, CellWorldAction> qla = new QLearningAgent<Cell<Double>, CellWorldAction>(
                MDPFactory.createActionsFunctionForFigure17_1(cw),
                CellWorldAction.None, 0.2, 1.0, 5, 2.0);

        cwe.addAgent(qla);

        output_utility_learning_rates(qla, 20, 10000, 500, 20);

        System.out.println("=========================");
    }

    //
    // PRIVATE METHODS
    //
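    /**
     * Executes numRuns runs of numTrialsPerRun trials each with the given
     * agent in the 4x3 cell world, sampling the learned utilities every
     * reportEveryN trials. Prints tab-separated learning curves for selected
     * cells (from the last run) and, across all runs, the RMS error of the
     * estimated U(1,1) against its true value of 0.705, suitable for pasting
     * into a spreadsheet to generate graphs.
     */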
    private static void output_utility_learning_rates(
            ReinforcementAgent<Cell<Double>, CellWorldAction> reinforcementAgent,
            int numRuns, int numTrialsPerRun, int rmseTrialsToReport,
            int reportEveryN) {
        if (rmseTrialsToReport > (numTrialsPerRun / reportEveryN)) {
            throw new IllegalArgumentException(
                    "Requesting to report too many RMSE trials, max allowed for args is "
                            + (numTrialsPerRun / reportEveryN));
        }

        CellWorld<Double> cw = CellWorldFactory.createCellWorldForFig17_1();
        CellWorldEnvironment cwe = new CellWorldEnvironment(
                cw.getCellAt(1, 1),
                cw.getCells(),
                MDPFactory.createTransitionProbabilityFunctionForFigure17_1(cw),
                new JavaRandomizer());
        cwe.addAgent(reinforcementAgent);

        // Collect the agent's utility estimates every reportEveryN trials,
        // for each of numRuns independent runs.
        Map<Integer, List<Map<Cell<Double>, Double>>> runs = new HashMap<Integer, List<Map<Cell<Double>, Double>>>();
        for (int r = 0; r < numRuns; r++) {
            reinforcementAgent.reset();
            List<Map<Cell<Double>, Double>> trials = new ArrayList<Map<Cell<Double>, Double>>();
            for (int t = 0; t < numTrialsPerRun; t++) {
                cwe.executeTrial();
                if (0 == t % reportEveryN) {
                    Map<Cell<Double>, Double> u = reinforcementAgent.getUtility();
                    if (null == u.get(cw.getCellAt(1, 1))) {
                        throw new IllegalStateException(
                                "Bad Utility State Encountered: r=" + r
                                        + ", t=" + t + ", u=" + u);
                    }
                    trials.add(u);
                }
            }
            runs.put(r, trials);
        }

        // Learning curves for selected cells (0.0 if not yet estimated).
        StringBuilder v4_3 = new StringBuilder();
        StringBuilder v3_3 = new StringBuilder();
        StringBuilder v1_3 = new StringBuilder();
        StringBuilder v1_1 = new StringBuilder();
        StringBuilder v3_2 = new StringBuilder();
        StringBuilder v2_1 = new StringBuilder();
        for (int t = 0; t < (numTrialsPerRun / reportEveryN); t++) {
            // Use the last run
            Map<Cell<Double>, Double> u = runs.get(numRuns - 1).get(t);

            v4_3.append((u.containsKey(cw.getCellAt(4, 3)) ? u.get(cw.getCellAt(4, 3)) : 0.0) + "\t");
            v3_3.append((u.containsKey(cw.getCellAt(3, 3)) ? u.get(cw.getCellAt(3, 3)) : 0.0) + "\t");
            v1_3.append((u.containsKey(cw.getCellAt(1, 3)) ? u.get(cw.getCellAt(1, 3)) : 0.0) + "\t");
            v1_1.append((u.containsKey(cw.getCellAt(1, 1)) ? u.get(cw.getCellAt(1, 1)) : 0.0) + "\t");
            v3_2.append((u.containsKey(cw.getCellAt(3, 2)) ? u.get(cw.getCellAt(3, 2)) : 0.0) + "\t");
            v2_1.append((u.containsKey(cw.getCellAt(2, 1)) ? u.get(cw.getCellAt(2, 1)) : 0.0) + "\t");
        }

        StringBuilder rmseValues = new StringBuilder();
        for (int t = 0; t < rmseTrialsToReport; t++) {
            // Calculate the root mean square error for the utility of (1,1)
            // for this trial # across all runs.
            double sumSquaredError = 0;
            for (int r = 0; r < numRuns; r++) {
                Map<Cell<Double>, Double> u = runs.get(r).get(t);
                Double val1_1 = u.get(cw.getCellAt(1, 1));
                if (null == val1_1) {
                    throw new IllegalStateException(
                            "U(1,1) is not present: r=" + r + ", t=" + t
                                    + ", runs.size=" + runs.size()
                                    + ", runs(r).size()=" + runs.get(r).size()
                                    + ", u=" + u);
                }
                // 0.705 is the true utility of cell (1,1) in the 4x3 world
                sumSquaredError += Math.pow(0.705 - val1_1, 2);
            }
            double rmse = Math.sqrt(sumSquaredError / runs.size());
            rmseValues.append(rmse);
            rmseValues.append("\t");
        }

        System.out.println("Note: You may copy and paste the following lines into a spreadsheet to generate graphs of learning rate and RMS error in utility:");
        System.out.println("(4,3)" + "\t" + v4_3);
        System.out.println("(3,3)" + "\t" + v3_3);
        System.out.println("(1,3)" + "\t" + v1_3);
        System.out.println("(1,1)" + "\t" + v1_1);
        System.out.println("(3,2)" + "\t" + v3_2);
        System.out.println("(2,1)" + "\t" + v2_1);
        // "RMSeiu" = root mean square error in utility of (1,1)
        System.out.println("RMSeiu" + "\t" + rmseValues);
    }
}