/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.sysml.runtime.instructions.cp;

import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;

import org.apache.sysml.api.DMLScript;
import org.apache.sysml.lops.Lop;
import org.apache.sysml.parser.DMLProgram;
import org.apache.sysml.parser.DataIdentifier;
import org.apache.sysml.parser.Expression.DataType;
import org.apache.sysml.parser.Expression.ValueType;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.DMLScriptException;
import org.apache.sysml.runtime.controlprogram.FunctionProgramBlock;
import org.apache.sysml.runtime.controlprogram.LocalVariableMap;
import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContextFactory;
import org.apache.sysml.runtime.instructions.Instruction;
import org.apache.sysml.runtime.instructions.InstructionUtils;

/**
 * CP instruction that calls a function program block identified by namespace and name.
 * It binds the caller's operands to the function's formal input parameters, executes the
 * function in a fresh execution context, and binds the function's outputs back to the
 * caller's output variable names.
 */
public class FunctionCallCPInstruction extends CPInstruction 
{
	private String _functionName;
	private String _namespace;
	
	// stores both the bound input and output parameters
	private ArrayList<CPOperand> _boundInputParamOperands;
	private ArrayList<String> _boundInputParamNames;
	private ArrayList<String> _boundOutputParamNames;
	
	/**
	 * Creates a function call instruction.
	 *
	 * @param namespace namespace of the called function
	 * @param functName name of the called function (also used as the opcode)
	 * @param boundInParamOperands caller-side input operands, in formal parameter order
	 * @param boundInParamNames names of the caller-side input operands
	 * @param boundOutParamNames caller-side variable names receiving the function outputs
	 * @param istr original instruction string
	 */
	public FunctionCallCPInstruction(String namespace, String functName, ArrayList<CPOperand> boundInParamOperands, 
			ArrayList<String> boundInParamNames, ArrayList<String> boundOutParamNames, String istr) 
	{
		super(null, functName, istr);
		_cptype = CPINSTRUCTION_TYPE.External;
		_functionName = functName;
		_namespace = namespace;
		_boundInputParamOperands = boundInParamOperands;
		_boundInputParamNames = boundInParamNames;
		_boundOutputParamNames = boundOutParamNames;
	}
	
	public String getFunctionName(){
		return _functionName;
	}
	
	public String getNamespace() {
		return _namespace;
	}
	
	/**
	 * Parses a function call instruction string.
	 * Schema: extfunct, namespace, fname, num inputs, num outputs, inputs, outputs.
	 *
	 * @param str instruction string
	 * @return the parsed instruction
	 * @throws DMLRuntimeException if the instruction string cannot be parsed
	 */
	public static FunctionCallCPInstruction parseInstruction(String str) 
		throws DMLRuntimeException 
	{
		//schema: extfunct, fname, num inputs, num outputs, inputs, outputs
		String[] parts = InstructionUtils.getInstructionPartsWithValueType ( str );
		String namespace = parts[1];
		String functionName = parts[2];
		int numInputs = Integer.parseInt(parts[3]);
		int numOutputs = Integer.parseInt(parts[4]);
		ArrayList<CPOperand> boundInParamOperands = new ArrayList<CPOperand>();
		ArrayList<String> boundInParamNames = new ArrayList<String>();
		ArrayList<String> boundOutParamNames = new ArrayList<String>();
		for (int i = 0; i < numInputs; i++) {
			CPOperand operand = new CPOperand(parts[5 + i]);
			boundInParamOperands.add(operand);
			boundInParamNames.add(operand.getName());
		}
		for (int i = 0; i < numOutputs; i++) {
			boundOutParamNames.add(parts[5 + numInputs + i]);
		}
		
		return new FunctionCallCPInstruction ( namespace,functionName, 
				boundInParamOperands, boundInParamNames, boundOutParamNames, str );
	}
	
	/**
	 * Default pre-processing; in debug mode it additionally records entry into the
	 * called function via {@code ec.handleDebugFunctionEntry}.
	 */
	@Override
	public Instruction preprocessInstruction(ExecutionContext ec)
		throws DMLRuntimeException 
	{
		//default pre-process behavior
		Instruction tmp = super.preprocessInstruction(ec);
		
		//maintain debug state (function call stack) 
		if( DMLScript.ENABLE_DEBUG_MODE ) {
			ec.handleDebugFunctionEntry((FunctionCallCPInstruction) tmp);
		}

		return tmp;
	}

	/**
	 * Executes the function call.
	 * <ol>
	 * <li>Binds inputs (literals, symbol-table entries, or declared defaults) to formal parameters.</li>
	 * <li>Pins the caller's input variables while the function runs.</li>
	 * <li>Executes the function in a new execution context.</li>
	 * <li>Cleans up returned variables that are not declared outputs.</li>
	 * <li>Binds declared outputs to the caller's output variable names.</li>
	 * </ol>
	 *
	 * @param ec caller's execution context
	 * @throws DMLRuntimeException if an input variable is missing, an output is not assigned,
	 *         or the function body fails ({@link DMLScriptException} is rethrown unwrapped)
	 */
	@Override
	public void processInstruction(ExecutionContext ec) 
		throws DMLRuntimeException 
	{		
		if( LOG.isTraceEnabled() ){
			LOG.trace("Executing instruction : " + this.toString());
		}
		// get the function program block (stored in the Program object)
		FunctionProgramBlock fpb = ec.getProgram().getFunctionProgramBlock(_namespace, _functionName);
		
		// create bindings to formal parameters for given function call
		// These are the bindings passed to the FunctionProgramBlock for function execution 
		LocalVariableMap functionVariables = new LocalVariableMap();
		for( int i=0; i<fpb.getInputParams().size(); i++) 
		{
			DataIdentifier currFormalParam = fpb.getInputParams().get(i);
			String currFormalParamName = currFormalParam.getName();
			ValueType valType = currFormalParam.getValueType();
			Data currFormalParamValue = null; 
			
			// CASE (a): default values, if call w/ less params than signature (scalars only)
			// Uses >= (not >): at i == size() there is no bound operand, so falling through
			// to case (b) would fail on _boundInputParamOperands.get(i).
			if( i >= _boundInputParamNames.size() ) { 
				String defaultVal = currFormalParam.getDefaultValue();
				// NOTE(review): isLiteral=false is passed for the default value; confirm that
				// getScalarInput resolves defaults correctly if they are stored as literal text.
				currFormalParamValue = ec.getScalarInput(defaultVal, valType, false);
			}
			// CASE (b) literals or symbol table entries
			else {
				CPOperand operand = _boundInputParamOperands.get(i);
				String varname = operand.getName();
				//error handling non-existing variables
				if( !operand.isLiteral() && !ec.containsVariable(varname) ) {
					throw new DMLRuntimeException("Input variable '"+varname+"' not existing on call of " + 
							DMLProgram.constructFunctionKey(_namespace, _functionName) + " (line "+getLineNum()+").");
				}
				//get input matrix/frame/scalar
				currFormalParamValue = (operand.getDataType()!=DataType.SCALAR) ? ec.getVariable(varname) : 
					ec.getScalarInput(varname, operand.getValueType(), operand.isLiteral());
				
				//graceful value type conversion for scalar inputs whose value type differs
				//from the value type declared by the formal parameter of the function
				if( currFormalParamValue.getDataType() == DataType.SCALAR 
					&& currFormalParamValue.getValueType() != valType ) 
				{
					ScalarObject so = (ScalarObject) currFormalParamValue;
					currFormalParamValue = ScalarObjectFactory
							.createScalarObject(valType, so);
				}
			}
			
			functionVariables.put(currFormalParamName, currFormalParamValue);
		}
		
		// Pin the input variables so that they do not get deleted 
		// from pb's symbol table at the end of execution of function
		HashMap<String,Boolean> pinStatus = ec.pinVariables(_boundInputParamNames);
		
		// Create a symbol table under a new execution context for the function invocation,
		// and copy the function arguments into the created table. 
		ExecutionContext fn_ec = ExecutionContextFactory.createContext(false, ec.getProgram());
		if (DMLScript.USE_ACCELERATOR) {
			// hand the GPU context over to the callee for the duration of the call
			fn_ec.setGPUContext(ec.getGPUContext());
			ec.setGPUContext(null);
			fn_ec.getGPUContext().initializeThread();
		}
		fn_ec.setVariables(functionVariables);
		
		// execute the function block
		try {
			fpb._functionName = this._functionName;
			fpb._namespace = this._namespace;
			fpb.execute(fn_ec);
		}
		catch (DMLScriptException e) {
			// script-level errors propagate unchanged
			throw e;
		}
		catch (Exception e){
			String fname = DMLProgram.constructFunctionKey(_namespace, _functionName);
			throw new DMLRuntimeException("error executing function " + fname, e);
		}
		
		LocalVariableMap retVars = fn_ec.getVariables();  
		
		// cleanup all returned variables w/o binding 
		Collection<String> retVarnames = new LinkedList<String>(retVars.keySet());
		HashSet<String> probeVars = new HashSet<String>();
		for(DataIdentifier di : fpb.getOutputParams())
			probeVars.add(di.getName());
		for( String var : retVarnames ) {
			if( !probeVars.contains(var) ) //cleanup candidate
			{
				Data dat = fn_ec.removeVariable(var);
				if( dat instanceof MatrixObject )
					fn_ec.cleanupMatrixObject((MatrixObject)dat);
			}
		}
		
		// Unpin the pinned variables
		ec.unpinVariables(_boundInputParamNames, pinStatus);

		if (DMLScript.USE_ACCELERATOR) {
			// return the GPU context to the caller
			ec.setGPUContext(fn_ec.getGPUContext());
			fn_ec.setGPUContext(null);
			ec.getGPUContext().initializeThread();
		}
		
		// add the updated binding for each return variable to the variables in original symbol table
		for (int i=0; i< fpb.getOutputParams().size(); i++){
			String boundVarName = _boundOutputParamNames.get(i);
			Data boundValue = retVars.get(fpb.getOutputParams().get(i).getName());
			if (boundValue == null)
				throw new DMLRuntimeException(boundVarName + " was not assigned a return value");

			//cleanup existing data bound to output variable name
			//(skipped if the function returned the very same object)
			Data exdata = ec.removeVariable(boundVarName);
			if ( exdata instanceof MatrixObject && exdata != boundValue ) {
				ec.cleanupMatrixObject( (MatrixObject)exdata );
			}
			
			//add/replace data in symbol table
			if( boundValue instanceof MatrixObject )
				((MatrixObject) boundValue).setVarName(boundVarName);
			ec.setVariable(boundVarName, boundValue);
		}
	}

	/**
	 * In debug mode, records exit from the called function via
	 * {@code ec.handleDebugFunctionExit}, then runs the default post-processing.
	 */
	@Override
	public void postprocessInstruction(ExecutionContext ec)
		throws DMLRuntimeException 
	{
		//maintain debug state (function call stack) 
		if (DMLScript.ENABLE_DEBUG_MODE ) {
			ec.handleDebugFunctionExit( this );
		}
		
		//default post-process behavior
		super.postprocessInstruction(ec);
	}

	@Override
	public void printMe() {
		LOG.debug("ExternalBuiltInFunction: " + this.toString());
	}

	public String getGraphString() {
		return "ExtBuiltinFunc: " + _functionName;
	}
	
	public ArrayList<String> getBoundInputParamNames()
	{
		return _boundInputParamNames;
	}
	
	public ArrayList<String> getBoundOutputParamNames()
	{
		return _boundOutputParamNames;
	}
	
	/**
	 * Renames the called function.
	 * Updates the instruction string, the function name attribute and the opcode.
	 *
	 * @param fname new function name
	 */
	public void setFunctionName(String fname)
	{
		//update instruction string
		String oldfname = _functionName;
		instString = updateInstStringFunctionName(oldfname, fname);
		
		//set attribute
		_functionName = fname;
		instOpcode = fname;
	}
	
	/**
	 * Builds a copy of the current instruction string with the function-name field
	 * replaced. The replacement happens only if that field, operand index 3 of the raw
	 * instruction string, equals {@code pattern}.
	 *
	 * @param pattern expected current function name
	 * @param replace new function name
	 * @return the (possibly) modified instruction string; the instruction itself is not changed
	 */
	public String updateInstStringFunctionName(String pattern, String replace)
	{
		//split current instruction
		String[] parts = instString.split(Lop.OPERAND_DELIMITOR);
		if( parts[3].equals(pattern) )
			parts[3] = replace;	
		
		//construct and set modified instruction
		StringBuilder sb = new StringBuilder();
		for( String part : parts ) {
			sb.append(part);
			sb.append(Lop.OPERAND_DELIMITOR);
		}

		return sb.substring( 0, sb.length()-Lop.OPERAND_DELIMITOR.length() );
	}
}