/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package org.apache.sysml.runtime.controlprogram; import java.util.ArrayList; import org.apache.sysml.api.DMLScript; import org.apache.sysml.conf.ConfigurationManager; import org.apache.sysml.hops.recompile.Recompiler; import org.apache.sysml.parser.DataIdentifier; import org.apache.sysml.runtime.DMLRuntimeException; import org.apache.sysml.runtime.DMLScriptException; import org.apache.sysml.runtime.controlprogram.context.ExecutionContext; import org.apache.sysml.runtime.instructions.cp.Data; import org.apache.sysml.utils.Statistics; public class FunctionProgramBlock extends ProgramBlock { public String _functionName; public String _namespace; protected ArrayList<ProgramBlock> _childBlocks; protected ArrayList<DataIdentifier> _inputParams; protected ArrayList<DataIdentifier> _outputParams; private boolean _recompileOnce = false; public FunctionProgramBlock( Program prog, ArrayList<DataIdentifier> inputParams, ArrayList<DataIdentifier> outputParams) { super(prog); _childBlocks = new ArrayList<ProgramBlock>(); _inputParams = new ArrayList<DataIdentifier>(); for (DataIdentifier id : inputParams){ _inputParams.add(new DataIdentifier(id)); } _outputParams = new ArrayList<DataIdentifier>(); for (DataIdentifier id : outputParams){ _outputParams.add(new DataIdentifier(id)); } } public ArrayList<DataIdentifier> getInputParams(){ return _inputParams; } public ArrayList<DataIdentifier> getOutputParams(){ return _outputParams; } public void addProgramBlock(ProgramBlock childBlock) { _childBlocks.add(childBlock); } public void setChildBlocks( ArrayList<ProgramBlock> pbs) { _childBlocks = pbs; } public ArrayList<ProgramBlock> getChildBlocks() { return _childBlocks; } @Override public void execute(ExecutionContext ec) throws DMLRuntimeException { //dynamically recompile entire function body (according to function inputs) try { if( ConfigurationManager.isDynamicRecompilation() && isRecompileOnce() && ParForProgramBlock.RESET_RECOMPILATION_FLAGs ) { long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0; //note: it is important to reset the recompilation flags here // (1) it is safe to reset recompilation flags because a 'recompile_once' // function will be recompiled for every execution. // (2) without reset, there would be no benefit in recompiling the entire function LocalVariableMap tmp = (LocalVariableMap) ec.getVariables().clone(); Recompiler.recompileProgramBlockHierarchy(_childBlocks, tmp, _tid, true); if( DMLScript.STATISTICS ){ long t1 = System.nanoTime(); Statistics.incrementFunRecompileTime(t1-t0); Statistics.incrementFunRecompiles(); } } } catch(Exception ex) { throw new DMLRuntimeException("Error recompiling function body.", ex); } // for each program block try { for (int i=0 ; i < this._childBlocks.size() ; i++) { ec.updateDebugState(i); _childBlocks.get(i).execute(ec); } } catch (DMLScriptException e) { throw e; } catch (Exception e){ throw new DMLRuntimeException(this.printBlockErrorLocation() + "Error evaluating function program block", e); } // check return values checkOutputParameters(ec.getVariables()); } protected void checkOutputParameters( LocalVariableMap vars ) { for( DataIdentifier diOut : _outputParams ) { String varName = diOut.getName(); Data dat = vars.get( varName ); if( dat == null ) LOG.error("Function output "+ varName +" is missing."); else if( dat.getDataType() != diOut.getDataType() ) LOG.warn("Function output "+ varName +" has wrong data type: "+dat.getDataType()+"."); else if( dat.getValueType() != diOut.getValueType() ) LOG.warn("Function output "+ varName +" has wrong value type: "+dat.getValueType()+"."); } } public void setRecompileOnce( boolean flag ) { _recompileOnce = flag; } public boolean isRecompileOnce() { return _recompileOnce; } public String printBlockErrorLocation(){ return "ERROR: Runtime error in function program block generated from function statement block between lines " + _beginLine + " and " + _endLine + " -- "; } }