package dist.test;
import java.util.Random;
import dist.Distribution;
import dist.DiscreteDistributionTable;
import shared.ConvergenceTrainer;
import shared.DataSet;
import shared.Instance;
import dist.hmm.HiddenMarkovModelReestimator;
import dist.hmm.ForwardBackwardProbabilityCalculator;
import dist.hmm.ModularHiddenMarkovModel;
import dist.hmm.SimpleStateDistributionTable;
import dist.hmm.StateDistribution;
/**
 * A test class for testing long term dependencies.
 *
 * <p>Each trial builds a randomly initialized {@link ModularHiddenMarkovModel},
 * generates a handful of labelled input sequences from a small "monster"
 * simulation, runs a {@link HiddenMarkovModelReestimator} wrapped in a
 * {@link ConvergenceTrainer} on them, and then checks the log probability
 * the trained model assigns to every training sequence.
 *
 * <p>The label attached to a "smell" input depends on a hunger flag that is
 * not encoded in the input value itself and that only flips when a smell
 * goes away, which is the long term dependency the model has to capture.
 *
 * @author Andrew Guillory gtg008g@mail.gatech.edu
 * @version 1.0
 */
public class HMMConditionalMonsterTest {
    /** The number of independent trials (fresh random model and data) to run */
    private static final int TRIAL_COUNT = 1000;
    /** A trial fails if any sequence's log probability is below this value */
    private static final double LOG_PROBABILITY_THRESHOLD = -0.01;
    /** The number of sequences generated per trial */
    private static final int SEQUENCE_COUNT = 5;
    /** The number of steps in each sequence */
    private static final int SEQUENCE_LENGTH = 100;
    /** The state count */
    private static final int STATE_COUNT = 4;
    /** The input range */
    private static final int INPUT_RANGE = 4;
    /** The input for smelling something during the day */
    private static final int SMELL_DAY = 0;
    /** The input for smelling nothing during the day */
    private static final int NO_SMELL_DAY = 1;
    /** The input for smelling something during the night */
    private static final int SMELL_NIGHT = 2;
    /** The input for smelling nothing during the night */
    private static final int NO_SMELL_NIGHT = 3;
    /** The output range */
    private static final int OUTPUT_RANGE = 4;
    /** The run away output (smell while not hungry) */
    private static final int RUN_AWAY = 0;
    /** The run towards output (smell while hungry) */
    private static final int RUN_TOWARDS = 1;
    /** The stay still output (no smell during the day) */
    private static final int STAY_STILL = 2;
    /** The sleep output (no smell during the night) */
    private static final int SLEEP = 3;

    /**
     * The main method; runs {@link #TRIAL_COUNT} trials and prints a running
     * and a final tally of successful trials and total trainer iterations.
     * @param args ignored
     */
    public static void main(String[] args) {
        int count = 0;
        int goodCount = 0;
        int iterations = 0;
        // unseeded, so sharing one generator across trials is equivalent
        // to creating a fresh one per trial
        Random random = new Random();
        while (count < TRIAL_COUNT) {
            ModularHiddenMarkovModel model = createRandomModel();
            DataSet[] dataSets = toDataSets(generateSequences(random));

            System.out.println("Reestimations of model based on sequences: ");
            HiddenMarkovModelReestimator bwr = new HiddenMarkovModelReestimator(model, dataSets);
            ConvergenceTrainer trainer = new ConvergenceTrainer(bwr);
            trainer.train();
            iterations += trainer.getIterations();
            System.out.println(model + "\n");

            System.out.println("Log probabilities of sequences: ");
            if (allSequencesFit(model, dataSets)) {
                goodCount++;
            }
            count++;
            System.out.println("So Far " + goodCount + " / " + count);
            System.out.println(iterations + " iterations");
        }
        System.out.println(goodCount + " / " + count);
        System.out.println(iterations + " iterations");
    }

    /**
     * Builds a model with {@link #STATE_COUNT} states whose output,
     * transition and initial state distributions are all randomly initialized.
     * @return the new model
     */
    private static ModularHiddenMarkovModel createRandomModel() {
        ModularHiddenMarkovModel model = new ModularHiddenMarkovModel(STATE_COUNT);
        // one output distribution per state
        Distribution[] outputs = new Distribution[STATE_COUNT];
        for (int i = 0; i < outputs.length; i++) {
            outputs[i] = DiscreteDistributionTable.random(INPUT_RANGE, OUTPUT_RANGE);
        }
        model.setOutputDistributions(outputs);
        // one transition distribution per state
        StateDistribution[] transitions = new StateDistribution[STATE_COUNT];
        for (int i = 0; i < transitions.length; i++) {
            transitions[i] = randomStateDistribution();
        }
        model.setTransitionDistributions(transitions);
        model.setInitialStateDistribution(randomStateDistribution());
        return model;
    }

    /**
     * Creates a state distribution table from the probability matrix of a
     * random {@link #INPUT_RANGE} by {@link #STATE_COUNT} discrete table.
     * @return the new state distribution
     */
    private static SimpleStateDistributionTable randomStateDistribution() {
        return new SimpleStateDistributionTable(
            DiscreteDistributionTable.random(INPUT_RANGE, STATE_COUNT).getProbabilityMatrix());
    }

    /**
     * Generates {@link #SEQUENCE_COUNT} labelled sequences.
     * @param random the source of randomness
     * @return the sequences
     */
    private static Instance[][] generateSequences(Random random) {
        Instance[][] sequences = new Instance[SEQUENCE_COUNT][];
        for (int i = 0; i < sequences.length; i++) {
            sequences[i] = generateSequence(random);
        }
        return sequences;
    }

    /**
     * Generates one labelled sequence of {@link #SEQUENCE_LENGTH} steps.
     * The simulation starts hungry with a random smell and day/night state.
     * After every step the smell and the day/night state each flip with
     * their current probability, and that probability is redrawn whenever
     * the flip happens.
     * @param random the source of randomness
     * @return the sequence
     */
    private static Instance[] generateSequence(Random random) {
        Instance[] sequence = new Instance[SEQUENCE_LENGTH];
        boolean smellSomething = random.nextBoolean();
        boolean day = random.nextBoolean();
        boolean isHungry = true;
        double smellProbability = random.nextDouble();
        double dayProbability = random.nextDouble();
        for (int j = 0; j < sequence.length; j++) {
            sequence[j] = createStep(smellSomething, isHungry, day);
            if (random.nextDouble() < smellProbability) {
                smellProbability = random.nextDouble();
                if (smellSomething) {
                    // hunger only changes at the moment a smell goes away
                    isHungry = !isHungry;
                }
                smellSomething = !smellSomething;
            }
            if (random.nextDouble() < dayProbability) {
                dayProbability = random.nextDouble();
                day = !day;
            }
        }
        return sequence;
    }

    /**
     * Creates the labelled instance for a single step. The input encodes
     * only smell and day/night; hunger influences the label alone.
     * @param smellSomething whether something is smelled
     * @param isHungry whether the actor is hungry
     * @param day whether it is day
     * @return the input instance with its output label set
     */
    private static Instance createStep(boolean smellSomething, boolean isHungry, boolean day) {
        int input;
        int output;
        if (smellSomething) {
            input = day ? SMELL_DAY : SMELL_NIGHT;
            output = isHungry ? RUN_TOWARDS : RUN_AWAY;
        } else if (day) {
            input = NO_SMELL_DAY;
            output = STAY_STILL;
        } else {
            input = NO_SMELL_NIGHT;
            output = SLEEP;
        }
        Instance step = new Instance(input);
        step.setLabel(new Instance(output));
        return step;
    }

    /**
     * Wraps each sequence in its own data set.
     * @param sequences the sequences
     * @return one data set per sequence, in the same order
     */
    private static DataSet[] toDataSets(Instance[][] sequences) {
        DataSet[] dataSets = new DataSet[sequences.length];
        for (int i = 0; i < dataSets.length; i++) {
            dataSets[i] = new DataSet(sequences[i]);
        }
        return dataSets;
    }

    /**
     * Prints the log probability of every data set under the model and
     * reports whether all of them reach {@link #LOG_PROBABILITY_THRESHOLD}.
     * "FAILURE" is printed once, at the first data set that falls short.
     * @param model the trained model
     * @param dataSets the sequences to score
     * @return true if no data set scored below the threshold
     */
    private static boolean allSequencesFit(ModularHiddenMarkovModel model, DataSet[] dataSets) {
        boolean success = true;
        for (int i = 0; i < dataSets.length; i++) {
            ForwardBackwardProbabilityCalculator fbc =
                new ForwardBackwardProbabilityCalculator(model, dataSets[i]);
            // compute once; the value is both printed and compared
            double logProbability = fbc.calculateLogProbability();
            System.out.println(logProbability);
            if (success && logProbability < LOG_PROBABILITY_THRESHOLD) {
                success = false;
                System.out.println("FAILURE");
            }
        }
        return success;
    }
}