/*******************************************************************************
* Copyright (C) 2010-2012 Dominik Jain.
*
* This file is part of ProbCog.
*
* ProbCog is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* ProbCog is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with ProbCog. If not, see <http://www.gnu.org/licenses/>.
******************************************************************************/
package probcog.hmm;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Vector;
import be.ac.ulg.montefiore.run.jahmm.Hmm;
import be.ac.ulg.montefiore.run.jahmm.Observation;
import be.ac.ulg.montefiore.run.jahmm.Opdf;
import be.ac.ulg.montefiore.run.jahmm.OpdfFactory;
/**
 * A standard hidden Markov model (HMM).
 *
 * <p>The number of states is either fixed at construction time or left open and
 * supplied later through {@link #setNumStates(int)}. Until it has been supplied,
 * {@link #getNumStates()} returns {@code null} and the model cannot be trained.
 *
 * @param <O> the type of observation emitted by the model's states
 * @author Dominik Jain
 */
public class HMM<O extends Observation> extends Hmm<O> implements IHMM<O> {
    private static final long serialVersionUID = 1L;

    /** Factory handed to {@code init} when the number of states is set after construction. */
    protected OpdfFactory<? extends Opdf<O>> opdfFactory;

    /** Number of states, or {@code null} while it has not been determined yet. */
    protected Integer numStates = null;

    /**
     * Constructs an HMM with a known number of states.
     *
     * @param nbStates the number of states
     * @param opdfFactory the factory passed on to the superclass constructor
     */
    public HMM(int nbStates, OpdfFactory<? extends Opdf<O>> opdfFactory) {
        super(nbStates, opdfFactory);
        // keep the factory so the protected field is never silently null for subclasses
        this.opdfFactory = opdfFactory;
        numStates = nbStates;
    }

    /**
     * Constructs an HMM where an appropriate number of states is determined during learning.
     * The number of states must be supplied via {@link #setNumStates(int)} before the
     * model is trained.
     *
     * @param opdfFactory the factory used to initialize the model once the number of states is set
     */
    public HMM(OpdfFactory<? extends Opdf<O>> opdfFactory) {
        super();
        this.opdfFactory = opdfFactory;
    }

    /**
     * Fits the observation distribution of a single state to the given data.
     * All segments are flattened into one collection of observations, which is then
     * passed to the {@code fit} method of the state's distribution.
     *
     * @param state the index of the state whose observation model is learnt
     * @param data the segments of observations attributed to that state
     */
    public void learnObservationModel(int state, Collection<? extends Collection<? extends O>> data) {
        List<O> observations = new ArrayList<O>();
        for (Collection<? extends O> segment : data) {
            observations.addAll(segment);
        }
        System.out.printf("    learning observation model for state %d from %d data points...\n", state, observations.size());
        this.opdfs.get(state).fit(observations);
    }

    /**
     * Learns the transition matrix and all observation models from labelled training data.
     *
     * <p>Within each sequence, every observation except the first reports one transition
     * (label of the preceding observation to label of the current one) to a
     * {@link TransitionLearner}; consecutive observations of the same segment therefore
     * report a transition from the segment's label to itself. The learner's result is
     * installed via {@link #setA(double[][])}. Afterwards, for each state, the segments
     * returned by {@code getSegments(state)} across all sequences are collected and used
     * to fit that state's observation model.
     *
     * @param trainingData the labelled segment sequences to learn from
     * @param usePseudoCounts flag passed on to the {@link TransitionLearner}
     * @throws IllegalStateException if the number of states has not been set yet
     */
    public void learn(Iterable<? extends SegmentSequence<? extends O>> trainingData, boolean usePseudoCounts) {
        if (numStates == null) {
            throw new IllegalStateException("Number of states has not been set; call setNumStates before learning");
        }
        // unbox once instead of on every loop iteration
        final int n = numStates;

        TransitionLearner tl = new TransitionLearner(n, usePseudoCounts);
        for (SegmentSequence<? extends O> ss : trainingData) {
            // null marks the first observation of a sequence, which has no predecessor
            Integer prevLabel = null;
            for (Segment<? extends O> seg : ss) {
                // one step per observation; the observation's value itself is irrelevant here
                for (O ignoredPoint : seg) {
                    if (prevLabel != null) {
                        tl.learn(prevLabel, seg.label);
                    }
                    prevLabel = seg.label;
                }
            }
        }
        setA(tl.finish());

        for (int state = 0; state < n; state++) {
            List<Segment<? extends O>> data = new ArrayList<Segment<? extends O>>();
            for (SegmentSequence<? extends O> ss : trainingData) {
                Collection<? extends Segment<? extends O>> segs = ss.getSegments(state);
                // a sequence may have no segments for this state
                if (segs != null) {
                    data.addAll(segs);
                }
            }
            this.learnObservationModel(state, data);
        }
    }

    /**
     * Replaces the state transition matrix. The array is stored by reference;
     * it is neither copied nor validated.
     *
     * @param A the new transition matrix
     */
    @Override
    public void setA(double[][] A) {
        this.a = A;
    }

    /**
     * Replaces the initial state distribution. The array is stored by reference.
     *
     * @param pi the new initial state distribution, with one entry per state
     * @throws IllegalStateException if the number of states has not been set yet
     * @throws IllegalArgumentException if the array length differs from the number of states
     */
    @Override
    public void setPi(double[] pi) {
        if (numStates == null) {
            throw new IllegalStateException("Number of states has not been set; call setNumStates first");
        }
        if (pi.length != numStates.intValue()) {
            throw new IllegalArgumentException("Incorrect array length");
        }
        this.pi = pi;
    }

    /**
     * Returns the number of states.
     *
     * @return the number of states, or {@code null} if it has not been determined yet
     */
    @Override
    public Integer getNumStates() {
        return numStates;
    }

    /**
     * Sets the number of states for an HMM whose number of states was previously unknown
     * and initializes the model with the factory given at construction time.
     *
     * @param numStates the number of states
     * @throws IllegalAccessException if the number of states is already known
     */
    @Override
    public void setNumStates(int numStates) throws IllegalAccessException {
        if (this.numStates != null) {
            throw new IllegalAccessException("Cannot set number of states in model which was constructed with known number of states");
        }
        this.numStates = numStates;
        // initialize the HMM
        init(numStates, opdfFactory);
    }

    /**
     * Returns the observation model of a state, wrapping the state's observation
     * distribution in a new {@link OpdfObservationModel} on every call.
     *
     * @param state the index of the state
     * @return an observation model backed by the state's distribution
     */
    @Override
    public IObservationModel<O> getObservationModel(int state) {
        return new OpdfObservationModel<O>(this.opdfs.get(state));
    }
}