/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysml.runtime.functionobjects;
import java.util.EnumMap;
import java.util.HashMap;

import org.apache.commons.math3.util.FastMath;
import org.apache.sysml.api.DMLScript;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.DMLScriptException;
/**
 * Class with a pre-defined set of builtin function objects. The constructor is private, so
 * instances cannot be created elsewhere; obtain the shared object for a builtin via
 * {@link #getBuiltinFnObject(String)} or {@link #getBuiltinFnObject(BuiltinCode)}.
 *
 * Notes on commons.math FastMath:
 * - FastMath uses lookup tables and interpolation instead of native calls.
 * - The memory overhead for those tables is roughly 48KB in total (acceptable).
 * - Micro and application benchmarks showed significant (30%-3x) performance improvements
 *   for most operations, without loss of accuracy.
 * - atan / sqrt were 20% slower in FastMath and hence, we use Math there.
 * - round / abs were equivalent in FastMath and hence, we use Math there.
 * - Finally, there is just one argument against FastMath: the comparison heavily depends
 *   on the JVM. For example, currently the IBM JDK JIT compiles sqrt to HW instructions,
 *   which makes this operation very efficient; as soon as other operations like log/exp are
 *   similarly compiled, we should rerun the micro benchmarks, and switch back if necessary.
 */
public class Builtin extends ValueFunction
{
	private static final long serialVersionUID = 3836744687789840574L;

	public enum BuiltinCode { SIN, COS, TAN, ASIN, ACOS, ATAN, LOG, LOG_NZ, MIN, MAX, ABS, SIGN, SQRT, EXP, PLOGP, PRINT, PRINTF, NROW, NCOL, LENGTH, ROUND, MAXINDEX, MININDEX, STOP, CEIL, FLOOR, CUMSUM, CUMPROD, CUMMIN, CUMMAX, INVERSE, SPROP, SIGMOID, SELP }

	/** The builtin function implemented by this object. */
	public BuiltinCode bFunc;

	/** Selects commons.math FastMath (true) or java.lang.Math (false) where both are used. */
	private static final boolean FASTMATH = true;

	/** Maps builtin/opcode names (e.g. "sin", "ucumk+", "sel+") to their builtin codes. */
	static public HashMap<String, BuiltinCode> String2BuiltinCode;
	static {
		String2BuiltinCode = new HashMap<String, BuiltinCode>();
		String2BuiltinCode.put( "sin"     , BuiltinCode.SIN);
		String2BuiltinCode.put( "cos"     , BuiltinCode.COS);
		String2BuiltinCode.put( "tan"     , BuiltinCode.TAN);
		String2BuiltinCode.put( "asin"    , BuiltinCode.ASIN);
		String2BuiltinCode.put( "acos"    , BuiltinCode.ACOS);
		String2BuiltinCode.put( "atan"    , BuiltinCode.ATAN);
		String2BuiltinCode.put( "log"     , BuiltinCode.LOG);
		String2BuiltinCode.put( "log_nz"  , BuiltinCode.LOG_NZ);
		String2BuiltinCode.put( "min"     , BuiltinCode.MIN);
		String2BuiltinCode.put( "max"     , BuiltinCode.MAX);
		String2BuiltinCode.put( "maxindex", BuiltinCode.MAXINDEX);
		String2BuiltinCode.put( "minindex", BuiltinCode.MININDEX);
		String2BuiltinCode.put( "abs"     , BuiltinCode.ABS);
		String2BuiltinCode.put( "sign"    , BuiltinCode.SIGN);
		String2BuiltinCode.put( "sqrt"    , BuiltinCode.SQRT);
		String2BuiltinCode.put( "exp"     , BuiltinCode.EXP);
		String2BuiltinCode.put( "plogp"   , BuiltinCode.PLOGP);
		String2BuiltinCode.put( "print"   , BuiltinCode.PRINT);
		String2BuiltinCode.put( "printf"  , BuiltinCode.PRINTF);
		String2BuiltinCode.put( "nrow"    , BuiltinCode.NROW);
		String2BuiltinCode.put( "ncol"    , BuiltinCode.NCOL);
		String2BuiltinCode.put( "length"  , BuiltinCode.LENGTH);
		String2BuiltinCode.put( "round"   , BuiltinCode.ROUND);
		String2BuiltinCode.put( "stop"    , BuiltinCode.STOP);
		String2BuiltinCode.put( "ceil"    , BuiltinCode.CEIL);
		String2BuiltinCode.put( "floor"   , BuiltinCode.FLOOR);
		String2BuiltinCode.put( "ucumk+"  , BuiltinCode.CUMSUM);
		String2BuiltinCode.put( "ucum*"   , BuiltinCode.CUMPROD);
		String2BuiltinCode.put( "ucummin" , BuiltinCode.CUMMIN);
		String2BuiltinCode.put( "ucummax" , BuiltinCode.CUMMAX);
		String2BuiltinCode.put( "inverse" , BuiltinCode.INVERSE);
		String2BuiltinCode.put( "sprop"   , BuiltinCode.SPROP);
		String2BuiltinCode.put( "sigmoid" , BuiltinCode.SIGMOID);
		String2BuiltinCode.put( "sel+"    , BuiltinCode.SELP);
	}

	/**
	 * Exactly one shared object per builtin code. The map is filled during class
	 * initialization, which the JVM runs once per class loader before any other thread can
	 * observe these statics; no unsynchronized lazy check-then-set is needed.
	 */
	private static final EnumMap<BuiltinCode, Builtin> BUILTIN_OBJECTS = createBuiltinObjects();

	/** Creates the shared function object for every {@link BuiltinCode}. */
	private static EnumMap<BuiltinCode, Builtin> createBuiltinObjects() {
		EnumMap<BuiltinCode, Builtin> objects = new EnumMap<BuiltinCode, Builtin>(BuiltinCode.class);
		for( BuiltinCode code : BuiltinCode.values() )
			objects.put(code, new Builtin(code));
		return objects;
	}

	private Builtin(BuiltinCode bf) {
		bFunc = bf;
	}

	/** @return the builtin code implemented by this object */
	public BuiltinCode getBuiltinCode() {
		return bFunc;
	}

	/**
	 * Looks up the shared function object by builtin/opcode name.
	 *
	 * @param str builtin name as registered in {@link #String2BuiltinCode}
	 * @return the shared object, or null if the name is not registered
	 */
	public static Builtin getBuiltinFnObject (String str)
	{
		BuiltinCode code = String2BuiltinCode.get(str);
		return getBuiltinFnObject( code );
	}

	/**
	 * Returns the shared function object for the given code.
	 *
	 * @param code builtin code (may be null)
	 * @return the shared object for {@code code}, or null if {@code code} is null
	 */
	public static Builtin getBuiltinFnObject(BuiltinCode code)
	{
		if ( code == null )
			return null;
		return BUILTIN_OBJECTS.get(code);
	}

	/**
	 * Applies this unary builtin to a scalar.
	 *
	 * @param in input value
	 * @return result of the builtin
	 * @throws DMLRuntimeException if this builtin has no unary numeric form
	 */
	@Override
	public double execute (double in)
		throws DMLRuntimeException
	{
		switch(bFunc) {
			case SIN:    return FASTMATH ? FastMath.sin(in) : Math.sin(in);
			case COS:    return FASTMATH ? FastMath.cos(in) : Math.cos(in);
			case TAN:    return FASTMATH ? FastMath.tan(in) : Math.tan(in);
			case ASIN:   return FASTMATH ? FastMath.asin(in) : Math.asin(in);
			case ACOS:   return FASTMATH ? FastMath.acos(in) : Math.acos(in);
			case ATAN:   return Math.atan(in); //faster in Math
			case CEIL:   return FASTMATH ? FastMath.ceil(in) : Math.ceil(in);
			case FLOOR:  return FASTMATH ? FastMath.floor(in) : Math.floor(in);
			case LOG:    return FASTMATH ? FastMath.log(in) : Math.log(in);
			case LOG_NZ: return (in==0) ? 0 : (FASTMATH ? FastMath.log(in) : Math.log(in));
			case ABS:    return Math.abs(in); //no need for FastMath
			case SIGN:   return FASTMATH ? FastMath.signum(in) : Math.signum(in);
			case SQRT:   return Math.sqrt(in); //faster in Math
			case EXP:    return FASTMATH ? FastMath.exp(in) : Math.exp(in);
			case ROUND:  return Math.round(in); //no need for FastMath

			case PLOGP:
				if (in == 0.0)
					return 0.0;
				else if (in < 0)
					return Double.NaN;
				else
					return (in * (FASTMATH ? FastMath.log(in) : Math.log(in)));

			case SPROP:
				//sample proportion: P*(1-P)
				return in * (1 - in);

			case SIGMOID:
				//sigmoid: 1/(1+exp(-x))
				return FASTMATH ? 1 / (1 + FastMath.exp(-in)) : 1 / (1 + Math.exp(-in));

			case SELP:
				//select positive: x*(x>0)
				return (in > 0) ? in : 0;

			default:
				throw new DMLRuntimeException("Builtin.execute(): Unknown operation: " + bFunc);
		}
	}

	/** Widens {@code in} to double and delegates to {@link #execute(double)}. */
	@Override
	public double execute (long in) throws DMLRuntimeException {
		return execute((double)in);
	}

	/**
	 * Applies this binary builtin to two scalars.
	 *
	 * The comparisons below intentionally use the primitive relational operators
	 * (==, !=, <=, >=) instead of <code>Double.compare()</code>, because the two differ in
	 * how NaN and -0.0 are handled. The behavior of the methods in the <code>Double</code>
	 * class is designed mainly to make Java collections work properly. For more details,
	 * see the documentation of <code>Double.equals()</code> and <code>Double.compareTo()</code>.
	 *
	 * @param in1 first operand
	 * @param in2 second operand (for LOG/LOG_NZ: the logarithm base)
	 * @return result of the builtin
	 * @throws DMLRuntimeException if this builtin has no binary numeric form
	 */
	@Override
	public double execute (double in1, double in2) throws DMLRuntimeException {
		switch(bFunc) {
			case MAX:
			case CUMMAX:
				return (in1 >= in2 ? in1 : in2);

			case MIN:
			case CUMMIN:
				return (in1 <= in2 ? in1 : in2);

			// *** HACK ALERT *** HACK ALERT *** HACK ALERT ***
			// rowIndexMax() and its siblings require comparing four values (two values and
			// two indices), but the aggregation API only allows two values. So this method
			// receives the two cell values to be compared and performs just the value part
			// of the comparison. The outcome is an integer cast to double, since the
			// aggregation API cannot return anything but a double:
			//   . 2 => the values are equal (tie); per the original design the caller then
			//          uses whichever index is higher
			//   . 1 => in1 wins (strictly greater for MAXINDEX, strictly smaller for MININDEX)
			//   . 0 => in2 wins
			// NOTE(review): which of in1/in2 holds the current vs. the candidate value is
			// decided by the calling aggregation code, which is not visible here.
			case MAXINDEX:
				if (in1 == in2)
					return 2;
				return (in1 > in2) ? 1 : 0;

			case MININDEX:
				if (in1 == in2)
					return 2;
				return (in1 < in2) ? 1 : 0;
			// *** END HACK ***

			case LOG:
				// no domain check on in1: non-positive inputs yield NaN/-Infinity from log()
				if( FASTMATH )
					return (FastMath.log(in1)/FastMath.log(in2));
				else
					return (Math.log(in1)/Math.log(in2));

			case LOG_NZ:
				if( FASTMATH )
					return (in1==0) ? 0 : (FastMath.log(in1)/FastMath.log(in2));
				else
					return (in1==0) ? 0 : (Math.log(in1)/Math.log(in2));

			default:
				throw new DMLRuntimeException("Builtin.execute(): Unknown operation: " + bFunc);
		}
	}

	/**
	 * Simplified version of {@link #execute(double, double)} without exception handling,
	 * covering only MAX/CUMMAX, MIN/CUMMIN, MAXINDEX and MININDEX.
	 *
	 * Unlike {@link #execute(double, double)}, MAXINDEX/MININDEX here return only 1 or 0:
	 * 1 when in1 &gt;= in2 (MAXINDEX) or in1 &lt;= in2 (MININDEX), i.e. ties also yield 1.
	 *
	 * @param in1 double 1
	 * @param in2 double 2
	 * @return result, or -1 for any other builtin code
	 */
	public double execute2(double in1, double in2)
	{
		switch(bFunc) {
			case MAX:
			case CUMMAX:
				return (in1 >= in2 ? in1 : in2);

			case MIN:
			case CUMMIN:
				return (in1 <= in2 ? in1 : in2);

			case MAXINDEX:
				return (in1 >= in2) ? 1 : 0;

			case MININDEX:
				return (in1 <= in2) ? 1 : 0;

			default:
				// For performance reasons, avoid throwing an exception
				return -1;
		}
	}

	/**
	 * Applies this binary builtin to two long scalars (result widened to double).
	 * MAXINDEX/MININDEX follow {@link #execute2(double, double)} (ties yield 1).
	 *
	 * @param in1 first operand
	 * @param in2 second operand (for LOG/LOG_NZ: the logarithm base)
	 * @return result of the builtin
	 * @throws DMLRuntimeException if this builtin has no binary numeric form
	 */
	@Override
	public double execute (long in1, long in2) throws DMLRuntimeException {
		switch(bFunc) {
			case MAX:
			case CUMMAX:   return (in1 >= in2 ? in1 : in2);

			case MIN:
			case CUMMIN:   return (in1 <= in2 ? in1 : in2);

			case MAXINDEX: return (in1 >= in2) ? 1 : 0;
			case MININDEX: return (in1 <= in2) ? 1 : 0;

			case LOG:
				// no domain check on in1: non-positive inputs yield NaN/-Infinity from log()
				if( FASTMATH )
					return (FastMath.log(in1)/FastMath.log(in2));
				else
					return (Math.log(in1)/Math.log(in2));

			case LOG_NZ:
				if( FASTMATH )
					return (in1==0) ? 0 : (FastMath.log(in1)/FastMath.log(in2));
				else
					return (in1==0) ? 0 : (Math.log(in1)/Math.log(in2));

			default:
				throw new DMLRuntimeException("Builtin.execute(): Unknown operation: " + bFunc);
		}
	}

	/**
	 * Applies a string builtin: PRINT/PRINTF write {@code in1} to stdout (unless printing
	 * is suppressed via {@code DMLScript.suppressPrint2Stdout()}) and return null; STOP
	 * aborts the script.
	 *
	 * @param in1 message text
	 * @return always null for PRINT/PRINTF
	 * @throws DMLScriptException for STOP, carrying {@code in1} as its message
	 * @throws DMLRuntimeException for any other builtin code
	 */
	@Override
	public String execute (String in1)
		throws DMLRuntimeException
	{
		switch (bFunc) {
			case PRINT:
			case PRINTF:
				// both print the (already formatted) text followed by a newline
				if (!DMLScript.suppressPrint2Stdout())
					System.out.println(in1);
				return null;
			case STOP:
				throw new DMLScriptException(in1);
			default:
				throw new DMLRuntimeException("Builtin.execute(): Unknown operation: " + bFunc);
		}
	}
}