package dist.test; import dist.Distribution; import dist.DiscreteDistributionTable; import shared.DataSet; import shared.Instance; import dist.hmm.HiddenMarkovModelReestimator; import dist.hmm.ForwardBackwardProbabilityCalculator; import dist.hmm.ModularHiddenMarkovModel; import dist.hmm.SimpleStateDistributionTable; import dist.hmm.StateDistribution; import dist.hmm.StateSequenceCalculator; /** * A test class for running a simple wumpus world * example * @author Andrew Guillory gtg008g@mail.gatech.edu * @version 1.0 */ public class HMMWumpusTest { /** The state count */ private static int STATE_COUNT = 2; /** The input range */ private static int INPUT_RANGE = 5; /** The smell left input */ private static int SMELL_LEFT = 0; /** The smell right input */ private static int SMELL_RIGHT = 1; /** The smell up input */ private static int SMELL_UP = 2; /** The smell down input */ private static int SMELL_DOWN = 3; /** The no smell input */ private static int NO_SMELL = 4; /** The output range */ private static int OUTPUT_RANGE = 4; /** The move left output */ private static int MOVE_LEFT = 0; /** The move right output */ private static int MOVE_RIGHT = 1; /** The move up output */ private static int MOVE_UP = 2; /** The move down output */ private static int MOVE_DOWN = 3; /** * The main method * @param args ignored */ public static void main(String[] args) { // simple wumpus world test ModularHiddenMarkovModel model = new ModularHiddenMarkovModel(STATE_COUNT); model.setOutputDistributions(new Distribution[] { DiscreteDistributionTable.random(INPUT_RANGE, OUTPUT_RANGE), DiscreteDistributionTable.random(INPUT_RANGE, OUTPUT_RANGE), }); model.setTransitionDistributions(new StateDistribution[] { new SimpleStateDistributionTable(DiscreteDistributionTable.random(INPUT_RANGE, STATE_COUNT).getProbabilityMatrix()), new SimpleStateDistributionTable(DiscreteDistributionTable.random(INPUT_RANGE, STATE_COUNT).getProbabilityMatrix()), }); model.setInitialStateDistribution( new SimpleStateDistributionTable(DiscreteDistributionTable.random(INPUT_RANGE, STATE_COUNT).getProbabilityMatrix())); Instance[] sequence = new Instance[] { new Instance(NO_SMELL, MOVE_UP), new Instance(SMELL_LEFT, MOVE_RIGHT), new Instance(SMELL_RIGHT, MOVE_LEFT), new Instance(SMELL_UP, MOVE_DOWN), new Instance(SMELL_DOWN, MOVE_UP) }; DataSet[] sequences = new DataSet[] { new DataSet(sequence), }; System.out.println(model + "\n"); System.out.println("Observation Sequences: "); for (int i = 0; i < sequences.length; i++) { System.out.println(sequences[i]); } System.out.println(); ForwardBackwardProbabilityCalculator fbc = new ForwardBackwardProbabilityCalculator(model, sequences[0]); System.out.println("Log probability of first sequence: "); System.out.println(fbc.calculateLogProbability()); System.out.println(); StateSequenceCalculator vc =new StateSequenceCalculator(model, sequences[0]); int[] states = vc.calculateStateSequence(); System.out.println("Most likely state sequence of first sequence: "); for (int i = 0; i < states.length; i++) { System.out.print(states[i] + " "); } System.out.println(); System.out.println(); System.out.println("Reestimations of model based on sequences: "); HiddenMarkovModelReestimator bwr = new HiddenMarkovModelReestimator(model, sequences); bwr.train(); System.out.println(model + "\n"); bwr.train(); System.out.println(model + "\n"); for (int i = 0; i < 20; i++) { bwr.train(); } System.out.println(model + "\n"); fbc = new ForwardBackwardProbabilityCalculator(model, sequences[0]); System.out.println("Log probability of first sequence: "); System.out.println(fbc.calculateLogProbability()); System.out.println("Log probabilities of other sequences: "); for (int i = 1; i < sequences.length; i++) { fbc = new ForwardBackwardProbabilityCalculator(model, sequences[i]); System.out.println(fbc.calculateLogProbability()); } } }