/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.mahout.classifier.sequencelearning.hmm;
import java.util.Arrays;
import java.util.List;
import org.apache.mahout.math.DenseMatrix;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.Vector;
import org.junit.Test;
/**
 * Unit tests for the static helpers in {@code HmmUtils}: model validation,
 * encoding/decoding of state sequences, model normalization and model
 * truncation.
 *
 * <p>Tests that need a ready-made model obtain it through {@code getModel()},
 * which is not defined in this class (presumably inherited from
 * {@code HMMTestBase} -- confirm there).
 */
public class HMMUtilsTest extends HMMTestBase {

  /** 2x2 matrix whose rows each sum to 1. */
  private Matrix legal22;
  /** 2x3 matrix whose rows each sum to 1. */
  private Matrix legal23;
  /** 3x3 matrix whose rows each sum to 1. */
  private Matrix legal33;
  /** Vector of length 2 whose entries sum to 1. */
  private Vector legal2;
  /** 2x2 matrix with entries greater than 1, i.e. not a probability matrix. */
  private Matrix illegal22;

  /**
   * Runs the base class set-up and then (re)creates the matrix and vector
   * fixtures shared by the validator tests.
   *
   * @throws Exception if the base class set-up fails
   */
  @Override
  public void setUp() throws Exception {
    super.setUp();
    legal22 = new DenseMatrix(new double[][]{{0.5, 0.5}, {0.3, 0.7}});
    legal23 = new DenseMatrix(new double[][]{{0.2, 0.2, 0.6},
        {0.3, 0.3, 0.4}});
    legal33 = new DenseMatrix(new double[][]{{0.1, 0.1, 0.8},
        {0.1, 0.2, 0.7}, {0.2, 0.3, 0.5}});
    legal2 = new DenseVector(new double[]{0.4, 0.6});
    illegal22 = new DenseMatrix(new double[][]{{1, 2}, {3, 4}});
  }

  /**
   * A model assembled from mutually consistent, row-normalized parameters
   * must pass validation without an exception.
   */
  @Test
  public void testValidatorLegal() {
    HmmUtils.validate(new HmmModel(legal22, legal23, legal2));
  }

  /**
   * Combining a 3x3 matrix with a 2x3 matrix and a length-2 vector must be
   * rejected with an {@link IllegalArgumentException}.
   */
  @Test
  public void testValidatorDimensionError() {
    try {
      HmmUtils.validate(new HmmModel(legal33, legal23, legal2));
      // Reaching this line means nothing was thrown.
      fail("Expected IllegalArgumentException for mismatched dimensions");
    } catch (IllegalArgumentException expected) {
      // success
    }
  }

  /**
   * A matrix whose entries are not probabilities must be rejected with an
   * {@link IllegalArgumentException}.
   */
  @Test
  public void testValidatorIllegelMatrixError() {
    try {
      HmmUtils.validate(new HmmModel(illegal22, legal23, legal2));
      // Reaching this line means nothing was thrown.
      fail("Expected IllegalArgumentException for a non-probability matrix");
    } catch (IllegalArgumentException expected) {
      // success
    }
  }

  /**
   * Encodes hidden and output state names to integer IDs. Per the expected
   * arrays below, the names "H4" and "O4" are expected to come back as the
   * supplied default value (-1).
   */
  @Test
  public void testEncodeStateSequence() {
    String[] hiddenSequence = {"H1", "H2", "H0", "H3", "H4"};
    String[] outputSequence = {"O1", "O2", "O4", "O0"};
    // encode the hidden and the output sequence
    int[] hiddenSequenceEnc = HmmUtils.encodeStateSequence(getModel(),
        Arrays.asList(hiddenSequence), false, -1);
    int[] outputSequenceEnc = HmmUtils.encodeStateSequence(getModel(),
        Arrays.asList(outputSequence), true, -1);
    // expected state sequences
    int[] hiddenSequenceExp = {1, 2, 0, 3, -1};
    int[] outputSequenceExp = {1, 2, -1, 0};
    // Check the lengths first: otherwise an empty or truncated result would
    // make the element-wise loops pass without comparing everything.
    assertEquals(hiddenSequenceExp.length, hiddenSequenceEnc.length);
    assertEquals(outputSequenceExp.length, outputSequenceEnc.length);
    // compare element by element
    for (int i = 0; i < hiddenSequenceExp.length; ++i) {
      assertEquals(hiddenSequenceExp[i], hiddenSequenceEnc[i]);
    }
    for (int i = 0; i < outputSequenceExp.length; ++i) {
      assertEquals(outputSequenceExp[i], outputSequenceEnc[i]);
    }
  }

  /**
   * Decodes integer state IDs back to state names. Per the expected arrays
   * below, the ID 10 is expected to come back as the supplied default name
   * ("unknown") for both hidden and output states.
   */
  @Test
  public void testDecodeStateSequence() {
    int[] hiddenSequence = {1, 2, 0, 3, 10};
    int[] outputSequence = {1, 2, 10, 0};
    // decode the hidden and the output sequence
    List<String> hiddenSequenceDec = HmmUtils.decodeStateSequence(
        getModel(), hiddenSequence, false, "unknown");
    List<String> outputSequenceDec = HmmUtils.decodeStateSequence(
        getModel(), outputSequence, true, "unknown");
    // expected state sequences
    String[] hiddenSequenceExp = {"H1", "H2", "H0", "H3", "unknown"};
    String[] outputSequenceExp = {"O1", "O2", "unknown", "O0"};
    // Check the sizes first so that surplus decoded elements are detected.
    assertEquals(hiddenSequenceExp.length, hiddenSequenceDec.size());
    assertEquals(outputSequenceExp.length, outputSequenceDec.size());
    // compare element by element
    for (int i = 0; i < hiddenSequenceExp.length; ++i) {
      assertEquals(hiddenSequenceExp[i], hiddenSequenceDec.get(i));
    }
    for (int i = 0; i < outputSequenceExp.length; ++i) {
      assertEquals(outputSequenceExp[i], outputSequenceDec.get(i));
    }
  }

  /**
   * Builds a model from unnormalized parameters and checks that it passes
   * validation after {@code HmmUtils.normalizeModel} has been applied.
   */
  @Test
  public void testNormalizeModel() {
    DenseVector ip = new DenseVector(new double[]{10, 20});
    DenseMatrix tr = new DenseMatrix(new double[][]{{10, 10}, {20, 25}});
    DenseMatrix em = new DenseMatrix(new double[][]{{5, 7}, {10, 15}});
    HmmModel model = new HmmModel(tr, em, ip);
    HmmUtils.normalizeModel(model);
    // the model should be valid now
    HmmUtils.validate(model);
  }

  /**
   * Truncates a model whose parameters are 0.9998 in one position per
   * row/vector and 0.0001 elsewhere, using a threshold of 0.01. The test
   * expects the truncated model to be valid and to contain exactly 1.0 in
   * the dominant positions and 0.0 everywhere else (within EPSILON).
   */
  @Test
  public void testTruncateModel() {
    DenseVector ip = new DenseVector(new double[]{0.0001, 0.0001, 0.9998});
    DenseMatrix tr = new DenseMatrix(new double[][]{
        {0.9998, 0.0001, 0.0001}, {0.0001, 0.9998, 0.0001},
        {0.0001, 0.0001, 0.9998}});
    DenseMatrix em = new DenseMatrix(new double[][]{
        {0.9998, 0.0001, 0.0001}, {0.0001, 0.9998, 0.0001},
        {0.0001, 0.0001, 0.9998}});
    HmmModel model = new HmmModel(tr, em, ip);
    // now truncate the model
    HmmModel sparseModel = HmmUtils.truncateModel(model, 0.01);
    // first make sure this is a valid model
    HmmUtils.validate(sparseModel);
    // now check whether the values are as expected
    Vector sparseIp = sparseModel.getInitialProbabilities();
    Matrix sparseTr = sparseModel.getTransitionMatrix();
    Matrix sparseEm = sparseModel.getEmissionMatrix();
    for (int i = 0; i < sparseModel.getNrOfHiddenStates(); ++i) {
      // only the third initial probability was above the threshold
      assertEquals(i == 2 ? 1.0 : 0.0, sparseIp.getQuick(i), EPSILON);
      for (int j = 0; j < sparseModel.getNrOfHiddenStates(); ++j) {
        // both matrices were dominated by their diagonal entries
        if (i == j) {
          assertEquals(1.0, sparseTr.getQuick(i, j), EPSILON);
          assertEquals(1.0, sparseEm.getQuick(i, j), EPSILON);
        } else {
          assertEquals(0.0, sparseTr.getQuick(i, j), EPSILON);
          assertEquals(0.0, sparseEm.getQuick(i, j), EPSILON);
        }
      }
    }
  }
}