package org.societies.context.user.refinement.impl.bayesianLibrary.bayesianLearner.impl; import java.io.Serializable; import java.util.Enumeration; import java.util.HashMap; import java.util.Map; import java.util.SortedSet; import java.util.TreeSet; import java.util.Vector; import org.societies.context.user.refinement.impl.bayesianLibrary.bayesianLearner.exceptions.CountsNotCompleteException; import org.societies.context.user.refinement.impl.bayesianLibrary.bayesianLearner.exceptions.NodeNotAvailableException; import org.societies.context.user.refinement.impl.bayesianLibrary.bayesianLearner.exceptions.NodeValueIndexNotInNodeRangeException; import org.societies.context.user.refinement.impl.bayesianLibrary.bayesianLearner.exceptions.ParentConfigurationNotApplicableException; import org.societies.context.user.refinement.impl.bayesianLibrary.bayesianLearner.exceptions.ParentsNotContainedException; import org.societies.context.user.refinement.impl.bayesianLibrary.bayesianLearner.exceptions.RVNotInstantiatedException; import org.societies.context.user.refinement.impl.bayesianLibrary.bayesianLearner.exceptions.RangeValueNotApplicableException; import org.societies.context.user.refinement.impl.bayesianLibrary.bayesianLearner.interfaces.BayesianProbabilitiesEstimator; import org.societies.context.user.refinement.impl.bayesianLibrary.bayesianLearner.interfaces.ConditionalProbabilityTable; import org.societies.context.user.refinement.impl.bayesianLibrary.bayesianLearner.interfaces.CountTable; import org.societies.context.user.refinement.impl.bayesianLibrary.bayesianLearner.interfaces.InstantiatedRV; import org.societies.context.user.refinement.impl.bayesianLibrary.bayesianLearner.interfaces.JointMeasurement; import org.societies.context.user.refinement.impl.bayesianLibrary.bayesianLearner.interfaces.PriorTable; import org.societies.context.user.refinement.impl.bayesianLibrary.bayesianLearner.interfaces.RandomVariable; public class SimpleCountingLearningEngine implements BayesianProbabilitiesEstimator,Serializable{ private static final long serialVersionUID = -6170201680160868746L; protected Map<RandomVariable,Map<SortedSet<RandomVariable>,CountTable>> countTablesUnderLearning; private Vector<JointMeasurement> measurements; public SimpleCountingLearningEngine() { this.countTablesUnderLearning = new HashMap<RandomVariable,Map<SortedSet<RandomVariable>,CountTable>>(); this.measurements = new Vector<JointMeasurement>(); } public void addMeasurement(JointMeasurement meas) { this.measurements.add(meas); } public void refreshAllLearningTables() throws NodeNotAvailableException{ RandomVariable[] targetNodes = this.countTablesUnderLearning.keySet().toArray(new RandomVariable[0]); for (int i=0;i<targetNodes.length;i++) { // System.out.println("Learning for RV : " + targetNodes[i]); Map<SortedSet<RandomVariable>,CountTable> parentsToCountTableMap = (this.countTablesUnderLearning.get(targetNodes[i])); SortedSet<RandomVariable>[] parentsArray = parentsToCountTableMap.keySet().toArray(new TreeSet/*<RandomVariable>*/[0]); for (int j=0;j<parentsArray.length;j++) { this.computeCountTable(targetNodes[i], parentsArray[j]); } } } /** * Computes the specified currently requested count table for targetNode and parents, based on the measurements. * @param targetNode * @param parents * @throws NodeNotAvailableException */ private void computeCountTable(RandomVariable targetNode, SortedSet<RandomVariable>parents) throws NodeNotAvailableException { CountTable target_ct = (CountTable) (this.countTablesUnderLearning.get(targetNode)).get(parents); Enumeration<JointMeasurement> enumer = this.measurements.elements(); while (enumer.hasMoreElements()) { JointMeasurement jm = enumer.nextElement(); // System.out.println("... using measurement : " + jm); InstantiatedRV rvnvp_target = (InstantiatedRV) jm.getInstantiatedRV().get(targetNode); if (rvnvp_target == null) { throw new NodeNotAvailableException("\nNode: " + targetNode + " not in JointMeasurement: " + jm); } int targetNode_value_target; if (!rvnvp_target.isMissingInstantiation()) { int targetNode_value_index = 0; try { targetNode_value_target = rvnvp_target.getRVValue(); targetNode_value_index = targetNode.getNodeRangePositionFromValue(targetNode_value_target); } catch (RVNotInstantiatedException e1) { System.err.println("In computeCountTable() catch (RVNotInstantiatedException e); this should not happen."); e1.printStackTrace(); } catch (NodeValueIndexNotInNodeRangeException e) { System.err.println("In computeCountTable() catch (NodeValueIndexNotInNodeRangeException e); this should not happen."); e.printStackTrace(); } try { int parentConfiguration = target_ct.computeParentConfiguration(jm, true); if (parentConfiguration>=0) { target_ct.incrementCount(parentConfiguration, targetNode_value_index + 1); } else { System.err.println("In computeCountTables(): found a missing RV value for some parent of RV: " + targetNode); } } catch (ParentConfigurationNotApplicableException e) { e.printStackTrace(); } catch (RangeValueNotApplicableException e) { e.printStackTrace(); } catch (ParentsNotContainedException e) { // actually, as long as // argument "true" in // computeParentConfiguration() above, // this will not occur! System.err.println("In computeCountTables(), catch (ParentsNotContainedException e): This should not happen"); e.printStackTrace(); } } } target_ct.setCounted(true); } /** * Adds any new internal nodes to the passed countTable needed to incorporate the current Learning request. * Here, this means checking that the CountTables hold all nodes that are * in the parents of node i or the node i itself. * @param node_i * @param parents_of_node_i */ private void extendTables(RandomVariable node_i, SortedSet<RandomVariable> parents_of_node_i) { Map<RandomVariable,Map<SortedSet<RandomVariable>,CountTable>> countTableMapMap; countTableMapMap = this.countTablesUnderLearning; Map<SortedSet<RandomVariable>,CountTable> ct_mapSortedSet_node_i = null; if (!countTableMapMap.containsKey(node_i)) { ct_mapSortedSet_node_i = new HashMap<SortedSet<RandomVariable>,CountTable>(); countTableMapMap.put(node_i, ct_mapSortedSet_node_i); } else { ct_mapSortedSet_node_i = countTableMapMap.get(node_i); } this.updateParentMembership(node_i, ct_mapSortedSet_node_i, parents_of_node_i); } /** * Make sure that the Map of SortedSet of RandomVariable to CountTable, ct_mapSortedSet_node_i, * contains a SortedSet of RandomVariables with exactly all parents_of_node_i. * If not, then make a new counting table and store it in ct_mapSortedSet_node_i. * @param node_i * @param ct_mapSortedSet_node_i * @param parents_of_node_i */ private void updateParentMembership(RandomVariable node_i, Map<SortedSet<RandomVariable>,CountTable> ct_mapSortedSet_node_i, SortedSet<RandomVariable> parents_of_node_i) { // System.out.println("ZZZ updateParentMembership for " + node_i + ". Parents: " + parents_of_node_i + " ct_mapSortedSet_node_i " + ct_mapSortedSet_node_i); if (ct_mapSortedSet_node_i.containsKey(parents_of_node_i)) return; CountTable newcountTable = null; newcountTable = createNewCountingTable(node_i, parents_of_node_i); ct_mapSortedSet_node_i.put(new TreeSet<RandomVariable>(parents_of_node_i), newcountTable); } /** * @param node_i * @param parents_of_node_i * @return */ private CountTable createNewCountingTable(RandomVariable node_i, SortedSet<RandomVariable> parents_of_node_i) { CountTable ct_node_i = new NaiveCountTable(node_i, parents_of_node_i); return ct_node_i; } /** * @param node_i * @param parents_of_node_i * @return */ private PriorTable createNewPriorTable(int n_equiv, RandomVariable node_i, SortedSet<RandomVariable> parents_of_node_i) { PriorTable pt_node_i = new NaivePriorTable(n_equiv, node_i, parents_of_node_i); return pt_node_i; } /* (non-Javadoc) * @see eu.ist.daidalos.pervasive.bayesianLibrary.bayesianLearner.interfaces.BayesianProbabilitiesEstimator#getCounts(eu.ist.daidalos.pervasive.bayesianLibrary.bayesianLearner.interfaces.RandomVariable, java.util.SortedSet) */ public CountTable getCounts(RandomVariable targetNode, SortedSet<RandomVariable> parentNodes) { this.extendTables(targetNode, parentNodes); Map<SortedSet<RandomVariable>,CountTable> ct_mapSortedSet_node_i = this.countTablesUnderLearning.get(targetNode); if (!(ct_mapSortedSet_node_i.get(parentNodes)).isCounted()) { try { this.computeCountTable(targetNode, parentNodes); } catch (NodeNotAvailableException e) { e.printStackTrace(); } } return ct_mapSortedSet_node_i.get(parentNodes); } /* (non-Javadoc) * @see eu.ist.daidalos.pervasive.bayesianLibrary.bayesianLearner.interfaces.BayesianProbabilitiesEstimator#getCPT(eu.ist.daidalos.pervasive.bayesianLibrary.bayesianLearner.interfaces.RandomVariable, java.util.SortedSet) */ public ConditionalProbabilityTable getCPT(RandomVariable targetNode, SortedSet<RandomVariable> parentNodes, PriorTable alphas) throws CountsNotCompleteException{ ConditionalProbabilityTable cpt = this.computeCPT(targetNode, parentNodes, this.getCounts(targetNode, parentNodes), alphas); return cpt; } /* (non-Javadoc) * @see eu.ist.daidalos.pervasive.bayesianLibrary.bayesianLearner.interfaces.BayesianProbabilitiesEstimator#computeCPT(eu.ist.daidalos.pervasive.bayesianLibrary.bayesianLearner.interfaces.RandomVariable, java.util.SortedSet, eu.ist.daidalos.pervasive.bayesianLibrary.bayesianLearner.interfaces.CountTable, eu.ist.daidalos.pervasive.bayesianLibrary.bayesianLearner.interfaces.PriorTable) */ public ConditionalProbabilityTable computeCPT(RandomVariable targetNode, SortedSet<RandomVariable> parentNodes, CountTable counts, PriorTable alphas) throws CountsNotCompleteException{ ConditionalProbabilityTable cpt = new CountingCPT(this.getCounts(targetNode, parentNodes), alphas); return cpt; } /* (non-Javadoc) * @see eu.ist.daidalos.pervasive.bayesianLibrary.bayesianLearner.interfaces.BayesianProbabilitiesEstimator#resetTables() */ public void resetTables() { this.countTablesUnderLearning.clear(); // TODO: call any registered consumers that want to listen to this event. Eg cache in BNCandidate local fitness } /* (non-Javadoc) * @see eu.ist.daidalos.pervasive.bayesianLibrary.bayesianLearner.interfaces.BayesianProbabilitiesEstimator#clearMeasurements() */ public void clearMeasurements() { this.measurements.clear(); } @Override public PriorTable getUniformPriors(int n_equiv, RandomVariable rv, SortedSet<RandomVariable> parents_of_node_i) { return this.createNewPriorTable(n_equiv, rv, parents_of_node_i); } public String toString() { return ("Simple Counting LE " + this.measurements.size()); } public void setCurrentBayesianNetworkStructure(BayesianNetworkCandidate bnc) { // Not needed here! } }