/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package org.apache.sysml.runtime.controlprogram; import java.util.ArrayList; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.sysml.api.DMLScript; import org.apache.sysml.api.MLContextProxy; import org.apache.sysml.conf.ConfigurationManager; import org.apache.sysml.hops.Hop; import org.apache.sysml.hops.OptimizerUtils; import org.apache.sysml.hops.recompile.Recompiler; import org.apache.sysml.lops.Lop; import org.apache.sysml.parser.StatementBlock; import org.apache.sysml.parser.Expression.ValueType; import org.apache.sysml.runtime.DMLRuntimeException; import org.apache.sysml.runtime.DMLScriptException; import org.apache.sysml.runtime.controlprogram.caching.MatrixObject; import org.apache.sysml.runtime.controlprogram.caching.MatrixObject.UpdateType; import org.apache.sysml.runtime.controlprogram.context.ExecutionContext; import org.apache.sysml.runtime.instructions.Instruction; import org.apache.sysml.runtime.instructions.cp.BooleanObject; import org.apache.sysml.runtime.instructions.cp.ComputationCPInstruction; import org.apache.sysml.runtime.instructions.cp.Data; import org.apache.sysml.runtime.instructions.cp.DoubleObject; import org.apache.sysml.runtime.instructions.cp.IntObject; import org.apache.sysml.runtime.instructions.cp.ScalarObject; import org.apache.sysml.runtime.instructions.cp.StringObject; import org.apache.sysml.runtime.instructions.cp.VariableCPInstruction; import org.apache.sysml.runtime.matrix.data.MatrixBlock; import org.apache.sysml.utils.Statistics; import org.apache.sysml.yarn.DMLAppMasterUtils; public class ProgramBlock { protected static final Log LOG = LogFactory.getLog(ProgramBlock.class.getName()); private static final boolean CHECK_MATRIX_SPARSITY = false; protected Program _prog; // pointer to Program this ProgramBlock is part of protected ArrayList<Instruction> _inst; //additional attributes for recompile protected StatementBlock _sb = null; protected long _tid = 0; //by default _t0 public ProgramBlock(Program prog) { _prog = prog; _inst = new ArrayList<Instruction>(); } //////////////////////////////////////////////// // getters, setters and similar functionality //////////////////////////////////////////////// public Program getProgram(){ return _prog; } public void setProgram(Program prog){ _prog = prog; } public StatementBlock getStatementBlock(){ return _sb; } public void setStatementBlock( StatementBlock sb ){ _sb = sb; } public ArrayList<Instruction> getInstructions() { return _inst; } public Instruction getInstruction(int i) { return _inst.get(i); } public void setInstructions( ArrayList<Instruction> inst ) { _inst = inst; } public void addInstruction(Instruction inst) { _inst.add(inst); } public void addInstructions(ArrayList<Instruction> inst) { _inst.addAll(inst); } public int getNumInstructions() { return _inst.size(); } public void setThreadID( long id ){ _tid = id; } ////////////////////////////////////////////////////////// // core instruction execution (program block, predicate) ////////////////////////////////////////////////////////// /** * Executes this program block (incl recompilation if required). * * @param ec execution context * @throws DMLRuntimeException if DMLRuntimeException occurs */ public void execute(ExecutionContext ec) throws DMLRuntimeException { ArrayList<Instruction> tmp = _inst; //dynamically recompile instructions if enabled and required try { if( DMLScript.isActiveAM() ) //set program block specific remote memory DMLAppMasterUtils.setupProgramBlockRemoteMaxMemory(this); long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0; if( ConfigurationManager.isDynamicRecompilation() && _sb != null && _sb.requiresRecompilation() ) { tmp = Recompiler.recompileHopsDag( _sb, _sb.get_hops(), ec.getVariables(), null, false, true, _tid); if( MLContextProxy.isActive() ) tmp = MLContextProxy.performCleanupAfterRecompilation(tmp); } if( DMLScript.STATISTICS ){ long t1 = System.nanoTime(); Statistics.incrementHOPRecompileTime(t1-t0); if( tmp!=_inst ) Statistics.incrementHOPRecompileSB(); } } catch(Exception ex) { throw new DMLRuntimeException("Unable to recompile program block.", ex); } //actual instruction execution executeInstructions(tmp, ec); } /** * Executes given predicate instructions (incl recompilation if required) * * @param inst list of instructions * @param hops high-level operator * @param requiresRecompile true if requires recompile * @param retType value type of the return type * @param ec execution context * @return scalar object * @throws DMLRuntimeException if DMLRuntimeException occurs */ public ScalarObject executePredicate(ArrayList<Instruction> inst, Hop hops, boolean requiresRecompile, ValueType retType, ExecutionContext ec) throws DMLRuntimeException { ArrayList<Instruction> tmp = inst; //dynamically recompile instructions if enabled and required try { long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0; if( ConfigurationManager.isDynamicRecompilation() && requiresRecompile ) { tmp = Recompiler.recompileHopsDag( hops, ec.getVariables(), null, false, true, _tid); } if( DMLScript.STATISTICS ){ long t1 = System.nanoTime(); Statistics.incrementHOPRecompileTime(t1-t0); if( tmp!=inst ) Statistics.incrementHOPRecompilePred(); } } catch(Exception ex) { throw new DMLRuntimeException("Unable to recompile predicate instructions.", ex); } //actual instruction execution return executePredicateInstructions(tmp, retType, ec); } protected void executeInstructions(ArrayList<Instruction> inst, ExecutionContext ec) throws DMLRuntimeException { for (int i = 0; i < inst.size(); i++) { //indexed access required due to dynamic add Instruction currInst = inst.get(i); //execute instruction ec.updateDebugState(i); executeSingleInstruction(currInst, ec); } } protected ScalarObject executePredicateInstructions(ArrayList<Instruction> inst, ValueType retType, ExecutionContext ec) throws DMLRuntimeException { ScalarObject ret = null; String retName = null; //execute all instructions for (int i = 0; i < inst.size(); i++) { //indexed access required due to debug mode Instruction currInst = inst.get(i); if( !isRemoveVariableInstruction(currInst) ) { //execute instruction ec.updateDebugState(i); executeSingleInstruction(currInst, ec); //get last return name if(currInst instanceof ComputationCPInstruction ) retName = ((ComputationCPInstruction) currInst).getOutputVariableName(); else if(currInst instanceof VariableCPInstruction && ((VariableCPInstruction)currInst).getOutputVariableName()!=null) retName = ((VariableCPInstruction)currInst).getOutputVariableName(); } } //get return value TODO: how do we differentiate literals and variables? ret = (ScalarObject) ec.getScalarInput(retName, retType, false); //execute rmvar instructions for (int i = 0; i < inst.size(); i++) { //indexed access required due to debug mode Instruction currInst = inst.get(i); if( isRemoveVariableInstruction(currInst) ) { ec.updateDebugState(i); executeSingleInstruction(currInst, ec); } } //check and correct scalar ret type (incl save double to int) if( ret.getValueType() != retType ) switch( retType ) { case BOOLEAN: ret = new BooleanObject(ret.getName(),ret.getBooleanValue()); break; case INT: ret = new IntObject(ret.getName(),ret.getLongValue()); break; case DOUBLE: ret = new DoubleObject(ret.getName(),ret.getDoubleValue()); break; case STRING: ret = new StringObject(ret.getName(),ret.getStringValue()); break; default: //do nothing } return ret; } private void executeSingleInstruction( Instruction currInst, ExecutionContext ec ) throws DMLRuntimeException { try { // start time measurement for statistics long t0 = (DMLScript.STATISTICS || LOG.isTraceEnabled()) ? System.nanoTime() : 0; // pre-process instruction (debug state, inst patching, listeners) Instruction tmp = currInst.preprocessInstruction( ec ); // process actual instruction tmp.processInstruction( ec ); // post-process instruction (debug) tmp.postprocessInstruction( ec ); // maintain aggregate statistics if( DMLScript.STATISTICS) { Statistics.maintainCPHeavyHitters( tmp.getExtendedOpcode(), System.nanoTime()-t0); } // optional trace information (instruction and runtime) if( LOG.isTraceEnabled() ) { long t1 = System.nanoTime(); String time = String.format("%.3f",((double)t1-t0)/1000000000); LOG.trace("Instruction: "+ tmp + " (executed in " + time + "s)."); } // optional check for correct nnz and sparse/dense representation of all // variables in symbol table (for tracking source of wrong representation) if( CHECK_MATRIX_SPARSITY ) { checkSparsity( tmp, ec.getVariables() ); } } catch (Exception e) { if (!DMLScript.ENABLE_DEBUG_MODE) { if ( e instanceof DMLScriptException) throw (DMLScriptException)e; else throw new DMLRuntimeException(this.printBlockErrorLocation() + "Error evaluating instruction: " + currInst.toString() , e); } else { ec.handleDebugException(e); } } } protected UpdateType[] prepareUpdateInPlaceVariables(ExecutionContext ec, long tid) throws DMLRuntimeException { if( _sb == null || _sb.getUpdateInPlaceVars().isEmpty() ) return null; ArrayList<String> varnames = _sb.getUpdateInPlaceVars(); UpdateType[] flags = new UpdateType[varnames.size()]; for( int i=0; i<flags.length; i++ ) if( ec.getVariable(varnames.get(i)) != null ) { String varname = varnames.get(i); MatrixObject mo = ec.getMatrixObject(varname); flags[i] = mo.getUpdateType(); //create deep copy if required and if it fits in thread-local mem budget if( flags[i]==UpdateType.COPY && OptimizerUtils.getLocalMemBudget()/2 > OptimizerUtils.estimateSizeExactSparsity(mo.getMatrixCharacteristics())) { MatrixObject moNew = new MatrixObject(mo); MatrixBlock mbVar = mo.acquireRead(); moNew.acquireModify( !mbVar.isInSparseFormat() ? new MatrixBlock(mbVar) : new MatrixBlock(mbVar, MatrixBlock.DEFAULT_INPLACE_SPARSEBLOCK, true) ); moNew.setFileName(mo.getFileName()+Lop.UPDATE_INPLACE_PREFIX+tid); mo.release(); moNew.release(); moNew.setUpdateType(UpdateType.INPLACE); ec.setVariable(varname, moNew); } } return flags; } protected void resetUpdateInPlaceVariableFlags(ExecutionContext ec, UpdateType[] flags) throws DMLRuntimeException { if( flags == null ) return; //reset update-in-place flag to pre-loop status ArrayList<String> varnames = _sb.getUpdateInPlaceVars(); for( int i=0; i<varnames.size(); i++ ) if( ec.getVariable(varnames.get(i)) != null && flags[i] !=null ) { MatrixObject mo = ec.getMatrixObject(varnames.get(i)); mo.setUpdateType(flags[i]); } } private boolean isRemoveVariableInstruction(Instruction inst) { return ( inst instanceof VariableCPInstruction && ((VariableCPInstruction)inst).isRemoveVariable() ); } private void checkSparsity( Instruction lastInst, LocalVariableMap vars ) throws DMLRuntimeException { for( String varname : vars.keySet() ) { Data dat = vars.get(varname); if( dat instanceof MatrixObject ) { MatrixObject mo = (MatrixObject)dat; if( mo.isDirty() && !mo.isPartitioned() ) { MatrixBlock mb = mo.acquireRead(); boolean sparse1 = mb.isInSparseFormat(); long nnz1 = mb.getNonZeros(); synchronized( mb ) { //potential state change mb.recomputeNonZeros(); mb.examSparsity(); } boolean sparse2 = mb.isInSparseFormat(); long nnz2 = mb.getNonZeros(); mo.release(); if( nnz1 != nnz2 ) throw new DMLRuntimeException("Matrix nnz meta data was incorrect: ("+varname+", actual="+nnz1+", expected="+nnz2+", inst="+lastInst+")"); if( sparse1 != sparse2 ) throw new DMLRuntimeException("Matrix was in wrong data representation: ("+varname+", actual="+sparse1+", expected="+sparse2 + ", nrow="+mb.getNumRows()+", ncol="+mb.getNumColumns()+", nnz="+nnz1+", inst="+lastInst+")"); } } } } /////////////////////////////////////////////////////////////////////////// // store position information for program blocks /////////////////////////////////////////////////////////////////////////// public int _beginLine, _beginColumn; public int _endLine, _endColumn; public void setBeginLine(int passed) { _beginLine = passed; } public void setBeginColumn(int passed) { _beginColumn = passed; } public void setEndLine(int passed) { _endLine = passed; } public void setEndColumn(int passed) { _endColumn = passed; } public void setAllPositions(int blp, int bcp, int elp, int ecp){ _beginLine = blp; _beginColumn = bcp; _endLine = elp; _endColumn = ecp; } public int getBeginLine() { return _beginLine; } public int getBeginColumn() { return _beginColumn; } public int getEndLine() { return _endLine; } public int getEndColumn() { return _endColumn; } public String printBlockErrorLocation(){ return "ERROR: Runtime error in program block generated from statement block between lines " + _beginLine + " and " + _endLine + " -- "; } }