/*******************************************************************************
* Copyright (C) 2007-2012 Dominik Jain.
*
* This file is part of ProbCog.
*
* ProbCog is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* ProbCog is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with ProbCog. If not, see <http://www.gnu.org/licenses/>.
******************************************************************************/
package probcog.srl;
import java.io.BufferedReader;
import java.io.PrintStream;
import java.io.StringReader;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map.Entry;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import probcog.inference.IParameterHandler;
import probcog.inference.ParameterHandler;
import probcog.prolog.PrologKnowledgeBase;
import probcog.srl.directed.ABLModel;
import probcog.srl.taxonomy.Concept;
import probcog.srl.taxonomy.Taxonomy;
import edu.tum.cs.util.FileUtil;
import edu.tum.cs.util.StringTool;
import edu.tum.cs.util.datastruct.MultiIterator;
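/**
 * Base class for evidence/training databases.
 *
 * @param <VariableType> the type of the variable entries stored in the database
 * @param <VarValueType> the type of value returned by {@link #getVariableValue(String, boolean)}
 * @author Dominik Jain
 */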
public abstract class GenericDatabase<VariableType extends AbstractVariable<?>, VarValueType> implements IParameterHandler {
/**
 * maps variable names to Variable objects containing values;
 * the keys are the lower-cased key strings of the variables
 */
protected HashMap<String, VariableType> entries;
/**
 * lookup tables for relation keys: for each key, maps the comma-separated values of the
 * key parameters to the full parameter array of the (true) entry they were taken from
 */
protected HashMap<RelationKey, HashMap<String, String[]>> functionalDependencies;
/**
 * maps domain (type) names to the set of elements known for that domain
 */
protected HashMap<String, HashSet<String>> domains;
/**
 * the relational model this database belongs to
 */
public RelationalModel model;
/**
 * the Prolog knowledge base holding the model's Prolog rules and asserted facts
 */
protected PrologKnowledgeBase prolog;
/**
 * whether any Prolog value that is computed is always cached by saving it to the database.
 * The default value is false, because especially during learning, caching all Prolog values
 * may require excessive amounts of memory.
 */
protected boolean cachePrologValues = false;
/**
 * true iff the database was extended with all the values that can be computed with the Prolog KB, i.e.
 * it is true iff all corresponding variables have been explicitly added to the database
 */
protected Boolean prologDatabaseExtended = false;
/**
 * once true, any attempt to add a variable results in an exception
 */
protected boolean immutable = false;
// taxonomy-related variables
/**
 * the model's taxonomy (null if the model has none)
 */
protected Taxonomy taxonomy;
/**
 * maps an entity name to the type it was last assigned to;
 * only initialized if the model has a taxonomy
 */
protected HashMap<String, String> entity2type;
/**
 * cache of the combined domains (type plus sub-types) built by getDomain;
 * only initialized if the model has a taxonomy
 */
protected HashMap<String, MultiIterator<String>> multiDomains;
/**
 * whether to print debug output
 */
protected boolean debug = false;
/**
 * whether to print progress information (e.g. while reading a database file)
 */
protected boolean verbose = false;
/**
 * handler exposing this object's settable parameters
 */
protected ParameterHandler paramHandler;
/**
 * Constructs an empty database for the given model.
 * <p>
 * Initializes the domains (one empty domain per taxonomy concept, if the model has a
 * taxonomy), adds the model's guaranteed domain elements and loads the model's Prolog
 * rules into a fresh Prolog knowledge base.
 *
 * @param model
 *            the relational model the database is to be used with
 * @throws Exception
 *             if initialization fails, in particular if one of the model's Prolog
 *             rules cannot be processed
 */
public GenericDatabase(RelationalModel model) throws Exception {
	this.model = model;
	entries = new HashMap<String, VariableType>();
	domains = new HashMap<String, HashSet<String>>();
	functionalDependencies = new HashMap<RelationKey, HashMap<String, String[]>>();
	taxonomy = model.getTaxonomy();

	// register the settable parameters
	// NOTE(review): the second registration previously reused the name "debug" for
	// setVerbose, which left no parameter named "verbose"; assumed to be a copy/paste slip.
	paramHandler = new ParameterHandler(this);
	paramHandler.add("debug", "setDebug");
	paramHandler.add("verbose", "setVerbose");

	// initialize domains
	if(taxonomy != null) {
		entity2type = new HashMap<String, String>();
		multiDomains = new HashMap<String, MultiIterator<String>>();
		for(Concept c : taxonomy.getConcepts()) {
			domains.put(c.name, new HashSet<String>());
		}
	}

	// fill domains with guaranteed domain elements
	for(Entry<String, ? extends Collection<String>> e : model.getGuaranteedDomainElements().entrySet()) {
		for(String element : e.getValue())
			fillDomain(e.getKey(), element);
	}

	// build the Prolog knowledge base from the model's rules (if any)
	Collection<String> prologRules = model.getPrologRules();
	prolog = new PrologKnowledgeBase();
	if(prologRules != null && !prologRules.isEmpty()) {
		System.out.println("building Prolog knowledge base... ");
		for(String rule : prologRules) {
			try {
				System.out.println("telling " + rule);
				prolog.tell(rule);
			}
			catch(Throwable e) {
				// wrap the failure so that the offending rule is reported; the cause is preserved
				throw new Exception("Error processing rule '" + rule + "'", e);
			}
		}
	}
}
/**
 * Switches debug output on or off.
 *
 * @param debug
 *            true to enable debug output
 */
public void setDebug(final boolean debug) {
	this.debug = debug;
}
/**
 * Gets a variable's value as stored in the database.
 *
 * @param varName
 *            the name of the variable whose value is to be retrieved
 * @param closedWorld
 *            whether to make the closed-world assumption, i.e. to assume
 *            that any Boolean variable for which we do not have a value is
 *            "False"
 * @return If a value for the given variable is stored in the database (or
 *         can be computed based on Prolog rules), it is returned.
 *         Otherwise, null is returned - unless the closed-world
 *         assumption is being made and the variable is Boolean, in which
 *         case the default value of "False" is returned.
 * @throws Exception
 *             if the value cannot be determined due to an error
 */
public abstract VarValueType getVariableValue(String varName, boolean closedWorld) throws Exception;
/**
 * Gets a variable's value as a single string.
 * NOTE(review): the exact semantics are defined by the subclasses; presumably this is the
 * string-valued counterpart of {@link #getVariableValue(String, boolean)} - confirm.
 *
 * @param varName
 *            the name of the variable whose value is to be retrieved
 * @param closedWorld
 *            whether to make the closed-world assumption
 * @return the value as a string
 * @throws Exception
 *             if the value cannot be determined due to an error
 */
public abstract String getSingleVariableValue(String varName, boolean closedWorld) throws Exception;
/**
 * Retrieves a variable setting.
 * The lookup is case-insensitive, because entries are stored under lower-cased keys.
 *
 * @param varName
 *            the name of the variable to retrieve
 * @return the variable setting with the given name if it is
 *         contained in the database, null otherwise
 * @deprecated because it does not really work with prolog predicates
 */
@Deprecated
public VariableType getVariable(String varName) {
	return entries.get(varName.toLowerCase());
}
/**
 * Checks whether the database contains an entry for the given variable name.
 * Variables of logically determined functions are always considered to be contained,
 * because their value can be computed.
 *
 * @param varName
 *            the name of the variable, e.g. "foo(A,B)"
 * @return true if there is an explicit entry for the variable or if the variable's
 *         function is declared as logical in the model; false otherwise (in particular
 *         if the function is unknown to the model)
 */
public boolean contains(String varName) {
	if(entries.containsKey(varName.toLowerCase()))
		return true;
	// for logically determined functions, we always have a value
	int parenPos = varName.indexOf('(');
	// a name without an argument list is treated as the function name itself
	// (previously this case raised a StringIndexOutOfBoundsException)
	String functionName = parenPos >= 0 ? varName.substring(0, parenPos) : varName;
	Signature sig = model.getSignature(functionName);
	// a function that is not declared in the model cannot be logically determined
	return sig != null && sig.isLogical;
}
/**
 * Adds the given variable to the database if it isn't already present.
 * Undefined functions are not tolerated and Prolog assertions are made where applicable.
 *
 * @param var
 *            the variable to add
 * @return true if the variable was added, false if it was already present
 * @throws Exception
 *             if the variable cannot be added
 */
public boolean addVariable(VariableType var) throws Exception {
	final boolean ignoreUndefinedFunctions = false;
	final boolean doPrologAssertions = true;
	return addVariable(var, ignoreUndefinedFunctions, doPrologAssertions);
}
/**
 * Adds the given variable to the database if it isn't already present.
 * <p>
 * Besides storing the entry, this fills the domains with the entry's parameters (and,
 * for non-Boolean functions, its value), asserts true instances of logically determined
 * functions to the Prolog knowledge base and updates the relation key lookup tables.
 * The entry is validated against the signature <i>before</i> anything is modified, so a
 * rejected entry leaves neither the Prolog knowledge base nor the domains changed.
 *
 * @param var
 *            the variable to add
 * @param ignoreUndefinedFunctions
 *            if true, variables whose function is not declared in the model are
 *            silently skipped; if false, they cause an exception
 * @param doPrologAssertions
 *            whether true instances of logically determined functions are to be
 *            asserted to the Prolog knowledge base
 * @return true if the variable was added; false if it was already present or was skipped
 * @throws Exception
 *             if the database is immutable, the function is undefined (and not ignored)
 *             or the number of parameters does not match the signature
 */
protected boolean addVariable(VariableType var, boolean ignoreUndefinedFunctions, boolean doPrologAssertions) throws Exception {
	if(immutable)
		throw new Exception("Tried to add a value to an immutable database");
	String entryKey = var.getKeyString().toLowerCase();
	if(entries.containsKey(entryKey))
		return false;

	Signature sig = model.getSignature(var.functionName);
	if(sig == null) {
		// if the predicate is not in the model, end here
		if(ignoreUndefinedFunctions)
			return false;
		throw new Exception(String.format("Function %s appears in the data but is not declared in the model.", var.functionName));
	}
	// validate the arity before causing any side effects
	if(sig.argTypes.length != var.params.length)
		throw new Exception("The database entry '" + var.getKeyString() + "' is not compatible with the signature definition of the corresponding function: expected " + sig.argTypes.length + " parameters as per the signature, got " + var.params.length + ".");

	// for logically determined functions, assert any true instances to the Prolog KB
	if(sig.isLogical && doPrologAssertions && var.isTrue()) {
		String fact = formatPrologFact(var.functionName, var.params);
		if(debug) System.out.println("Prolog: asserted " + fact);
		prolog.tell(fact + ".");
	}

	// fill domains
	if(!sig.isBoolean())
		fillDomain(sig.returnType, var);
	for(int i = 0; i < sig.argTypes.length; i++)
		fillDomain(sig.argTypes[i], var.params[i]);

	// add the entry to the main store
	entries.put(entryKey, var);

	// update lookup tables for keys (but only if the value is true)
	Collection<RelationKey> keys = this.model.getRelationKeys(var.functionName);
	if(keys == null || !var.isTrue())
		return true;
	for(RelationKey key : keys) {
		// compute key for map entry: the comma-separated values of the key parameters
		StringBuilder lookupKey = new StringBuilder();
		int i = 0;
		for(Integer paramIdx : key.keyIndices) {
			if(i++ > 0)
				lookupKey.append(',');
			lookupKey.append(var.params[paramIdx]);
		}
		HashMap<String, String[]> hm = functionalDependencies.get(key);
		if(hm == null) {
			hm = new HashMap<String, String[]>();
			functionalDependencies.put(key, hm);
		}
		hm.put(lookupKey.toString(), var.params);
	}
	return true;
}

/**
 * Builds the Prolog term "name(arg1,...,argN)" for a function instance, lower-casing the
 * first character of the function name and of each argument.
 * NOTE(review): for a function without arguments this yields "name()" (the previous
 * string arithmetic yielded the malformed "name)"); confirm which form the Prolog
 * engine expects for zero-argument predicates.
 */
private static String formatPrologFact(String functionName, String[] params) {
	StringBuilder sb = new StringBuilder(lowerCaseFirst(functionName));
	sb.append('(');
	for(int i = 0; i < params.length; i++) {
		if(i > 0)
			sb.append(',');
		sb.append(lowerCaseFirst(params[i]));
	}
	return sb.append(')').toString();
}

/**
 * Returns the given string with its first character converted to lower case
 * (an empty string is returned unchanged).
 */
private static String lowerCaseFirst(String s) {
	if(s.length() == 0)
		return s;
	return s.substring(0, 1).toLowerCase() + s.substring(1);
}
/**
 * Adds the value of the given variable to the given domain.
 * This is invoked when a variable of a non-Boolean function is added to the database,
 * with the function's return type as the domain name.
 *
 * @param domName
 *            the name of the domain to extend
 * @param var
 *            the variable whose value is to be added to the domain
 * @throws Exception
 *             if the domain cannot be extended
 */
public abstract void fillDomain(String domName, VariableType var) throws Exception;
/**
 * Looks up the full parameter set of the (true) entry that matches the given values of
 * a relation key.
 *
 * @param key
 *            the relation key
 * @param keyValues
 *            the values of the key parameters, in the order of the key's indices
 * @return the parameters of the matching entry, or null if there is no lookup table
 *         for the key or no entry for the given values
 */
public String[] getParameterSet(RelationKey key, String[] keyValues) {
	HashMap<String, String[]> lookupTable = functionalDependencies.get(key);
	return lookupTable == null ? null : lookupTable.get(StringTool.join(",", keyValues));
}
/**
 * Reads the contents of a BLOG-style database file into this database.
 * Entries whose function is not declared in the model cause an exception.
 *
 * @param databaseFilename
 *            the name of the file to read
 * @throws Exception
 *             if the file cannot be read or contains an invalid entry
 */
public void readBLOGDB(String databaseFilename) throws Exception {
	final boolean ignoreUndefinedNodes = false;
	readBLOGDB(databaseFilename, ignoreUndefinedNodes);
}
/**
 * Reads the contents of a BLOG-style database file into this database.
 * <p>
 * Comments (<code>//...</code> and <code>/* ... *&#47;</code>) are removed, then the file
 * is processed line by line. A line is either a domain declaration of the form
 * <code>name = {a, b, c}</code>, a variable assignment (as understood by
 * {@link #readEntry(String)}) or empty; anything else is an error.
 *
 * @param databaseFilename
 *            the name of the file to read
 * @param ignoreUndefinedNodes
 *            whether to silently skip entries whose function is not declared in the
 *            model (rather than throwing an exception)
 * @throws Exception
 *             if the file cannot be read or a line can be interpreted neither as a
 *             domain declaration nor as a variable assignment
 */
public void readBLOGDB(String databaseFilename, boolean ignoreUndefinedNodes) throws Exception {
	// read file content
	if(verbose)
		System.out.printf("  reading contents of %s...\n", databaseFilename);
	String dbContent = FileUtil.readTextFile(databaseFilename);

	// remove comments
	if(verbose)
		System.out.println("  removing comments");
	Pattern comments = Pattern.compile("//.*?$|/\\*.*?\\*/", Pattern.MULTILINE | Pattern.DOTALL);
	Matcher matcher = comments.matcher(dbContent);
	dbContent = matcher.replaceAll("");

	// read lines
	if(verbose)
		System.out.println("  reading items");
	Pattern re_domDecl = Pattern.compile("(\\w+)\\s*=\\s*\\{(.*?)\\}");
	BufferedReader br = new BufferedReader(new StringReader(dbContent));
	String line;
	int numVars = 0;
	while((line = br.readLine()) != null) {
		line = line.trim();

		// parse domain decls
		matcher = re_domDecl.matcher(line);
		if(matcher.matches()) {
			String domName = matcher.group(1);
			// trim the element list so that whitespace directly inside the braces
			// (as in "{ a, b }") does not end up in the first/last constant
			String[] constants = matcher.group(2).trim().split("\\s*,\\s*");
			constants = ABLModel.makeDomainElements(constants);
			for(String c : constants)
				fillDomain(domName, c);
			continue;
		}

		// parse variable assignment
		VariableType var = readEntry(line);
		if(var != null) {
			addVariable(var, ignoreUndefinedNodes, true);
			if(++numVars % 100 == 0 && verbose)
				System.out.print("    " + numVars + " vars read\r");
			continue;
		}

		// something else
		if(line.length() != 0) {
			throw new Exception("Database entry could not be read: " + line);
		}
	}
}
/**
 * Parses a single (trimmed) line of a database file as a variable assignment.
 *
 * @param line
 *            the line to parse
 * @return the variable described by the line, or null if the line is not a variable
 *         assignment (a non-empty line for which null is returned is reported as an
 *         error by the caller)
 * @throws Exception
 *             if the line cannot be processed
 */
protected abstract VariableType readEntry(String line) throws Exception;
/**
 * Adds the given value to the domain of the given type.
 * <p>
 * If a taxonomy is in use, each entity is associated with a single type: an entity
 * previously assigned to a super-type of <code>type</code> is moved to
 * <code>type</code>, whereas an entity previously assigned to a sub-type of
 * <code>type</code> is left where it is (the old assignment is more specific).
 *
 * @param type
 *            name of the domain/type
 * @param value
 *            the value/entity name to add
 * @throws Exception
 *             if the taxonomy query fails
 */
public void fillDomain(String type, String value) throws Exception {
	if(taxonomy != null) {
		String previousType = entity2type.get(value);
		if(previousType != null) {
			// already assigned to exactly this type: nothing to do
			if(previousType.equals(type))
				return;
			if(taxonomy.query_isa(type, previousType)) {
				// new type is sub-type --> reassign, i.e. remove from the old domain
				domains.get(previousType).remove(value);
			}
			else if(taxonomy.query_isa(previousType, type)) {
				// new type is super-type --> keep the old, more specific assignment
				return;
			}
			// otherwise the two types have no taxonomic relationship: the entity is kept
			// in the old domain and is additionally added to the new one (no warning is issued)
		}
		entity2type.put(value, type);
	}

	// add to the domain, creating the domain first if it does not exist yet
	HashSet<String> domain = domains.get(type);
	if(domain == null) {
		domain = new HashSet<String>();
		domains.put(type, domain);
	}
	domain.add(value);
}
/**
 * Checks the domains for overlaps and merges domains if necessary.
 * <p>
 * Whenever two domains share an element, the later one is merged into the earlier one:
 * the model's type is replaced accordingly and the elements are added to the earlier
 * domain.
 * NOTE(review): the <code>domains</code> map itself still maps the replaced name to its
 * old element set, and merges are not propagated transitively to names merged earlier.
 *
 * @param verbose
 *            whether to print a message for each pair of domains that is merged
 */
public void checkDomains(boolean verbose) {
	ArrayList<HashSet<String>> doms = new ArrayList<HashSet<String>>();
	ArrayList<String> domNames = new ArrayList<String>();
	for(Entry<String, HashSet<String>> entry : domains.entrySet()) {
		doms.add(entry.getValue());
		domNames.add(entry.getKey());
	}
	for(int i = 0; i < doms.size(); i++) {
		for(int j = i + 1; j < doms.size(); j++) {
			// compare the i-th domain to the j-th
			HashSet<String> dom1 = doms.get(i);
			HashSet<String> dom2 = doms.get(j);
			for(String value : dom1) {
				if(!dom2.contains(value))
					continue;
				// replace all occurrences of the j-th domain by the i-th
				if(verbose)
					System.out.println("Domains " + domNames.get(i) + " and " + domNames.get(j) + " overlap (both contain " + value + "). Merging...");
				String targetDomName = domNames.get(i);
				this.model.replaceType(domNames.get(j), targetDomName);
				// add all elements of j-th domain to the i-th
				dom1.addAll(dom2);
				doms.set(j, dom1);
				// the entity-to-type map exists only if a taxonomy is in use
				if(entity2type != null) {
					for(String v : dom2)
						entity2type.put(v, targetDomName);
				}
				// dom1 was modified, so the iteration over it must end here
				break;
			}
		}
	}
}
/**
 * Retrieves the elements of a domain.
 * If a taxonomy is in use, the domain is the combination of the domains of the given
 * type and all of its sub-types; the combined domain is cached.
 *
 * @param domName
 *            the name of the domain/type
 * @return the domain's elements or null if the domain is not found
 * @throws Exception
 *             if the taxonomy cannot be queried
 */
public Iterable<String> getDomain(String domName) throws Exception {
	// without a taxonomy, domains are flat
	if(taxonomy == null)
		return domains.get(domName);

	MultiIterator<String> cached = multiDomains.get(domName);
	if(cached != null)
		return cached;

	// combine the domains of the type's descendants
	MultiIterator<String> combined = new MultiIterator<String>();
	int numSubDomains = 0;
	for(Concept c : taxonomy.getDescendants(domName)) {
		Iterable<String> subdom = domains.get(c.name);
		if(subdom == null)
			continue;
		combined.add(subdom);
		numSubDomains++;
	}
	MultiIterator<String> result = numSubDomains > 0 ? combined : null;
	multiDomains.put(domName, result);
	return result;
}
/**
 * Retrieves all entries in the database.
 * Note that this finalizes the database (see {@link #finalize()}), i.e. the database is
 * extended with the Prolog-derived values and becomes immutable.
 *
 * @return the collection of all variables stored in the database
 * @throws Exception
 *             if finalizing the database fails
 */
public Collection<VariableType> getEntries() throws Exception {
	this.finalize();
	Collection<VariableType> allEntries = entries.values();
	return allEntries;
}
/**
 * If we are using a Prolog KB, extends the database with the values of all groundings
 * of all logically determined functions (unless it has already been extended).
 *
 * @throws Exception
 *             if generating the groundings or computing a value fails
 */
protected void extendWithPrologValues() throws Exception {
	// TODO This does quite a bit of perhaps unnecessary work; it might be better to let Prolog compute just the instances that hold in a single query
	if(debug) System.out.println("extending database with Prolog values...");
	if(prolog == null || prologDatabaseExtended)
		return;
	// set the flag first, so that the work is not repeated
	prologDatabaseExtended = true;
	for(Signature sig : this.model.getSignatures()) {
		if(!sig.isLogical)
			continue;
		Collection<String[]> groundings = ParameterGrounder.generateGroundings(sig, this);
		for(String[] grounding : groundings)
			getPrologValue(sig, grounding, true);
	}
}
/**
 * Makes sure this database is finalized, i.e. all values that can be derived via Prolog
 * have been computed, and renders the database immutable.
 * There is no harm in calling this function several times.
 * <p>
 * NOTE(review): this method has the same name and parameter list as
 * {@link Object#finalize()} and therefore overrides it, so it may also be invoked by the
 * garbage collector. It cannot be renamed without breaking callers.
 *
 * @throws Exception
 *             if extending the database with the Prolog values fails
 */
public void finalize() throws Exception {
	this.extendWithPrologValues();
	this.immutable = true;
}
/**
 * @return true if the database has been finalized, i.e. is immutable
 */
public boolean isFinalized() {
	return this.immutable;
}
/**
 * Computes the value of a variable via Prolog and, if Prolog values are cached or
 * adding is forced, adds it to the database.
 *
 * @param sig
 *            the signature of the variable's function
 * @param args
 *            the arguments of the function
 * @param forceAddToDatabase
 *            whether to add the computed value to the database even if Prolog values
 *            are not generally cached
 * @return the truth value computed by Prolog
 * @throws Exception
 *             if the query fails or the value cannot be added to the database
 */
protected boolean getPrologValue(Signature sig, String[] args, boolean forceAddToDatabase) throws Exception {
	// the first character of each argument is lower-cased for the Prolog query
	String[] prologArgs = new String[args.length];
	int idx = 0;
	for(String arg : args)
		prologArgs[idx++] = arg.substring(0, 1).toLowerCase() + arg.substring(1);

	boolean value = prolog.ask(Signature.formatVarName(sig.functionName, prologArgs));

	String valueString;
	if(value)
		valueString = "True";
	else
		valueString = "False";
	VariableType var = makeVar(sig.functionName, args, valueString);

	if(!cachePrologValues && !forceAddToDatabase)
		return value;
	// no Prolog assertion here: the value was derived from the knowledge base itself
	boolean added = addVariable(var, false, false);
	if(added && debug)
		System.out.println("Prolog: computed " + var);
	return value;
}
/**
 * Creates a variable object for the given function instance and value.
 *
 * @param functionName
 *            the name of the function
 * @param args
 *            the arguments of the function
 * @param value
 *            the value as a string (this class passes "True" or "False")
 * @return the variable object
 */
protected abstract VariableType makeVar(String functionName, String[] args, String value);
/**
 * Adds all missing values of ground atoms of the given predicate, setting
 * them to "False". Invoke <i>after</i> the database has been read!
 *
 * @param predName
 *            the name of the predicate
 * @throws Exception
 *             if the predicate's signature is unknown or a value cannot be added
 */
public void setClosedWorldPred(String predName) throws Exception {
	Signature signature = this.model.getSignature(predName);
	if(signature == null)
		throw new Exception("Cannot determine signature of " + predName);
	// start the enumeration of groundings with an empty parameter assignment
	int arity = signature.argTypes.length;
	setClosedWorldPred(signature, 0, new String[arity]);
}
/**
 * Recursively enumerates all groundings of the given predicate and adds a "False"
 * entry for each grounding the database does not yet contain.
 *
 * @param sig
 *            the signature of the predicate
 * @param i
 *            the index of the next parameter to bind
 * @param params
 *            the parameter assignment; positions before <code>i</code> are bound
 * @throws Exception
 *             if a domain cannot be retrieved or a value cannot be added
 */
protected void setClosedWorldPred(Signature sig, int i, String[] params) throws Exception {
	if(i != params.length) {
		// bind the i-th parameter to each element of its domain in turn
		Iterable<String> domain = this.getDomain(sig.argTypes[i]);
		if(domain != null) {
			for(String element : domain) {
				params[i] = element;
				setClosedWorldPred(sig, i + 1, params);
			}
		}
		return;
	}
	// all parameters are bound: add the atom as "False" unless it is already contained
	String atomName = Signature.formatVarName(sig.functionName, params);
	if(this.contains(atomName))
		return;
	// the array is cloned, because it is reused for the subsequent groundings
	VariableType falseVar = makeVar(sig.functionName, params.clone(), "False");
	this.addVariable(falseVar);
}
/**
 * Retrieves the signature of a function from the model.
 *
 * @param functionName
 *            the name of the function
 * @return whatever the model returns for the given function name
 */
public Signature getSignature(String functionName) {
	Signature signature = this.model.getSignature(functionName);
	return signature;
}
/**
 * Prints all domains, one per line, in the form "name: element1, element2, ...".
 *
 * @param out
 *            the stream to print to
 */
public void printDomain(PrintStream out) {
	for(Entry<String, HashSet<String>> domain : domains.entrySet()) {
		String domName = domain.getKey();
		String elements = StringTool.join(", ", domain.getValue());
		out.println(domName + ": " + elements);
	}
}
/**
 * Prints all entries of the database to standard output, one per line.
 * Note that retrieving the entries finalizes the database.
 *
 * @throws Exception
 *             if the entries cannot be retrieved
 */
public void print() throws Exception {
	Collection<VariableType> allEntries = getEntries();
	for(VariableType entry : allEntries) {
		System.out.println(entry.toString());
	}
}
/**
 * Retrieves the mapping from domain names to domain elements.
 * This is supported only for models without a taxonomy.
 *
 * @return the map of domains (the internal map, not a copy)
 * @throws Exception
 *             if the model uses a taxonomy
 */
public HashMap<String, HashSet<String>> getDomains() throws Exception {
	if(taxonomy == null)
		return domains;
	throw new Exception("Cannot safely return the set of domains for a model that uses a taxonomy");
}
/**
 * @return the relational model this database was constructed for
 */
public RelationalModel getModel() {
	return model;
}
/**
 * Gets the type of the given constant by searching through the domains.
 * The name of the first domain found to contain the constant is returned.
 *
 * @param constant
 *            the constant to look up
 * @return the type name or null if the constant is unknown
 */
public String getConstantType(String constant) {
	String typeName = null;
	for(Entry<String, HashSet<String>> domain : this.domains.entrySet()) {
		if(!domain.getValue().contains(constant))
			continue;
		typeName = domain.getKey();
		break;
	}
	return typeName;
}
/**
 * @return the handler through which this database's parameters can be set
 */
@Override
public ParameterHandler getParameterHandler() {
	return this.paramHandler;
}
/**
 * Switches verbose output (progress information) on or off.
 *
 * @param verbose
 *            true to enable verbose output
 */
public void setVerbose(final boolean verbose) {
	this.verbose = verbose;
}
}