/**
* Copyright (C) 2015-2016, BMW Car IT GmbH and BMW AG
* Author: Stefan Holder (stefan.holder@bmw.de)
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.bmw.hmm;
import static java.lang.Math.log;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.junit.Test;
import com.bmw.hmm.SequenceState;
import com.bmw.hmm.Transition;
import com.bmw.hmm.ViterbiAlgorithm;
public class ViterbiAlgorithmTest {
private static class Rain {
final static Rain T = new Rain();
final static Rain F = new Rain();
@Override
public String toString() {
if (this == T) {
return "Rain";
} else if (this == F) {
return "Sun";
}
throw new IllegalStateException();
}
}
private static class Umbrella {
final static Umbrella T = new Umbrella();
final static Umbrella F = new Umbrella();
@Override
public String toString() {
if (this == T) {
return "Umbrella";
} else if (this == F) {
return "No umbrella";
}
throw new IllegalStateException();
}
}
private static class Descriptor {
final static Descriptor R2R = new Descriptor();
final static Descriptor R2S = new Descriptor();
final static Descriptor S2R = new Descriptor();
final static Descriptor S2S = new Descriptor();
@Override
public String toString() {
if (this == R2R) {
return "R2R";
} else if (this == R2S) {
return "R2S";
} else if (this == S2R) {
return "S2R";
} else if (this == S2S) {
return "S2S";
}
throw new IllegalStateException();
}
}
/** Absolute tolerance used when comparing expected and actual message probabilities. */
private static final double DELTA = 1e-8;
/**
 * Collects the {@code state} of every element of the given sequence, preserving order.
 *
 * @param sequenceStates sequence to read the states from
 * @return a new list containing one state per sequence element
 */
private List<Rain> states(List<SequenceState<Rain, Umbrella, Descriptor>> sequenceStates) {
    final int length = sequenceStates.size();
    final List<Rain> extracted = new ArrayList<>(length);
    for (int i = 0; i < length; i++) {
        extracted.add(sequenceStates.get(i).state);
    }
    return extracted;
}
/**
 * Tests the Viterbi algorithm with the umbrella example taken from Russell, Norvig: Artificial
 * Intelligence - A Modern Approach, 3rd edition, chapter 15.2.3. Note that the probabilities in
 * Figure 15.5 are different, since the book uses initial probabilities and the probabilities
 * for message m1:1 are normalized (not wrong but unnecessary).
 */
@Test
public void testComputeMostLikelySequence() {
    final List<Rain> rainOrSun = new ArrayList<Rain>(Arrays.asList(Rain.T, Rain.F));

    // Emission log probabilities, one map per possible observation.
    final Map<Rain, Double> umbrellaEmissions = new LinkedHashMap<>();
    umbrellaEmissions.put(Rain.T, log(0.9));
    umbrellaEmissions.put(Rain.F, log(0.2));
    final Map<Rain, Double> noUmbrellaEmissions = new LinkedHashMap<>();
    noUmbrellaEmissions.put(Rain.T, log(0.1));
    noUmbrellaEmissions.put(Rain.F, log(0.8));

    // Each transition is used as key for both its log probability and its descriptor.
    final Transition<Rain> rainToRain = new Transition<Rain>(Rain.T, Rain.T);
    final Transition<Rain> rainToSun = new Transition<Rain>(Rain.T, Rain.F);
    final Transition<Rain> sunToRain = new Transition<Rain>(Rain.F, Rain.T);
    final Transition<Rain> sunToSun = new Transition<Rain>(Rain.F, Rain.F);
    final Map<Transition<Rain>, Double> transitions = new LinkedHashMap<>();
    final Map<Transition<Rain>, Descriptor> descriptors = new LinkedHashMap<>();
    transitions.put(rainToRain, log(0.7));
    descriptors.put(rainToRain, Descriptor.R2R);
    transitions.put(rainToSun, log(0.3));
    descriptors.put(rainToSun, Descriptor.R2S);
    transitions.put(sunToRain, log(0.3));
    descriptors.put(sunToRain, Descriptor.S2R);
    transitions.put(sunToSun, log(0.7));
    descriptors.put(sunToSun, Descriptor.S2S);

    // Observation sequence: umbrella, umbrella, no umbrella, umbrella.
    final ViterbiAlgorithm<Rain, Umbrella, Descriptor> viterbi = new ViterbiAlgorithm<>(true);
    viterbi.startWithInitialObservation(Umbrella.T, rainOrSun, umbrellaEmissions);
    viterbi.nextStep(Umbrella.T, rainOrSun, umbrellaEmissions, transitions, descriptors);
    viterbi.nextStep(Umbrella.F, rainOrSun, noUmbrellaEmissions, transitions, descriptors);
    viterbi.nextStep(Umbrella.T, rainOrSun, umbrellaEmissions, transitions, descriptors);
    final List<SequenceState<Rain, Umbrella, Descriptor>> mostLikely =
            viterbi.computeMostLikelySequence();

    // Check most likely sequence
    final Rain[] expectedStates = {Rain.T, Rain.T, Rain.F, Rain.T};
    final Umbrella[] expectedObservations = {Umbrella.T, Umbrella.T, Umbrella.F, Umbrella.T};
    // The first sequence element has no incoming transition, hence no descriptor.
    final Descriptor[] expectedDescriptors =
            {null, Descriptor.R2R, Descriptor.R2S, Descriptor.S2R};
    assertEquals(4, mostLikely.size());
    for (int i = 0; i < expectedStates.length; i++) {
        assertEquals(expectedStates[i], mostLikely.get(i).state);
    }
    for (int i = 0; i < expectedObservations.length; i++) {
        assertEquals(expectedObservations[i], mostLikely.get(i).observation);
    }
    for (int i = 0; i < expectedDescriptors.length; i++) {
        assertEquals(expectedDescriptors[i], mostLikely.get(i).transitionDescriptor);
    }

    // Check for HMM breaks
    assertFalse(viterbi.isBroken());

    // Check message history: one row per time step, columns are {Rain.T, Rain.F}.
    final double[][] expectedProbabilities = {
        {0.9, 0.2},
        {0.567, 0.054},
        {0.03969, 0.13608},
        {0.0367416, 0.0190512}
    };
    final List<Map<Rain, Double>> expectedMessageHistory = new ArrayList<>();
    for (double[] row : expectedProbabilities) {
        final Map<Rain, Double> expectedMessage = new LinkedHashMap<>();
        expectedMessage.put(Rain.T, row[0]);
        expectedMessage.put(Rain.F, row[1]);
        expectedMessageHistory.add(expectedMessage);
    }
    checkMessageHistory(expectedMessageHistory, viterbi.messageHistory());
}
/**
 * Asserts that both histories have the same number of messages and that the messages at the
 * same position match according to {@code checkMessage}.
 *
 * @param expectedMessageHistory expected messages, one per time step
 * @param actualMessageHistory messages to verify, in the same order
 */
private void checkMessageHistory(List<Map<Rain, Double>> expectedMessageHistory,
        List<Map<Rain, Double>> actualMessageHistory) {
    assertEquals(expectedMessageHistory.size(), actualMessageHistory.size());
    int step = 0;
    for (Map<Rain, Double> expectedMessage : expectedMessageHistory) {
        checkMessage(expectedMessage, actualMessageHistory.get(step));
        step++;
    }
}
/**
 * Asserts that both messages have the same number of entries and that, for every state of the
 * expected message, the exponential of the actual value equals the expected value within
 * {@code DELTA}.
 *
 * @param expectedMessage expected value per state, compared against {@code exp(actual)}
 * @param actualMessage values to verify; each is passed through {@link Math#exp(double)}
 */
private void checkMessage(Map<Rain, Double> expectedMessage, Map<Rain, Double> actualMessage) {
    assertEquals(expectedMessage.size(), actualMessage.size());
    for (Map.Entry<Rain, Double> entry : expectedMessage.entrySet()) {
        final Double actualLogValue = actualMessage.get(entry.getKey());
        // Report a missing state as a readable assertion failure instead of letting the
        // unboxing below throw a NullPointerException.
        assertTrue("Missing message entry for state " + entry.getKey(),
                actualLogValue != null);
        assertEquals(entry.getValue(), Math.exp(actualLogValue), DELTA);
    }
}
/**
 * Checks that an algorithm instance that never received an observation returns an empty
 * sequence and is not reported as broken.
 */
@Test
public void testEmptySequence() {
    final ViterbiAlgorithm<Rain, Umbrella, Descriptor> untouched = new ViterbiAlgorithm<>();

    assertEquals(Arrays.asList(), untouched.computeMostLikelySequence());
    assertFalse(untouched.isBroken());
}
/**
 * Checks that the algorithm reports a break and returns an empty sequence if every candidate
 * of the initial observation has emission probability zero.
 */
@Test
public void testBreakAtInitialMessage() {
    final List<Rain> rainOrSun = new ArrayList<Rain>(Arrays.asList(Rain.T, Rain.F));

    // log(0.0) is negative infinity, i.e. no candidate can emit the observation.
    final double impossible = log(0.0);
    final Map<Rain, Double> emissions = new LinkedHashMap<>();
    for (Rain candidate : rainOrSun) {
        emissions.put(candidate, impossible);
    }

    final ViterbiAlgorithm<Rain, Umbrella, Descriptor> viterbi = new ViterbiAlgorithm<>();
    viterbi.startWithInitialObservation(Umbrella.T, rainOrSun, emissions);

    assertTrue(viterbi.isBroken());
    assertEquals(Arrays.asList(), viterbi.computeMostLikelySequence());
}
/**
 * Checks that the algorithm reports a break and returns an empty sequence if the initial
 * observation comes without any candidates.
 */
@Test
public void testEmptyInitialMessage() {
    final List<Rain> noCandidates = new ArrayList<Rain>();
    final Map<Rain, Double> noEmissions = new LinkedHashMap<Rain, Double>();

    final ViterbiAlgorithm<Rain, Umbrella, Descriptor> viterbi = new ViterbiAlgorithm<>();
    viterbi.startWithInitialObservation(Umbrella.T, noCandidates, noEmissions);

    assertTrue(viterbi.isBroken());
    assertEquals(Arrays.asList(), viterbi.computeMostLikelySequence());
}
/**
 * Checks that the algorithm reports a break if all transitions of the first step have
 * probability zero, and that the returned sequence then consists of {@code Rain.T} only.
 */
@Test
public void testBreakAtFirstTransition() {
    final List<Rain> rainOrSun = new ArrayList<Rain>(Arrays.asList(Rain.T, Rain.F));
    final Map<Rain, Double> emissions = new LinkedHashMap<>();
    emissions.put(Rain.T, log(0.9));
    emissions.put(Rain.F, log(0.2));

    final ViterbiAlgorithm<Rain, Umbrella, Descriptor> viterbi = new ViterbiAlgorithm<>();
    viterbi.startWithInitialObservation(Umbrella.T, rainOrSun, emissions);
    assertFalse(viterbi.isBroken());

    // Every transition between the candidates is impossible: log(0.0) is negative infinity.
    final Map<Transition<Rain>, Double> impossibleTransitions = new LinkedHashMap<>();
    for (Rain from : rainOrSun) {
        for (Rain to : rainOrSun) {
            impossibleTransitions.put(new Transition<Rain>(from, to), log(0.0));
        }
    }
    viterbi.nextStep(Umbrella.T, rainOrSun, emissions, impossibleTransitions);

    assertTrue(viterbi.isBroken());
    assertEquals(Arrays.asList(Rain.T), states(viterbi.computeMostLikelySequence()));
}
/**
 * Checks that the algorithm reports a break if the first step after the initial observation
 * has no candidates at all, and that the returned sequence then consists of {@code Rain.T}
 * only.
 */
@Test
public void testBreakAtFirstTransitionWithNoCandidates() {
    final List<Rain> rainOrSun = new ArrayList<Rain>(Arrays.asList(Rain.T, Rain.F));
    final Map<Rain, Double> emissions = new LinkedHashMap<>();
    emissions.put(Rain.T, log(0.9));
    emissions.put(Rain.F, log(0.2));

    final ViterbiAlgorithm<Rain, Umbrella, Descriptor> viterbi = new ViterbiAlgorithm<>();
    viterbi.startWithInitialObservation(Umbrella.T, rainOrSun, emissions);
    assertFalse(viterbi.isBroken());

    // Second time step: nothing to choose from.
    final List<Rain> noCandidates = new ArrayList<Rain>();
    final Map<Rain, Double> noEmissions = new LinkedHashMap<Rain, Double>();
    final Map<Transition<Rain>, Double> noTransitions =
            new LinkedHashMap<Transition<Rain>, Double>();
    viterbi.nextStep(Umbrella.T, noCandidates, noEmissions, noTransitions);

    assertTrue(viterbi.isBroken());
    assertEquals(Arrays.asList(Rain.T), states(viterbi.computeMostLikelySequence()));
}
/**
 * Checks that the algorithm is not broken after a first step with uniform transition
 * probabilities, reports a break after a second step whose transitions all have probability
 * zero, and then returns the sequence {@code Rain.T, Rain.T}.
 */
@Test
public void testBreakAtSecondTransition() {
    final List<Rain> rainOrSun = new ArrayList<Rain>(Arrays.asList(Rain.T, Rain.F));
    final Map<Rain, Double> emissions = new LinkedHashMap<>();
    emissions.put(Rain.T, log(0.9));
    emissions.put(Rain.F, log(0.2));

    final ViterbiAlgorithm<Rain, Umbrella, Descriptor> viterbi = new ViterbiAlgorithm<>();
    viterbi.startWithInitialObservation(Umbrella.T, rainOrSun, emissions);
    assertFalse(viterbi.isBroken());

    // First step: every transition is equally likely.
    final Map<Transition<Rain>, Double> uniformTransitions = new LinkedHashMap<>();
    for (Rain from : rainOrSun) {
        for (Rain to : rainOrSun) {
            uniformTransitions.put(new Transition<Rain>(from, to), log(0.5));
        }
    }
    viterbi.nextStep(Umbrella.T, rainOrSun, emissions, uniformTransitions);
    assertFalse(viterbi.isBroken());

    // Second step: every transition is impossible.
    final Map<Transition<Rain>, Double> impossibleTransitions = new LinkedHashMap<>();
    for (Rain from : rainOrSun) {
        for (Rain to : rainOrSun) {
            impossibleTransitions.put(new Transition<Rain>(from, to), log(0.0));
        }
    }
    viterbi.nextStep(Umbrella.T, rainOrSun, emissions, impossibleTransitions);

    assertTrue(viterbi.isBroken());
    assertEquals(Arrays.asList(Rain.T, Rain.T), states(viterbi.computeMostLikelySequence()));
}
/**
 * Checks if the first candidate is returned if multiple candidates are equally likely.
 */
@Test
public void testDeterministicCandidateOrder() {
    final List<Rain> candidates = new ArrayList<>();
    candidates.add(Rain.T);
    candidates.add(Rain.F);

    // Reverse usual order of emission and transition probabilities keys since their order
    // should not matter.
    final Map<Rain, Double> emissionLogProbabilitiesForUmbrella = new LinkedHashMap<>();
    emissionLogProbabilitiesForUmbrella.put(Rain.F, log(0.5));
    emissionLogProbabilitiesForUmbrella.put(Rain.T, log(0.5));

    final Map<Rain, Double> emissionLogProbabilitiesForNoUmbrella = new LinkedHashMap<>();
    emissionLogProbabilitiesForNoUmbrella.put(Rain.F, log(0.5));
    emissionLogProbabilitiesForNoUmbrella.put(Rain.T, log(0.5));

    final Map<Transition<Rain>, Double> transitionLogProbabilities = new LinkedHashMap<>();
    transitionLogProbabilities.put(new Transition<Rain>(Rain.F, Rain.T), log(0.5));
    transitionLogProbabilities.put(new Transition<Rain>(Rain.F, Rain.F), log(0.5));
    transitionLogProbabilities.put(new Transition<Rain>(Rain.T, Rain.T), log(0.5));
    transitionLogProbabilities.put(new Transition<Rain>(Rain.T, Rain.F), log(0.5));

    final ViterbiAlgorithm<Rain, Umbrella, Descriptor> viterbi = new ViterbiAlgorithm<>(true);
    viterbi.startWithInitialObservation(Umbrella.T, candidates,
            emissionLogProbabilitiesForUmbrella);
    viterbi.nextStep(Umbrella.T, candidates, emissionLogProbabilitiesForUmbrella,
            transitionLogProbabilities);
    viterbi.nextStep(Umbrella.F, candidates, emissionLogProbabilitiesForNoUmbrella,
            transitionLogProbabilities);
    viterbi.nextStep(Umbrella.T, candidates, emissionLogProbabilitiesForUmbrella,
            transitionLogProbabilities);

    final List<SequenceState<Rain, Umbrella, Descriptor>> result =
            viterbi.computeMostLikelySequence();

    // Check most likely sequence: Rain.T is the first candidate, so it must win every tie.
    // Comparing whole lists verifies both the length and each state.
    assertEquals(Arrays.asList(Rain.T, Rain.T, Rain.T, Rain.T), states(result));
}
}