/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package org.apache.sysml.hops.rewrite; import java.util.ArrayList; import org.apache.sysml.hops.DataOp; import org.apache.sysml.hops.Hop; import org.apache.sysml.hops.HopsException; import org.apache.sysml.hops.Hop.DataOpTypes; import org.apache.sysml.hops.OptimizerUtils; import org.apache.sysml.parser.DataIdentifier; import org.apache.sysml.parser.ForStatementBlock; import org.apache.sysml.parser.IndexedIdentifier; import org.apache.sysml.parser.StatementBlock; import org.apache.sysml.parser.VariableSet; import org.apache.sysml.parser.WhileStatementBlock; import org.apache.sysml.parser.Expression.DataType; import org.apache.sysml.parser.Expression.ValueType; /** * Rule: Insert checkpointing operations for caching purposes. Currently, we * follow a heuristic of checkpointing (1) all variables used read-only in loops, * and (2) intermediates used by multiple consumers. * * TODO (2) implement injection for multiple consumers (local and global). * */ public class RewriteInjectSparkLoopCheckpointing extends StatementBlockRewriteRule { private boolean _checkCtx = false; public RewriteInjectSparkLoopCheckpointing(boolean checkParForContext) { _checkCtx = checkParForContext; } @Override public ArrayList<StatementBlock> rewriteStatementBlock(StatementBlock sb, ProgramRewriteStatus status) throws HopsException { ArrayList<StatementBlock> ret = new ArrayList<StatementBlock>(); if( !OptimizerUtils.isSparkExecutionMode() ) { ret.add(sb); // nothing to do here return ret; //return original statement block } //1) We currently add checkpoint operations without information about the global program structure, //this assumes that redundant checkpointing is prevented at runtime level (instruction-level) //2) Also, we do not take size information into account right now. This means that all candidates //are checkpointed even if they are only used by CP operations. int blocksize = status.getBlocksize(); //block size set by reblock rewrite //apply rewrite for while, for, and parfor (the decision for parfor loop bodies is deferred until parfor //optimization because otherwise we would prevent remote parfor) if( (sb instanceof WhileStatementBlock || sb instanceof ForStatementBlock) //incl parfor && (_checkCtx ? !status.isInParforContext() : true) ) { //step 1: determine checkpointing candidates ArrayList<String> candidates = new ArrayList<String>(); VariableSet read = sb.variablesRead(); VariableSet updated = sb.variablesUpdated(); for( String rvar : read.getVariableNames() ) if( !updated.containsVariable(rvar) && read.getVariable(rvar).getDataType()==DataType.MATRIX ) candidates.add(rvar); //step 2: insert statement block with checkpointing operations if( !candidates.isEmpty() ) //existing candidates { StatementBlock sb0 = new StatementBlock(); sb0.setDMLProg(sb.getDMLProg()); sb0.setAllPositions(sb.getFilename(), sb.getBeginLine(), sb.getBeginColumn(), sb.getEndLine(), sb.getEndColumn()); ArrayList<Hop> hops = new ArrayList<Hop>(); VariableSet livein = new VariableSet(); VariableSet liveout = new VariableSet(); for( String var : candidates ) { DataIdentifier dat = read.getVariable(var); long dim1 = (dat instanceof IndexedIdentifier) ? ((IndexedIdentifier)dat).getOrigDim1() : dat.getDim1(); long dim2 = (dat instanceof IndexedIdentifier) ? ((IndexedIdentifier)dat).getOrigDim2() : dat.getDim2(); DataOp tread = new DataOp(var, DataType.MATRIX, ValueType.DOUBLE, DataOpTypes.TRANSIENTREAD, dat.getFilename(), dim1, dim2, dat.getNnz(), blocksize, blocksize); tread.setRequiresCheckpoint( true ); DataOp twrite = new DataOp(var, DataType.MATRIX, ValueType.DOUBLE, tread, DataOpTypes.TRANSIENTWRITE, null); HopRewriteUtils.setOutputParameters(twrite, dim1, dim2, blocksize, blocksize, dat.getNnz()); hops.add(twrite); livein.addVariable(var, read.getVariable(var)); liveout.addVariable(var, read.getVariable(var)); } sb0.set_hops(hops); sb0.setLiveIn(livein); sb0.setLiveOut(liveout); ret.add(sb0); //maintain rewrite status status.setInjectedCheckpoints(); } } //add original statement block to end ret.add(sb); return ret; } }