/******************************************************************************* * Copyright (C) 2007-2012 Dominik Jain. * * This file is part of ProbCog. * * ProbCog is free software: you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. * * ProbCog is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with ProbCog. If not, see <http://www.gnu.org/licenses/>. ******************************************************************************/ package probcog.srl.directed.learning; import java.util.Collection; import java.util.HashMap; import java.util.Map; import java.util.Vector; import java.util.Map.Entry; import probcog.srl.Database; import probcog.srl.GenericDatabase; import probcog.srl.Signature; import probcog.srl.ValueDistribution; import probcog.srl.directed.DecisionNode; import probcog.srl.directed.ExtendedNode; import probcog.srl.directed.ParentGrounder; import probcog.srl.directed.RelationalBeliefNetwork; import probcog.srl.directed.RelationalNode; import probcog.srl.directed.ParentGrounder.ParentGrounding; import edu.ksu.cis.bnj.ver3.core.CPF; import edu.ksu.cis.bnj.ver3.core.Discrete; import edu.ksu.cis.bnj.ver3.core.values.ValueDouble; import edu.ksu.cis.bnj.ver3.core.values.ValueZero; import edu.tum.cs.util.StringTool; /** * Learner for the parameters of conditional probability tables of fragments. 
* @author Dominik Jain
 */
public class CPTLearner extends probcog.bayesnets.learning.CPTLearner {
    /** Debugging aid (per-node value counts); not initialised by this class. */
    protected HashMap<Integer, HashMap<String, Integer>> marginals;
    /** Number of groundings counted / skipped for the node currently being processed. */
    protected int numCounted, numNotCounted;
    /** Whether progress output is printed while counting. */
    protected boolean verbose;
    /** Whether additional diagnostic output is printed. */
    protected boolean debug = false;

    /**
     * Creates a learner with uniform-default and debug output both disabled.
     * @param bn the relational belief network whose CPTs are to be learnt
     * @throws Exception
     */
    public CPTLearner(RelationalBeliefNetwork bn) throws Exception {
        this(bn, false, false);
    }

    /**
     * Creates a learner.
     * @param bn the relational belief network whose CPTs are to be learnt
     * @param uniformDefault passed on to {@code setUniformDefault}
     * @param debug whether to print additional diagnostic output
     * @throws Exception
     */
    public CPTLearner(RelationalBeliefNetwork bn, boolean uniformDefault, boolean debug) throws Exception {
        super(bn);
        setUniformDefault(uniformDefault);
        this.debug = debug;
    }

    /**
     * Prints the counted/total status line if verbose output is enabled.
     * @param force if true, print regardless of the current total; otherwise print only
     *              when the total number of processed groundings is a multiple of 10
     */
    protected void printCountStatus(boolean force) {
        int total = numCounted + numNotCounted;
        boolean doPrint = force || total % 10 == 0;
        if (verbose && doPrint) {
            System.out.printf(" %d/%d counted\r", numCounted, total);
        }
    }

    /**
     * Counts an example (variable setting) by incrementing the counter for the given variable.
     * The example is skipped if the node has no CPT, if one of its decision parents is not
     * true, if its parents cannot be grounded, or if a precondition parent is not "true".
     * @param db the database containing propositions
     * @param node node of the variable for which we are counting an example
     * @param params the node's actual parameters
     * @param closedWorld whether the closed-world assumption is to be made
     * @throws Exception if a required value cannot be found in the database or in a domain
     */
    protected void processGrounding(GenericDatabase<?,?> db, RelationalNode node, String[] params, boolean closedWorld) throws Exception {
        // if the node is not CPT-based, skip it
        if (!node.hasCPT()) {
            return;
        }

        // the example is counted only if all decision parents are true
        Collection<DecisionNode> decisions = node.getDecisionParents();
        for (DecisionNode decision : decisions) {
            if (!decision.isTrue(node.params, params, db, closedWorld)) {
                numNotCounted++;
                printCountStatus(false);
                return;
            }
        }

        RelationalBeliefNetwork bn = (RelationalBeliefNetwork) this.bn;
        // the counter associated with the node
        ExampleCounter counter = this.counters[node.index];
        // the main variable's name
        String varName = Signature.formatVarName(node.getFunctionName(), params);

        // obtain all groundings of the relevant variables
        ParentGrounder pg = bn.getParentGrounder(node);
        Vector<ParentGrounding> groundings = pg.getGroundings(params, db);
        if (groundings == null) {
            if (debug) {
                System.err.println("Variable " + varName + " skipped because parents could not be grounded.");
            }
            return;
        }

        // TODO: example weighting (e.g. based on the homogeneity of relational parents when
        // averaging conditional probabilities) is not implemented; every example has weight 1
        double exampleWeight = 1.0;

        for (ParentGrounding grounding : groundings) {
            Map<Integer, String[]> paramSets = grounding.nodeArgs;

            // check precondition parents
            // TODO do we really need this? Preconditions are checked in ParentGrounder?
            if (!allPreconditionsTrue(varName, db, closedWorld, bn, paramSets, counter)) {
                numNotCounted++;
                printCountStatus(false);
                continue;
            }

            // preconditions were met: determine the domain indices of all relevant nodes
            // and count the example
            int[] domainIndices = new int[this.nodes.length];
            countVariableR(varName, db, closedWorld, bn, paramSets, counter, domainIndices, exampleWeight, 0);
            numCounted++;

            if (debug && verbose) {
                // just debug output: the counted value and the setting of the other grounded nodes
                StringBuilder condition = new StringBuilder();
                for (Entry<Integer, String[]> e : paramSets.entrySet()) {
                    // compare by value, not by reference
                    if (e.getKey().intValue() == node.index) {
                        continue;
                    }
                    RelationalNode rn = bn.getRelationalNode(e.getKey());
                    condition.append(' ');
                    condition.append(rn.getVariableName(e.getValue()));
                    condition.append('=');
                    condition.append(rn.getDomain().getName(domainIndices[rn.index]));
                }
                System.out.println(" " + node.getVariableName(params) + "=" + node.getDomain().getName(domainIndices[node.index]) + " |" + condition);
            }
        }
    }

    /**
     * Checks whether all precondition nodes among the counter's nodes have the value "true"
     * (case-insensitive) for the given grounding. The entry at index 0 of the counter's node
     * indices is not checked (presumably the child node itself - verify against ExampleCounter).
     * @throws Exception if no value is available for a precondition node
     */
    private boolean allPreconditionsTrue(String varName, GenericDatabase<?,?> db, boolean closedWorld, RelationalBeliefNetwork bn,
            Map<Integer, String[]> paramSets, ExampleCounter counter) throws Exception {
        for (int i = 1; i < counter.nodeIndices.length; i++) {
            ExtendedNode extCurrent = bn.getExtendedNode(counter.nodeIndices[i]);
            if (!(extCurrent instanceof RelationalNode)) {
                continue;
            }
            RelationalNode ndCurrent = (RelationalNode) extCurrent;
            if (!ndCurrent.isPrecondition) {
                continue;
            }
            String[] actualParams = paramSets.get(ndCurrent.index);
            String value = ndCurrent.getValueInDB(actualParams, db, closedWorld);
            if (value == null) {
                // fail with a meaningful message rather than a bare NullPointerException
                throw new Exception(String.format("Could not find setting for precondition node named '%s' while processing '%s'", ndCurrent.getName(), varName));
            }
            // preconditions are required to be "True"
            if (!value.equalsIgnoreCase("true")) {
                return false;
            }
        }
        return true;
    }

    /**
     * Lists the names of all elements of the given domain, separated by commas
     * (used in error messages).
     */
    private static String formatDomainElements(Discrete dom) {
        String[] domElems = new String[dom.getOrder()];
        for (int j = 0; j < domElems.length; j++) {
            domElems[j] = dom.getName(j);
        }
        return StringTool.join(",", domElems);
    }

    /**
     * Helper function that recursively sets the domain indices of the counter's nodes
     * and counts the example once all of them have been set.
     * @param varName name of the main variable (used in error messages)
     * @param db the database from which values are read
     * @param closedWorld whether the closed-world assumption is to be made
     * @param bn the network containing the nodes
     * @param paramSets the actual parameters of each grounded node, keyed by node index
     * @param counter the counter of the main node
     * @param domainIndices array (indexed by node index) into which domain indices are written
     * @param exampleWeight the weight with which the example is counted
     * @param i index into {@code counter.nodeIndices} of the node to handle next
     * @throws Exception if a node has no grounding, no value, or a value outside its domain
     */
    protected void countVariableR(String varName, GenericDatabase<?,?> db, boolean closedWorld, RelationalBeliefNetwork bn,
            Map<Integer, String[]> paramSets, ExampleCounter counter, int[] domainIndices, double exampleWeight, int i) throws Exception {
        // all nodes handled: count the example
        if (i == counter.nodeIndices.length) {
            counter.count(domainIndices, exampleWeight);
            printCountStatus(false);
            return;
        }

        ExtendedNode extCurrent = bn.getExtendedNode(counter.nodeIndices[i]);

        // decision node parents are always true, because we use them to define hard constraints
        // on the use of the CPT we are learning; whether the constraint that they represent is
        // actually satisfied was checked beforehand
        if (extCurrent instanceof DecisionNode) {
            domainIndices[extCurrent.index] = 0; // 0 is true
            countVariableR(varName, db, closedWorld, bn, paramSets, counter, domainIndices, exampleWeight, i + 1);
            return;
        }

        // it's a regular parent: get the corresponding RelationalNode object
        RelationalNode ndCurrent = (RelationalNode) extCurrent;

        // side affair: learn the CPT of constant nodes here by incrementing the counter
        if (ndCurrent.isConstant) {
            String[] constParams = paramSets.get(ndCurrent.index);
            if (constParams == null || constParams.length == 0) {
                throw new Exception(String.format("Constant node %s has no grounding while processing %s", ndCurrent, varName));
            }
            int constIdx = ndCurrent.getDomain().findName(constParams[0]);
            if (constIdx == -1) {
                // do not write an invalid index into domainIndices
                throw new Exception(String.format("'%s' not found in domain of constant node %s while processing %s", constParams[0], ndCurrent, varName));
            }
            domainIndices[ndCurrent.index] = constIdx;
            this.counters[ndCurrent.index].count(domainIndices);
            countVariableR(varName, db, closedWorld, bn, paramSets, counter, domainIndices, exampleWeight, i + 1);
            return;
        }

        // preconditions were handled above/in ParentGrounder
        if (ndCurrent.isPrecondition) {
            domainIndices[extCurrent.index] = 0; // 0 is true
            countVariableR(varName, db, closedWorld, bn, paramSets, counter, domainIndices, exampleWeight, i + 1);
            return;
        }

        // determine the value of the node given the parameter settings implied by the main node
        String[] actualParams = paramSets.get(ndCurrent.index);
        if (actualParams == null) {
            Vector<String> availableNodes = new Vector<String>();
            for (Integer idx : paramSets.keySet()) {
                availableNodes.add(idx.toString() + "/" + ndCurrent.getNetwork().getRelationalNode(idx).toString());
            }
            throw new Exception("Relevant node " + ndCurrent.index + "/" + ndCurrent + " has no grounding for main node instantiation " + varName + "; have only " + availableNodes.toString());
        }
        Object value = db.getVariableValue(ndCurrent.getVariableName(actualParams), closedWorld);
        if (value == null) {
            throw new Exception(String.format("Could not find setting for node named '%s' while processing '%s'", ndCurrent.getName(), varName));
        }

        // get the current node's domain and the index of its setting
        Discrete dom = (Discrete) (ndCurrent.node.getDomain());
        if (value instanceof String) {
            int domainIdx = dom.findName((String) value);
            if (domainIdx == -1) {
                throw new Exception(String.format("'%s' not found in domain of %s {%s} while processing %s", value, ndCurrent.getVariableName(actualParams), formatDomainElements(dom), varName));
            }
            domainIndices[extCurrent.index] = domainIdx;
            countVariableR(varName, db, closedWorld, bn, paramSets, counter, domainIndices, exampleWeight, i + 1);
        }
        else if (value instanceof ValueDistribution) {
            // soft evidence: count each value with positive probability, weighted by that probability
            ValueDistribution vd = (ValueDistribution) value;
            for (Entry<String,Double> e : vd.entrySet()) {
                int domainIdx = dom.findName((String) e.getKey());
                if (domainIdx == -1) {
                    throw new Exception(String.format("'%s' not found in domain of %s {%s} while processing %s", e.getKey(), ndCurrent.getFunctionName(), formatDomainElements(dom), varName));
                }
                domainIndices[extCurrent.index] = domainIdx;
                double p = e.getValue();
                if (p > 0) {
                    countVariableR(varName, db, closedWorld, bn, paramSets, counter, domainIndices, exampleWeight * p, i + 1);
                }
            }
        }
        // NOTE(review): values of any other type are silently ignored, i.e. the example is
        // not counted - confirm whether this should be reported as an error instead
    }

    /**
     * Formerly learnt the CPTs from only the data given in the database (without making the
     * closed-world assumption); now always fails.
     * @param db
     * @throws Exception always ("No longer supported")
     * @deprecated no longer supported; see {@link #learnTyped(GenericDatabase, boolean, boolean)}
     */
    @Deprecated
    public void learn(Database db) throws Exception {
        throw new Exception("No longer supported");
    }

    /**
     * Generates for all nodes all the possible parameters (using the node signatures and domain
     * elements from the database) and counts the corresponding examples. Constant and built-in
     * predicate nodes are skipped; the CPTs of precondition and auxiliary nodes are set directly.
     * @param db the database (containing domains and propositions) to use
     * @param closedWorld whether to make the closed-world assumption
     * @param verbose whether to print progress output
     * @throws Exception
     */
    public void learnTyped(GenericDatabase<?,?> db, boolean closedWorld, boolean verbose) throws Exception {
        if (!initialized) {
            init();
        }
        this.verbose = verbose;
        RelationalBeliefNetwork bn = (RelationalBeliefNetwork) this.bn;

        // construct parent grounders for relevant nodes
        // (to check early on whether the structure is OK)
        for (RelationalNode node : bn.getRelationalNodes()) {
            if (node.isConstant || node.isBuiltInPred() || !node.hasCPT()) {
                continue;
            }
            node.getParentGrounder();
        }

        // learn CPTs
        for (RelationalNode node : bn.getRelationalNodes()) {
            // ignore constant nodes as they do not correspond to logical atoms
            if (node.isConstant || node.isBuiltInPred()) {
                continue;
            }

            if (verbose) {
                System.out.println(" " + node.getName());
            }

            // for precondition nodes, simply set CPT to 100% true
            if (node.isPrecondition) {
                CPF cpf = node.node.getCPF();
                int numColumns = cpf.getRowLength(); // should be 1 (just in case)
                ValueDouble v1 = new ValueDouble(1.0);
                ValueZero zero = new ValueZero();
                for (int i = 0; i < numColumns; i++) {
                    cpf.put(i, v1);
                    cpf.put(i + cpf.getColumnValueAddressOffset(), zero);
                }
                continue;
            }

            // for auxiliary nodes, init to uniform distribution
            if (node.isAuxiliary) {
                CPF cpf = node.node.getCPF();
                int numRows = cpf.getDomainProduct()[0].getDomain().getOrder();
                ValueDouble v = new ValueDouble(1.0 / numRows);
                for (int i = 0; i < cpf.size(); i++) {
                    cpf.put(i, v);
                }
                continue;
            }

            numCounted = 0;
            numNotCounted = 0;

            // consider all possible bindings for the node's parameters and count
            String[] params = new String[node.params.length];
            processAllGroundings(db, node, params, bn.getSignature(node.getFunctionName()).argTypes, 0, closedWorld);

            if (verbose) {
                printCountStatus(true);
                System.out.println();
            }
        }
    }

    /**
     * Generates all groundings (possible lists of parameters) of the given node and counts the
     * corresponding examples.
     * @param db the database (containing domains and propositions) to use
     * @param node the node for which to count examples
     * @param params current list of parameters (filled in place)
     * @param domainNames list of domain names, with one entry for each parameter
     * @param i index into params at which to insert the next parameter
     * @param closedWorld whether to make the closed-world assumption
     * @throws Exception if a domain is not found, or if the closed-world assumption is not
     *         made and the database contains no value for a grounding
     */
    protected void processAllGroundings(GenericDatabase<?,?> db, RelationalNode node, String[] params, String[] domainNames, int i, boolean closedWorld) throws Exception {
        // if we have the full set of parameters, count the example
        if (i == params.length) {
            if (!closedWorld) {
                String varName = Signature.formatVarName(node.getFunctionName(), params);
                if (!db.contains(varName)) {
                    throw new Exception("Incomplete data: No value for " + varName);
                }
            }
            processGrounding(db, node, params, closedWorld);
            return;
        }

        // a constant parameter is taken over as is
        if (RelationalNode.isConstant(node.params[i])) {
            params[i] = node.params[i];
            processAllGroundings(db, node, params, domainNames, i + 1, closedWorld);
            return;
        }

        // otherwise consider all ways of extending the current list of parameters
        // using the domain elements that are applicable
        Iterable<String> domain = db.getDomain(domainNames[i]);
        if (domain == null) {
            throw new Exception("Error while grounding " + node + ": Domain " + domainNames[i] + " not found or is empty.");
        }
        for (String element : domain) {
            params[i] = element;
            processAllGroundings(db, node, params, domainNames, i + 1, closedWorld);
        }
    }
}