package aima.core.probability.bayes.impl; import java.util.ArrayList; import java.util.LinkedHashSet; import java.util.List; import java.util.Map; import java.util.Set; import aima.core.probability.CategoricalDistribution; import aima.core.probability.Factor; import aima.core.probability.ProbabilityModel; import aima.core.probability.RandomVariable; import aima.core.probability.bayes.ConditionalProbabilityTable; import aima.core.probability.domain.FiniteDomain; import aima.core.probability.proposition.AssignmentProposition; import aima.core.probability.util.ProbUtil; import aima.core.probability.util.ProbabilityTable; /** * Default implementation of the ConditionalProbabilityTable interface. * * @author Ciaran O'Reilly * */ public class CPT implements ConditionalProbabilityTable { private RandomVariable on = null; private LinkedHashSet<RandomVariable> parents = new LinkedHashSet<RandomVariable>(); private ProbabilityTable table = null; private List<Object> onDomain = new ArrayList<Object>(); public CPT(RandomVariable on, double[] values, RandomVariable... conditionedOn) { this.on = on; if (null == conditionedOn) { conditionedOn = new RandomVariable[0]; } RandomVariable[] tableVars = new RandomVariable[conditionedOn.length + 1]; for (int i = 0; i < conditionedOn.length; i++) { tableVars[i] = conditionedOn[i]; parents.add(conditionedOn[i]); } tableVars[conditionedOn.length] = on; table = new ProbabilityTable(values, tableVars); onDomain.addAll(((FiniteDomain) on.getDomain()).getPossibleValues()); checkEachRowTotalsOne(); } public double probabilityFor(final Object... values) { return table.getValue(values); } // // START-ConditionalProbabilityDistribution @Override public RandomVariable getOn() { return on; } @Override public Set<RandomVariable> getParents() { return parents; } @Override public Set<RandomVariable> getFor() { return table.getFor(); } @Override public boolean contains(RandomVariable rv) { return table.contains(rv); } @Override public double getValue(Object... eventValues) { return table.getValue(eventValues); } @Override public double getValue(AssignmentProposition... eventValues) { return table.getValue(eventValues); } @Override public Object getSample(double probabilityChoice, Object... parentValues) { return ProbUtil.sample(probabilityChoice, on, getConditioningCase(parentValues).getValues()); } @Override public Object getSample(double probabilityChoice, AssignmentProposition... parentValues) { return ProbUtil.sample(probabilityChoice, on, getConditioningCase(parentValues).getValues()); } // END-ConditionalProbabilityDistribution // // // START-ConditionalProbabilityTable @Override public CategoricalDistribution getConditioningCase(Object... parentValues) { if (parentValues.length != parents.size()) { throw new IllegalArgumentException( "The number of parent value arguments [" + parentValues.length + "] is not equal to the number of parents [" + parents.size() + "] for this CPT."); } AssignmentProposition[] aps = new AssignmentProposition[parentValues.length]; int idx = 0; for (RandomVariable parentRV : parents) { aps[idx] = new AssignmentProposition(parentRV, parentValues[idx]); idx++; } return getConditioningCase(aps); } @Override public CategoricalDistribution getConditioningCase( AssignmentProposition... parentValues) { if (parentValues.length != parents.size()) { throw new IllegalArgumentException( "The number of parent value arguments [" + parentValues.length + "] is not equal to the number of parents [" + parents.size() + "] for this CPT."); } final ProbabilityTable cc = new ProbabilityTable(getOn()); ProbabilityTable.Iterator pti = new ProbabilityTable.Iterator() { private int idx = 0; @Override public void iterate(Map<RandomVariable, Object> possibleAssignment, double probability) { cc.getValues()[idx] = probability; idx++; } }; table.iterateOverTable(pti, parentValues); return cc; } public Factor getFactorFor(final AssignmentProposition... evidence) { Set<RandomVariable> fofVars = new LinkedHashSet<RandomVariable>( table.getFor()); for (AssignmentProposition ap : evidence) { fofVars.remove(ap.getTermVariable()); } final ProbabilityTable fof = new ProbabilityTable(fofVars); // Otherwise need to iterate through the table for the // non evidence variables. final Object[] termValues = new Object[fofVars.size()]; ProbabilityTable.Iterator di = new ProbabilityTable.Iterator() { public void iterate(Map<RandomVariable, Object> possibleWorld, double probability) { if (0 == termValues.length) { fof.getValues()[0] += probability; } else { int i = 0; for (RandomVariable rv : fof.getFor()) { termValues[i] = possibleWorld.get(rv); i++; } fof.getValues()[fof.getIndex(termValues)] += probability; } } }; table.iterateOverTable(di, evidence); return fof; } // END-ConditionalProbabilityTable // // // PRIVATE METHODS // private void checkEachRowTotalsOne() { ProbabilityTable.Iterator di = new ProbabilityTable.Iterator() { private int rowSize = onDomain.size(); private int iterateCnt = 0; private double rowProb = 0; public void iterate(Map<RandomVariable, Object> possibleWorld, double probability) { iterateCnt++; rowProb += probability; if (iterateCnt % rowSize == 0) { if (Math.abs(1 - rowProb) > ProbabilityModel.DEFAULT_ROUNDING_THRESHOLD) { throw new IllegalArgumentException("Row " + (iterateCnt / rowSize) + " of CPT does not sum to 1.0."); } rowProb = 0; } } }; table.iterateOverTable(di); } }