/******************************************************************************* * Copyright (C) 2006-2012 Dominik Jain. * * This file is part of ProbCog. * * ProbCog is free software: you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. * * ProbCog is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with ProbCog. If not, see <http://www.gnu.org/licenses/>. ******************************************************************************/ package probcog.srldb; import java.io.FileInputStream; import java.io.FileOutputStream; import java.io.IOException; import java.io.ObjectInputStream; import java.io.ObjectOutputStream; import java.io.PrintStream; import java.io.Serializable; import java.util.Collection; import java.util.HashMap; import java.util.HashSet; import java.util.Iterator; import java.util.Vector; import probcog.clustering.BasicClusterer; import probcog.clustering.ClusterNamer; import probcog.srl.directed.ABLModel; import probcog.srldb.datadict.AutomaticDataDictionary; import probcog.srldb.datadict.DDAttribute; import probcog.srldb.datadict.DDException; import probcog.srldb.datadict.DDItem; import probcog.srldb.datadict.DataDictionary; import probcog.srldb.datadict.domain.Domain; import probcog.srldb.datadict.domain.OrderedStringDomain; import edu.tum.cs.util.datastruct.MultiIterator; /** * Represents a relational database. * @author Dominik Jain */ public class Database implements Cloneable, Serializable { private static final long serialVersionUID = 1L; protected HashSet<Link> links; protected HashSet<Object> objects; protected HashMap<String, HashMap<String, Object>> constantObjectsByType; protected DataDictionary datadict; /** * creates a relational database * @param dd the data dictionary that this database must conform to */ public Database(DataDictionary dd) { links = new HashSet<Link>(); objects = new HashSet<Object>(); constantObjectsByType = new HashMap<String, HashMap<String,Object>>(); datadict = dd; } /** * creates a relational database with an automatically generated data dictionary */ public Database() { this(new AutomaticDataDictionary()); } /** * performs clustering on an attribute that is defined for objects in the vector objects; The number of clusters is determined by the number of names in clusterNames * @param attribute the attribute whose values are to be clustered * @param objects a vector of objects, some of which (but not necessarily all) must have the attribute * @param clusterer a clusterer used to perform the clustering * @param clusterNamer a namer for the resulting clusters, which is used to redefine the attribute's domain and to update all the attribute values * @throws DDException if problems with data dictionary conformity are discovered * @throws Exception if there are no instances of the attribute, i.e. the attribute is undefines for all objects */ public static AttributeClustering clusterAttribute(DDAttribute attribute, Iterable<Item> objects, BasicClusterer<? extends weka.clusterers.Clusterer> clusterer, ClusterNamer<weka.clusterers.Clusterer> clusterNamer) throws DDException, Exception { String attrName = attribute.getName(); // create clusterer and collect instances int instances = 0; for(Item obj : objects) { String value = obj.getAttributeValue(attrName); if(value == null) continue; clusterer.addInstance(Double.parseDouble(value)); instances++; } // build clusterer clusterer.buildClusterer(); // get cluster names String[] clusterNames = clusterNamer.getNames(clusterer.getWekaClusterer()); // check number of instances if(instances < clusterNames.length) { System.err.println("Warning: attribute " + attrName + " was discarded because there are too few instances for clustering"); attribute.discard(); return null; } if(instances == 0) throw new Exception("The domain is empty; No instances could be clustered for attribute " + attrName); // apply cluster assignment to attribute values AttributeClustering ac = new AttributeClustering(); ac.clusterer = clusterer; ac.newDomain = new OrderedStringDomain(attribute.getDomain().getName(), clusterNames); applyClustering(attribute, objects, ac); return ac; } public static void applyClustering(DDAttribute attribute, Iterable<Item> objects, AttributeClustering ac) throws NumberFormatException, Exception { // apply cluster assignment to attribute values String attrName = attribute.getName(); for(Item obj : objects) { String value = obj.attribs.get(attrName); if(value != null) { int i = ac.clusterer.classify(Double.parseDouble(value)); String svalue = ac.newDomain.getValues()[i]; obj.attribs.put(attrName, svalue); /*if(attrName.equals("radDistRatio")) { Object o = ((Object)obj.getLink("isEllipseOf").getArguments()[1]); System.out.printf(" %s %-10s %s -> %s\n", o.getConstantName(), o.getString("objectT"), value, svalue); }*/ } } // redefine attribute domain attribute.setDomain(ac.newDomain); } public void writeMLNDatabase(PrintStream out) throws Exception { out.println("// *** mln database ***\n"); // links out.println("// links"); for(Link link : links) link.MLNprintFacts(out); // objects Counters cnt = new Counters(); for(Object obj : objects) { out.println("// " + obj.objType() + " #" + cnt.inc(obj.objType())); obj.MLNprintFacts(out); } } public void writeBLOGDatabase(PrintStream out) throws Exception { // check function names for(DDAttribute ddattr : this.getDataDictionary().getAttributes()) { if(ddattr.isDiscarded()) continue; if(!ABLModel.isValidFunctionName(ddattr.getName())) throw new Exception("'" + ddattr.getName() + "' is not a valid function name"); } // write all facts for(Object obj : objects) { obj.BLOGprintFacts(out); } for(Link link : links) { link.BLOGprintFacts(out); } } /** * outputs the basic MLN for this database, which contains domain definitions and predicate declarations * @param out the stream to write to */ public void writeBasicMLN(PrintStream out) { datadict.writeBasicMLN(out); } /** * writes this database object to a file (and, as a side-effect, changes all items' database references to this object in order to avoid saving data on other databases) * @param s * @throws IOException */ public void writeSRLDB(FileOutputStream s) throws IOException { // make sure that we do not by mistake save data on other databases // by setting all the items' databases to this for(Object o : this.objects) o.database = this; for(Link l : this.links) l.database = this; // clean up data dictionary this.datadict.cleanUp(); // save ObjectOutputStream objstream = new ObjectOutputStream(s); objstream.writeObject(this); objstream.close(); } /** * reads a previously stored database object from a file * @param s * @return * @throws IOException * @throws ClassNotFoundException */ public static Database fromFile(FileInputStream s) throws IOException, ClassNotFoundException { ObjectInputStream objstream = new ObjectInputStream(s); java.lang.Object object = objstream.readObject(); objstream.close(); return (Database)object; } protected String printParams(IRelationArgument[] arguments) { StringBuffer linkParams = new StringBuffer(); for(int i = 0; i < arguments.length; i++) { if(i > 0) linkParams.append(","); linkParams.append(Database.upperCaseString(arguments[i].getConstantName())); } return linkParams.toString(); } /** * Turns all Links with more than 2 arguments into 2 argument links by introducing a dummy object and linking all the * arguments to it. * @throws DDException */ public void flattenLinks() throws DDException { Integer n = 0; HashSet<Link> newLinks= new HashSet<Link>(); for(Iterator<Link> it = links.iterator(); it.hasNext();) { Link l = it.next(); if(l.getArguments().length > 2) { //Object obj = new Object(this,l.getName()+printParams(l.getArguments())+"_obj"); Object obj = new Object(this,l.toString()+"_obj"); addObject(obj); Object[] args = l.getArgumentObjects(); for(Integer i =0;i<args.length;i++){ Link newLink = new Link(this,l.getName()+"_"+i.toString(),args[i],obj); //addLink(newLink); newLinks.add(newLink); } it.remove(); n++; } } links.addAll(newLinks); } /** * outputs the data contained in this database to an XML database file for use with Proximity * @param out the stream to write to * @throws Exception */ public void writeProximityDatabase(java.io.PrintStream out) throws Exception { flattenLinks(); System.out.println("\n" + getDataDictionary()); out.println("<?xml version=\"1.0\" encoding=\"UTF-8\"?>"); out.println("<!DOCTYPE PROX3DB SYSTEM \"prox3db.dtd\">"); out.println("<PROX3DB>"); // objects System.out.println("writing objects..."); out.println(" <OBJECTS>"); for(Object obj : objects) { out.println(" <OBJECT ID=\"" + obj.id + "\"/>"); } out.println(" </OBJECTS>"); // links System.out.println("writing links..."); out.println(" <LINKS>"); HashSet<String> warnedLinks = new HashSet<String>(); for(Link link : links) { if(link.getArguments().length != 2) { if(!warnedLinks.contains(link.getName())) { System.err.println("Warning: non-binary link/relation found: " + link.getName() + " - using first two objects only"); warnedLinks.add(link.getName()); } } Object[] args = link.getArgumentObjects(); Object o1 = args[0]; Object o2 = args[1]; out.println(" <LINK ID=\"" + link.id + "\" O1-ID=\"" + o1.id + "\" O2-ID=\"" + o2.id + "\"/>"); } out.println(" </LINKS>"); // attributes System.out.println("writing attributes..."); out.println(" <ATTRIBUTES>"); // - regular attributes for(DDAttribute attrib : datadict.getAttributes()) { if(attrib.isDiscarded()) continue; DDItem owner = attrib.getOwner(); if(owner == null) // skip attributes without owner (i.e. dummy attributes that are created for constant relation arguments) continue; String attribName = attrib.getName(); System.out.println(" attribute " + attribName); out.println(" <ATTRIBUTE NAME=\"" + Database.stdAttribName(attribName) + "\" ITEM-TYPE=\"" + (attrib.getOwner().isObject() ? "O" : "L") + "\" DATA-TYPE=\"" + attrib.getType() + "\">"); Iterator<? extends Item> iItem = owner.isObject() ? objects.iterator() : links.iterator(); while(iItem.hasNext()) { Item item = (Item) iItem.next(); if(item.hasAttribute(attribName)) { out.println(" <ATTR-VALUE ITEM-ID=\"" + item.id + "\">"); out.println(" <COL-VALUE>" + Database.stdAttribStringValue(item.attribs.get(attribName)) + "</COL-VALUE></ATTR-VALUE>"); } } out.println(" </ATTRIBUTE>"); } // - special attribute objtype for objects out.println(" <ATTRIBUTE NAME=\"objtype\" ITEM-TYPE=\"O\" DATA-TYPE=\"str\">"); for(Object obj : objects) { out.println(" <ATTR-VALUE ITEM-ID=\"" + obj.id + "\">"); out.println(" <COL-VALUE>" + obj.objType() + "</COL-VALUE></ATTR-VALUE>"); } out.println(" </ATTRIBUTE>"); // - special attribute constantName for objects out.println(" <ATTRIBUTE NAME=\"constName\" ITEM-TYPE=\"O\" DATA-TYPE=\"str\">"); for(Object obj : objects) { out.println(" <ATTR-VALUE ITEM-ID=\"" + obj.id + "\">"); out.println(" <COL-VALUE>" + upperCaseString(obj.getConstantName()) + "</COL-VALUE></ATTR-VALUE>"); } out.println(" </ATTRIBUTE>"); // - special attribute link_tag for links out.println(" <ATTRIBUTE NAME=\"link_tag\" ITEM-TYPE=\"L\" DATA-TYPE=\"str\">"); for(Link link : links) { out.println(" <ATTR-VALUE ITEM-ID=\"" + link.id + "\">"); out.println(" <COL-VALUE>" + link.getName() + "</COL-VALUE></ATTR-VALUE>"); } out.println(" </ATTRIBUTE>"); out.println(" </ATTRIBUTES>"); // done out.println("</PROX3DB>"); } /** * returns a string where the first letter is lower case * @param s the string to convert * @return the string s with the first letter converted to lower case */ public static String lowerCaseString(String s) { char[] c = s.toCharArray(); c[0] = Character.toLowerCase(c[0]); return new String(c); } /** * returns a string where the first letter is upper case * @param s the string to convert * @return the string s with the first letter converted to upper case */ public static String upperCaseString(String s) { char[] c = s.toCharArray(); c[0] = Character.toUpperCase(c[0]); return new String(c); } public static String stdAttribName(String attribName) { return lowerCaseString(attribName); } public static String stdPredicateName(String name) { return lowerCaseString(name); // NOTE: for BLNs, it's highly advisable for functions/predicates to be lower-case, because otherwise they can't be used in Prolog } public static String stdDomainName(String domainName) { return lowerCaseString(domainName); } public static String stdAttribStringValue(String strValue) { // make sure the value's first character is upper case char[] value = strValue.toCharArray(); value[0] = Character.toUpperCase(value[0]); // if there are spaces in the value, remove them and make the following letters upper case int len = 1; for(int i = 1; i < value.length;) { if(value[i] == ' ') { value[len++] = Character.toUpperCase(value[++i]); i++; } else value[len++] = value[i++]; } return new String(value, 0, len); } /** * verifies compatibility of the data with the data dictionary * and merges domains with overlapping value sets in the data dictionary */ public void check() throws DDException, Exception { // check objects for(Object obj : objects) { datadict.checkObject(obj); } // check relations for(Link link : this.links) { datadict.checkLink(link); } // check data dictionary consistency (non-overlapping domains, etc.) datadict.check(); } public static class AttributeClustering { public BasicClusterer<?> clusterer; public Domain<?> newDomain; } /** * performs clustering on the attributes for which it was specified in the data dictionary * @throws DDException * @throws Exception */ public HashMap<DDAttribute, AttributeClustering> doClustering(HashMap<DDAttribute, AttributeClustering> clusterers) throws DDException, Exception { System.out.println("clustering..."); if(clusterers == null) clusterers = new HashMap<DDAttribute, AttributeClustering>(); MultiIterator<Item> items = new MultiIterator<Item>(); items.add(objects); items.add(links); for(DDAttribute attrib : this.datadict.getAttributes()) { if(attrib.isDiscarded()) continue; if(attrib.requiresClustering()) { System.out.println(" " + attrib.getName()); AttributeClustering ac; ac = clusterers.get(attrib); if(ac != null) { applyClustering(attrib, items, ac); continue; } ac = attrib.doClustering(items); System.out.println(" " + ac.newDomain); clusterers.put(attrib, ac); } } return clusterers; } public HashMap<DDAttribute, AttributeClustering> doClustering() throws DDException, Exception { return doClustering(null); } public Database clone() { try { return (Database)super.clone(); } catch (CloneNotSupportedException e) { return null; } } public Collection<Link> getLinks() { return links; } /** * returns all links in which the given object appears * @param o * @return */ public Vector<Link> getLinks(Object o) { Vector<Link> v = new Vector<Link>(); for(Link l : this.links) { for(int i = 0; i < l.arguments.length; i++) if(l.arguments[i] == o) v.add(l); } return v; } public Collection<Object> getObjects() { return objects; } public void addObject(Object obj) throws DDException { if(objects.add(obj)) this.datadict.onCommitObject(obj); } public void addLink(Link l) throws DDException { if(links.add(l)) this.datadict.onCommitLink(l); } public DataDictionary getDataDictionary() { return datadict; } public void setDataDictionary(DataDictionary dd) { this.datadict = dd; } public static class Counters { protected HashMap<String, Integer> counters; public Counters() { counters = new HashMap<String, Integer>(); } public Integer inc(String name) { Integer c = counters.get(name); if(c == null) counters.put(name, c=new Integer(1)); else counters.put(name, c=new Integer(c+1)); return c; } public String toString() { return counters.toString(); } } /** * empties this database */ public void clear() { this.objects.clear(); this.links.clear(); } public Object getConstantAsObject(String objType, String constantName) throws DDException { HashMap<String, Object> name2obj = constantObjectsByType.get(objType); if(name2obj == null) { name2obj = new HashMap<String, Object>(); constantObjectsByType.put(objType, name2obj); } Object o = name2obj.get(constantName); if(o == null) { o = new Object(this, objType, constantName); o.commit(); name2obj.put(constantName, o); } return o; } public void printData() { for(Object o : objects) o.printData(); for(Link l : links) l.printData(); } }