/**
* (C) Copyright IBM Corp. 2010, 2015
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package com.ibm.bi.dml.runtime.instructions.cp;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import com.ibm.bi.dml.api.DMLScript;
import com.ibm.bi.dml.lops.Lop;
import com.ibm.bi.dml.parser.DataIdentifier;
import com.ibm.bi.dml.parser.Expression.DataType;
import com.ibm.bi.dml.parser.Expression.ValueType;
import com.ibm.bi.dml.runtime.DMLRuntimeException;
import com.ibm.bi.dml.runtime.DMLScriptException;
import com.ibm.bi.dml.runtime.DMLUnsupportedOperationException;
import com.ibm.bi.dml.runtime.controlprogram.FunctionProgramBlock;
import com.ibm.bi.dml.runtime.controlprogram.LocalVariableMap;
import com.ibm.bi.dml.runtime.controlprogram.caching.MatrixObject;
import com.ibm.bi.dml.runtime.controlprogram.context.ExecutionContext;
import com.ibm.bi.dml.runtime.controlprogram.context.ExecutionContextFactory;
import com.ibm.bi.dml.runtime.instructions.Instruction;
import com.ibm.bi.dml.runtime.instructions.InstructionUtils;
/**
 * CP instruction that calls a function stored in the program as a {@link FunctionProgramBlock}.
 * <p>
 * Execution proceeds in four steps:
 * <ol>
 * <li>Bind the caller's arguments, or the declared default values, to the function's formal
 *     input parameters.</li>
 * <li>Run the function block in a newly created execution context.</li>
 * <li>Remove every variable left in that context that is not a declared output parameter.</li>
 * <li>Bind the returned values to the caller's output variable names.</li>
 * </ol>
 */
public class FunctionCallCPInstruction extends CPInstruction
{
	/** Index (in the parts returned by getInstructionPartsWithValueType) of the first input operand. */
	private static final int FIRST_PARAM_INDEX = 5;
	
	private String _functionName;
	private String _namespace;
	
	public String getFunctionName(){
		return _functionName;
	}
	
	public String getNamespace() {
		return _namespace;
	}
	
	// stores both the bound input and output parameters
	private ArrayList<CPOperand> _boundInputParamOperands;
	private ArrayList<String> _boundInputParamNames;
	private ArrayList<String> _boundOutputParamNames;
	
	public FunctionCallCPInstruction(String namespace, String functName, ArrayList<CPOperand> boundInParamOperands, ArrayList<String> boundInParamNames, ArrayList<String> boundOutParamNames, String istr) {
		super(null, functName, istr);
		_cptype = CPINSTRUCTION_TYPE.External;
		_functionName = functName;
		_namespace = namespace;
		_boundInputParamOperands = boundInParamOperands;
		_boundInputParamNames = boundInParamNames;
		_boundOutputParamNames = boundOutParamNames;
	}
	
	/**
	 * Parses a function call instruction.
	 * <p>
	 * Instruction format (':::' stands for the operand delimiter; the opcode comes first):
	 * <pre>
	 * extFunct:::[namespace]:::[function name]:::[num input params]:::[num output params]:::[list of delimited input params]:::[list of delimited output params]
	 * </pre>
	 * These are the "bound names" for the inputs / outputs. For example,
	 * {@code out1 = foo(in1, in2)} with function {@code foo} in namespace {@code ns} yields
	 * (simplified)
	 * <pre>
	 * extFunct:::ns:::foo:::2:::1:::in1:::in2:::out1
	 * </pre>
	 * Each input param is parsed into a {@link CPOperand}; output params are kept as plain names.
	 *
	 * @param str the instruction string
	 * @return the parsed function call instruction
	 * @throws DMLRuntimeException propagated from instruction parsing
	 * @throws DMLUnsupportedOperationException propagated from instruction parsing
	 */
	public static Instruction parseInstruction(String str) throws DMLRuntimeException, DMLUnsupportedOperationException {
		String[] parts = InstructionUtils.getInstructionPartsWithValueType ( str );
		String namespace = parts[1];
		String functionName = parts[2];
		int numInputs = Integer.parseInt(parts[3]);
		int numOutputs = Integer.parseInt(parts[4]);
		
		ArrayList<CPOperand> boundInParamOperands = new ArrayList<CPOperand>();
		ArrayList<String> boundInParamNames = new ArrayList<String>();
		ArrayList<String> boundOutParamNames = new ArrayList<String>();
		
		for (int i = 0; i < numInputs; i++) {
			CPOperand operand = new CPOperand(parts[FIRST_PARAM_INDEX + i]);
			boundInParamOperands.add(operand);
			boundInParamNames.add(operand.getName());
		}
		// output names directly follow the input operands
		for (int i = 0; i < numOutputs; i++) {
			boundOutParamNames.add(parts[FIRST_PARAM_INDEX + numInputs + i]);
		}
		
		return new FunctionCallCPInstruction ( namespace,functionName, boundInParamOperands, boundInParamNames, boundOutParamNames, str );
	}
	
	@Override
	public Instruction preprocessInstruction(ExecutionContext ec)
		throws DMLRuntimeException, DMLUnsupportedOperationException 
	{
		//default pre-process behavior
		Instruction tmp = super.preprocessInstruction(ec);
		
		//maintain debug state (function call stack) 
		if( DMLScript.ENABLE_DEBUG_MODE ) {
			ec.handleDebugFunctionEntry((FunctionCallCPInstruction) tmp);
		}

		return tmp;
	}

	/**
	 * Executes the function call: binds inputs, runs the function block in a new
	 * execution context, and binds the declared outputs back into {@code ec}.
	 *
	 * @param ec the caller's execution context
	 * @throws DMLRuntimeException if the function body fails (wrapped, except for DMLScriptException)
	 * @throws DMLUnsupportedOperationException if a declared output was not assigned a value
	 */
	@Override
	public void processInstruction(ExecutionContext ec) 
		throws DMLRuntimeException, DMLUnsupportedOperationException
	{
		if( LOG.isTraceEnabled() ){
			LOG.trace("Executing instruction : " + this.toString());
		}
		
		// get the function program block (stored in the Program object)
		FunctionProgramBlock fpb = ec.getProgram().getFunctionProgramBlock(_namespace, _functionName);
		
		// create bindings to formal parameters for given function call
		// These are the bindings passed to the FunctionProgramBlock for function execution 
		LocalVariableMap functionVariables = bindInputParameters(ec, fpb);
		
		// Pin the input variables so that they do not get deleted 
		// from pb's symbol table at the end of execution of function
		HashMap<String,Boolean> pinStatus = ec.pinVariables(_boundInputParamNames);
		
		LocalVariableMap retVars;
		try {
			// Create a symbol table under a new execution context for the function invocation,
			// and copy the function arguments into the created table. 
			ExecutionContext fn_ec = ExecutionContextFactory.createContext(false, ec.getProgram());
			fn_ec.setVariables(functionVariables);
			
			// execute the function block
			try {
				fpb.execute(fn_ec);
			}
			catch (DMLScriptException e) {
				throw e;
			}
			catch (Exception e){
				String fname = this._namespace + "::" + this._functionName;
				throw new DMLRuntimeException("error executing function " + fname, e);
			}
			
			retVars = fn_ec.getVariables();
			
			// cleanup all returned variables w/o binding; this must happen while the
			// inputs are still pinned, so the caller's inputs are not cleaned up here
			cleanupUnboundReturnVariables(fpb, fn_ec, retVars);
		}
		finally {
			// Unpin the pinned variables (also on failure, so the caller's pin status is always restored)
			ec.unpinVariables(_boundInputParamNames, pinStatus);
		}
		
		// add the updated binding for each return variable to the variables in original symbol table
		bindOutputParameters(ec, fpb, retVars);
	}
	
	/**
	 * Builds the symbol table for the function invocation by binding each formal input
	 * parameter to a value.
	 * <p>
	 * The declared default value is used in two cases: the call passes fewer arguments than
	 * the signature declares, or the bound argument is a non-literal whose variable is absent
	 * from {@code ec}. Otherwise, a scalar operand is read via getScalarInput and a
	 * non-scalar operand is looked up in {@code ec}.
	 *
	 * @param ec caller's execution context
	 * @param fpb the function being called
	 * @return the new map from formal parameter names to bound values
	 */
	private LocalVariableMap bindInputParameters(ExecutionContext ec, FunctionProgramBlock fpb)
		throws DMLRuntimeException, DMLUnsupportedOperationException
	{
		LocalVariableMap functionVariables = new LocalVariableMap();
		int numFormalParams = fpb.getInputParams().size();
		for( int i=0; i<numFormalParams; i++ )
		{
			DataIdentifier currFormalParam = fpb.getInputParams().get(i);
			Data currFormalParamValue = null;
			
			// CASE (a): default values, if call w/ less params than signature (scalars only).
			// Note '>=': index i is already out of range for the bound operands when i == size().
			if( i >= _boundInputParamNames.size() 
				|| (!_boundInputParamOperands.get(i).isLiteral() && ec.getVariable(_boundInputParamNames.get(i)) == null) )
			{
				// NOTE(review): behavior when no default value is declared depends on getScalarInput
				String defaultVal = currFormalParam.getDefaultValue();
				currFormalParamValue = ec.getScalarInput(defaultVal, currFormalParam.getValueType(), false);
			}
			// CASE (b) literals or symbol table entries
			else {
				CPOperand operand = _boundInputParamOperands.get(i);
				if( operand.getDataType()==DataType.SCALAR )
					currFormalParamValue = ec.getScalarInput(operand.getName(), operand.getValueType(), operand.isLiteral());
				else
					currFormalParamValue = ec.getVariable(operand.getName());
			}
			
			functionVariables.put(currFormalParam.getName(), currFormalParamValue);
		}
		return functionVariables;
	}
	
	/**
	 * Removes every variable from the function's context that is not a declared output
	 * parameter, and cleans up removed matrix objects.
	 *
	 * @param fpb the called function (source of the declared output names)
	 * @param fn_ec the function's execution context
	 * @param retVars the function's variable map after execution
	 */
	private static void cleanupUnboundReturnVariables(FunctionProgramBlock fpb, ExecutionContext fn_ec, LocalVariableMap retVars)
		throws DMLRuntimeException, DMLUnsupportedOperationException
	{
		HashSet<String> probeVars = new HashSet<String>();
		for( DataIdentifier di : fpb.getOutputParams() )
			probeVars.add(di.getName());
		
		// iterate over a snapshot of the names, because removeVariable modifies the variables
		Collection<String> retVarnames = new LinkedList<String>(retVars.keySet());
		for( String var : retVarnames ) {
			if( !probeVars.contains(var) ) //cleanup candidate
			{
				Data dat = fn_ec.removeVariable(var);
				if( dat instanceof MatrixObject )
					fn_ec.cleanupMatrixObject((MatrixObject)dat);
			}
		}
	}
	
	/**
	 * Binds each declared output of the function to the corresponding bound output name
	 * in the caller's context. Any matrix previously bound to that name is cleaned up,
	 * unless it is the returned object itself.
	 *
	 * @param ec caller's execution context
	 * @param fpb the called function
	 * @param retVars the function's variables after execution
	 * @throws DMLUnsupportedOperationException if a declared output has no value
	 */
	private void bindOutputParameters(ExecutionContext ec, FunctionProgramBlock fpb, LocalVariableMap retVars)
		throws DMLRuntimeException, DMLUnsupportedOperationException
	{
		for( int i=0; i<fpb.getOutputParams().size(); i++ ) {
			String boundVarName = _boundOutputParamNames.get(i);
			Data boundValue = retVars.get(fpb.getOutputParams().get(i).getName());
			if( boundValue == null )
				throw new DMLUnsupportedOperationException(boundVarName + " was not assigned a return value");
			
			//cleanup existing data bound to output variable name
			Data exdata = ec.removeVariable(boundVarName);
			if( exdata instanceof MatrixObject && exdata != boundValue ) {
				ec.cleanupMatrixObject( (MatrixObject)exdata );
			}
			
			//add/replace data in symbol table
			if( boundValue instanceof MatrixObject )
				((MatrixObject) boundValue).setVarName(boundVarName);
			ec.setVariable(boundVarName, boundValue);
		}
	}

	@Override
	public void postprocessInstruction(ExecutionContext ec)
		throws DMLRuntimeException 
	{
		//maintain debug state (function call stack) 
		if (DMLScript.ENABLE_DEBUG_MODE ) {
			ec.handleDebugFunctionExit( this );
		}
		
		//default post-process behavior
		super.postprocessInstruction(ec);
	}
	
	@Override
	public void printMe() {
		LOG.debug("ExternalBuiltInFunction: " + this.toString());
	}

	public String getGraphString() {
		return "ExtBuiltinFunc: " + _functionName;
	}
	
	public ArrayList<String> getBoundInputParamNames()
	{
		return _boundInputParamNames;
	}
	
	public ArrayList<String> getBoundOutputParamNames()
	{
		return _boundOutputParamNames;
	}
	
	/**
	 * Renames the called function.
	 * <p>
	 * Updates the function name, the opcode, and the function-name token inside the
	 * instruction string.
	 *
	 * @param fname the new function name
	 */
	public void setFunctionName(String fname)
	{
		//update instruction string
		String oldfname = _functionName;
		instString = updateInstStringFunctionName(oldfname, fname);
		
		//set attribute
		_functionName = fname;
		instOpcode = fname;
	}
	
	/**
	 * Returns a copy of the current instruction string with the function-name token replaced.
	 * <p>
	 * The token is replaced only if it equals {@code pattern}. The instruction string itself
	 * is not modified.
	 *
	 * @param pattern the expected current function name
	 * @param replace the new function name
	 * @return the rebuilt instruction string
	 */
	public String updateInstStringFunctionName(String pattern, String replace)
	{
		//split current instruction (limit -1 keeps trailing empty operands, so the rebuild is lossless)
		String[] parts = instString.split(Lop.OPERAND_DELIMITOR, -1);
		
		// index 3 of the raw split is the function name: the raw string carries one leading
		// token more than the parts seen by parseInstruction (presumably the exec type)
		if( parts.length > 3 && parts[3].equals(pattern) )
			parts[3] = replace;
		
		//construct and return modified instruction
		StringBuilder sb = new StringBuilder();
		for( int i=0; i<parts.length; i++ ) {
			if( i > 0 )
				sb.append(Lop.OPERAND_DELIMITOR);
			sb.append(parts[i]);
		}
		return sb.toString();
	}
}