/*
* Author: tdanford
* Date: Dec 3, 2008
*/
package org.seqcode.ml.bayesnets;
import java.util.*;
import org.seqcode.gseutils.models.Model;
import org.seqcode.math.probability.FiniteDistribution;
/**
 * Conditional distribution of a single child variable given a set of parent
 * variables.
 *
 * <p>Internally this is a table that maps each assignment of parent values
 * (a {@link BNValues}) to a {@link FiniteDistribution} sized to the child
 * variable. The table is populated in the constructor with one entry for every
 * assignment produced by {@link BNValuesIterator}, and it keeps that insertion
 * order.
 */
public class BNCpd {

    /** Parent variables; a private copy of the array handed to the constructor. */
    private final BNVar[] parents;

    /** The variable whose distribution is stored in this table. */
    private final BNVar child;

    /** Parent-value assignment -> distribution over the child's values. */
    private final Map<BNValues,FiniteDistribution> cpd;

    /**
     * Builds a table with one freshly constructed distribution per parent-value
     * assignment.
     *
     * @param pars the parent variables; the array is cloned, so later changes to
     *             the caller's array do not affect this object
     * @param chld the child variable
     */
    public BNCpd(BNVar[] pars, BNVar chld) {
        parents = pars.clone();
        child = chld;
        cpd = new LinkedHashMap<BNValues,FiniteDistribution>();

        int childSize = child.size();
        BNValuesIterator itr = new BNValuesIterator(parents);
        while (itr.hasNext()) {
            BNValues parvals = itr.next();
            cpd.put(parvals, new FiniteDistribution(childSize));
        }
    }

    /**
     * Writes the table to standard output: a header line naming the child
     * variable, followed by one line per parent-value assignment.
     */
    public void print() {
        System.out.println(String.format("CPD: %s", child.getName()));
        // Iterate entries directly rather than looking each key up again.
        for (Map.Entry<BNValues,FiniteDistribution> entry : cpd.entrySet()) {
            System.out.println(String.format("%s -> %s", entry.getKey(), entry.getValue()));
        }
    }

    /**
     * Sums {@link #logLikelihood(Model)} over every observation supplied by the
     * iterator.
     *
     * @param obs the observations; the iterator is consumed completely. Accepts
     *            iterators over any {@code Model} subtype, matching
     *            {@link #learn(Iterator)}.
     * @return the sum of the per-observation log-likelihoods; {@code 0.0} when
     *         the iterator is empty
     */
    public double logLikelihood(Iterator<? extends Model> obs) {
        double sum = 0.0;
        while (obs.hasNext()) {
            sum += logLikelihood(obs.next());
        }
        return sum;
    }

    /**
     * Computes the natural log of the probability this table assigns to the
     * child's value in {@code m}, given the parent values found in {@code m}.
     *
     * @param m the observation to score
     * @return {@code Math.log} of the stored probability
     */
    public double logLikelihood(Model m) {
        BNValues parvals = new BNValues(m, parents);
        Integer childValue = child.encode(child.findValue(m));
        // NOTE(review): if parvals is not a key of the table, cpd.get returns
        // null and this line throws NullPointerException - confirm that
        // BNValuesIterator covers every assignment a Model can produce.
        return Math.log(cpd.get(parvals).getProb(childValue));
    }

    /**
     * Returns the child's size multiplied by the size of every parent.
     *
     * <p>This counts every entry of the table; nothing is subtracted for the
     * normalization constraint on each row.
     *
     * @return the product of {@code child.size()} and each parent's
     *         {@code size()}
     */
    public int countParameters() {
        int count = child.size();
        for (BNVar parent : parents) {
            count *= parent.size();
        }
        return count;
    }

    /**
     * Draws a value for the child from the distribution stored for the given
     * parent-value assignment.
     *
     * @param parvalues the parent-value assignment selecting the distribution
     * @return the sampled value, decoded by the child variable
     */
    public Object sample(BNValues parvalues) {
        return child.decode(cpd.get(parvalues).sampleValue());
    }

    /**
     * Samples a new child value conditioned on the parent values in {@code m}
     * and stores it back into {@code m} through the child variable.
     *
     * @param m the model whose child value is replaced
     */
    public void resample(Model m) {
        Object value = sample(new BNValues(m, parents));
        child.setValue(m, value);
    }

    /**
     * Re-estimates every distribution in the table from the observations:
     * clears them all, adds each observation's encoded child value to the
     * distribution selected by that observation's parent values, then
     * normalizes them all.
     *
     * @param obs the observations; the iterator is consumed completely
     */
    public void learn(Iterator<? extends Model> obs) {
        for (FiniteDistribution dist : cpd.values()) {
            dist.clear();
        }

        while (obs.hasNext()) {
            Model m = obs.next();
            BNValues parvals = new BNValues(m, parents);
            Integer childValue = child.encode(child.findValue(m));
            cpd.get(parvals).addValue(childValue);
        }

        // NOTE(review): distributions whose parent assignment never occurred in
        // obs are normalized while still cleared - confirm how
        // FiniteDistribution.normalize() behaves in that case.
        for (FiniteDistribution dist : cpd.values()) {
            dist.normalize();
        }
    }
}