/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysml.runtime.controlprogram;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.NoSuchElementException;
import org.apache.sysml.api.DMLScript;
import org.apache.sysml.hops.Hop;
import org.apache.sysml.parser.ForStatementBlock;
import org.apache.sysml.parser.Expression.ValueType;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.DMLScriptException;
import org.apache.sysml.runtime.controlprogram.caching.MatrixObject.UpdateType;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysml.runtime.instructions.Instruction;
import org.apache.sysml.runtime.instructions.cp.Data;
import org.apache.sysml.runtime.instructions.cp.IntObject;
import org.apache.sysml.runtime.instructions.cp.ScalarObject;
import org.apache.sysml.runtime.util.UtilFunctions;
import org.apache.sysml.yarn.DMLAppMasterUtils;
/**
 * Runtime program block for a for loop: evaluates the from, to and increment
 * predicates once at loop entry, then executes all child blocks once per value
 * of the resulting integer sequence, and finally runs the exit instructions.
 */
public class ForProgramBlock extends ProgramBlock
{
	protected ArrayList<Instruction> _fromInstructions;
	protected ArrayList<Instruction> _toInstructions;
	protected ArrayList<Instruction> _incrementInstructions;
	protected ArrayList<Instruction> _exitInstructions;
	protected ArrayList<ProgramBlock> _childBlocks;

	/**
	 * Predicate descriptors: index 0 is the name of the iteration variable,
	 * indexes 1, 2 and 3 hold the from, to and increment entries. A non-null
	 * entry at 1..3 is either the name of a scalar variable or an integer
	 * literal; a null entry means the value is obtained by executing the
	 * corresponding predicate instructions.
	 */
	protected String[] _iterablePredicateVars;

	/**
	 * Creates a for program block with empty exit instructions and no child blocks.
	 * The from/to/increment instruction lists are left unset.
	 *
	 * @param prog the program passed on to the super constructor
	 * @param iterPredVars predicate descriptors, see {@link #_iterablePredicateVars}
	 */
	public ForProgramBlock(Program prog, String[] iterPredVars) {
		super(prog);
		_exitInstructions = new ArrayList<Instruction>();
		_childBlocks = new ArrayList<ProgramBlock>();
		_iterablePredicateVars = iterPredVars;
	}

	/** @return the instructions of the 'from' predicate (may be null if never set) */
	public ArrayList<Instruction> getFromInstructions() {
		return _fromInstructions;
	}

	/** @param instructions the instructions of the 'from' predicate */
	public void setFromInstructions(ArrayList<Instruction> instructions) {
		_fromInstructions = instructions;
	}

	/** @return the instructions of the 'to' predicate (may be null if never set) */
	public ArrayList<Instruction> getToInstructions() {
		return _toInstructions;
	}

	/** @param instructions the instructions of the 'to' predicate */
	public void setToInstructions(ArrayList<Instruction> instructions) {
		_toInstructions = instructions;
	}

	/** @return the instructions of the 'increment' predicate (may be null if never set) */
	public ArrayList<Instruction> getIncrementInstructions() {
		return _incrementInstructions;
	}

	/** @param instructions the instructions of the 'increment' predicate */
	public void setIncrementInstructions(ArrayList<Instruction> instructions) {
		_incrementInstructions = instructions;
	}

	/** @return the instructions executed after the loop has finished */
	public ArrayList<Instruction> getExitInstructions() {
		return _exitInstructions;
	}

	/** @param inst the instructions executed after the loop has finished */
	public void setExitInstructions(ArrayList<Instruction> inst) {
		_exitInstructions = inst;
	}

	/**
	 * Appends a block to the loop body.
	 *
	 * @param childBlock the block to append
	 */
	public void addProgramBlock(ProgramBlock childBlock) {
		_childBlocks.add(childBlock);
	}

	/** @return the blocks of the loop body, in execution order */
	public ArrayList<ProgramBlock> getChildBlocks() {
		return _childBlocks;
	}

	/** @param pbs the blocks of the loop body, in execution order */
	public void setChildBlocks(ArrayList<ProgramBlock> pbs) {
		_childBlocks = pbs;
	}

	/** @return the predicate descriptors, see {@link #_iterablePredicateVars} */
	public String[] getIterablePredicateVars() {
		return _iterablePredicateVars;
	}

	/** @param iterPredVars the predicate descriptors, see {@link #_iterablePredicateVars} */
	public void setIterablePredicateVars(String[] iterPredVars) {
		_iterablePredicateVars = iterPredVars;
	}

	/**
	 * Executes the for loop.
	 * <p>
	 * From, to and increment are evaluated exactly once before the first
	 * iteration. If neither increment instructions nor an increment descriptor
	 * are present, the increment defaults to 1 when from &lt;= to and to -1
	 * otherwise. For every value of the sequence the iteration variable is set
	 * in the execution context and all child blocks are executed in order.
	 * The exit instructions run after the loop completed.
	 *
	 * @param ec the execution context holding the variables
	 * @throws DMLRuntimeException if the increment evaluates to zero, or if a
	 *         predicate, the loop body or the exit instructions fail; a
	 *         DMLScriptException raised in the body is rethrown unwrapped
	 */
	@Override
	public void execute(ExecutionContext ec)
		throws DMLRuntimeException
	{
		// name of the iteration variable
		String iterVarName = _iterablePredicateVars[0];

		// evaluate from, to, incr only once (assumption: known at for entry)
		IntObject from = executePredicateInstructions( 1, _fromInstructions, ec );
		IntObject to = executePredicateInstructions( 2, _toInstructions, ec );
		boolean defaultIncr = (_incrementInstructions == null || _incrementInstructions.isEmpty())
			&& _iterablePredicateVars[3] == null;
		IntObject incr = defaultIncr ?
			new IntObject((from.getLongValue()<=to.getLongValue()) ? 1 : -1) :
			executePredicateInstructions( 3, _incrementInstructions, ec );

		if( incr.getLongValue() == 0 ) //would produce infinite loop
			throw new DMLRuntimeException(this.printBlockErrorLocation() + "Expression for increment of variable '" + iterVarName + "' must evaluate to a non-zero value.");

		// execute for loop
		try
		{
			// prepare update in-place variables
			UpdateType[] flags = prepareUpdateInPlaceVariables(ec, _tid);

			// run for loop body for each instance of predicate sequence
			SequenceIterator seqIter = new SequenceIterator(iterVarName, from, to, incr);
			for( IntObject iterVar : seqIter )
			{
				//set iteration variable
				ec.setVariable(iterVarName, iterVar);

				//execute all child blocks
				for( int i=0; i < this._childBlocks.size(); i++ ) {
					ec.updateDebugState( i );
					_childBlocks.get(i).execute(ec);
				}
			}

			// reset update-in-place variables
			resetUpdateInPlaceVariableFlags(ec, flags);
		}
		catch (DMLScriptException e) {
			//propagate stop call
			throw e;
		}
		catch (Exception e) {
			throw new DMLRuntimeException(printBlockErrorLocation() + "Error evaluating for program block", e);
		}

		//execute exit instructions
		try {
			executeInstructions(_exitInstructions, ec);
		}
		catch (Exception e){
			throw new DMLRuntimeException(printBlockErrorLocation() + "Error evaluating for exit instructions", e);
		}
	}

	/**
	 * Evaluates one of the loop predicates to an integer object.
	 * <p>
	 * If a descriptor exists at {@code pos}, it is first probed as a variable
	 * name in the execution context; if no scalar variable of that name exists,
	 * the descriptor is parsed as an integer literal. Without a descriptor the
	 * given instructions are executed via {@code executePredicate}, using the
	 * predicate hops and recompilation flag of the statement block if available.
	 *
	 * @param pos predicate position: 1 = from, 2 = to, 3 = increment
	 * @param instructions the instructions of the predicate, used only if no
	 *        descriptor exists at {@code pos}
	 * @param ec the execution context holding the variables
	 * @return the predicate value; non-integer scalars are converted via their long value
	 * @throws DMLRuntimeException if the evaluation fails for any reason
	 */
	protected IntObject executePredicateInstructions( int pos, ArrayList<Instruction> instructions, ExecutionContext ec )
		throws DMLRuntimeException
	{
		ScalarObject tmp = null;
		IntObject ret = null;

		try
		{
			if( _iterablePredicateVars[pos] != null )
			{
				//probe for scalar variables (instanceof is false for null)
				Data ldat = ec.getVariable( _iterablePredicateVars[pos] );
				if( ldat instanceof ScalarObject )
					tmp = (ScalarObject)ldat;
				else //handle literals
					tmp = new IntObject( UtilFunctions.parseToLong(_iterablePredicateVars[pos]) );
			}
			else
			{
				if( _sb!=null )
				{
					if( DMLScript.isActiveAM() ) //set program block specific remote memory
						DMLAppMasterUtils.setupProgramBlockRemoteMaxMemory(this);

					ForStatementBlock fsb = (ForStatementBlock)_sb;
					Hop predHops = null;
					boolean recompile = false;
					if (pos == 1){
						predHops = fsb.getFromHops();
						recompile = fsb.requiresFromRecompilation();
					}
					else if (pos == 2) {
						predHops = fsb.getToHops();
						recompile = fsb.requiresToRecompilation();
					}
					else if (pos == 3){
						predHops = fsb.getIncrementHops();
						recompile = fsb.requiresIncrementRecompilation();
					}
					tmp = (IntObject) executePredicate(instructions, predHops, recompile, ValueType.INT, ec);
				}
				else
					tmp = (IntObject) executePredicate(instructions, null, false, ValueType.INT, ec);
			}
		}
		catch(Exception ex)
		{
			throw new DMLRuntimeException(this.printBlockErrorLocation() +"Error evaluating '" + getPredicateName(pos) + "' predicate", ex);
		}

		//final check of resulting int object (guaranteed to be non-null, see executePredicate)
		if( tmp instanceof IntObject )
			ret = (IntObject)tmp;
		else //downcast to int if necessary
			ret = new IntObject(tmp.getName(),tmp.getLongValue());

		return ret;
	}

	/**
	 * Maps a predicate position to the name used in error messages.
	 *
	 * @param pos predicate position
	 * @return "from", "to" or "increment" for positions 1, 2, 3; null otherwise
	 */
	private String getPredicateName( int pos ) {
		switch( pos ) {
			case 1: return "from";
			case 2: return "to";
			case 3: return "increment";
			default: return null;
		}
	}

	/**
	 * @return the prefix for error messages raised by this block, including
	 *         the begin and end line of the originating for statement block
	 */
	public String printBlockErrorLocation(){
		return "ERROR: Runtime error in for program block generated from for statement block between lines " + _beginLine + " and " + _endLine + " -- ";
	}

	/**
	 * Utility class for iterating over positive or negative predicate sequences.
	 * <p>
	 * The object is its own iterator and can therefore be traversed only once.
	 * The sequence ends when the current value passes the inclusive bound or
	 * when the next value is not representable as a long.
	 */
	protected class SequenceIterator implements Iterator<IntObject>, Iterable<IntObject>
	{
		private String _varName = null;
		private long _cur = -1;
		private long _to = -1;
		private long _incr = -1;
		private boolean _inuse = false;
		//set once advancing the current value would overflow the long range
		private boolean _exhausted = false;

		/**
		 * @param varName name given to every produced integer object
		 * @param from first value of the sequence
		 * @param to inclusive bound of the sequence
		 * @param incr step between consecutive values
		 */
		protected SequenceIterator(String varName, IntObject from, IntObject to, IntObject incr) {
			_varName = varName;
			_cur = from.getLongValue();
			_to = to.getLongValue();
			_incr = incr.getLongValue();
		}

		@Override
		public boolean hasNext() {
			if( _exhausted )
				return false;
			return _incr > 0 ? _cur <= _to : _cur >= _to;
		}

		/**
		 * @return a new integer object holding the current sequence value
		 * @throws NoSuchElementException if the sequence has no more values
		 */
		@Override
		public IntObject next() {
			if( !hasNext() )
				throw new NoSuchElementException("No more elements in predicate sequence.");

			IntObject ret = new IntObject( _varName, _cur );

			//update current val; the addition overflowed iff both operands
			//have a sign different from the sign of the result, in which case
			//a wrapped-around value would make the loop run forever
			long next = _cur + _incr;
			if( ((_cur ^ next) & (_incr ^ next)) < 0 )
				_exhausted = true;
			else
				_cur = next;

			return ret;
		}

		/**
		 * @return this object
		 * @throws IllegalStateException if an iterator was already requested
		 */
		@Override
		public Iterator<IntObject> iterator() {
			if( _inuse )
				throw new IllegalStateException("Unsupported reuse of iterator.");
			_inuse = true;
			return this;
		}

		/**
		 * Not supported.
		 *
		 * @throws UnsupportedOperationException always
		 */
		@Override
		public void remove() {
			throw new UnsupportedOperationException("Unsupported remove on iterator.");
		}
	}
}