/*******************************************************************************
* Copyright 2012 Analog Devices, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
********************************************************************************/
package com.analog.lyric.dimple.examples;
import org.eclipse.jdt.annotation.NonNullByDefault;
import com.analog.lyric.dimple.factorfunctions.core.FactorFunction;
import com.analog.lyric.dimple.model.core.FactorGraph;
import com.analog.lyric.dimple.model.domains.DiscreteDomain;
import com.analog.lyric.dimple.model.values.Value;
import com.analog.lyric.dimple.model.variables.Discrete;
import com.analog.lyric.dimple.options.BPOptions;
import cern.colt.Arrays;
/**
 * Dimple example: a seven-day hidden Markov model over the weather.
 * <p>
 * Each day's hidden weather is a {@link Discrete} variable over the domain {@code {"sunny", "rainy"}}.
 * Consecutive days are linked by a {@link TransitionFactorFunction}. Each day is also tied to a
 * fixed observed activity ({@code "walk"}, {@code "book"} or {@code "cook"}) by an
 * {@link ObservationFactorFunction}. The program solves the graph with 20 BP iterations and prints
 * Tuesday's belief array.
 */
public class HMM
{
	/**
	 * Weather transition between two consecutive days, returned as an energy (negative log weight).
	 * <p>
	 * From {@code "sunny"}: weight 0.8 for staying {@code "sunny"}, 0.2 otherwise.
	 * From any other state: weight 0.5 regardless of the next state.
	 */
	@NonNullByDefault
	@SuppressWarnings("null")
	public static class TransitionFactorFunction extends FactorFunction
	{
		@Override
		public final double evalEnergy(Value[] args)
		{
			final String state1 = (String)args[0].getObject();
			final String state2 = (String)args[1].getObject();
			final double value;
			if (state1.equals("sunny"))
			{
				value = state2.equals("sunny") ? 0.8 : 0.2;
			}
			else
			{
				value = 0.5;
			}
			return -Math.log(value);
		}
	}

	/**
	 * Emission weight of an observed activity given the day's weather, returned as an energy
	 * (negative log weight).
	 * <p>
	 * Weights (sunny / other): walk 0.7 / 0.2, book 0.1 / 0.4, cook 0.2 / 0.4.
	 *
	 * @throws IllegalArgumentException if the observation is not {@code "walk"}, {@code "book"}
	 *         or {@code "cook"}. Previously, any unknown value was silently scored as {@code "cook"}.
	 */
	@NonNullByDefault
	@SuppressWarnings("null")
	public static class ObservationFactorFunction extends FactorFunction
	{
		@Override
		public final double evalEnergy(Value[] args)
		{
			final String state = (String)args[0].getObject();
			final String observation = (String)args[1].getObject();
			final boolean sunny = state.equals("sunny");
			final double value;
			if (observation.equals("walk"))
			{
				value = sunny ? 0.7 : 0.2;
			}
			else if (observation.equals("book"))
			{
				value = sunny ? 0.1 : 0.4;
			}
			else if (observation.equals("cook"))
			{
				value = sunny ? 0.2 : 0.4;
			}
			else
			{
				throw new IllegalArgumentException("Unknown observation: " + observation);
			}
			return -Math.log(value);
		}
	}

	/**
	 * Builds the seven-day weather HMM, runs 20 BP iterations and prints Tuesday's belief.
	 *
	 * @param args ignored
	 */
	@SuppressWarnings("null")
	public static void main(String[] args)
	{
		// Named 'hmm' (not 'HMM') so the local does not shadow the enclosing class name.
		final FactorGraph hmm = new FactorGraph();
		final DiscreteDomain domain = DiscreteDomain.create("sunny","rainy");

		// Observed activity for each day, Monday (index 0) through Sunday (index 6).
		final String[] observations = {"walk", "walk", "cook", "walk", "cook", "book", "book"};
		final int numDays = observations.length;

		final Discrete[] weather = new Discrete[numDays];
		for (int day = 0; day < numDays; ++day)
		{
			weather[day] = new Discrete(domain);
		}

		// Chain consecutive days. All transition factors are added before any observation
		// factor, matching the original construction order.
		final TransitionFactorFunction trans = new TransitionFactorFunction();
		for (int day = 0; day + 1 < numDays; ++day)
		{
			hmm.addFactor(trans, weather[day], weather[day + 1]);
		}

		// Attach each day's observed activity. The string is passed directly as a factor argument.
		final ObservationFactorFunction obs = new ObservationFactorFunction();
		for (int day = 0; day < numDays; ++day)
		{
			hmm.addFactor(obs, weather[day], observations[day]);
		}

		// Prior on Monday; presumably in domain order (sunny, rainy) -- confirm against setPrior docs.
		weather[0].setPrior(0.7, 0.3);

		hmm.setOption(BPOptions.iterations, 20);
		hmm.solve();

		final double[] belief = weather[1].getBelief(); // Tuesday
		System.out.println(Arrays.toString(belief));
	}
}