package water.parser;
import java.io.IOException;
import java.io.InputStream;
import java.util.*;
import water.H2O;
import water.score.*;
import water.util.Log;
/** Parse PMML models
*
* Full recursive-descent style parsing. MUCH easier to track the control
* flows than a SAX-style parser, and does not require the entire doc like a
* DOM-style. More tightly tied to the XML structure, but in theory PMML is
* a multi-vendor standard and fairly stable.
*
* Like a good R-D parser, uses a separate function for parsing each XML
* element. Each function expects to be at a particular parse-point
* (generally after the opening '<' and before the tag is parsed), and
* always leaves the parse just after the close-tag '>'. The semantic
* interpretation is then interleaved with the parsing, with higher levels
* passing down needed info to lower element levels, and lower levels
* directly returning results to the higher levels.
*
* @author <a href="mailto:cliffc@h2o.ai"></a>
* @version 1.0
*/
public class PMMLParser {
final InputStream _is;  // where the PMML text is read from
int[] _buf;             // pushback stack used by push()/peek()
int _idx;               // number of characters currently pushed back in _buf
/** Feature datatypes declared by PMML DataField elements. These appear before
 * we know what kind of model we are parsing, so they are parsed globally (for
 * all models). Each constant carries the Java type name used for it in
 * generated code. */
public static enum DataTypes {
  DOUBLE("double"), INT("int"), BOOLEAN("boolean"), STRING("String");
  final String _jname;
  DataTypes( String jname ) { _jname = jname; }
  /** Map a PMML {@code dataType} attribute value to a constant. Matching is
   * case-insensitive and independent of the default locale. The PMML
   * spellings "integer" and "float" are accepted as aliases for INT and DOUBLE.
   * @throws IllegalArgumentException if the name is not recognized */
  public static DataTypes parse(String s) {
    // Locale.ROOT: with e.g. a Turkish default locale "int".toUpperCase() is not "INT"
    String u = s.toUpperCase(Locale.ROOT);
    if( "INTEGER".equals(u) ) return INT;
    if( "FLOAT"  .equals(u) ) return DOUBLE;
    return DataTypes.valueOf(u);
  }
  /** Java type name used for this datatype in generated code. */
  public String jname() { return _jname; }
}
// Global (per-parse) type mappings: field name -> declared datatype. Examples:
//   <DataField name="Species" optype="categorical" dataType="string">
//   <DataField name="creditScore" dataType="double" optype="continuous" />
public final HashMap<String,DataTypes> _types = new HashMap<String,DataTypes>();
// Global (per-parse) enum mappings: field name -> its declared <Value> levels,
// sorted. Example:
//   <DataField name="Species" optype="categorical" dataType="string">
//     <Value value="setosa"/>
//     <Value value="versicolor"/>
//     <Value value="virginica"/>
//   </DataField>
public final HashMap<String,String[]> _enums = new HashMap<String,String[]>();
/** Unchecked exception for malformed, truncated, or unsupported PMML input. */
public static class ParseException extends RuntimeException {
  public ParseException( String msg ) { super(msg); }
  /** Like {@link #ParseException(String)}, but also records the underlying cause. */
  public ParseException( String msg, Throwable cause ) { super(msg, cause); }
}
/** Parse one PMML document from {@code is} and return the model it describes.
 * The stream is read sequentially and is not closed here. */
public static ScoreModel parse( InputStream is ) {
  final PMMLParser parser = new PMMLParser(is);
  return parser.parse();
}
private PMMLParser(InputStream is) {
  _buf = new int[2];   // two-slot pushback buffer
  _is = is;
}
// Parse the document: an optional <?xml ...?> declaration, then the <PMML> root.
private ScoreModel parse() {
  skipWS().expect('<');
  final boolean hasDeclaration = peek() == '?';
  if( hasDeclaration ) {
    pXMLVersion();
    skipWS().expect('<');
  }
  return pPMML();
}
// Parse and discard the XML declaration, e.g. <?xml version="1.0" encoding="UTF-8" ?>.
// Entered just after the opening '<'; leaves the parse just after the closing "?>".
private PMMLParser pXMLVersion() {
  expect("?xml");
  // Skip whitespace before testing for '?', so "... ?>" (space before the
  // close) is accepted. Attribute names and values are read and ignored.
  while( skipWS().peek() != '?' ) {
    token();
    skipWS().expect('=').str();
  }
  return expect("?>");
}
// The whole PMML element: Header, DataDictionary, then one model element.
// Entered just after the '<' of "<PMML"; leaves the parse after "</PMML>".
// Only Scorecard models are currently supported.
// @throws ParseException for any other model element
private ScoreModel pPMML() {
  expect("PMML").skipAttrs();
  expect('>').skipWS().expect('<');
  pGeneric("Header");           // Header content is skipped, not interpreted
  skipWS().expect('<');
  pDataDictionary();            // fills _types and _enums
  String mtag = skipWS().expect('<').token();
  // Fail with a clear message rather than barfing on the model's contents.
  // TODO: "MiningModel" via RFScoreModel.parse(this)
  if( !"Scorecard".equals(mtag) )
    throw new ParseException("Unsupported PMML model type <"+mtag+">");
  ScoreModel scm = ScorecardModel.parse(this);
  skipWS().expect("</PMML>");
  return scm;
}
// Skip an arbitrary XML element and all of its children. Entered just after
// the opening '<'. If hdr is non-null it is the expected tag name (checked only
// when assertions are enabled). Text content is discarded character by character.
public PMMLParser pGeneric(String hdr) {
  final String tag = token();
  assert hdr == null || tag.equals(hdr);
  skipAttrs();
  if( peek() == '/' ) return expect("/>");          // self-closing <tag .../>
  expect('>');
  for( ;; ) {
    if( get() != '<' ) continue;                     // ignore text content
    if( peek() == '/' ) return expect('/').expect(tag).expect('>');
    pGeneric(null);                                  // nested child element
  }
}
// Read the DataDictionary element (entered just after its '<'), parsing every
// DataField child. Leaves the parse after "</DataDictionary>".
private PMMLParser pDataDictionary() {
  expect("DataDictionary");
  skipAttrs();
  expect('>');
  for( ;; ) {
    skipWS().expect('<');
    if( peek() == '/' ) break;                       // reached the close tag
    pDataField();
  }
  return expect("/DataDictionary>");
}
// Read one DataField (entered just after its '<'). Records its datatype in
// _types; unless the element is self-closing, also collects its children via
// pDataFieldValue and stores them, sorted, in _enums.
private PMMLParser pDataField() {
  final HashMap<String,String> fieldAttrs = expect("DataField").attrs();
  final String fieldName = fieldAttrs.get("name");
  _types.put(fieldName, DataTypes.parse(fieldAttrs.get("dataType")));
  if( peek() == '/' ) return expect("/>");           // no levels declared
  expect('>');
  final ArrayList<String> levels = new ArrayList<String>();
  for( ;; ) {
    skipWS().expect('<');
    if( peek() == '/' ) break;
    levels.add(pDataFieldValue());
  }
  final String[] sorted = levels.toArray(new String[levels.size()]);
  Arrays.sort(sorted);
  _enums.put(fieldName, sorted);
  return expect("/DataField>");
}
// Read a single <Value .../> level of a DataField, entered just after its '<'.
// Attributes may appear in any order, with optional whitespace around '='. Only
// "value" is used; any others (e.g. displayValue) are read and ignored. Both
// <Value .../> and <Value ...></Value> are accepted.
// NOTE(review): a value carrying property="missing"/"invalid" is returned like
// any other level -- confirm whether callers should drop those.
// @return the level string
// @throws ParseException if there is no "value" attribute
private String pDataFieldValue() {
  HashMap<String,String> attrs = expect("Value").attrs();
  String val = attrs == null ? null : attrs.get("value");
  if( val == null ) throw new ParseException("<Value> element without a 'value' attribute");
  if( peek() == '/' ) expect("/>");
  else expect('>').skipWS().expect("</Value>");
  return val;
}
// Parse one PMML predicate element, entered just after its '<'. Handles
// SimplePredicate, CompoundPredicate, SimpleSetPredicate and <True/>.
// @throws ParseException for any other predicate element
public Predicate pPredicate() {
  String t = token();
  HashMap<String,String> attrs = attrs();
  if( "SimplePredicate"   .equals(t) ) return pSimplePredicate(attrs);
  if( "CompoundPredicate" .equals(t) ) return pCompoundPredicate(attrs);
  if( "SimpleSetPredicate".equals(t) ) return pSimpleSetPredicate(attrs);
  if( "True"              .equals(t) ) { expect("/>"); return new True(); }
  throw new ParseException("Unhandled predicate element <"+t+">");
}
// Build a comparison from <SimplePredicate field=".." operator=".." value=".."/>;
// the attributes have already been read, and the trailing "/>" is consumed here.
private Predicate pSimplePredicate(HashMap<String,String> attrs) {
  expect("/>");
  final String field = attrs.get("field");
  final Operators op = Operators.valueOf(attrs.get("operator"));
  final String constant = attrs.get("value");
  return Predicate.makeSimple(field, op, constant);
}
// Parse a <CompoundPredicate booleanOperator=".."> whose attributes have
// already been read, consuming through "</CompoundPredicate>". PMML allows two
// or more child predicates. Extra children are folded left-associatively, so
// "a and b and c" becomes ((a and b) and c).
// @throws ParseException if the operator is a set operator (isIn/isNotIn);
//         a name unknown to BooleanOperators fails in valueOf instead
private Predicate pCompoundPredicate(HashMap<String,String> attrs) {
  expect(">");
  BooleanOperators op = BooleanOperators.valueOf(attrs.get("booleanOperator"));
  CompoundPredicate cp = CompoundPredicate.make(op);
  if( cp == null )
    throw new ParseException("Unsupported CompoundPredicate booleanOperator '"+op+"'");
  cp._l = skipWS().expect('<').pPredicate();
  cp._r = skipWS().expect('<').pPredicate();
  while( skipWS().expect('<').peek() != '/' ) {    // third and later children
    CompoundPredicate next = CompoundPredicate.make(op);
    next._l = cp;
    next._r = pPredicate();
    cp = next;
  }
  expect("/CompoundPredicate>");
  return cp;
}
// Parse <SimpleSetPredicate field=".." booleanOperator="isIn|isNotIn"> whose
// attributes have already been read: one nested <Array> of members, then the close tag.
// @throws ParseException if the operator is not a set operator (and/or)
private Predicate pSimpleSetPredicate(HashMap<String,String> attrs) {
  expect('>');
  BooleanOperators op = BooleanOperators.valueOf(attrs.get("booleanOperator"));
  IsIn in = IsIn.make(attrs.get("field"), op);
  if( in == null )
    throw new ParseException("Unsupported SimpleSetPredicate booleanOperator '"+op+"'");
  in._values = skipWS().expect('<').pArray();
  skipWS().expect("</SimpleSetPredicate>");
  return in;
}
// Parse an <Array n=".." type="string"> element (entered just after its '<') and
// return its entries. Entries are whitespace-separated:
//  - an entry starting with '"' or "&quot;" is read as a quoted string and may contain spaces;
//  - otherwise it runs to the next whitespace or '<'.
// "n" is optional; when present the entry count must equal it. The
// type="string" requirement is checked only when assertions are enabled.
// @throws ParseException on a count mismatch or malformed element
private String[] pArray() {
  HashMap<String,String> attrs = expect("Array").attrs();
  expect('>');
  assert attrs != null && "string".equals(attrs.get("type"));
  ArrayList<String> vals = new ArrayList<String>();
  int b;
  while( (b = skipWS().peek()) != '<' ) {
    if( b=='&' || b=='"' ) { vals.add(str()); continue; }
    StringBuilder sb = new StringBuilder();
    int c;
    while( !Character.isWhitespace(c = get()) && c != '<' ) sb.append((char)c);
    push(c);                    // leave the terminator for the next skipWS/expect
    vals.add(sb.toString());
  }
  String n = attrs == null ? null : attrs.get("n");
  if( n != null && Integer.parseInt(n.trim()) != vals.size() )
    throw new ParseException("Array declares n="+n+" but has "+vals.size()+" entries");
  expect("</Array>");
  return vals.toArray(new String[vals.size()]);
}
// Return the next character, taking pushed-back characters first.
// @throws ParseException at end of stream ("Premature EOF"), or, with the
//         IOException attached as its cause, if reading the stream fails
public int get() {
  if( _idx > 0 ) return _buf[--_idx];
  int b;
  try {
    b = _is.read();
  } catch( IOException ioe ) {
    Log.err(ioe);
    ParseException pe = new ParseException("I/O error while reading PMML: "+ioe);
    pe.initCause(ioe);
    throw pe;
  }
  if( b == -1 ) throw new ParseException("Premature EOF");
  return b;
}
// Return the next character without consuming it. A pushed-back character is
// returned as-is; otherwise one character is read with get() and pushed back,
// so EOF and I/O errors are reported exactly as get() reports them.
public int peek() {
  if( _idx > 0 ) return _buf[_idx-1];
  return push(get());
}
// Push one character back for the next get()/peek() and return it.
// NOTE(review): no bounds check; _buf is allocated with two slots.
int push( int b ) {
  _buf[_idx] = b;
  _idx++;
  return b;
}
// Like get(), but decodes an XML reference beginning with '&'. Supported forms:
//  - named: &quot; &apos; &amp; &lt; &gt;
//  - numeric: &#NN; and &#xHH;
// Quoted strings inside element text (e.g. Array entries) are written with
// &quot;, and attribute values may contain the other entities.
// Numeric references are limited to single UTF-16 units because callers store
// each result as one char.
// @throws ParseException for an unknown, out-of-range, or unterminated reference
public int qget() {
  int b = get();
  if( b != '&' ) return b;
  StringBuilder ent = new StringBuilder();
  int c;
  while( (c = get()) != ';' ) {
    if( ent.length() >= 10 )
      throw new ParseException("Unterminated XML entity reference '&"+ent+"'");
    ent.append((char)c);
  }
  String e = ent.toString();
  if( e.equals("quot") ) return '"';
  if( e.equals("apos") ) return '\'';
  if( e.equals("amp" ) ) return '&';
  if( e.equals("lt"  ) ) return '<';
  if( e.equals("gt"  ) ) return '>';
  if( e.length() > 1 && e.charAt(0) == '#' ) {
    try {
      boolean hex = e.charAt(1) == 'x' || e.charAt(1) == 'X';
      int cp = hex ? Integer.parseInt(e.substring(2), 16) : Integer.parseInt(e.substring(1));
      if( cp >= 0 && cp <= 0xFFFF ) return cp;
    } catch( NumberFormatException ignored ) { /* reported below */ }
  }
  throw new ParseException("Unknown XML entity reference '&"+e+";'");
}
// Consume whitespace. The first non-whitespace character is pushed back so the
// next get()/peek() returns it. Returns this for chaining.
public PMMLParser skipWS() {
  int c = get();
  while( Character.isWhitespace(c) ) c = get();
  push(c);
  return this;
}
// Consume one character and require it to equal tok. On a mismatch, barf()
// throws a ParseException. Returns this for chaining.
public PMMLParser expect( char tok ) {
  final char got = (char)get();
  if( got != tok ) return barf(tok, got);
  return this;
}
// Consume the characters of toks in order, each checked with expect(char).
public PMMLParser expect( String toks ) {
  for( char ch : toks.toCharArray() ) expect(ch);
  return this;
}
// Always throws a ParseException. The message names the expected and found
// characters, followed by up to 512 characters of the remaining input to help
// locate the error. The dump stops early at end of input or after the fourth newline.
public PMMLParser barf( char tok, char c ) {
  final StringBuilder msg = new StringBuilder("Expected '");
  msg.append(tok).append("' but found '").append(c).append("'");
  int newlines = 0;
  for( int n = 0; n < 512; n++ ) {
    char next;
    try { next = (char)get(); }
    catch( ParseException eof ) { break; }
    msg.append(next);
    if( next == '\n' && newlines++ > 2 ) break;
  }
  throw new ParseException(msg.toString());
}
// Read a name token: a Java identifier start character, then any run of Java
// identifier parts or ':' (so "xsi:type" is one token). The terminating
// character is pushed back.
// @throws ParseException if the first character cannot start a token
public String token() {
  int ch = get();
  if( !Character.isJavaIdentifierStart(ch) )
    throw new ParseException("Expected token start but found '"+(char)ch+"'");
  final StringBuilder name = new StringBuilder();
  do {
    name.append((char)ch);
    ch = get();
  } while( Character.isJavaIdentifierPart(ch) || ch == ':' );
  push(ch);
  return name.toString();
}
// Read a quoted string, skipping leading whitespace.
//  - Opening quote: '"', '\'', or a reference that qget() decodes to one of
//    those (e.g. &quot;).
//  - End: the next qget() character equal to the opening quote. That closing
//    quote is consumed but not returned.
// NOTE(review): the first character after the opening quote is read with get(),
// not qget(), so a reference in that position is kept undecoded -- confirm intent.
private String str() {
  final int quote = skipWS().qget();
  if( quote != '"' && quote != '\'' )
    throw new ParseException("Expected one of ' or \" but found '"+(char)quote+"'");
  final StringBuilder sb = new StringBuilder();
  for( int ch = get(); ch != quote; ch = qget() )
    sb.append((char)ch);
  return sb.toString();
}
// Read name="value" (or name='value') attributes up to the next '/' or '>',
// which is left unconsumed. Returns name -> value, with later duplicates
// overwriting earlier ones, or null when there are no attributes at all.
public HashMap<String,String> attrs() {
  HashMap<String,String> result = null;
  for( int b = skipWS().peek(); b != '/' && b != '>'; b = skipWS().peek() ) {
    if( result == null ) result = new HashMap<String,String>();
    final String name = token();
    final String value = skipWS().expect('=').str();
    result.put(name, value);
  }
  return result;
}
// Skip attributes without recording them, stopping (unconsumed) at the next
// '/' or '>'. Each attribute is everything up to '=', then a value enclosed in
// matching ' or " quotes. No entity decoding is done here.
public void skipAttrs() {
  for( int b = skipWS().peek(); b != '/' && b != '>'; b = skipWS().peek() ) {
    int ch;
    do { ch = get(); } while( ch != '=' );            // name and '='
    final int quote = skipWS().get();
    if( quote != '"' && quote != '\'' )
      throw new ParseException("Expected one of ' or \" but found '"+(char)quote+"'");
    do { ch = get(); } while( ch != quote );          // value and closing quote
  }
}
// -------------------------------------------------------------------------
// -------------------------------------------------------------------------
// Comparison operators of PMML SimplePredicate. Constant names match the
// "operator" attribute spellings because the parser maps them with valueOf.
// NOTE(review): other PMML operators (e.g. notEqual) are not listed -- confirm.
public static enum Operators {
  lessOrEqual,
  lessThan,
  greaterOrEqual,
  greaterThan,
  equal,
  isMissing;
}
// Values of the "booleanOperator" attribute: and/or for CompoundPredicate,
// isIn/isNotIn for SimpleSetPredicate.
public static enum BooleanOperators {
  isNotIn,
  and,
  or,
  isIn;
}
/** Base class of the predicates parsed from PMML. A predicate can be evaluated
 * directly (the match methods) or rendered as a Java boolean expression (the
 * toJava* methods). */
public static abstract class Predicate {
  /** Evaluate against a boxed feature value, which may be null. */
  public abstract boolean match(Comparable value);
  /** Evaluate against a feature supplied as a string and/or a double. */
  public abstract boolean match(String sval, double dval);
  /** Append a Java expression testing the numeric variable {@code jname}. */
  public abstract StringBuilder toJavaNum( StringBuilder sb, String jname );
  /** Append a Java expression testing the boolean feature variable {@code jname};
   * unimplemented unless a subclass overrides it. */
  public StringBuilder toJavaBool( StringBuilder sb, String jname ) { throw H2O.unimpl(); }
  /** Append a Java expression testing the String variable {@code jname};
   * unimplemented unless a subclass overrides it. */
  public StringBuilder toJavaStr( StringBuilder sb, String jname ) { throw H2O.unimpl(); }
  /** Build the predicate for a PMML SimplePredicate.
   * @param field name of the feature being tested
   * @param op    comparison operator
   * @param cons  constant to compare with; not needed (and ignored) for isMissing
   * @throws ParseException if a comparison operator has no constant */
  public static Predicate makeSimple(String field, Operators op, String cons) {
    // Checked first so isMissing works whether or not a value attribute is present.
    if( op == Operators.isMissing ) return new IsMissing(field);
    // Without a constant every comparison would test against null/NaN; fail
    // here even when assertions are disabled.
    if( cons == null )
      throw new ParseException("SimplePredicate on '"+field+"' with operator "+op+" has no value");
    switch (op) {
    case lessOrEqual   : return new LessOrEqual   (field,cons);
    case lessThan      : return new LessThan      (field,cons);
    case greaterOrEqual: return new GreaterOrEqual(field,cons);
    case greaterThan   : return new GreaterThan   (field,cons);
    case equal         : return new Equals        (field,cons);
    default            : throw new RuntimeException("missing "+field+" "+op+" "+cons);
    }
  }
  /** Name of the feature this predicate tests; unimplemented unless overridden. */
  public String unique_name() { throw H2O.unimpl(); }
}
public static abstract class Comparison extends Predicate {
// Used to define comparisons like:
// "income < 10000" which _name==income, and _str=="10000", _num==10000
public final String _name;// Feature name, e.g. "bad_email" or "income"
public final String _str; // Constant compare value as a String
public final double _num; // Constant compare value or NaN if not applicable
public final double _bool;// Constant boolean value or NaN if not applicable
public Comparison(String name, String str) {
_name = name;
_str = str;
_num = getNumber (str);// Convert to a 'double'
_bool= getBoolean(str);// Convert to a 'boolean'
}
public String unique_name() { return _name; }
}
/** Less-or-equal: numeric if the constant is a number, else boolean
 * (false < true) if it is a boolean literal, else lexicographic string order. */
public static class LessOrEqual extends Comparison {
  public LessOrEqual(String name, String str) { super(name,str); }
  @Override public boolean match(Comparable value) {
    if( !Double.isNaN(_num ) ) return getNumber (value) <= _num ;
    if( !Double.isNaN(_bool) ) return getBoolean(value) <= _bool;
    String s = getString(value);
    return s != null && s.compareTo(_str) <= 0;
  }
  // Same precedence as match(Comparable). This keeps a boolean or string
  // constant from being compared against NaN, which could never match.
  @Override public boolean match(String sval, double dval) {
    if( !Double.isNaN(_num ) ) return dval <= _num ;
    if( !Double.isNaN(_bool) ) return dval <= _bool;
    return sval != null && sval.compareTo(_str) <= 0;
  }
  @Override public String toString() { return "X<=" + _str; }
  @Override public StringBuilder toJavaNum( StringBuilder sb, String jname ) {
    return sb.append(jname).append("<=").append(_num);
  }
}
/** Less-than: numeric if the constant is a number, else boolean
 * (false < true) if it is a boolean literal, else lexicographic string order. */
public static class LessThan extends Comparison {
  public LessThan(String name, String str) { super(name,str); }
  @Override public boolean match(Comparable value) {
    if( !Double.isNaN(_num ) ) return getNumber (value) < _num ;
    if( !Double.isNaN(_bool) ) return getBoolean(value) < _bool;
    String s = getString(value);
    return s != null && s.compareTo(_str) < 0;
  }
  // Same precedence as match(Comparable). This keeps a boolean or string
  // constant from being compared against NaN, which could never match.
  @Override public boolean match(String sval, double dval) {
    if( !Double.isNaN(_num ) ) return dval < _num ;
    if( !Double.isNaN(_bool) ) return dval < _bool;
    return sval != null && sval.compareTo(_str) < 0;
  }
  @Override public String toString() { return "X<" + _str; }
  @Override public StringBuilder toJavaNum( StringBuilder sb, String jname ) {
    return sb.append(jname).append("<").append(_num);
  }
}
/** Greater-or-equal: numeric if the constant is a number, else boolean
 * (false < true) if it is a boolean literal, else lexicographic string order. */
public static class GreaterOrEqual extends Comparison {
  public GreaterOrEqual(String name, String con) { super(name,con); }
  @Override public boolean match(Comparable value) {
    if( !Double.isNaN(_num ) ) return getNumber (value) >= _num ;
    if( !Double.isNaN(_bool) ) return getBoolean(value) >= _bool;
    String s = getString(value);
    return s != null && s.compareTo(_str) >= 0;
  }
  // Same precedence as match(Comparable). This keeps a boolean or string
  // constant from being compared against NaN, which could never match.
  @Override public boolean match(String sval, double dval) {
    if( !Double.isNaN(_num ) ) return dval >= _num ;
    if( !Double.isNaN(_bool) ) return dval >= _bool;
    return sval != null && sval.compareTo(_str) >= 0;
  }
  @Override public String toString() { return "X>=" + _str; }
  @Override public StringBuilder toJavaNum( StringBuilder sb, String jname ) {
    return sb.append(jname).append(">=").append(_num);
  }
}
/** Greater-than: numeric if the constant is a number, else boolean
 * (false < true) if it is a boolean literal, else lexicographic string order. */
public static class GreaterThan extends Comparison {
  public GreaterThan(String name, String str) { super(name,str); }
  @Override public boolean match(Comparable value) {
    if( !Double.isNaN(_num ) ) return getNumber (value) > _num ;
    if( !Double.isNaN(_bool) ) return getBoolean(value) > _bool;
    String s = getString(value);
    return s != null && s.compareTo(_str) > 0;
  }
  // Same precedence as match(Comparable). This keeps a boolean or string
  // constant from being compared against NaN, which could never match.
  @Override public boolean match(String sval, double dval) {
    if( !Double.isNaN(_num ) ) return dval > _num ;
    if( !Double.isNaN(_bool) ) return dval > _bool;
    return sval != null && sval.compareTo(_str) > 0;
  }
  @Override public String toString() { return "X>" + _str; }
  @Override public StringBuilder toJavaNum( StringBuilder sb, String jname ) {
    return sb.append(jname).append(">").append(_num);
  }
}
/** Matches an absent feature value: null when boxed, NaN when given as a double. */
public static class IsMissing extends Predicate {
  public final String _name; // feature name, e.g. 'dependents'
  public IsMissing( String name ) { _name = name; }
  @Override public boolean match(Comparable value) { return value == null; }
  @Override public boolean match(String sval, double dval) { return Double.isNaN(dval); }
  @Override public String toString() { return "isMissing"; }
  @Override public StringBuilder toJavaNum( StringBuilder sb, String jname ) {
    return sb.append("Double.isNaN(").append(jname).append(')');
  }
  @Override public StringBuilder toJavaBool( StringBuilder sb, String jname ) {
    return sb.append("Double.isNaN(").append(jname).append(')');
  }
  @Override public StringBuilder toJavaStr( StringBuilder sb, String jname ) {
    return sb.append(jname).append("==null");
  }
  @Override public String unique_name() { return _name; }
}
/** Equality with a constant: numeric if the constant parses as a number, else
 * boolean if it is "true"/"false", else exact string equality. */
public static class Equals extends Comparison {
  public Equals(String name, String str) { super(name,str); }
  @Override public boolean match(Comparable value) {
    if( !Double.isNaN(_num ) ) return getNumber (value) == _num ;
    if( !Double.isNaN(_bool) ) return getBoolean(value) == _bool;
    String s = getString(value);
    return s==null ? false : s.compareTo(_str) == 0;
  }
  @Override public boolean match(String sval, double dval) {
    if( !Double.isNaN(_num ) ) return dval == _num ;
    if( !Double.isNaN(_bool) ) return dval == _bool;
    return _str.equals(sval);
  }
  @Override public String toString() { return "X==" + _str; }
  @Override public StringBuilder toJavaNum( StringBuilder sb, String jname ) {
    return sb.append(jname).append("==").append(_num);
  }
  @Override public StringBuilder toJavaBool( StringBuilder sb, String jname ) {
    return sb.append(jname).append("==").append(_bool);
  }
  // _str comes from the PMML input, so it must be escaped before it is embedded
  // in a Java string literal; a quote or backslash would otherwise break the code.
  @Override public StringBuilder toJavaStr( StringBuilder sb, String jname ) {
    return appendJavaLiteral(sb, _str).append(".equals(").append(jname).append(")");
  }
  /** Append s as a double-quoted Java string literal.
   * - quote, backslash, newline, CR and tab use their standard escapes;
   * - other control characters use octal escapes;
   * - non-ASCII characters use Unicode escapes. */
  private static StringBuilder appendJavaLiteral( StringBuilder sb, String s ) {
    sb.append('"');
    for( int i = 0; i < s.length(); i++ ) {
      char c = s.charAt(i);
      switch( c ) {
      case '"' : sb.append("\\\""); break;
      case '\\': sb.append("\\\\"); break;
      case '\n': sb.append("\\n");  break;
      case '\r': sb.append("\\r");  break;
      case '\t': sb.append("\\t");  break;
      default:
        if( c < 0x20 )      sb.append(String.format(Locale.ROOT, "\\%03o", (int)c));
        else if( c > 0x7E ) sb.append(String.format(Locale.ROOT, "\\u%04x", (int)c));
        else                sb.append(c);
      }
    }
    return sb.append('"');
  }
}
/** Binary boolean combination of two child predicates, {@code _l} and
 * {@code _r}; the concrete subclass supplies the operator. */
public static abstract class CompoundPredicate extends Predicate {
  Predicate _l, _r;   // left and right operands, assigned by the parser
  @Override public StringBuilder toJavaNum( StringBuilder sb, String jname ) { throw H2O.unimpl(); }
  /** Append "(left rel right)", rendering both sides with toJavaNum. */
  public StringBuilder makeNum(StringBuilder sb, String jname, String rel) {
    sb.append('(');
    _l.toJavaNum(sb, jname);
    sb.append(' ').append(rel).append(' ');
    _r.toJavaNum(sb, jname);
    return sb.append(')');
  }
  /** Append "(left rel right)", rendering both sides with toJavaStr. */
  public StringBuilder makeStr(StringBuilder sb, String jname, String rel) {
    sb.append('(');
    _l.toJavaStr(sb, jname);
    sb.append(' ').append(rel).append(' ');
    _r.toJavaStr(sb, jname);
    return sb.append(')');
  }
  /** Return a new And or Or for those operators, or null for any other operator. */
  public static CompoundPredicate make(BooleanOperators op) {
    switch( op ) {
    case and: return new And();
    case or:  return new Or();
    default:  return null;
    }
  }
  /** Uses the left operand's feature name. */
  @Override public String unique_name() { return _l.unique_name(); }
}
/** Short-circuit conjunction of the two child predicates. */
public static class And extends CompoundPredicate {
  @Override public final boolean match(Comparable value) {
    return _l.match(value) ? _r.match(value) : false;
  }
  @Override public final boolean match(String sval, double dval) {
    return _l.match(sval, dval) ? _r.match(sval, dval) : false;
  }
  @Override public String toString() {
    return "(" + _l.toString() + " and " + _r.toString() + ")";
  }
  @Override public StringBuilder toJavaNum( StringBuilder sb, String jname ) { return makeNum(sb, jname, "&&"); }
  @Override public StringBuilder toJavaStr( StringBuilder sb, String jname ) { return makeStr(sb, jname, "&&"); }
}
/** Short-circuit disjunction of the two child predicates. */
public static class Or extends CompoundPredicate {
  @Override public final boolean match(Comparable value) {
    return _l.match(value) ? true : _r.match(value);
  }
  @Override public final boolean match(String sval, double dval) {
    return _l.match(sval, dval) ? true : _r.match(sval, dval);
  }
  @Override public String toString() {
    return "(" + _l.toString() + " or " + _r.toString() + ")";
  }
  @Override public StringBuilder toJavaNum( StringBuilder sb, String jname ) { return makeNum(sb, jname, "||"); }
  @Override public StringBuilder toJavaStr( StringBuilder sb, String jname ) { return makeStr(sb, jname, "||"); }
}
/** Set membership (SimpleSetPredicate with isIn): matches when the feature's
 * string form equals one of {@code _values}. */
public static class IsIn extends Predicate {
  public final String _name; // Feature name, like 'state'
  public String[] _values;   // member set; assigned by the parser after make()
  public IsIn(String name, String[] values) { _name=name; _values = values; }
  @Override public boolean match(Comparable value) {
    // Compare by string form, as Comparison's string fallback does, so a boxed
    // non-String such as Long 5 can match the member "5".
    return contains(getString(value));
  }
  @Override public boolean match(String sval, double dval) { return contains(sval); }
  // True if s equals one of the members; a null s never matches.
  private boolean contains( String s ) {
    for( String t : _values ) if( t.equals(s) ) return true;
    return false;
  }
  @Override public String toString() {
    StringBuilder x = new StringBuilder("X is in {");
    for( String s : _values ) x.append(s).append(' ');
    return x.append('}').toString();
  }
  @Override public StringBuilder toJavaNum( StringBuilder sb, String jname ) { throw H2O.unimpl(); }
  // Emits "a".equals(jname) || "b".equals(jname) || ... || false. Members come
  // from the PMML input, so each is escaped as a Java string literal.
  @Override public StringBuilder toJavaStr( StringBuilder sb, String jname ) {
    for( String s : _values )
      appendJavaLiteral(sb, s).append(".equals(").append(jname).append(") || ");
    return sb.append("false");
  }
  /** IsIn for isIn, IsNotIn for isNotIn, null for any other operator. The
   * member array is left null for the caller to fill in. */
  public static IsIn make(String name, BooleanOperators op) {
    switch( op ) {
    case isIn   : return new IsIn   (name,null);
    case isNotIn: return new IsNotIn(name,null);
    default     : return null;
    }
  }
  @Override public String unique_name() { return _name; }
  /** Append s as a double-quoted Java string literal.
   * - quote, backslash, newline, CR and tab use their standard escapes;
   * - other control characters use octal escapes;
   * - non-ASCII characters use Unicode escapes. */
  private static StringBuilder appendJavaLiteral( StringBuilder sb, String s ) {
    sb.append('"');
    for( int i = 0; i < s.length(); i++ ) {
      char c = s.charAt(i);
      switch( c ) {
      case '"' : sb.append("\\\""); break;
      case '\\': sb.append("\\\\"); break;
      case '\n': sb.append("\\n");  break;
      case '\r': sb.append("\\r");  break;
      case '\t': sb.append("\\t");  break;
      default:
        if( c < 0x20 )      sb.append(String.format(Locale.ROOT, "\\%03o", (int)c));
        else if( c > 0x7E ) sb.append(String.format(Locale.ROOT, "\\u%04x", (int)c));
        else                sb.append(c);
      }
    }
    return sb.append('"');
  }
}
/** Negated set membership: matches exactly when the IsIn test would not. */
public static class IsNotIn extends IsIn {
  public IsNotIn(String name, String[] values) { super(name, values); }
  @Override public boolean match(Comparable value) {
    final boolean member = super.match(value);
    return !member;
  }
  @Override public boolean match(String sval, double dval) {
    final boolean member = super.match(sval, dval);
    return !member;
  }
  @Override public String toString() { return "!(" + super.toString() + ")"; }
  @Override public StringBuilder toJavaNum( StringBuilder sb, String jname ) { throw H2O.unimpl(); }
  @Override public StringBuilder toJavaStr( StringBuilder sb, String jname ) {
    super.toJavaStr(sb.append("!("), jname);
    return sb.append(")");
  }
}
/** The PMML <True/> predicate: matches every input; renders as the literal true. */
public static class True extends Predicate {
  @Override public boolean match(Comparable value)             { return true; }
  @Override public boolean match(String sval, double dval)     { return true; }
  @Override public String toString()                           { return "true"; }
  @Override public StringBuilder toJavaNum ( StringBuilder sb, String jname ) { return sb.append("true"); }
  @Override public StringBuilder toJavaBool( StringBuilder sb, String jname ) { return sb.append("true"); }
  @Override public StringBuilder toJavaStr ( StringBuilder sb, String jname ) { return sb.append("true"); }
  @Override public String unique_name()                        { return ""; }
}
// Happy Helper Methods for the generated code (also used by the predicates above).
/** Look up feature s in row and convert it as getNumber(Comparable) does. */
public static double getNumber( HashMap<String,Comparable> row, String s ) {
  final Comparable v = row.get(s);
  return getNumber(v);
}
/** Convert a boxed value to a double.
 * - Any Number: its doubleValue().
 * - A String: parsed as a double (surrounding whitespace allowed).
 * - Anything else, including null and unparsable strings: NaN. */
public static double getNumber( Comparable o ) {
  // Most specific common types first, then the general Number case.
  if( o instanceof Double ) return ((Double)o).doubleValue();
  if( o instanceof Long   ) return ((Long  )o).doubleValue();
  if( o instanceof Number ) return ((Number)o).doubleValue();
  if( o instanceof String ) {
    try { return Double.parseDouble((String)o); }
    catch( NumberFormatException ignored ) { /* not numeric: NaN below */ }
  }
  return Double.NaN;
}
/** Look up feature s in row and convert it as getBoolean(Comparable) does. */
public static double getBoolean( HashMap<String,Comparable> row, String s ) {
  final Comparable v = row.get(s);
  return getBoolean(v);
}
/** Convert a boxed value to a boolean encoded as a double.
 * - A Boolean, or the string "true"/"false" in any case: 1.0 or 0.0.
 * - Anything else, including null: NaN. */
public static double getBoolean( Comparable o ) {
  if( o instanceof Boolean ) return ((Boolean)o) ? 1.0 : 0.0;
  if( o instanceof String ) {
    // equalsIgnoreCase cannot throw, so no try/catch is needed here
    String s = (String)o;
    if( "true" .equalsIgnoreCase(s) ) return 1.0;
    if( "false".equalsIgnoreCase(s) ) return 0.0;
  }
  return Double.NaN;
}
/** Look up feature s in row and convert it as getString(Comparable) does. */
public static String getString( HashMap<String,Comparable> row, String s ) {
  final Comparable v = row.get(s);
  return getString(v);
}
/** String form of a boxed value: its toString() (a String is returned as
 * itself); null stays null. */
public static String getString( Comparable o ) {
  if( o == null ) return null;
  return o.toString();
}
}