/******************************************************************************* * Copyright (C) 2010-2012 Dominik Jain. * * This file is part of ProbCog. * * ProbCog is free software: you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. * * ProbCog is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with ProbCog. If not, see <http://www.gnu.org/licenses/>. ******************************************************************************/ package probcog.hmm.latent; import java.util.List; import probcog.hmm.ForwardCalculator; import probcog.hmm.HMM; import probcog.hmm.Segment; import be.ac.ulg.montefiore.run.jahmm.Hmm; import be.ac.ulg.montefiore.run.jahmm.ObservationVector; import be.ac.ulg.montefiore.run.jahmm.Opdf; import be.ac.ulg.montefiore.run.jahmm.OpdfFactory; import be.ac.ulg.montefiore.run.jahmm.OpdfIndependentGaussiansFactory; import be.ac.ulg.montefiore.run.jahmm.learn.BaumWelchScaledLearner; import edu.tum.cs.util.datastruct.ParameterMap; /** * A standard HMM for use as submodel of an LDHMM * @author Dominik Jain */ public class SubHMMSimple extends HMM<ObservationVector> implements ISubHMM { private static final long serialVersionUID = 1L; public SubHMMSimple(int nbStates, int numSubLevels, int obsDimension) { super(nbStates, getOpdfFactory(obsDimension)); } public SubHMMSimple(int numSubLevels, int obsDimension) { super(getOpdfFactory(obsDimension)); } protected static OpdfFactory<? extends Opdf<ObservationVector>> getOpdfFactory(int obsDimension) { //return new OpdfMultiGaussianFactory(obsDimension); return new OpdfIndependentGaussiansFactory(obsDimension); } public void learnViaBaumWelch(List<? extends Segment<? extends ObservationVector>> s) { BaumWelchScaledLearner bw = new BaumWelchScaledLearner(); Hmm<ObservationVector> hmm = bw.learn(this, s); this.pi = hmm.getPi(); this.a = hmm.getA(); this.opdfs = hmm.getOpdfs(); } public void learnViaClustering(Iterable<? extends Segment<? extends ObservationVector>> s, boolean usePseudoCounts) throws Exception { SubHMM.learnViaClustering(this, s, usePseudoCounts); } @Override public double getDwellProbability(int state, int dwellTime) { return a[state][state]; } @Override public double getTransitionProbability(int from, int dwellTime, int to) { return a[from][to]; } @Override public void learn(List<? extends Segment<? extends ObservationVector>> s, ParameterMap learningParams) throws Exception { if(learningParams.getBoolean("learnSubHMMViaBaumWelch")) learnViaBaumWelch(s); else learnViaClustering(s, learningParams.getBoolean("usePseudoCounts")); } public ForwardCalculator<ObservationVector> getForwardCalculator() { return new ForwardCalculator<ObservationVector>(this); } }