/* * Encog(tm) Core v3.4 - Java Version * http://www.heatonresearch.com/encog/ * https://github.com/encog/encog-java-core * Copyright 2008-2016 Heaton Research, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * * For more information on Heaton Research copyrights, licenses * and trademarks visit: * http://www.heatonresearch.com/copyright */ package org.encog.ml.hmm.train.kmeans; import java.util.Collection; import java.util.List; import org.encog.ml.MLMethod; import org.encog.ml.TrainingImplementationType; import org.encog.ml.data.MLDataPair; import org.encog.ml.data.MLDataSet; import org.encog.ml.data.MLSequenceSet; import org.encog.ml.data.basic.BasicMLDataSet; import org.encog.ml.hmm.HiddenMarkovModel; import org.encog.ml.hmm.alog.ViterbiCalculator; import org.encog.ml.hmm.distributions.StateDistribution; import org.encog.ml.train.MLTrain; import org.encog.ml.train.strategy.Strategy; import org.encog.neural.networks.training.propagation.TrainingContinuation; /** * Train a Hidden Markov Model (HMM) with the KMeans algorithm. Makes use of * KMeans clustering to estimate the transitional and observational * probabilities for the HMM. * * Unlike Baum Welch training, this method does not require a prior estimate of * the HMM model, it starts from scratch. * * Faber, Clustering and the Continuous k-Means Algorithm, Los Alamos Science, * no. 22, 1994. */ public class TrainKMeans implements MLTrain { private final Clusters clusters; private final int states; private final MLSequenceSet sequnces; private boolean done; private final HiddenMarkovModel modelHMM; private int iteration; private HiddenMarkovModel method; private final MLSequenceSet training; public TrainKMeans(final HiddenMarkovModel method, final MLSequenceSet sequences) { this.method = method; this.modelHMM = method; this.sequnces = sequences; this.states = method.getStateCount(); this.training = sequences; this.clusters = new Clusters(this.states, sequences); this.done = false; } @Override public void addStrategy(final Strategy strategy) { } @Override public boolean canContinue() { return false; } @Override public void finishTraining() { } @Override public double getError() { return this.done ? 0 : 100; } @Override public TrainingImplementationType getImplementationType() { return TrainingImplementationType.Iterative; } @Override public int getIteration() { return this.iteration; } @Override public MLMethod getMethod() { return this.method; } @Override public List<Strategy> getStrategies() { return null; } @Override public MLDataSet getTraining() { return this.training; } @Override public boolean isTrainingDone() { return this.done; } @Override public void iteration() { final HiddenMarkovModel hmm = this.modelHMM.cloneStructure(); learnPi(hmm); learnTransition(hmm); learnOpdf(hmm); this.done = optimizeCluster(hmm); this.method = hmm; } @Override public void iteration(final int count) { // this.iteration = count; } private void learnOpdf(final HiddenMarkovModel hmm) { for (int i = 0; i < hmm.getStateCount(); i++) { final Collection<MLDataPair> clusterObservations = this.clusters .cluster(i); if (clusterObservations.size() < 1) { final StateDistribution o = this.modelHMM .createNewDistribution(); hmm.setStateDistribution(i, o); } else { final MLDataSet temp = new BasicMLDataSet(); for (final MLDataPair pair : clusterObservations) { temp.add(pair); } hmm.getStateDistribution(i).fit(temp); } } } private void learnPi(final HiddenMarkovModel hmm) { final double[] pi = new double[this.states]; for (int i = 0; i < this.states; i++) { pi[i] = 0.; } for (final MLDataSet sequence : this.sequnces.getSequences()) { pi[this.clusters.cluster(sequence.get(0))]++; } for (int i = 0; i < this.states; i++) { hmm.setPi(i, pi[i] / this.sequnces.size()); } } private void learnTransition(final HiddenMarkovModel hmm) { for (int i = 0; i < hmm.getStateCount(); i++) { for (int j = 0; j < hmm.getStateCount(); j++) { hmm.setTransitionProbability(i, j, 0.); } } for (final MLDataSet obsSeq : this.sequnces.getSequences()) { if (obsSeq.size() < 2) { continue; } int first_state; int second_state = this.clusters.cluster(obsSeq.get(0)); for (int i = 1; i < obsSeq.size(); i++) { first_state = second_state; second_state = this.clusters.cluster(obsSeq.get(i)); hmm.setTransitionProbability( first_state, second_state, hmm.getTransitionProbability(first_state, second_state) + 1.); } } /* Normalize Aij array */ for (int i = 0; i < hmm.getStateCount(); i++) { double sum = 0; for (int j = 0; j < hmm.getStateCount(); j++) { sum += hmm.getTransitionProbability(i, j); } if (sum == 0.) { for (int j = 0; j < hmm.getStateCount(); j++) { hmm.setTransitionProbability(i, j, 1. / hmm.getStateCount()); } } else { for (int j = 0; j < hmm.getStateCount(); j++) { hmm.setTransitionProbability(i, j, hmm.getTransitionProbability(i, j) / sum); } } } } private boolean optimizeCluster(final HiddenMarkovModel hmm) { boolean modif = false; for (final MLDataSet obsSeq : this.sequnces.getSequences()) { final ViterbiCalculator vc = new ViterbiCalculator(obsSeq, hmm); final int states[] = vc.stateSequence(); for (int i = 0; i < states.length; i++) { final MLDataPair o = obsSeq.get(i); if (this.clusters.cluster(o) != states[i]) { modif = true; this.clusters.remove(o, this.clusters.cluster(o)); this.clusters.put(o, states[i]); } } } return !modif; } @Override public TrainingContinuation pause() { return null; } @Override public void resume(final TrainingContinuation state) { } @Override public void setError(final double error) { } @Override public void setIteration(final int iteration) { this.iteration = iteration; } }