/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package org.apache.sysml.runtime.functionobjects; import java.util.HashMap; import org.apache.commons.math3.distribution.AbstractRealDistribution; import org.apache.commons.math3.distribution.ChiSquaredDistribution; import org.apache.commons.math3.distribution.ExponentialDistribution; import org.apache.commons.math3.distribution.FDistribution; import org.apache.commons.math3.distribution.NormalDistribution; import org.apache.commons.math3.distribution.TDistribution; import org.apache.commons.math3.exception.MathArithmeticException; import org.apache.sysml.runtime.DMLRuntimeException; import org.apache.sysml.runtime.util.UtilFunctions; /** * Function object for builtin function that takes a list of name=value parameters. * This class can not be instantiated elsewhere. */ public class ParameterizedBuiltin extends ValueFunction { private static final long serialVersionUID = -5966242955816522697L; public enum ParameterizedBuiltinCode { CDF, INVCDF, RMEMPTY, REPLACE, REXPAND, TRANSFORM, TRANSFORMAPPLY, TRANSFORMDECODE }; public enum ProbabilityDistributionCode { INVALID, NORMAL, EXP, CHISQ, F, T }; public ParameterizedBuiltinCode bFunc; public ProbabilityDistributionCode distFunc; static public HashMap<String, ParameterizedBuiltinCode> String2ParameterizedBuiltinCode; static { String2ParameterizedBuiltinCode = new HashMap<String, ParameterizedBuiltinCode>(); String2ParameterizedBuiltinCode.put( "cdf", ParameterizedBuiltinCode.CDF); String2ParameterizedBuiltinCode.put( "invcdf", ParameterizedBuiltinCode.INVCDF); String2ParameterizedBuiltinCode.put( "rmempty", ParameterizedBuiltinCode.RMEMPTY); String2ParameterizedBuiltinCode.put( "replace", ParameterizedBuiltinCode.REPLACE); String2ParameterizedBuiltinCode.put( "rexpand", ParameterizedBuiltinCode.REXPAND); String2ParameterizedBuiltinCode.put( "transform", ParameterizedBuiltinCode.TRANSFORM); String2ParameterizedBuiltinCode.put( "transformapply", ParameterizedBuiltinCode.TRANSFORMAPPLY); String2ParameterizedBuiltinCode.put( "transformdecode", ParameterizedBuiltinCode.TRANSFORMDECODE); } static public HashMap<String, ProbabilityDistributionCode> String2DistCode; static { String2DistCode = new HashMap<String,ProbabilityDistributionCode>(); String2DistCode.put("normal" , ProbabilityDistributionCode.NORMAL); String2DistCode.put("exp" , ProbabilityDistributionCode.EXP); String2DistCode.put("chisq" , ProbabilityDistributionCode.CHISQ); String2DistCode.put("f" , ProbabilityDistributionCode.F); String2DistCode.put("t" , ProbabilityDistributionCode.T); } // We should create one object for every builtin function that we support private static ParameterizedBuiltin normalObj = null, expObj = null, chisqObj = null, fObj = null, tObj = null; private static ParameterizedBuiltin inormalObj = null, iexpObj = null, ichisqObj = null, ifObj = null, itObj = null; private ParameterizedBuiltin(ParameterizedBuiltinCode bf) { bFunc = bf; distFunc = ProbabilityDistributionCode.INVALID; } private ParameterizedBuiltin(ParameterizedBuiltinCode bf, ProbabilityDistributionCode dist) { bFunc = bf; distFunc = dist; } public static ParameterizedBuiltin getParameterizedBuiltinFnObject (String str) throws DMLRuntimeException { return getParameterizedBuiltinFnObject (str, null); } public static ParameterizedBuiltin getParameterizedBuiltinFnObject (String str, String str2) throws DMLRuntimeException { ParameterizedBuiltinCode code = String2ParameterizedBuiltinCode.get(str); switch ( code ) { case CDF: // str2 will point the appropriate distribution ProbabilityDistributionCode dcode = String2DistCode.get(str2.toLowerCase()); switch(dcode) { case NORMAL: if ( normalObj == null ) normalObj = new ParameterizedBuiltin(ParameterizedBuiltinCode.CDF, dcode); return normalObj; case EXP: if ( expObj == null ) expObj = new ParameterizedBuiltin(ParameterizedBuiltinCode.CDF, dcode); return expObj; case CHISQ: if ( chisqObj == null ) chisqObj = new ParameterizedBuiltin(ParameterizedBuiltinCode.CDF, dcode); return chisqObj; case F: if ( fObj == null ) fObj = new ParameterizedBuiltin(ParameterizedBuiltinCode.CDF, dcode); return fObj; case T: if ( tObj == null ) tObj = new ParameterizedBuiltin(ParameterizedBuiltinCode.CDF, dcode); return tObj; default: throw new DMLRuntimeException("Invalid distribution code: " + dcode); } case INVCDF: // str2 will point the appropriate distribution ProbabilityDistributionCode distcode = String2DistCode.get(str2.toLowerCase()); switch(distcode) { case NORMAL: if ( inormalObj == null ) inormalObj = new ParameterizedBuiltin(ParameterizedBuiltinCode.INVCDF, distcode); return inormalObj; case EXP: if ( iexpObj == null ) iexpObj = new ParameterizedBuiltin(ParameterizedBuiltinCode.INVCDF, distcode); return iexpObj; case CHISQ: if ( ichisqObj == null ) ichisqObj = new ParameterizedBuiltin(ParameterizedBuiltinCode.INVCDF, distcode); return ichisqObj; case F: if ( ifObj == null ) ifObj = new ParameterizedBuiltin(ParameterizedBuiltinCode.INVCDF, distcode); return ifObj; case T: if ( itObj == null ) itObj = new ParameterizedBuiltin(ParameterizedBuiltinCode.INVCDF, distcode); return itObj; default: throw new DMLRuntimeException("Invalid distribution code: " + distcode); } case RMEMPTY: return new ParameterizedBuiltin(ParameterizedBuiltinCode.RMEMPTY); case REPLACE: return new ParameterizedBuiltin(ParameterizedBuiltinCode.REPLACE); case REXPAND: return new ParameterizedBuiltin(ParameterizedBuiltinCode.REXPAND); case TRANSFORM: return new ParameterizedBuiltin(ParameterizedBuiltinCode.TRANSFORM); case TRANSFORMAPPLY: return new ParameterizedBuiltin(ParameterizedBuiltinCode.TRANSFORMAPPLY); case TRANSFORMDECODE: return new ParameterizedBuiltin(ParameterizedBuiltinCode.TRANSFORMDECODE); default: throw new DMLRuntimeException("Invalid parameterized builtin code: " + code); } } public double execute(HashMap<String,String> params) throws DMLRuntimeException { switch(bFunc) { case CDF: case INVCDF: switch(distFunc) { case NORMAL: case EXP: case CHISQ: case F: case T: return computeFromDistribution(distFunc, params, (bFunc==ParameterizedBuiltinCode.INVCDF)); default: throw new DMLRuntimeException("Unsupported distribution (" + distFunc + ")."); } default: throw new DMLRuntimeException("ParameterizedBuiltin.execute(): Unknown operation: " + bFunc); } } /** * Helper function to compute distribution-specific cdf (both lowertail and uppertail) and inverse cdf. * * @param dcode probablility distribution code * @param params map of parameters * @param inverse true if inverse * @return cdf or inverse cdf * @throws MathArithmeticException if MathArithmeticException occurs * @throws DMLRuntimeException if DMLRuntimeException occurs */ private double computeFromDistribution (ProbabilityDistributionCode dcode, HashMap<String,String> params, boolean inverse ) throws MathArithmeticException, DMLRuntimeException { // given value is "quantile" when inverse=false, and it is "probability" when inverse=true double val = Double.parseDouble(params.get("target")); boolean lowertail = true; if(params.get("lower.tail") != null) { lowertail = Boolean.parseBoolean(params.get("lower.tail")); } AbstractRealDistribution distFunction = null; switch(dcode) { case NORMAL: double mean = 0.0, sd = 1.0; // default values for mean and sd String mean_s = params.get("mean"), sd_s = params.get("sd"); if(mean_s != null) mean = Double.parseDouble(mean_s); if(sd_s != null) sd = Double.parseDouble(sd_s); if ( sd <= 0 ) throw new DMLRuntimeException("Standard deviation for Normal distribution must be positive (" + sd + ")"); distFunction = new NormalDistribution(mean, sd); break; case EXP: double exp_rate = 1.0; // default value for 1/mean or rate if(params.get("rate") != null) exp_rate = Double.parseDouble(params.get("rate")); if ( exp_rate <= 0 ) { throw new DMLRuntimeException("Rate for Exponential distribution must be positive (" + exp_rate + ")"); } // For exponential distribution: mean = 1/rate distFunction = new ExponentialDistribution(1.0/exp_rate); break; case CHISQ: if ( params.get("df") == null ) { throw new DMLRuntimeException("" + "Degrees of freedom must be specified for chi-squared distribution " + "(e.g., q=qchisq(0.5, df=20); p=pchisq(target=q, df=1.2))"); } int df = UtilFunctions.parseToInt(params.get("df")); if ( df <= 0 ) { throw new DMLRuntimeException("Degrees of Freedom for chi-squared distribution must be positive (" + df + ")"); } distFunction = new ChiSquaredDistribution(df); break; case F: if ( params.get("df1") == null || params.get("df2") == null ) { throw new DMLRuntimeException("" + "Degrees of freedom must be specified for F distribution " + "(e.g., q = qf(target=0.5, df1=20, df2=30); p=pf(target=q, df1=20, df2=30))"); } int df1 = UtilFunctions.parseToInt(params.get("df1")); int df2 = UtilFunctions.parseToInt(params.get("df2")); if ( df1 <= 0 || df2 <= 0) { throw new DMLRuntimeException("Degrees of Freedom for F distribution must be positive (" + df1 + "," + df2 + ")"); } distFunction = new FDistribution(df1, df2); break; case T: if ( params.get("df") == null ) { throw new DMLRuntimeException("" + "Degrees of freedom is needed to compute probabilities from t distribution " + "(e.g., q = qt(target=0.5, df=10); p = pt(target=q, df=10))"); } int t_df = UtilFunctions.parseToInt(params.get("df")); if ( t_df <= 0 ) { throw new DMLRuntimeException("Degrees of Freedom for t distribution must be positive (" + t_df + ")"); } distFunction = new TDistribution(t_df); break; default: throw new DMLRuntimeException("Invalid distribution code: " + dcode); } double ret = Double.NaN; if(inverse) { // inverse cdf ret = distFunction.inverseCumulativeProbability(val); } else if(lowertail) { // cdf (lowertail) ret = distFunction.cumulativeProbability(val); } else { // cdf (upper tail) // TODO: more accurate distribution-specific computation of upper tail probabilities ret = 1.0 - distFunction.cumulativeProbability(val); } return ret; } }