/******************************************************************************* * Copyright (c) 2010 Haifeng Li * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. *******************************************************************************/ package smile.sequence; import org.junit.After; import org.junit.AfterClass; import org.junit.Before; import org.junit.BeforeClass; import org.junit.Test; import static org.junit.Assert.*; import smile.math.Math; import smile.stat.distribution.EmpiricalDistribution; /** * * @author Haifeng Li */ @SuppressWarnings({"rawtypes", "unchecked"}) public class HMMTest { double[] pi = {0.5, 0.5}; double[][] a = {{0.8, 0.2}, {0.2, 0.8}}; double[][] b = {{0.6, 0.4}, {0.4, 0.6}}; public HMMTest() { } @BeforeClass public static void setUpClass() throws Exception { } @AfterClass public static void tearDownClass() throws Exception { } @Before public void setUp() { } @After public void tearDown() { } /** * Test of numStates method, of class HMM. */ @Test public void testNumStates() { System.out.println("numStates"); HMM hmm = new HMM(pi, a, b); int expResult = 2; int result = hmm.numStates(); assertEquals(expResult, result); } /** * Test of numSymbols method, of class HMM. */ @Test public void testNumSymbols() { System.out.println("numSymbols"); HMM hmm = new HMM(pi, a, b); int expResult = 2; int result = hmm.numSymbols(); assertEquals(expResult, result); } /** * Test of getInitialStateProbabilities method, of class HMM. */ @Test public void testGetInitialStateProbabilities() { System.out.println("getInitialStateProbabilities"); HMM hmm = new HMM(pi, a, b); double[] expResult = pi; double[] result = hmm.getInitialStateProbabilities(); for (int i = 0; i < expResult.length; i++) { assertEquals(expResult[i], result[i], 1E-7); } } /** * Test of getStateTransitionProbabilities method, of class HMM. */ @Test public void testGetStateTransitionProbabilities() { System.out.println("getStateTransitionProbabilities"); HMM hmm = new HMM(pi, a, b); double[][] expResult = a; double[][] result = hmm.getStateTransitionProbabilities(); for (int i = 0; i < expResult.length; i++) { for (int j = 0; j < expResult[i].length; j++) { assertEquals(expResult[i][j], result[i][j], 1E-7); } } } /** * Test of getSymbolEmissionProbabilities method, of class HMM. */ @Test public void testGetSymbolEmissionProbabilities() { System.out.println("getSymbolEmissionProbabilities"); HMM hmm = new HMM(pi, a, b); double[][] expResult = b; double[][] result = hmm.getSymbolEmissionProbabilities(); for (int i = 0; i < expResult.length; i++) { for (int j = 0; j < expResult[i].length; j++) { assertEquals(expResult[i][j], result[i][j], 1E-7); } } } /** * Test of p method, of class HMM. */ @Test public void testP_intArr_intArr() { System.out.println("p"); int[] o = {0, 0, 1, 1, 0, 1, 1, 0}; int[] s = {0, 0, 1, 1, 1, 1, 1, 0}; HMM hmm = new HMM(pi, a, b); double expResult = 7.33836e-05; double result = hmm.p(o, s); assertEquals(expResult, result, 1E-10); } /** * Test of logp method, of class HMM. */ @Test public void testLogp_intArr_intArr() { System.out.println("logp"); HMM hmm = new HMM(pi, a, b); int[] o = {0, 0, 1, 1, 0, 1, 1, 0}; int[] s = {0, 0, 1, 1, 1, 1, 1, 0}; double expResult = -9.51981; double result = hmm.logp(o, s); assertEquals(expResult, result, 1E-5); } /** * Test of p method, of class HMM. */ @Test public void testP_intArr() { System.out.println("p"); HMM hmm = new HMM(pi, a, b); int[] o = {0, 0, 1, 1, 0, 1, 1, 0}; double expResult = 0.003663364; double result = hmm.p(o); assertEquals(expResult, result, 1E-9); } /** * Test of logp method, of class HMM. */ @Test public void testLogp_intArr() { System.out.println("logp"); HMM hmm = new HMM(pi, a, b); int[] o = {0, 0, 1, 1, 0, 1, 1, 0}; double expResult = -5.609373; double result = hmm.logp(o); assertEquals(expResult, result, 1E-6); } /** * Test of predict method, of class HMM. */ @Test public void testPredict() { System.out.println("predict"); HMM hmm = new HMM(pi, a, b); int[] o = {0, 0, 1, 1, 0, 1, 1, 0}; int[] s = {0, 0, 0, 0, 0, 0, 0, 0}; int[] result = hmm.predict(o); assertEquals(o.length, result.length); for (int i = 0; i < s.length; i++) { assertEquals(s[i], result[i]); } } /** * Test of learn method, of class HMM. */ @Test public void testLearn() { System.out.println("learn"); EmpiricalDistribution initial = new EmpiricalDistribution(pi); EmpiricalDistribution[] transition = new EmpiricalDistribution[a.length]; for (int i = 0; i < transition.length; i++) { transition[i] = new EmpiricalDistribution(a[i]); } EmpiricalDistribution[] emission = new EmpiricalDistribution[b.length]; for (int i = 0; i < emission.length; i++) { emission[i] = new EmpiricalDistribution(b[i]); } int[][] sequences = new int[5000][]; int[][] labels = new int[5000][]; for (int i = 0; i < sequences.length; i++) { sequences[i] = new int[30 * (Math.randomInt(5) + 1)]; labels[i] = new int[sequences[i].length]; int state = (int) initial.rand(); sequences[i][0] = (int) emission[state].rand(); labels[i][0] = state; for (int j = 1; j < sequences[i].length; j++) { state = (int) transition[state].rand(); sequences[i][j] = (int) emission[state].rand(); labels[i][j] = state; } } HMM hmm = new HMM(sequences, labels); System.out.println(hmm); double[] pi2 = {0.55, 0.45}; double[][] a2 = {{0.7, 0.3}, {0.15, 0.85}}; double[][] b2 = {{0.45, 0.55}, {0.3, 0.7}}; HMM init = new HMM(pi2, a2, b2); HMM result = init.learn(sequences, 100); System.out.println(result); } /** * Test of p method, of class HMM. */ @Test public void testP_intArr_intArr2() { System.out.println("p"); String[] symbols = {"0", "1"}; HMM<String> hmm = new HMM<>(pi, a, b, symbols); String[] o = {"0", "0", "1", "1", "0", "1", "1", "0"}; int[] s = {0, 0, 1, 1, 1, 1, 1, 0}; double expResult = 7.33836e-05; double result = hmm.p(o, s); assertEquals(expResult, result, 1E-10); } /** * Test of logp method, of class HMM. */ @Test public void testLogp_intArr_intArr2() { System.out.println("logp"); String[] symbols = {"0", "1"}; HMM<String> hmm = new HMM<>(pi, a, b, symbols); String[] o = {"0", "0", "1", "1", "0", "1", "1", "0"}; int[] s = {0, 0, 1, 1, 1, 1, 1, 0}; double expResult = -9.51981; double result = hmm.logp(o, s); assertEquals(expResult, result, 1E-5); } /** * Test of p method, of class HMM. */ @Test public void testP_intArr2() { System.out.println("p"); String[] symbols = {"0", "1"}; HMM<String> hmm = new HMM<>(pi, a, b, symbols); String[] o = {"0", "0", "1", "1", "0", "1", "1", "0"}; double expResult = 0.003663364; double result = hmm.p(o); assertEquals(expResult, result, 1E-9); } /** * Test of logp method, of class HMM. */ @Test public void testLogp_intArr2() { System.out.println("logp"); String[] symbols = {"0", "1"}; HMM<String> hmm = new HMM<>(pi, a, b, symbols); String[] o = {"0", "0", "1", "1", "0", "1", "1", "0"}; double expResult = -5.609373; double result = hmm.logp(o); assertEquals(expResult, result, 1E-6); } /** * Test of predict method, of class HMM. */ @Test public void testPredict2() { System.out.println("predict"); String[] symbols = {"0", "1"}; HMM<String> hmm = new HMM<>(pi, a, b, symbols); String[] o = {"0", "0", "1", "1", "0", "1", "1", "0"}; int[] s = {0, 0, 0, 0, 0, 0, 0, 0}; int[] result = hmm.predict(o); assertEquals(o.length, result.length); for (int i = 0; i < s.length; i++) { assertEquals(s[i], result[i]); } } /** * Test of predict method, of class HMM. */ @Test public void testPredict3() { System.out.println("predict"); String[] symbols = {"H", "T", "P"}; double[] pi2 = {0.4, 0.3, 0.3}; double[][] a2 = { {0.3, 0.4, 0.3}, {0.3, 0.3, 0.4}, {0.4, 0.2, 0.4} }; double[][] b2 = { {0.4, 0.3, 0.3}, {0.5, 0.2, 0.3}, {0.2, 0.3, 0.5} }; HMM<String> hmm = new HMM<>(pi2, a2, b2, symbols); String[] o = {"H", "H", "P", "P", "P", "H", "H", "H", "P", "P", "P", "H", "T", "T", "T"}; int[] s = {0, 1, 2, 2, 2, 0, 1, 1, 2, 2, 2, 0, 2, 2, 0}; int[] result = hmm.predict(o); assertEquals(o.length, result.length); for (int i = 0; i < s.length; i++) { assertEquals(s[i], result[i]); } } /** * Test of learn method, of class HMM. */ @Test public void testLearn2() { System.out.println("learn"); EmpiricalDistribution initial = new EmpiricalDistribution(pi); EmpiricalDistribution[] transition = new EmpiricalDistribution[a.length]; for (int i = 0; i < transition.length; i++) { transition[i] = new EmpiricalDistribution(a[i]); } EmpiricalDistribution[] emission = new EmpiricalDistribution[b.length]; for (int i = 0; i < emission.length; i++) { emission[i] = new EmpiricalDistribution(b[i]); } String[] symbols = {"0", "1"}; String[][] sequences = new String[5000][]; int[][] labels = new int[5000][]; for (int i = 0; i < sequences.length; i++) { sequences[i] = new String[30 * (Math.randomInt(5) + 1)]; labels[i] = new int[sequences[i].length]; int state = (int) initial.rand(); sequences[i][0] = symbols[(int) emission[state].rand()]; labels[i][0] = state; for (int j = 1; j < sequences[i].length; j++) { state = (int) transition[state].rand(); sequences[i][j] = symbols[(int) emission[state].rand()]; labels[i][j] = state; } } HMM<String> hmm = new HMM(sequences, labels); System.out.println(hmm); double[] pi2 = {0.55, 0.45}; double[][] a2 = {{0.7, 0.3}, {0.15, 0.85}}; double[][] b2 = {{0.45, 0.55}, {0.3, 0.7}}; HMM<String> init = new HMM<>(pi2, a2, b2, symbols); HMM<String> result = init.learn(sequences, 100); System.out.println(result); } }