/**
 * (C) Copyright IBM Corp. 2010, 2015
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 */

package com.ibm.bi.dml.runtime.functionobjects;

import java.util.EnumMap;
import java.util.HashMap;

import org.apache.commons.math3.util.FastMath;

import com.ibm.bi.dml.api.DMLScript;
import com.ibm.bi.dml.runtime.DMLRuntimeException;
import com.ibm.bi.dml.runtime.DMLScriptException;
import com.ibm.bi.dml.runtime.DMLUnsupportedOperationException;

/**
 * Class with a pre-defined set of function objects, one per supported
 * {@link BuiltinFunctionCode}. This class cannot be instantiated elsewhere;
 * instances are obtained through {@link #getBuiltinFnObject(String)} or
 * {@link #getBuiltinFnObject(BuiltinFunctionCode)}.
 *
 * Notes on commons.math FastMath:
 * * FastMath uses lookup tables and interpolation instead of native calls.
 * * The memory overhead for those tables is roughly 48KB in total (acceptable).
 * * Micro and application benchmarks showed significant (30%-3x) performance
 *   improvements for most operations, without loss of accuracy.
 * * atan / sqrt were 20% slower in FastMath and hence, we use Math there.
 * * round / abs were equivalent in FastMath and hence, we use Math there.
 * * Finally, there is just one argument against FastMath - the comparison heavily
 *   depends on the JVM. For example, currently the IBM JDK JIT compiles to HW
 *   instructions for sqrt, which makes this operation very efficient; as soon as
 *   other operations like log/exp are similarly compiled, we should rerun the
 *   micro benchmarks, and switch back if necessary.
 *
 */
public class Builtin extends ValueFunction
{
	private static final long serialVersionUID = 3836744687789840574L;

	/** Codes of all builtin functions known to this class. */
	public enum BuiltinFunctionCode {
		INVALID, SIN, COS, TAN, ASIN, ACOS, ATAN, LOG, LOG_NZ, MIN, MAX, ABS, SQRT, EXP, PLOGP,
		PRINT, NROW, NCOL, LENGTH, ROUND, MAXINDEX, MININDEX, STOP, CEIL, FLOOR,
		CUMSUM, CUMPROD, CUMMIN, CUMMAX, INVERSE, SPROP, SIGMOID, SELP
	}

	/** Function code this object evaluates; set once by the private constructor. */
	public BuiltinFunctionCode bFunc;

	/** Compile-time switch between commons.math FastMath and java.lang.Math. */
	private static final boolean FASTMATH = true;

	/** Lookup table from opcode string to function code. */
	static public HashMap<String, BuiltinFunctionCode> String2BuiltinFunctionCode;
	static {
		String2BuiltinFunctionCode = new HashMap<String, BuiltinFunctionCode>();
		String2BuiltinFunctionCode.put( "sin"     , BuiltinFunctionCode.SIN);
		String2BuiltinFunctionCode.put( "cos"     , BuiltinFunctionCode.COS);
		String2BuiltinFunctionCode.put( "tan"     , BuiltinFunctionCode.TAN);
		String2BuiltinFunctionCode.put( "asin"    , BuiltinFunctionCode.ASIN);
		String2BuiltinFunctionCode.put( "acos"    , BuiltinFunctionCode.ACOS);
		String2BuiltinFunctionCode.put( "atan"    , BuiltinFunctionCode.ATAN);
		String2BuiltinFunctionCode.put( "log"     , BuiltinFunctionCode.LOG);
		String2BuiltinFunctionCode.put( "log_nz"  , BuiltinFunctionCode.LOG_NZ);
		String2BuiltinFunctionCode.put( "min"     , BuiltinFunctionCode.MIN);
		String2BuiltinFunctionCode.put( "max"     , BuiltinFunctionCode.MAX);
		String2BuiltinFunctionCode.put( "maxindex", BuiltinFunctionCode.MAXINDEX);
		String2BuiltinFunctionCode.put( "minindex", BuiltinFunctionCode.MININDEX);
		String2BuiltinFunctionCode.put( "abs"     , BuiltinFunctionCode.ABS);
		String2BuiltinFunctionCode.put( "sqrt"    , BuiltinFunctionCode.SQRT);
		String2BuiltinFunctionCode.put( "exp"     , BuiltinFunctionCode.EXP);
		String2BuiltinFunctionCode.put( "plogp"   , BuiltinFunctionCode.PLOGP);
		String2BuiltinFunctionCode.put( "print"   , BuiltinFunctionCode.PRINT);
		String2BuiltinFunctionCode.put( "nrow"    , BuiltinFunctionCode.NROW);
		String2BuiltinFunctionCode.put( "ncol"    , BuiltinFunctionCode.NCOL);
		String2BuiltinFunctionCode.put( "length"  , BuiltinFunctionCode.LENGTH);
		String2BuiltinFunctionCode.put( "round"   , BuiltinFunctionCode.ROUND);
		String2BuiltinFunctionCode.put( "stop"    , BuiltinFunctionCode.STOP);
		String2BuiltinFunctionCode.put( "ceil"    , BuiltinFunctionCode.CEIL);
		String2BuiltinFunctionCode.put( "floor"   , BuiltinFunctionCode.FLOOR);
		String2BuiltinFunctionCode.put( "ucumk+"  , BuiltinFunctionCode.CUMSUM);
		String2BuiltinFunctionCode.put( "ucum*"   , BuiltinFunctionCode.CUMPROD);
		String2BuiltinFunctionCode.put( "ucummin" , BuiltinFunctionCode.CUMMIN);
		String2BuiltinFunctionCode.put( "ucummax" , BuiltinFunctionCode.CUMMAX);
		String2BuiltinFunctionCode.put( "inverse" , BuiltinFunctionCode.INVERSE);
		String2BuiltinFunctionCode.put( "sprop"   , BuiltinFunctionCode.SPROP);
		String2BuiltinFunctionCode.put( "sigmoid" , BuiltinFunctionCode.SIGMOID);
		String2BuiltinFunctionCode.put( "sel+"    , BuiltinFunctionCode.SELP);
	}

	/**
	 * One shared object for every builtin function that we support (all codes
	 * except INVALID). The registry is filled once during class initialization,
	 * so every lookup of the same code returns the same instance.
	 */
	private static final EnumMap<BuiltinFunctionCode, Builtin> INSTANCES =
		new EnumMap<BuiltinFunctionCode, Builtin>(BuiltinFunctionCode.class);
	static {
		for( BuiltinFunctionCode code : BuiltinFunctionCode.values() ) {
			if( code != BuiltinFunctionCode.INVALID )
				INSTANCES.put(code, new Builtin(code));
		}
	}

	/**
	 * Creates the function object for the given code; only used to fill the registry.
	 *
	 * @param bf function code evaluated by the new object
	 */
	private Builtin(BuiltinFunctionCode bf) {
		bFunc = bf;
	}

	/**
	 * @return the function code evaluated by this object
	 */
	public BuiltinFunctionCode getBuiltinFunctionCode() {
		return bFunc;
	}

	/**
	 * Returns the shared function object for an opcode string.
	 *
	 * @param str opcode string, e.g. "sin" or "ucum*"
	 * @return the shared function object, or null if the string is not a known opcode
	 */
	public static Builtin getBuiltinFnObject(String str) {
		BuiltinFunctionCode code = String2BuiltinFunctionCode.get(str);
		return getBuiltinFnObject( code );
	}

	/**
	 * Returns the shared function object for a function code.
	 *
	 * @param code function code, may be null
	 * @return the shared function object, or null if code is null or INVALID
	 */
	public static Builtin getBuiltinFnObject(BuiltinFunctionCode code) {
		if( code == null )
			return null;

		// INVALID has no registered object --> null, as for any unknown code
		return INSTANCES.get(code);
	}

	/**
	 * Cloning is not supported because instances are shared per function code.
	 *
	 * @throws CloneNotSupportedException always
	 */
	@Override
	public Object clone() throws CloneNotSupportedException {
		// cloning is not supported for singleton classes
		throw new CloneNotSupportedException();
	}

	/**
	 * Checks whether this builtin function accepts the given number of operands.
	 *
	 * @param _arity number of operands
	 * @return true if the function can be invoked with that many operands
	 * @throws DMLUnsupportedOperationException if no arity is defined for this function code
	 */
	public boolean checkArity(int _arity) throws DMLUnsupportedOperationException {
		switch (bFunc) {
			case ABS:
			case SIN:
			case COS:
			case TAN:
			case ASIN:
			case ACOS:
			case ATAN:
			case SQRT:
			case EXP:
			case PLOGP:
			case NROW:
			case NCOL:
			case LENGTH:
			case ROUND:
			case PRINT:
			case MAXINDEX:
			case MININDEX:
			case STOP:
			case CEIL:
			case FLOOR:
			case CUMSUM:
			// the remaining cumulative aggregates are unary like CUMSUM
			// (registered as "ucum*", "ucummin", "ucummax")
			case CUMPROD:
			case CUMMIN:
			case CUMMAX:
			case INVERSE:
			case SPROP:
			case SIGMOID:
			case SELP:
				return (_arity == 1);

			case LOG:
			case LOG_NZ:
				// log(x) or log(x, base)
				return (_arity == 1 || _arity == 2);

			case MAX:
			case MIN:
				return (_arity == 2);

			default:
				throw new DMLUnsupportedOperationException("checkNumberOfOperands(): Unknown opcode: " + bFunc);
		}
	}

	/**
	 * Evaluates a builtin function with a single numeric input.
	 *
	 * @param in input value
	 * @return function result
	 * @throws DMLRuntimeException if this function code has no single-input numeric implementation
	 */
	public double execute (double in) throws DMLRuntimeException {
		switch(bFunc) {
			case SIN:   return FASTMATH ? FastMath.sin(in) : Math.sin(in);
			case COS:   return FASTMATH ? FastMath.cos(in) : Math.cos(in);
			case TAN:   return FASTMATH ? FastMath.tan(in) : Math.tan(in);
			case ASIN:  return FASTMATH ? FastMath.asin(in) : Math.asin(in);
			case ACOS:  return FASTMATH ? FastMath.acos(in) : Math.acos(in);
			case ATAN:  return Math.atan(in); //faster in Math
			case CEIL:  return FASTMATH ? FastMath.ceil(in) : Math.ceil(in);
			case FLOOR: return FASTMATH ? FastMath.floor(in) : Math.floor(in);

			case LOG:
				// no explicit domain check: for negative numbers, Math.log will return NaN
				return FASTMATH ? FastMath.log(in) : Math.log(in);

			case LOG_NZ:
				// logarithm of non-zero values only; zero stays zero
				return (in==0) ? 0 : FASTMATH ? FastMath.log(in) : Math.log(in);

			case ABS:
				return Math.abs(in); //no need for FastMath

			case SQRT:
				// no explicit domain check for negative inputs
				return Math.sqrt(in); //faster in Math

			case PLOGP:
				// p*log(p) with 0*log(0) defined as 0 and NaN for negative inputs
				if (in == 0.0)
					return 0.0;
				else if (in < 0)
					return Double.NaN;
				else
					return (in * (FASTMATH ? FastMath.log(in) : Math.log(in)));

			case EXP:
				return FASTMATH ? FastMath.exp(in) : Math.exp(in);

			case ROUND:
				// NOTE: Math.round returns a long, i.e., NaN maps to 0 and values
				// outside the long range saturate before being widened to double
				return Math.round(in); //no need for FastMath

			case SPROP:
				//sample proportion: P*(1-P)
				return in * (1 - in);

			case SIGMOID:
				//sigmoid: 1/(1+exp(-x))
				return FASTMATH ? 1 / (1 + FastMath.exp(-in)) : 1 / (1 + Math.exp(-in));

			case SELP:
				//select positive: x*(x>0)
				return (in > 0) ? in : 0;

			default:
				throw new DMLRuntimeException("Builtin.execute(): Unknown operation: " + bFunc);
		}
	}

	/**
	 * Evaluates a builtin function with a single integer input by delegating
	 * to {@link #execute(double)}.
	 *
	 * @param in input value
	 * @return function result
	 * @throws DMLRuntimeException if this function code has no single-input numeric implementation
	 */
	public double execute (long in) throws DMLRuntimeException {
		return this.execute((double)in);
	}

	/**
	 * Evaluates a builtin function with two numeric inputs.
	 *
	 * For MAXINDEX / MININDEX the result is a comparison code (0, 1, or 2),
	 * not a cell value; see the comment inside the method.
	 *
	 * @param in1 first input
	 * @param in2 second input (the base for LOG / LOG_NZ)
	 * @return function result
	 * @throws DMLRuntimeException if this function code has no two-input numeric implementation
	 */
	public double execute (double in1, double in2) throws DMLRuntimeException {
		switch(bFunc) {

			/*
			 * Arithmetic relational operators (==, !=, <=, >=) must be used instead of
			 * <code>Double.compare()</code> due to the inconsistencies in the way
			 * NaN and -0.0 are handled. The behavior of methods in
			 * <code>Double</code> class are designed mainly to make Java
			 * collections work properly. For more details, see the help for
			 * <code>Double.equals()</code> and <code>Double.compareTo()</code>.
			 */
			case MAX:
			case CUMMAX:
				return (in1 >= in2 ? in1 : in2);

			case MIN:
			case CUMMIN:
				return (in1 <= in2 ? in1 : in2);

			// *** HACK ALERT *** HACK ALERT *** HACK ALERT ***
			// rowIndexMax() and its siblings require comparing four values, but
			// the aggregation API only allows two values. So the execute()
			// method receives as its argument the two cell values to be
			// compared and performs just the value part of the comparison. We
			// return an integer cast down to a double, since the aggregation
			// API doesn't have any way to return anything but a double. The
			// integer returned takes on three possible values:
			//
			// . 0 => keep the index associated with in1
			//
			// . 1 => use the index associated with in2
			//
			// . 2 => use whichever index is higher (tie in value)
			//
			// NOTE(review): which operand carries which index is decided by the
			// callers, which are not visible here -- verify the 0/1 mapping
			// above against them before relying on it.
			case MAXINDEX:
				if (in1 == in2) {
					return 2;
				} else if (in1 > in2) {
					return 1;
				} else { // in1 < in2
					return 0;
				}

			case MININDEX:
				if (in1 == in2) {
					return 2;
				} else if (in1 < in2) {
					return 1;
				} else { // in1 > in2
					return 0;
				}
			// *** END HACK ***

			case LOG:
				// logarithm of in1 to base in2; no explicit domain check
				if( FASTMATH )
					return (FastMath.log(in1)/FastMath.log(in2));
				else
					return (Math.log(in1)/Math.log(in2));

			case LOG_NZ:
				// as LOG, but zero stays zero
				if( FASTMATH )
					return (in1==0) ? 0 : (FastMath.log(in1)/FastMath.log(in2));
				else
					return (in1==0) ? 0 : (Math.log(in1)/Math.log(in2));

			default:
				throw new DMLRuntimeException("Builtin.execute(): Unknown operation: " + bFunc);
		}
	}

	/**
	 * Simplified version without exception handling. Unlike
	 * {@link #execute(double, double)}, MAXINDEX / MININDEX return only 1 or 0
	 * (ties yield 1), and unsupported function codes yield -1.
	 *
	 * @param in1 first input
	 * @param in2 second input
	 * @return function result, or -1 for unsupported function codes
	 */
	public double execute2(double in1, double in2) {
		switch(bFunc) {
			case MAX:
			case CUMMAX:
				return (in1 >= in2 ? in1 : in2);

			case MIN:
			case CUMMIN:
				return (in1 <= in2 ? in1 : in2);

			case MAXINDEX:
				return (in1 >= in2) ? 1 : 0;

			case MININDEX:
				return (in1 <= in2) ? 1 : 0;

			default:
				// For performance reasons, avoid throwing an exception
				return -1;
		}
	}

	/**
	 * Evaluates a builtin function with two integer inputs. MAXINDEX / MININDEX
	 * return only 1 or 0 here (ties yield 1).
	 *
	 * @param in1 first input
	 * @param in2 second input (the base for LOG / LOG_NZ)
	 * @return function result
	 * @throws DMLRuntimeException if this function code has no two-input numeric implementation
	 */
	public double execute (long in1, long in2) throws DMLRuntimeException {
		switch(bFunc) {
			case MAX:
			case CUMMAX:
				return (in1 >= in2 ? in1 : in2);

			case MIN:
			case CUMMIN:
				return (in1 <= in2 ? in1 : in2);

			case MAXINDEX:
				return (in1 >= in2) ? 1 : 0;

			case MININDEX:
				return (in1 <= in2) ? 1 : 0;

			case LOG:
				// logarithm of in1 to base in2; no explicit domain check
				if( FASTMATH )
					return (FastMath.log(in1)/FastMath.log(in2));
				else
					return (Math.log(in1)/Math.log(in2));

			case LOG_NZ:
				// as LOG, but zero stays zero
				if( FASTMATH )
					return (in1==0) ? 0 : (FastMath.log(in1)/FastMath.log(in2));
				else
					return (in1==0) ? 0 : (Math.log(in1)/Math.log(in2));

			default:
				throw new DMLRuntimeException("Builtin.execute(): Unknown operation: " + bFunc);
		}
	}

	/**
	 * Evaluates a builtin function with a single string input; currently
	 * used only for PRINT and STOP.
	 *
	 * @param in1 message to print (PRINT) or to raise as script error (STOP)
	 * @return always null for PRINT; STOP never returns normally
	 * @throws DMLRuntimeException for any other function code; STOP throws a
	 *         DMLScriptException carrying the given message
	 */
	public String execute (String in1) throws DMLRuntimeException {
		switch (bFunc) {
			case PRINT:
				if (!DMLScript.suppressPrint2Stdout())
					System.out.println(in1);
				return null;

			case STOP:
				throw new DMLScriptException(in1);

			default:
				throw new DMLRuntimeException("Builtin.execute(): Unknown operation: " + bFunc);
		}
	}
}