/******************************************************************************* * Copyright (C) 2010-2012 Dominik Jain. * * This file is part of ProbCog. * * ProbCog is free software: you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. * * ProbCog is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with ProbCog. If not, see <http://www.gnu.org/licenses/>. ******************************************************************************/ package probcog.hmm.latent; import java.util.Collection; import be.ac.ulg.montefiore.run.distributions.GaussianDistribution; import be.ac.ulg.montefiore.run.jahmm.Observation; import be.ac.ulg.montefiore.run.jahmm.ObservationReal; import be.ac.ulg.montefiore.run.jahmm.OpdfGaussian; /** * Hidden semi-Markov model where state transition probabilities depend * on the dwell time, i.e. the number of time steps already spent in a state * @author Dominik Jain */ public abstract class DwellTimeHMM<O extends Observation> implements IDwellTimeHMM<O> { protected double[][] A; protected double[] pi; protected GaussianDistribution[] dwellTimeDist; protected Integer numStates; public DwellTimeHMM(int numStates) { init(numStates); } /** * constructs a completely uninitialized HMM */ public DwellTimeHMM() { numStates = null; } protected void init(int numStates) { if(numStates <= 0) throw new IllegalArgumentException("Number of states must be >= 0."); A = new double[numStates][numStates]; pi = new double[numStates]; dwellTimeDist = new GaussianDistribution[numStates]; this.numStates = numStates; } @Override public Integer getNumStates() { return numStates; } /** * (re-)initializes this HMM for the given number of states * @param numStates * @throws IllegalAccessException */ @Override public void setNumStates(int numStates) throws IllegalAccessException { if(this.numStates != null) throw new IllegalAccessException("Cannot set number of states in model which was constructed with known number of states"); init(numStates); } public double getPi(int state) { return pi[state]; } public double getDwellProbability(int label, int dwellTime) { return 1 - dwellTimeDist[label].cdf(dwellTime); } public double getTransitionProbability(int from, int dwellTime, int to) { double pDwell = 1 - dwellTimeDist[from].cdf(dwellTime); return (1.0-pDwell) * A[from][to]; } public void setA(double[][] A) { this.A = A; } public void setPi(double[] pi) { this.pi = pi; } public void learnDwellTimeDistribution(int state, Collection<ObservationReal> times) { OpdfGaussian pdf = new OpdfGaussian(); if(times.size() == 1) { System.err.println("Only 1 example for length, therefore adding additional items +/- 10"); ObservationReal r = times.iterator().next(); times.add(new ObservationReal(r.value+10)); times.add(new ObservationReal(r.value-10)); } pdf.fit(times); double var = pdf.variance(); dwellTimeDist[state] = new GaussianDistribution(pdf.mean(), var); } }