package water.score; import java.lang.reflect.Constructor; import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; import javassist.*; import water.parser.PMMLParser.DataTypes; import water.parser.PMMLParser.Predicate; import water.parser.PMMLParser; import water.score.ScoreModel; import water.util.Log; import water.util.Log.Tag.Sys; /** * Scorecard model - decision table. */ public class ScorecardModel extends ScoreModel { /** Initial score */ final double _initialScore; /** The rules to each for each feature, they map 1-to-1 with the Model's * column list. */ final RuleTable _rules[]; /** Score this model on the specified row of data. */ public double score(final HashMap<String, Comparable> row ) { // By default, use the scoring interpreter. The Builder JITs a new // subclass with an overloaded 'score(row)' call which has a JIT'd version // of the rules. i.e., calling 'score(row)' on the returned ScorecardModel // instance runs the fast version, but you can cast to the base version if // you want the interpreter. return score_interpreter(row); } // Use the rule interpreter public double score_interpreter(final HashMap<String, Comparable> row ) { double score = _initialScore; for( int i=0; i<_rules.length; i++ ) score += _rules[i].score(row.get(_colNames[i])); return score; } public double score(int[] MAP, String[] SS, double[] DS) { return score_interpreter(MAP,SS,DS); } private double score_interpreter(int[] MAP, String[] SS, double[] DS) { double score = _initialScore; for( int i=0; i<_rules.length; i++ ) { int idx = MAP[i]; String ss = idx==-1 ? null : SS[idx]; double dd = idx==-1 ? Double.NaN : DS[idx]; double s = _rules[i].score(ss,dd); score += s; } return score; } // JIT a score method with signature 'double score(HashMap row)' public void makeScoreHashMethod(CtClass scClass) { // Map of previously extracted PMML names, and their java equivs HashMap<String,String> vars = new HashMap<String,String>(); StringBuilder sb = new StringBuilder(); sb.append("double score( java.util.HashMap row ) {\n"+ " double score = "+_initialScore+";\n"); try { for( int i=0; i<_rules.length; i++ ) _rules[i].makeFeatureHashMethod(sb,vars,scClass); sb.append(" return score;\n}\n"); CtMethod happyMethod = CtMethod.make(sb.toString(),scClass); scClass.addMethod(happyMethod); } catch( Exception re ) { Log.err(Sys.SCORM,"Crashing:"+sb.toString(), new RuntimeException(re)); } } public void makeScoreAryMethod(CtClass scClass) { // Map of previously extracted PMML names, and their java equivs HashMap<String,String> vars = new HashMap<String,String>(); StringBuilder sb = new StringBuilder(); sb.append("double score( int[] MAP, java.lang.String[] SS, double[] DS ) {\n"+ " double score = "+_initialScore+";\n"); try { for( int i=0; i<_rules.length; i++ ) _rules[i].makeFeatureAryMethod(sb,vars,scClass,i); sb.append(" return score;\n}\n"); CtMethod happyMethod = CtMethod.make(sb.toString(),scClass); scClass.addMethod(happyMethod); } catch( Exception re ) { Log.err(Sys.SCORM,"Crashing:"+sb.toString(), new RuntimeException(re)); } } // Return the java-equivalent from the PMML variable name, creating and // installing it as needed. If the value is created, we also emit Java code // to emit it at runtime. public static String getName( String pname, DataTypes type, StringBuilder sb ) { String jname = xml2jname(pname); // Emit the code to do the load return jname; } /** Feature decision table */ public static class RuleTable { final String _name; final Rule[] _rule; final DataTypes _type; final double _baseScore; public RuleTable(String name, DataTypes type, Rule[] decisions, double baseScore) { _name = name; _type = type; _rule = decisions; _baseScore = baseScore; } public void makeFeatureHashMethod( StringBuilder sbParent, HashMap<String,String> vars, CtClass scClass ) { if( _type == null ) { Log.warn("Ignore untyped feature "+_name); return; } String jname = xml2jname(_name); StringBuilder sb = new StringBuilder(); sb.append("double ").append(jname).append("( java.util.HashMap row ) {\n"+ " double score = 0;\n"); switch( _type ) { case STRING : sb.append(" String " ); break; case BOOLEAN: sb.append(" double "); break; default : sb.append(" double " ); break; } sb.append(jname); switch( _type ) { case STRING : sb.append(" = water.parser.PMMLParser.getString (row,\""); break; case BOOLEAN: sb.append(" = water.parser.PMMLParser.getBoolean(row,\"" ); break; default : sb.append(" = water.parser.PMMLParser.getNumber (row,\"" ); break; } sb.append(_name).append("\");\n"); sb.append(" if( false ) ;\n"); for (Rule r : _rule) if( _type == DataTypes.STRING) r.toJavaStr(sb,jname); else if( _type == DataTypes.BOOLEAN) r.toJavaBool(sb,jname); else r.toJavaNum(sb,jname); // close the dangling 'else' from all the prior rules sb.append(" return score;\n}\n"); sbParent.append(" score += ").append(jname).append("(row);\n"); // Now install the method try { CtMethod happyMethod = CtMethod.make(sb.toString(),scClass); scClass.addMethod(happyMethod); } catch( Exception re ) { Log.err(Sys.SCORM,"Crashing:"+sb.toString(), new RuntimeException(re)); } } public void makeFeatureAryMethod( StringBuilder sbParent, HashMap<String,String> vars, CtClass scClass, int fidx ) { if( _type == null ) return; // Untyped, ignore String jname = xml2jname(_name); StringBuilder sb = new StringBuilder(); sb.append("double ").append(jname); sb.append("( int[]MAP, java.lang.String[]SS, double[]DS ) {\n"+ " double score = 0;\n"+ " int didx=MAP[").append(fidx).append("];\n"); switch( _type ) { case STRING : sb.append(" String " ); break; case BOOLEAN: sb.append(" boolean "); break; default : sb.append(" double " ); break; } sb.append(jname); switch( _type ) { case STRING : sb.append(" = didx==-1 ? null : SS[didx];\n"); break; case BOOLEAN: sb.append(" = didx==-1 ? false : DS[didx]==1.0;\n"); break; default : sb.append(" = didx==-1 ? Double.NaN : DS[didx];\n" ); break; } sb.append(" if( false ) ;\n"); for (Rule r : _rule) if( _type == DataTypes.STRING) r.toJavaStr(sb,jname); else if( _type == DataTypes.BOOLEAN) r.toJavaBool(sb,jname); else r.toJavaNum(sb,jname); // close the dangling 'else' from all the prior rules sb.append(" return score;\n}\n"); sbParent.append(" score += ").append(jname).append("(MAP,SS,DS);\n"); // Now install the method try { CtMethod happyMethod = CtMethod.make(sb.toString(),scClass); scClass.addMethod(happyMethod); } catch( Exception re ) { Log.err(Sys.SCORM,"Crashing:"+sb.toString(), new RuntimeException(re)); } } // The rule interpreter double score(Comparable value) { double score = 0; for (Rule r : _rule) { if( r.match(value) ) { score += r._score; break; } } return score; } double score(String s, double d) { double score = 0; for (Rule r : _rule) { if( r.match(s,d) ) { score += r._score; break; } } return score; } @Override public String toString() { return "RuleTable [_name=" + _name + ", _rule=" + Arrays.toString(_rule) + ", _type=" + _type + " baseScore="+_baseScore+"]"; } } /** Scorecard decision rule */ public static class Rule { final double _score; final Predicate _predicate; public Rule(double score, Predicate pred) { assert pred != null; _score = score; _predicate = pred; } boolean match(Comparable value) { return _predicate.match(value); } boolean match(String s, double d) { return _predicate.match(s,d); } @Override public String toString() { return _predicate.toString() + " => " + _score; } public StringBuilder toJavaNum( StringBuilder sb, String jname ) { sb.append(" else if( "); return _predicate.toJavaNum(sb,jname).append(" ) score += ").append(_score).append(";\n"); } public StringBuilder toJavaBool( StringBuilder sb, String jname ) { sb.append(" else if( "); return _predicate.toJavaBool(sb,jname).append(" ) score += ").append(_score).append(";\n"); } public StringBuilder toJavaStr( StringBuilder sb, String jname ) { sb.append(" else if( "); return _predicate.toJavaStr(sb,jname).append(" ) score += ").append(_score).append(";\n"); } String unique_name() { return _predicate.unique_name(); } } @Override public String toString() { return super.toString()+", _initialScore=" + _initialScore; } private ScorecardModel(String name, String[] colNames, double initialScore, RuleTable[] rules) { super(name,colNames); assert colNames.length==rules.length; _initialScore = initialScore; _rules = rules; } protected ScorecardModel(ScorecardModel base) { this(base._name,base._colNames,base._initialScore,base._rules); } /** Scorecard model builder: JIT a subclass with the fast version wired in to * 'score(row)' */ public static ScorecardModel make(final String name, final double initialScore, RuleTable[] rules) { // Get the list of features String[] colNames = new String[rules.length]; for( int i=0; i<rules.length; i++ ) colNames[i] = rules[i]._name; // javassist support for rewriting class files ClassPool _pool = ClassPool.getDefault(); try { // Make a javassist class in the java hierarchy String cname = uniqueClassName(name); CtClass scClass = _pool.makeClass(cname); CtClass baseClass = _pool.get("water.score.ScorecardModel"); // Full Name Lookup scClass.setSuperclass(baseClass); // Produce the scoring method(s) ScorecardModel scm = new ScorecardModel(name, colNames,initialScore, rules); scm.makeScoreHashMethod(scClass); scm.makeScoreAryMethod(scClass); // Produce a 1-arg constructor String cons = " public "+cname+"(water.score.ScorecardModel base) { super(base); }"; CtConstructor happyConst = CtNewConstructor.make(cons,scClass); scClass.addConstructor(happyConst); Class myClass = scClass.toClass(ScorecardModel.class.getClassLoader(), null); Constructor<ScorecardModel> co = myClass.getConstructor(ScorecardModel.class); ScorecardModel jitted_scm = co.newInstance(scm); return jitted_scm; } catch( Exception e ) { Log.err(Sys.SCORM,"Javassist failed",e); } return null; } // ------------------------------------------------------------------------- public static ScorecardModel parse( PMMLParser pmml ) { HashMap<String,String> attrs = pmml.attrs(); pmml.expect('>'); pmml.skipWS().expect('<').pGeneric("MiningSchema"); pmml.skipWS().expect('<').pGeneric("Output"); pmml.skipWS().expect('<'); RuleTable[] rules = pCharacteristics(pmml); pmml.skipWS().expect("</Scorecard>"); String is = attrs.get("initialScore"); double initialScore = is==null?0:PMMLParser.getNumber(is); return make(attrs.get("modelName"), initialScore, rules); } private static RuleTable[] pCharacteristics( PMMLParser pmml ) { pmml.expect("Characteristics>"); ArrayList<RuleTable> rts = new ArrayList(); while( pmml.skipWS().expect('<').peek() != '/' ) rts.add(pCharacteristic(pmml)); pmml.expect("/Characteristics>"); return rts.toArray(new RuleTable[0]); } private static RuleTable pCharacteristic( PMMLParser pmml ) { HashMap<String,String> attrs = pmml.expect("Characteristic").attrs(); pmml.expect('>'); ArrayList<Rule> rules = new ArrayList(); while( pmml.skipWS().expect('<').peek() != '/' ) rules.add(pAttribute(pmml)); pmml.expect("/Characteristic>"); String name = rules.get(0).unique_name(); DataTypes t = pmml._types.get(name); String bls = attrs.get("baselineScore"); double baseScore = bls == null?0:PMMLParser.getNumber(bls); return new RuleTable(name,t,rules.toArray(new Rule[0]),baseScore); } private static Rule pAttribute( PMMLParser pmml ) { HashMap<String,String> attrs = pmml.expect("Attribute").attrs(); pmml.expect('>').skipWS().expect('<'); Predicate pred = pmml.pPredicate(); pmml.skipWS().expect("</Attribute>"); String ps = attrs.get("partialScore"); double partialScore = ps==null?0:PMMLParser.getNumber(ps); return new Rule(partialScore,pred); } }