/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysml.runtime.instructions.cp;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import org.apache.sysml.api.DMLScript;
import org.apache.sysml.lops.Lop;
import org.apache.sysml.parser.DMLProgram;
import org.apache.sysml.parser.DataIdentifier;
import org.apache.sysml.parser.Expression.DataType;
import org.apache.sysml.parser.Expression.ValueType;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.DMLScriptException;
import org.apache.sysml.runtime.controlprogram.FunctionProgramBlock;
import org.apache.sysml.runtime.controlprogram.LocalVariableMap;
import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContextFactory;
import org.apache.sysml.runtime.instructions.Instruction;
import org.apache.sysml.runtime.instructions.InstructionUtils;
public class FunctionCallCPInstruction extends CPInstruction
{
private String _functionName;
private String _namespace;
public String getFunctionName(){
return _functionName;
}
public String getNamespace() {
return _namespace;
}
// stores both the bound input and output parameters
private ArrayList<CPOperand> _boundInputParamOperands;
private ArrayList<String> _boundInputParamNames;
private ArrayList<String> _boundOutputParamNames;
public FunctionCallCPInstruction(String namespace, String functName, ArrayList<CPOperand> boundInParamOperands, ArrayList<String> boundInParamNames, ArrayList<String> boundOutParamNames, String istr) {
super(null, functName, istr);
_cptype = CPINSTRUCTION_TYPE.External;
_functionName = functName;
_namespace = namespace;
_boundInputParamOperands = boundInParamOperands;
_boundInputParamNames = boundInParamNames;
_boundOutputParamNames = boundOutParamNames;
}
public static FunctionCallCPInstruction parseInstruction(String str)
throws DMLRuntimeException
{
//schema: extfunct, fname, num inputs, num outputs, inputs, outputs
String[] parts = InstructionUtils.getInstructionPartsWithValueType ( str );
String namespace = parts[1];
String functionName = parts[2];
int numInputs = Integer.valueOf(parts[3]);
int numOutputs = Integer.valueOf(parts[4]);
ArrayList<CPOperand> boundInParamOperands = new ArrayList<CPOperand>();
ArrayList<String> boundInParamNames = new ArrayList<String>();
ArrayList<String> boundOutParamNames = new ArrayList<String>();
for (int i = 0; i < numInputs; i++) {
CPOperand operand = new CPOperand(parts[5 + i]);
boundInParamOperands.add(operand);
boundInParamNames.add(operand.getName());
}
for (int i = 0; i < numOutputs; i++) {
boundOutParamNames.add(parts[5 + numInputs + i]);
}
return new FunctionCallCPInstruction ( namespace,functionName,
boundInParamOperands, boundInParamNames, boundOutParamNames, str );
}
@Override
public Instruction preprocessInstruction(ExecutionContext ec)
throws DMLRuntimeException
{
//default pre-process behavior
Instruction tmp = super.preprocessInstruction(ec);
//maintain debug state (function call stack)
if( DMLScript.ENABLE_DEBUG_MODE ) {
ec.handleDebugFunctionEntry((FunctionCallCPInstruction) tmp);
}
return tmp;
}
@Override
public void processInstruction(ExecutionContext ec)
throws DMLRuntimeException
{
if( LOG.isTraceEnabled() ){
LOG.trace("Executing instruction : " + this.toString());
}
// get the function program block (stored in the Program object)
FunctionProgramBlock fpb = ec.getProgram().getFunctionProgramBlock(_namespace, _functionName);
// create bindings to formal parameters for given function call
// These are the bindings passed to the FunctionProgramBlock for function execution
LocalVariableMap functionVariables = new LocalVariableMap();
for( int i=0; i<fpb.getInputParams().size(); i++)
{
DataIdentifier currFormalParam = fpb.getInputParams().get(i);
String currFormalParamName = currFormalParam.getName();
Data currFormalParamValue = null;
ValueType valType = fpb.getInputParams().get(i).getValueType();
// CASE (a): default values, if call w/ less params than signature (scalars only)
if( i > _boundInputParamNames.size() )
{
String defaultVal = fpb.getInputParams().get(i).getDefaultValue();
currFormalParamValue = ec.getScalarInput(defaultVal, valType, false);
}
// CASE (b) literals or symbol table entries
else {
CPOperand operand = _boundInputParamOperands.get(i);
String varname = operand.getName();
//error handling non-existing variables
if( !operand.isLiteral() && !ec.containsVariable(varname) ) {
throw new DMLRuntimeException("Input variable '"+varname+"' not existing on call of " +
DMLProgram.constructFunctionKey(_namespace, _functionName) + " (line "+getLineNum()+").");
}
//get input matrix/frame/scalar
currFormalParamValue = (operand.getDataType()!=DataType.SCALAR) ? ec.getVariable(varname) :
ec.getScalarInput(varname, operand.getValueType(), operand.isLiteral());
//graceful value type conversion for scalar inputs with wrong type
if( currFormalParamValue.getDataType() == DataType.SCALAR
&& currFormalParamValue.getValueType() != operand.getValueType() )
{
ScalarObject so = (ScalarObject) currFormalParamValue;
currFormalParamValue = ScalarObjectFactory
.createScalarObject(operand.getValueType(), so);
}
}
functionVariables.put(currFormalParamName, currFormalParamValue);
}
// Pin the input variables so that they do not get deleted
// from pb's symbol table at the end of execution of function
HashMap<String,Boolean> pinStatus = ec.pinVariables(_boundInputParamNames);
// Create a symbol table under a new execution context for the function invocation,
// and copy the function arguments into the created table.
ExecutionContext fn_ec = ExecutionContextFactory.createContext(false, ec.getProgram());
if (DMLScript.USE_ACCELERATOR) {
fn_ec.setGPUContext(ec.getGPUContext());
ec.setGPUContext(null);
fn_ec.getGPUContext().initializeThread();
}
fn_ec.setVariables(functionVariables);
// execute the function block
try {
fpb._functionName = this._functionName;
fpb._namespace = this._namespace;
fpb.execute(fn_ec);
}
catch (DMLScriptException e) {
throw e;
}
catch (Exception e){
String fname = DMLProgram.constructFunctionKey(_namespace, _functionName);
throw new DMLRuntimeException("error executing function " + fname, e);
}
LocalVariableMap retVars = fn_ec.getVariables();
// cleanup all returned variables w/o binding
Collection<String> retVarnames = new LinkedList<String>(retVars.keySet());
HashSet<String> probeVars = new HashSet<String>();
for(DataIdentifier di : fpb.getOutputParams())
probeVars.add(di.getName());
for( String var : retVarnames ) {
if( !probeVars.contains(var) ) //cleanup candidate
{
Data dat = fn_ec.removeVariable(var);
if( dat != null && dat instanceof MatrixObject )
fn_ec.cleanupMatrixObject((MatrixObject)dat);
}
}
// Unpin the pinned variables
ec.unpinVariables(_boundInputParamNames, pinStatus);
if (DMLScript.USE_ACCELERATOR) {
ec.setGPUContext(fn_ec.getGPUContext());
fn_ec.setGPUContext(null);
ec.getGPUContext().initializeThread();
}
// add the updated binding for each return variable to the variables in original symbol table
for (int i=0; i< fpb.getOutputParams().size(); i++){
String boundVarName = _boundOutputParamNames.get(i);
Data boundValue = retVars.get(fpb.getOutputParams().get(i).getName());
if (boundValue == null)
throw new DMLRuntimeException(boundVarName + " was not assigned a return value");
//cleanup existing data bound to output variable name
Data exdata = ec.removeVariable(boundVarName);
if ( exdata != null && exdata instanceof MatrixObject && exdata != boundValue ) {
ec.cleanupMatrixObject( (MatrixObject)exdata );
}
//add/replace data in symbol table
if( boundValue instanceof MatrixObject )
((MatrixObject) boundValue).setVarName(boundVarName);
ec.setVariable(boundVarName, boundValue);
}
}
@Override
public void postprocessInstruction(ExecutionContext ec)
throws DMLRuntimeException
{
//maintain debug state (function call stack)
if (DMLScript.ENABLE_DEBUG_MODE ) {
ec.handleDebugFunctionExit( this );
}
//default post-process behavior
super.postprocessInstruction(ec);
}
@Override
public void printMe() {
LOG.debug("ExternalBuiltInFunction: " + this.toString());
}
public String getGraphString() {
return "ExtBuiltinFunc: " + _functionName;
}
public ArrayList<String> getBoundInputParamNames()
{
return _boundInputParamNames;
}
public ArrayList<String> getBoundOutputParamNames()
{
return _boundOutputParamNames;
}
public void setFunctionName(String fname)
{
//update instruction string
String oldfname = _functionName;
instString = updateInstStringFunctionName(oldfname, fname);
//set attribute
_functionName = fname;
instOpcode = fname;
}
public String updateInstStringFunctionName(String pattern, String replace)
{
//split current instruction
String[] parts = instString.split(Lop.OPERAND_DELIMITOR);
if( parts[3].equals(pattern) )
parts[3] = replace;
//construct and set modified instruction
StringBuilder sb = new StringBuilder();
for( String part : parts ) {
sb.append(part);
sb.append(Lop.OPERAND_DELIMITOR);
}
return sb.substring( 0, sb.length()-Lop.OPERAND_DELIMITOR.length() );
}
}