/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package org.apache.sysml.parser; import java.io.IOException; import java.util.ArrayList; import java.util.HashMap; import org.apache.sysml.conf.ConfigurationManager; import org.apache.sysml.hops.Hop; import org.apache.sysml.hops.HopsException; import org.apache.sysml.hops.recompile.Recompiler; import org.apache.sysml.lops.Lop; public class WhileStatementBlock extends StatementBlock { private Hop _predicateHops; private Lop _predicateLops = null; private boolean _requiresPredicateRecompile = false; @Override public VariableSet validate(DMLProgram dmlProg, VariableSet ids, HashMap<String,ConstIdentifier> constVars, boolean conditional) throws LanguageException, ParseException, IOException { if (_statements.size() > 1){ raiseValidateError("WhileStatementBlock should have only 1 statement (while statement)", conditional); } WhileStatement wstmt = (WhileStatement) _statements.get(0); ConditionalPredicate predicate = wstmt.getConditionalPredicate(); // Record original size information before loop for ALL variables // Will compare size / type info for these after loop completes // Replace variables with changed size with unknown value VariableSet origVarsBeforeBody = new VariableSet(); for (String key : ids.getVariableNames()){ DataIdentifier origId = ids.getVariable(key); DataIdentifier copyId = new DataIdentifier(origId); origVarsBeforeBody.addVariable(key, copyId); } ////////////////////////////////////////////////////////////////////////////// // FIRST PASS: process the predicate / statement blocks in the body of the for statement /////////////////////////////////////////////////////////////////////////////// //remove updated vars from constants for( String var : _updated.getVariableNames() ) if( constVars.containsKey( var ) ) constVars.remove( var ); // process the statement blocks in the body of the while statement predicate.getPredicate().validateExpression(ids.getVariables(), constVars, conditional); ArrayList<StatementBlock> body = wstmt.getBody(); _dmlProg = dmlProg; for(StatementBlock sb : body) { //always conditional ids = sb.validate(dmlProg, ids, constVars, true); constVars = sb.getConstOut(); } if (!body.isEmpty()) { _constVarsIn.putAll(body.get(0).getConstIn()); _constVarsOut.putAll(body.get(body.size()-1).getConstOut()); } // for each updated variable boolean revalidationRequired = false; for (String key : _updated.getVariableNames()) { DataIdentifier startVersion = origVarsBeforeBody.getVariable(key); DataIdentifier endVersion = ids.getVariable(key); if (startVersion != null && endVersion != null) { //handle data type change (reject) if (!startVersion.getOutput().getDataType().equals(endVersion.getOutput().getDataType())){ raiseValidateError("WhileStatementBlock has unsupported conditional data type change of variable '"+key+"' in loop body.", conditional); } //handle size change long startVersionDim1 = (startVersion instanceof IndexedIdentifier) ? ((IndexedIdentifier)startVersion).getOrigDim1() : startVersion.getDim1(); long endVersionDim1 = (endVersion instanceof IndexedIdentifier) ? ((IndexedIdentifier)endVersion).getOrigDim1() : endVersion.getDim1(); long startVersionDim2 = (startVersion instanceof IndexedIdentifier) ? ((IndexedIdentifier)startVersion).getOrigDim2() : startVersion.getDim2(); long endVersionDim2 = (endVersion instanceof IndexedIdentifier) ? ((IndexedIdentifier)endVersion).getOrigDim2() : endVersion.getDim2(); boolean sizeUnchanged = ((startVersionDim1 == endVersionDim1) && (startVersionDim2 == endVersionDim2) ); //handle sparsity change //NOTE: nnz not propagated via validate, and hence, we conservatively assume that nnz have been changed. //long startVersionNNZ = startVersion.getNnz(); //long endVersionNNZ = endVersion.getNnz(); //boolean nnzUnchanged = (startVersionNNZ == endVersionNNZ); boolean nnzUnchanged = false; // IF size has changed -- if (!sizeUnchanged || !nnzUnchanged){ revalidationRequired = true; DataIdentifier recVersion = new DataIdentifier(endVersion); if(!sizeUnchanged) recVersion.setDimensions(-1, -1); if(!nnzUnchanged) recVersion.setNnz(-1); origVarsBeforeBody.addVariable(key, recVersion); } } } // revalidation is required -- size was updated for at least 1 variable if (revalidationRequired) { // update ids to the reconciled values ids = origVarsBeforeBody; ////////////////////////////////////////////////////////////////////////////// // SECOND PASS: process the predicate / statement blocks in the body of the for statement /////////////////////////////////////////////////////////////////////////////// //remove updated vars from constants for( String var : _updated.getVariableNames() ) if( constVars.containsKey( var ) ) constVars.remove( var ); // process the statement blocks in the body of the while statement predicate.getPredicate().validateExpression(ids.getVariables(), constVars, conditional); body = wstmt.getBody(); _dmlProg = dmlProg; for(StatementBlock sb : body) { //always conditional ids = sb.validate(dmlProg, ids, constVars, true); constVars = sb.getConstOut(); } if (!body.isEmpty()) { _constVarsIn.putAll(body.get(0).getConstIn()); _constVarsOut.putAll(body.get(body.size()-1).getConstOut()); } } return ids; } public VariableSet initializeforwardLV(VariableSet activeInPassed) throws LanguageException { WhileStatement wstmt = (WhileStatement)_statements.get(0); if (_statements.size() > 1){ LOG.error(_statements.get(0).printErrorLocation() + "WhileStatementBlock should have only 1 statement (while statement)"); throw new LanguageException(_statements.get(0).printErrorLocation() + "WhileStatementBlock should have only 1 statement (while statement)"); } _read = new VariableSet(); _read.addVariables(wstmt.getConditionalPredicate().variablesRead()); _updated.addVariables(wstmt.getConditionalPredicate().variablesUpdated()); _gen = new VariableSet(); _gen.addVariables(wstmt.getConditionalPredicate().variablesRead()); VariableSet current = new VariableSet(); current.addVariables(activeInPassed); for( StatementBlock sb : wstmt.getBody() ) { current = sb.initializeforwardLV(current); // for each generated variable in this block, check variable not killed // in prior statement block in while stmt blody for (String varName : sb._gen.getVariableNames()){ // IF the variable is NOT set in the while loop PRIOR to this stmt block, // THEN needs to be generated if (!_kill.getVariableNames().contains(varName)){ _gen.addVariable(varName, sb._gen.getVariable(varName)); } } _read.addVariables(sb._read); _updated.addVariables(sb._updated); // only add kill variables for statement blocks guaranteed to execute if (!(sb instanceof WhileStatementBlock) && !(sb instanceof ForStatementBlock) ){ _kill.addVariables(sb._kill); } } // set preliminary "warn" set -- variables that if used later may cause runtime error // if the loop is not executed // warnSet = (updated MINUS (updatedIfBody INTERSECT updatedElseBody)) MINUS current for (String varName : _updated.getVariableNames()){ if (!activeInPassed.containsVariable(varName)) { _warnSet.addVariable(varName, _updated.getVariable(varName)); } } // activeOut includes variables from passed live in and updated in the while body _liveOut = new VariableSet(); _liveOut.addVariables(current); _liveOut.addVariables(_updated); return _liveOut; } public VariableSet initializebackwardLV(VariableSet loPassed) throws LanguageException{ WhileStatement wstmt = (WhileStatement)_statements.get(0); VariableSet lo = new VariableSet(); lo.addVariables(loPassed); // calls analyze for each statement block in while stmt body int numBlocks = wstmt.getBody().size(); for (int i = numBlocks - 1; i >= 0; i--){ lo = wstmt.getBody().get(i).analyze(lo); } VariableSet loReturn = new VariableSet(); loReturn.addVariables(lo); return loReturn; } public void setPredicateHops(Hop hops) { _predicateHops = hops; } public ArrayList<Hop> get_hops() throws HopsException { if (_hops != null && !_hops.isEmpty()){ LOG.error(this._statements.get(0).printErrorLocation() + "there should be no HOPs associated with the WhileStatementBlock"); throw new HopsException(this._statements.get(0).printErrorLocation() + "there should be no HOPs associated with the WhileStatementBlock"); } return _hops; } public Hop getPredicateHops(){ return _predicateHops; } public Lop get_predicateLops() { return _predicateLops; } public void set_predicateLops(Lop predicateLops) { _predicateLops = predicateLops; } public VariableSet analyze(VariableSet loPassed) throws LanguageException{ VariableSet predVars = new VariableSet(); predVars.addVariables(((WhileStatement)_statements.get(0)).getConditionalPredicate().variablesRead()); predVars.addVariables(((WhileStatement)_statements.get(0)).getConditionalPredicate().variablesUpdated()); VariableSet candidateLO = new VariableSet(); candidateLO.addVariables(loPassed); candidateLO.addVariables(_gen); candidateLO.addVariables(predVars); VariableSet origLiveOut = new VariableSet(); origLiveOut.addVariables(_liveOut); origLiveOut.addVariables(predVars); origLiveOut.addVariables(_gen); _liveOut = new VariableSet(); for (String name : candidateLO.getVariableNames()){ if (origLiveOut.containsVariable(name)){ _liveOut.addVariable(name, candidateLO.getVariable(name)); } } initializebackwardLV(_liveOut); // set final warnSet: remove variables NOT in live out VariableSet finalWarnSet = new VariableSet(); for (String varName : _warnSet.getVariableNames()){ if (_liveOut.containsVariable(varName)){ finalWarnSet.addVariable(varName,_warnSet.getVariable(varName)); } } _warnSet = finalWarnSet; // for now just print the warn set for (String varName : _warnSet.getVariableNames()){ LOG.warn(_warnSet.getVariable(varName).printWarningLocation() + "Initialization of " + varName + " depends on while execution"); } // Cannot remove kill variables _liveIn = new VariableSet(); _liveIn.addVariables(_liveOut); _liveIn.addVariables(_gen); VariableSet liveInReturn = new VariableSet(); liveInReturn.addVariables(_liveIn); return liveInReturn; } ///////// // materialized hops recompilation flags //// public void updatePredicateRecompilationFlag() throws HopsException { _requiresPredicateRecompile = ConfigurationManager.isDynamicRecompilation() && Recompiler.requiresRecompilation(getPredicateHops()); } public boolean requiresPredicateRecompilation() { return _requiresPredicateRecompile; } }