package org.societies.context.user.refinement.impl.bayesianLibrary.bayesianLearner.impl;
import java.util.Set;
import java.util.SortedSet;
import org.societies.context.user.refinement.impl.bayesianLibrary.bayesianLearner.exceptions.NodeValueIndexNotInNodeRangeException;
import org.societies.context.user.refinement.impl.bayesianLibrary.bayesianLearner.exceptions.ParentConfigurationNotApplicableException;
import org.societies.context.user.refinement.impl.bayesianLibrary.bayesianLearner.exceptions.ParentsNotContainedException;
import org.societies.context.user.refinement.impl.bayesianLibrary.bayesianLearner.exceptions.PriorAndCountTablesMismatchException;
import org.societies.context.user.refinement.impl.bayesianLibrary.bayesianLearner.exceptions.RVNotInstantiatedException;
import org.societies.context.user.refinement.impl.bayesianLibrary.bayesianLearner.exceptions.RangeValueNotApplicableException;
import org.societies.context.user.refinement.impl.bayesianLibrary.bayesianLearner.interfaces.ConditionalProbabilityTable;
import org.societies.context.user.refinement.impl.bayesianLibrary.bayesianLearner.interfaces.CountTable;
import org.societies.context.user.refinement.impl.bayesianLibrary.bayesianLearner.interfaces.InstantiatedRV;
import org.societies.context.user.refinement.impl.bayesianLibrary.bayesianLearner.interfaces.JointMeasurement;
import org.societies.context.user.refinement.impl.bayesianLibrary.bayesianLearner.interfaces.PriorTable;
import org.societies.context.user.refinement.impl.bayesianLibrary.bayesianLearner.interfaces.RandomVariable;
/**
 * Conditional probability table whose entries are derived on demand from a
 * {@link CountTable} (observed counts) and a {@link PriorTable} (virtual counts).
 *
 * <p>Each entry is computed as
 * {@code (alpha_ijk + n_ijk) / (alpha_ij + n_ij)}, where {@code n_*} are read from
 * the count table and {@code alpha_*} from the prior table. Parent configurations
 * {@code j} run from 1 to {@code getJ_max()} and target range indices {@code k}
 * from 1 to {@code getK_max()}.
 *
 * <p>All structural queries and modifications (target variable, parents, K/J bounds)
 * are delegated to the count table only.
 */
public class CountingCPT implements ConditionalProbabilityTable{

    /** Source of the observed counts; also answers all structural queries. */
    private final CountTable ct;

    /** Source of the virtual (prior) counts. */
    private final PriorTable pt;

    /**
     * Creates a table backed by the given count and prior tables.
     * The arguments are not validated here; null or mismatching tables are
     * reported when a probability is requested.
     *
     * @param ct table of observed counts
     * @param pt table of virtual counts
     */
    public CountingCPT(CountTable ct, PriorTable pt){
        this.ct = ct;
        this.pt = pt;
    }

    /**
     * Verifies that both tables are present and describe the same target
     * variable and the same ordered parents.
     *
     * @throws NullPointerException if either table is null
     * @throws PriorAndCountTablesMismatchException if target variable or ordered
     *         parents of the two tables are not equal
     */
    private void checkCompatibilityCTandPT() throws PriorAndCountTablesMismatchException {
        if (this.ct == null) {
            throw new NullPointerException("\nCountTable is null");
        }
        if (this.pt == null) {
            throw new NullPointerException("\nPriorTable is null");
        }

        boolean sameTarget = this.ct.getTargetRV().equals(this.pt.getTargetRV());
        if (!sameTarget) {
            System.err.println("\nct: " + this.ct.getTargetRV().getName() + " pt: " + this.pt.getTargetRV().getName());
            throw new PriorAndCountTablesMismatchException("\nTarget node of CountTable " +
                    this.ct.getTargetRV() + " does not match Target node of Prior Table " + this.pt.getTargetRV());
        }

        boolean sameParents = this.ct.getOrderedParents().equals(this.pt.getOrderedParents());
        if (!sameParents) {
            System.err.println("\nct: " + this.ct.getOrderedParents() + " pt: " + this.pt.getOrderedParents());
            throw new PriorAndCountTablesMismatchException("\nOrdered Parents of CountTable " +
                    this.ct.getOrderedParents() + " does not match Ordered Parents of Prior Table " + this.pt.getOrderedParents());
        }
    }

    /**
     * Builds the complete probability table.
     *
     * @return an array of size {@code [getK_max()+1][getJ_max()+1]} indexed as
     *         {@code [k][j]}; row 0 and column 0 are never written and stay 0.0
     * @throws PriorAndCountTablesMismatchException if count and prior table do not match
     */
    public double[][] getProbabilityTable() throws PriorAndCountTablesMismatchException {
        this.checkCompatibilityCTandPT();
        double[][] cpTable = new double[this.ct.getK_max() + 1][this.ct.getJ_max() + 1];
        this.fillTable(cpTable);
        return cpTable;
    }

    /**
     * Writes the probability of every (k, j) combination into {@code cpTable[k][j]}.
     * An entry whose computation fails is reported on stderr and left unchanged.
     *
     * @param cpTable destination array, at least {@code [K_max+1][J_max+1]} in size
     */
    private void fillTable(double[][] cpTable) {
        final int kMax = this.ct.getK_max();
        final int jMax = this.ct.getJ_max();
        for (int k = 1; k <= kMax; k++) {
            for (int j = 1; j <= jMax; j++) {
                try {
                    cpTable[k][j] = this.computeTableValue(j, k);
                } catch (ParentConfigurationNotApplicableException e) {
                    System.err.println("In fillTables() catch (ParentConfigurationNotApplicableException e); this should not happen.");
                    e.printStackTrace();
                } catch (RangeValueNotApplicableException e) {
                    System.err.println("In fillTables() catch (RangeValueNotApplicableException e); this should not happen.");
                    e.printStackTrace();
                }
            }
        }
    }

    /**
     * Returns the probability of one range index of the target variable under
     * one parent configuration.
     *
     * @param parentConfiguration index j of the parent configuration
     * @param node_i_range_value_index_k index k into the range of the target variable
     * @return {@code (alpha_ijk + n_ijk) / (alpha_ij + n_ij)}
     * @throws ParentConfigurationNotApplicableException propagated from the underlying tables
     * @throws RangeValueNotApplicableException propagated from the underlying tables
     * @throws PriorAndCountTablesMismatchException if count and prior table do not match
     */
    public double getProbability(int parentConfiguration, int node_i_range_value_index_k)
            throws ParentConfigurationNotApplicableException, RangeValueNotApplicableException, PriorAndCountTablesMismatchException {
        this.checkCompatibilityCTandPT();
        return computeTableValue(parentConfiguration, node_i_range_value_index_k);
    }

    /**
     * Computes a single table entry without checking table compatibility first.
     *
     * <p>NOTE(review): there is no guard against a zero denominator; if
     * {@code alpha_ij + n_ij} is 0 the floating-point division yields NaN or an
     * infinite value rather than throwing.
     *
     * @param parentConfiguration index j of the parent configuration
     * @param node_i_range_value_index_k index k into the range of the target variable
     * @return {@code (alpha_ijk + n_ijk) / (alpha_ij + n_ij)}
     * @throws ParentConfigurationNotApplicableException propagated from the underlying tables
     * @throws RangeValueNotApplicableException propagated from the underlying tables
     */
    private double computeTableValue(int parentConfiguration, int node_i_range_value_index_k)
            throws ParentConfigurationNotApplicableException, RangeValueNotApplicableException {
        int n_ijk = this.ct.getCount(parentConfiguration, node_i_range_value_index_k);
        int n_ij = this.ct.getCount(parentConfiguration);
        double alpha_ijk = this.pt.getVirtualCount(parentConfiguration, node_i_range_value_index_k);
        double alpha_ij = this.pt.getVirtualCount(parentConfiguration);
        return (alpha_ijk + (double) n_ijk) / (alpha_ij + (double) n_ij);
    }

    /**
     * Returns the probability of a target value given instantiated parents.
     *
     * <p>The parent set is mapped to a parent configuration (missing parents are
     * not replaced) and the value is mapped to its range position plus one
     * (presumably converting a 0-based position to the 1-based index k).
     *
     * @param parents instantiated parent variables
     * @param node_i_range_value value of the target variable
     * @return the probability, or 0.0 if the computed parent configuration is
     *         rejected by the underlying tables (the failure is printed to stderr)
     * @throws NodeValueIndexNotInNodeRangeException if the range index is rejected
     *         by the underlying tables
     * @throws ParentsNotContainedException propagated from the parent configuration lookup
     * @throws PriorAndCountTablesMismatchException if count and prior table do not match
     * @throws RVNotInstantiatedException declared for callers; not raised directly here
     */
    public double getProbability(Set<InstantiatedRV> parents, int node_i_range_value)
            throws NodeValueIndexNotInNodeRangeException, ParentsNotContainedException,
            PriorAndCountTablesMismatchException, RVNotInstantiatedException {
        try {
            return this.getProbability(this.computeParentConfiguration(parents, false),
                    this.getTargetRV().getNodeRangePositionFromValue(node_i_range_value) + 1);
        } catch (ParentConfigurationNotApplicableException e) {
            System.err.println("In getProbability() catch (ParentConfigurationNotApplicableException e); this should not happen.");
            e.printStackTrace();
            return 0.0;
        } catch (RangeValueNotApplicableException e) {
            e.printStackTrace();
            throw new NodeValueIndexNotInNodeRangeException(e);
        }
    }

    /**
     * Delegates to the count table.
     *
     * @param test_set set of variables to compare against the parents
     * @return the count table's answer
     */
    public boolean containsExactlyAllParents(Set<RandomVariable> test_set) {
        return this.ct.containsExactlyAllParents(test_set);
    }

    /**
     * @return the target variable of the count table
     */
    public RandomVariable getTargetRV() {
        return this.ct.getTargetRV();
    }

    /**
     * Delegates the parent configuration lookup to the count table.
     *
     * @param instantiatedMeasurements joint measurement holding the parent values
     * @param replaceMissingParentsWithNOOBs passed through to the count table
     * @return the parent configuration index computed by the count table
     * @throws ParentsNotContainedException propagated from the count table
     */
    public int computeParentConfiguration(JointMeasurement instantiatedMeasurements, boolean replaceMissingParentsWithNOOBs) throws ParentsNotContainedException {
        return this.ct.computeParentConfiguration(instantiatedMeasurements, replaceMissingParentsWithNOOBs);
    }

    /**
     * Delegates the parent configuration lookup to the count table.
     *
     * @param instantiatedMeasurements instantiated parent variables
     * @param replaceMissingParentsWithNOOBs passed through to the count table
     * @return the parent configuration index computed by the count table
     * @throws ParentsNotContainedException propagated from the count table
     */
    public int computeParentConfiguration(Set<InstantiatedRV> instantiatedMeasurements, boolean replaceMissingParentsWithNOOBs)
            throws ParentsNotContainedException {
        return this.ct.computeParentConfiguration(instantiatedMeasurements, replaceMissingParentsWithNOOBs);
    }

    /**
     * @return the ordered parents of the count table
     */
    public SortedSet<RandomVariable> getOrderedParents() {
        return this.ct.getOrderedParents();
    }

    /**
     * @return the upper bound of the range index k, as reported by the count table
     */
    public int getK_max() {
        return this.ct.getK_max();
    }

    /**
     * @return the upper bound of the parent configuration index j, as reported by the count table
     */
    public int getJ_max() {
        return this.ct.getJ_max();
    }

    /**
     * Adds a parent to the count table.
     *
     * <p>NOTE(review): only the count table is modified here; the prior table is
     * not touched, so the compatibility check may fail afterwards unless the
     * prior table is updated elsewhere.
     *
     * @param parent the parent variable to add
     */
    public void addParent(RandomVariable parent) {
        this.ct.addParent(parent);
    }

    /**
     * Removes a parent from the count table. The prior table is not modified here.
     *
     * @param parent the parent variable to remove
     */
    public void removeParent(RandomVariable parent) {
        this.ct.removeParent(parent);
    }

    /**
     * Removes all parents from the count table. The prior table is not modified here.
     */
    public void removeAllParents() {
        this.ct.removeAllParents();
    }

    /**
     * Renders the target variable, its parents and every table entry as text.
     * If count and prior table do not match, the mismatch message is appended
     * instead of the probabilities.
     *
     * @return a multi-line, human-readable dump of this table
     */
    @Override
    public String toString() {
        StringBuilder out = new StringBuilder();
        out.append("Probability Table for ");
        out.append(this.getTargetRV());
        out.append(". With these parents: \n");
        for (RandomVariable parent : this.getOrderedParents()) {
            out.append(parent).append(" \n");
        }

        try {
            this.checkCompatibilityCTandPT();
        } catch (PriorAndCountTablesMismatchException e1) {
            e1.printStackTrace();
            out.append("error: " + e1.getMessage());
            return out.toString();
        }

        final int jMax = this.ct.getJ_max();
        final int kMax = this.ct.getK_max();
        final RandomVariable target = this.getTargetRV();

        out.append(" \nProbabilities (").append(jMax).append("):");
        for (int k = 1; k <= kMax; k++) {
            out.append("\nk: ").append(k).append("||");
            for (int j = 1; j <= jMax; j++) {
                try {
                    // Compute first so that nothing is appended for a failing cell.
                    double prob = this.computeTableValue(j, k);
                    out.append("j:").append(j).append("\\c=").append(prob).append("|");
                } catch (ParentConfigurationNotApplicableException e) {
                    System.err.println("In toString() catch (ParentConfigurationNotApplicableException e); this should not happen.");
                    e.printStackTrace();
                } catch (RangeValueNotApplicableException e) {
                    System.err.println("In toString() catch (RangeValueNotApplicableException e); this should not happen.");
                    e.printStackTrace();
                }
            }
            try {
                // The row label is built before appending, so a failed lookup adds nothing.
                out.append("|" + target.getNodeValueText(target.getNodeRange()[k - 1]));
            } catch (NodeValueIndexNotInNodeRangeException e) {
                e.printStackTrace();
            }
        }
        out.append("\n-----------------------------------------------------");
        return out.toString();
    }
}