/**
* (C) Copyright IBM Corp. 2010, 2015
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package com.ibm.bi.dml.runtime.functionobjects;
import java.util.HashMap;
import org.apache.commons.math3.util.FastMath;
import com.ibm.bi.dml.api.DMLScript;
import com.ibm.bi.dml.runtime.DMLRuntimeException;
import com.ibm.bi.dml.runtime.DMLScriptException;
import com.ibm.bi.dml.runtime.DMLUnsupportedOperationException;
/**
* Class with pre-defined set of objects. This class can not be instantiated elsewhere.
*
* Notes on commons.math FastMath:
* * FastMath uses lookup tables and interpolation instead of native calls.
* * The memory overhead for those tables is roughly 48KB in total (acceptable)
* * Micro and application benchmarks showed significant (30%-3x) performance improvements
* for most operations, without loss of accuracy.
* * atan / sqrt were 20% slower in FastMath and hence, we use Math there
* * round / abs were equivalent in FastMath and hence, we use Math there
* * Finally, there is just one argument against FastMath - The comparison heavily depends
* on the JVM. For example, currently the IBM JDK JIT compiles to HW instructions for sqrt
* which makes this operation very efficient; as soon as other operations like log/exp are
* similarly compiled, we should rerun the micro benchmarks, and switch back if necessary.
*
*/
public class Builtin extends ValueFunction
{
private static final long serialVersionUID = 3836744687789840574L;
public enum BuiltinFunctionCode { INVALID, SIN, COS, TAN, ASIN, ACOS, ATAN, LOG, LOG_NZ, MIN, MAX, ABS, SQRT, EXP, PLOGP, PRINT, NROW, NCOL, LENGTH, ROUND, MAXINDEX, MININDEX, STOP, CEIL, FLOOR, CUMSUM, CUMPROD, CUMMIN, CUMMAX, INVERSE, SPROP, SIGMOID, SELP };
public BuiltinFunctionCode bFunc;
private static final boolean FASTMATH = true;
static public HashMap<String, BuiltinFunctionCode> String2BuiltinFunctionCode;
static {
String2BuiltinFunctionCode = new HashMap<String, BuiltinFunctionCode>();
String2BuiltinFunctionCode.put( "sin" , BuiltinFunctionCode.SIN);
String2BuiltinFunctionCode.put( "cos" , BuiltinFunctionCode.COS);
String2BuiltinFunctionCode.put( "tan" , BuiltinFunctionCode.TAN);
String2BuiltinFunctionCode.put( "asin" , BuiltinFunctionCode.ASIN);
String2BuiltinFunctionCode.put( "acos" , BuiltinFunctionCode.ACOS);
String2BuiltinFunctionCode.put( "atan" , BuiltinFunctionCode.ATAN);
String2BuiltinFunctionCode.put( "log" , BuiltinFunctionCode.LOG);
String2BuiltinFunctionCode.put( "log_nz" , BuiltinFunctionCode.LOG_NZ);
String2BuiltinFunctionCode.put( "min" , BuiltinFunctionCode.MIN);
String2BuiltinFunctionCode.put( "max" , BuiltinFunctionCode.MAX);
String2BuiltinFunctionCode.put( "maxindex" , BuiltinFunctionCode.MAXINDEX);
String2BuiltinFunctionCode.put( "minindex" , BuiltinFunctionCode.MININDEX);
String2BuiltinFunctionCode.put( "abs" , BuiltinFunctionCode.ABS);
String2BuiltinFunctionCode.put( "sqrt" , BuiltinFunctionCode.SQRT);
String2BuiltinFunctionCode.put( "exp" , BuiltinFunctionCode.EXP);
String2BuiltinFunctionCode.put( "plogp" , BuiltinFunctionCode.PLOGP);
String2BuiltinFunctionCode.put( "print" , BuiltinFunctionCode.PRINT);
String2BuiltinFunctionCode.put( "nrow" , BuiltinFunctionCode.NROW);
String2BuiltinFunctionCode.put( "ncol" , BuiltinFunctionCode.NCOL);
String2BuiltinFunctionCode.put( "length" , BuiltinFunctionCode.LENGTH);
String2BuiltinFunctionCode.put( "round" , BuiltinFunctionCode.ROUND);
String2BuiltinFunctionCode.put( "stop" , BuiltinFunctionCode.STOP);
String2BuiltinFunctionCode.put( "ceil" , BuiltinFunctionCode.CEIL);
String2BuiltinFunctionCode.put( "floor" , BuiltinFunctionCode.FLOOR);
String2BuiltinFunctionCode.put( "ucumk+" , BuiltinFunctionCode.CUMSUM);
String2BuiltinFunctionCode.put( "ucum*" , BuiltinFunctionCode.CUMPROD);
String2BuiltinFunctionCode.put( "ucummin", BuiltinFunctionCode.CUMMIN);
String2BuiltinFunctionCode.put( "ucummax", BuiltinFunctionCode.CUMMAX);
String2BuiltinFunctionCode.put( "inverse", BuiltinFunctionCode.INVERSE);
String2BuiltinFunctionCode.put( "sprop", BuiltinFunctionCode.SPROP);
String2BuiltinFunctionCode.put( "sigmoid", BuiltinFunctionCode.SIGMOID);
String2BuiltinFunctionCode.put( "sel+", BuiltinFunctionCode.SELP);
}
// We should create one object for every builtin function that we support
private static Builtin sinObj = null, cosObj = null, tanObj = null, asinObj = null, acosObj = null, atanObj = null;
private static Builtin logObj = null, lognzObj = null, minObj = null, maxObj = null, maxindexObj = null, minindexObj=null;
private static Builtin absObj = null, sqrtObj = null, expObj = null, plogpObj = null, printObj = null;
private static Builtin nrowObj = null, ncolObj = null, lengthObj = null, roundObj = null, ceilObj=null, floorObj=null;
private static Builtin inverseObj=null, cumsumObj=null, cumprodObj=null, cumminObj=null, cummaxObj=null;
private static Builtin stopObj = null, spropObj = null, sigmoidObj = null, selpObj = null;
private Builtin(BuiltinFunctionCode bf) {
bFunc = bf;
}
public BuiltinFunctionCode getBuiltinFunctionCode() {
return bFunc;
}
/**
*
* @param str
* @return
*/
public static Builtin getBuiltinFnObject (String str)
{
BuiltinFunctionCode code = String2BuiltinFunctionCode.get(str);
return getBuiltinFnObject( code );
}
/**
*
* @param code
* @return
*/
public static Builtin getBuiltinFnObject(BuiltinFunctionCode code)
{
if ( code == null )
return null;
switch ( code ) {
case SIN:
if ( sinObj == null )
sinObj = new Builtin(BuiltinFunctionCode.SIN);
return sinObj;
case COS:
if ( cosObj == null )
cosObj = new Builtin(BuiltinFunctionCode.COS);
return cosObj;
case TAN:
if ( tanObj == null )
tanObj = new Builtin(BuiltinFunctionCode.TAN);
return tanObj;
case ASIN:
if ( asinObj == null )
asinObj = new Builtin(BuiltinFunctionCode.ASIN);
return asinObj;
case ACOS:
if ( acosObj == null )
acosObj = new Builtin(BuiltinFunctionCode.ACOS);
return acosObj;
case ATAN:
if ( atanObj == null )
atanObj = new Builtin(BuiltinFunctionCode.ATAN);
return atanObj;
case LOG:
if ( logObj == null )
logObj = new Builtin(BuiltinFunctionCode.LOG);
return logObj;
case LOG_NZ:
if ( lognzObj == null )
lognzObj = new Builtin(BuiltinFunctionCode.LOG_NZ);
return lognzObj;
case MAX:
if ( maxObj == null )
maxObj = new Builtin(BuiltinFunctionCode.MAX);
return maxObj;
case MAXINDEX:
if ( maxindexObj == null )
maxindexObj = new Builtin(BuiltinFunctionCode.MAXINDEX);
return maxindexObj;
case MIN:
if ( minObj == null )
minObj = new Builtin(BuiltinFunctionCode.MIN);
return minObj;
case MININDEX:
if ( minindexObj == null )
minindexObj = new Builtin(BuiltinFunctionCode.MININDEX);
return minindexObj;
case ABS:
if ( absObj == null )
absObj = new Builtin(BuiltinFunctionCode.ABS);
return absObj;
case SQRT:
if ( sqrtObj == null )
sqrtObj = new Builtin(BuiltinFunctionCode.SQRT);
return sqrtObj;
case EXP:
if ( expObj == null )
expObj = new Builtin(BuiltinFunctionCode.EXP);
return expObj;
case PLOGP:
if ( plogpObj == null )
plogpObj = new Builtin(BuiltinFunctionCode.PLOGP);
return plogpObj;
case PRINT:
if ( printObj == null )
printObj = new Builtin(BuiltinFunctionCode.PRINT);
return printObj;
case NROW:
if ( nrowObj == null )
nrowObj = new Builtin(BuiltinFunctionCode.NROW);
return nrowObj;
case NCOL:
if ( ncolObj == null )
ncolObj = new Builtin(BuiltinFunctionCode.NCOL);
return ncolObj;
case LENGTH:
if ( lengthObj == null )
lengthObj = new Builtin(BuiltinFunctionCode.LENGTH);
return lengthObj;
case ROUND:
if ( roundObj == null )
roundObj = new Builtin(BuiltinFunctionCode.ROUND);
return roundObj;
case CEIL:
if ( ceilObj == null )
ceilObj = new Builtin(BuiltinFunctionCode.CEIL);
return ceilObj;
case FLOOR:
if ( floorObj == null )
floorObj = new Builtin(BuiltinFunctionCode.FLOOR);
return floorObj;
case CUMSUM:
if ( cumsumObj == null )
cumsumObj = new Builtin(BuiltinFunctionCode.CUMSUM);
return cumsumObj;
case CUMPROD:
if ( cumprodObj == null )
cumprodObj = new Builtin(BuiltinFunctionCode.CUMPROD);
return cumprodObj;
case CUMMIN:
if ( cumminObj == null )
cumminObj = new Builtin(BuiltinFunctionCode.CUMMIN);
return cumminObj;
case CUMMAX:
if ( cummaxObj == null )
cummaxObj = new Builtin(BuiltinFunctionCode.CUMMAX);
return cummaxObj;
case INVERSE:
if ( inverseObj == null )
inverseObj = new Builtin(BuiltinFunctionCode.INVERSE);
return inverseObj;
case STOP:
if ( stopObj == null )
stopObj = new Builtin(BuiltinFunctionCode.STOP);
return stopObj;
case SPROP:
if ( spropObj == null )
spropObj = new Builtin(BuiltinFunctionCode.SPROP);
return spropObj;
case SIGMOID:
if ( sigmoidObj == null )
sigmoidObj = new Builtin(BuiltinFunctionCode.SIGMOID);
return sigmoidObj;
case SELP:
if ( selpObj == null )
selpObj = new Builtin(BuiltinFunctionCode.SELP);
return selpObj;
default:
// Unknown code --> return null
return null;
}
}
public Object clone() throws CloneNotSupportedException {
// cloning is not supported for singleton classes
throw new CloneNotSupportedException();
}
public boolean checkArity(int _arity) throws DMLUnsupportedOperationException {
switch (bFunc) {
case ABS:
case SIN:
case COS:
case TAN:
case ASIN:
case ACOS:
case ATAN:
case SQRT:
case EXP:
case PLOGP:
case NROW:
case NCOL:
case LENGTH:
case ROUND:
case PRINT:
case MAXINDEX:
case MININDEX:
case STOP:
case CEIL:
case FLOOR:
case CUMSUM:
case INVERSE:
case SPROP:
case SIGMOID:
case SELP:
return (_arity == 1);
case LOG:
case LOG_NZ:
return (_arity == 1 || _arity == 2);
case MAX:
case MIN:
return (_arity == 2);
default:
throw new DMLUnsupportedOperationException("checkNumberOfOperands(): Unknown opcode: " + bFunc);
}
}
public double execute (double in) throws DMLRuntimeException {
switch(bFunc) {
case SIN: return FASTMATH ? FastMath.sin(in) : Math.sin(in);
case COS: return FASTMATH ? FastMath.cos(in) : Math.cos(in);
case TAN: return FASTMATH ? FastMath.tan(in) : Math.tan(in);
case ASIN: return FASTMATH ? FastMath.asin(in) : Math.asin(in);
case ACOS: return FASTMATH ? FastMath.acos(in) : Math.acos(in);
case ATAN: return Math.atan(in); //faster in Math
case CEIL: return FASTMATH ? FastMath.ceil(in) : Math.ceil(in);
case FLOOR: return FASTMATH ? FastMath.floor(in) : Math.floor(in);
case LOG:
//if ( in <= 0 )
// throw new DMLRuntimeException("Builtin.execute(): logarithm can only be computed for non-negative numbers (input = " + in + ").");
// for negative numbers, Math.log will return NaN
return FASTMATH ? FastMath.log(in) : Math.log(in);
case LOG_NZ:
return (in==0) ? 0 : FASTMATH ? FastMath.log(in) : Math.log(in);
case ABS:
return Math.abs(in); //no need for FastMath
case SQRT:
//if ( in < 0 )
// throw new DMLRuntimeException("Builtin.execute(): squareroot can only be computed for non-negative numbers (input = " + in + ").");
return Math.sqrt(in); //faster in Math
case PLOGP:
if (in == 0.0)
return 0.0;
else if (in < 0)
return Double.NaN;
else
return (in * (FASTMATH ? FastMath.log(in) : Math.log(in)));
case EXP:
return FASTMATH ? FastMath.exp(in) : Math.exp(in);
case ROUND:
return Math.round(in); //no need for FastMath
case SPROP:
//sample proportion: P*(1-P)
return in * (1 - in);
case SIGMOID:
//sigmoid: 1/(1+exp(-x))
return FASTMATH ? 1 / (1 + FastMath.exp(-in)) : 1 / (1 + Math.exp(-in));
case SELP:
//select positive: x*(x>0)
return (in > 0) ? in : 0;
default:
throw new DMLRuntimeException("Builtin.execute(): Unknown operation: " + bFunc);
}
}
public double execute (long in) throws DMLRuntimeException {
return this.execute((double)in);
}
/*
* Builtin functions with two inputs
*/
public double execute (double in1, double in2) throws DMLRuntimeException {
switch(bFunc) {
/*
* Arithmetic relational operators (==, !=, <=, >=) must be instead of
* <code>Double.compare()</code> due to the inconsistencies in the way
* NaN and -0.0 are handled. The behavior of methods in
* <code>Double</code> class are designed mainly to make Java
* collections work properly. For more details, see the help for
* <code>Double.equals()</code> and <code>Double.comapreTo()</code>.
*/
case MAX:
case CUMMAX:
//return (Double.compare(in1, in2) >= 0 ? in1 : in2);
return (in1 >= in2 ? in1 : in2);
case MIN:
case CUMMIN:
//return (Double.compare(in1, in2) <= 0 ? in1 : in2);
return (in1 <= in2 ? in1 : in2);
// *** HACK ALERT *** HACK ALERT *** HACK ALERT ***
// rowIndexMax() and its siblings require comparing four values, but
// the aggregation API only allows two values. So the execute()
// method receives as its argument the two cell values to be
// compared and performs just the value part of the comparison. We
// return an integer cast down to a double, since the aggregation
// API doesn't have any way to return anything but a double. The
// integer returned takes on three posssible values: //
// . 0 => keep the index associated with in1 //
// . 1 => use the index associated with in2 //
// . 2 => use whichever index is higher (tie in value) //
case MAXINDEX:
if (in1 == in2) {
return 2;
} else if (in1 > in2) {
return 1;
} else { // in1 < in2
return 0;
}
case MININDEX:
if (in1 == in2) {
return 2;
} else if (in1 < in2) {
return 1;
} else { // in1 > in2
return 0;
}
// *** END HACK ***
case LOG:
//if ( in1 <= 0 )
// throw new DMLRuntimeException("Builtin.execute(): logarithm can be computed only for non-negative numbers.");
if( FASTMATH )
return (FastMath.log(in1)/FastMath.log(in2));
else
return (Math.log(in1)/Math.log(in2));
case LOG_NZ:
if( FASTMATH )
return (in1==0) ? 0 : (FastMath.log(in1)/FastMath.log(in2));
else
return (in1==0) ? 0 : (Math.log(in1)/Math.log(in2));
default:
throw new DMLRuntimeException("Builtin.execute(): Unknown operation: " + bFunc);
}
}
/**
* Simplified version without exception handling
*
* @param in1
* @param in2
* @return
*/
public double execute2(double in1, double in2)
{
switch(bFunc) {
case MAX:
case CUMMAX:
//return (Double.compare(in1, in2) >= 0 ? in1 : in2);
return (in1 >= in2 ? in1 : in2);
case MIN:
case CUMMIN:
//return (Double.compare(in1, in2) <= 0 ? in1 : in2);
return (in1 <= in2 ? in1 : in2);
case MAXINDEX:
return (in1 >= in2) ? 1 : 0;
case MININDEX:
return (in1 <= in2) ? 1 : 0;
default:
// For performance reasons, avoid throwing an exception
return -1;
}
}
public double execute (long in1, long in2) throws DMLRuntimeException {
switch(bFunc) {
case MAX:
case CUMMAX: return (in1 >= in2 ? in1 : in2);
case MIN:
case CUMMIN: return (in1 <= in2 ? in1 : in2);
case MAXINDEX: return (in1 >= in2) ? 1 : 0;
case MININDEX: return (in1 <= in2) ? 1 : 0;
case LOG:
//if ( in1 <= 0 )
// throw new DMLRuntimeException("Builtin.execute(): logarithm can be computed only for non-negative numbers.");
if( FASTMATH )
return (FastMath.log(in1)/FastMath.log(in2));
else
return (Math.log(in1)/Math.log(in2));
case LOG_NZ:
if( FASTMATH )
return (in1==0) ? 0 : (FastMath.log(in1)/FastMath.log(in2));
else
return (in1==0) ? 0 : (Math.log(in1)/Math.log(in2));
default:
throw new DMLRuntimeException("Builtin.execute(): Unknown operation: " + bFunc);
}
}
// currently, it is used only for PRINT and STOP
public String execute (String in1)
throws DMLRuntimeException
{
switch (bFunc) {
case PRINT:
if (!DMLScript.suppressPrint2Stdout())
System.out.println(in1);
return null;
case STOP:
throw new DMLScriptException(in1);
default:
throw new DMLRuntimeException("Builtin.execute(): Unknown operation: " + bFunc);
}
}
}