/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.sysml.runtime.instructions.spark;

import java.util.ArrayList;
import java.util.Iterator;

import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.function.PairFlatMapFunction;
import org.apache.spark.api.java.function.PairFunction;

import scala.Tuple2;

import org.apache.sysml.hops.AggBinaryOp.SparkAggType;
import org.apache.sysml.hops.OptimizerUtils;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysml.runtime.instructions.cp.CPOperand;
import org.apache.sysml.runtime.instructions.spark.data.LazyIterableIterator;
import org.apache.sysml.runtime.instructions.spark.data.PartitionedBroadcast;
import org.apache.sysml.runtime.instructions.spark.functions.IsFrameBlockInRange;
import org.apache.sysml.runtime.instructions.spark.utils.FrameRDDAggregateUtils;
import org.apache.sysml.runtime.instructions.spark.utils.SparkUtils;
import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
import org.apache.sysml.runtime.matrix.data.FrameBlock;
import org.apache.sysml.runtime.matrix.data.OperationsOnMatrixValues;
import org.apache.sysml.runtime.matrix.data.Pair;
import org.apache.sysml.runtime.matrix.operators.Operator;
import org.apache.sysml.runtime.util.IndexRange;
import org.apache.sysml.runtime.util.UtilFunctions;

public class FrameIndexingSPInstruction extends IndexingSPInstruction
{
	/*
	 * This class implements the frame indexing functionality inside Spark.
	 * Example instructions:
	 *     rangeReIndex:mVar1:Var2:Var3:Var4:Var5:mVar6
	 *         input=mVar1, output=mVar6,
	 *         bounds = (Var2,Var3,Var4,Var5)
	 *         rowindex_lower: Var2, rowindex_upper: Var3
	 *         colindex_lower: Var4, colindex_upper: Var5
	 *     leftIndex:mVar1:mVar2:Var3:Var4:Var5:Var6:mVar7
	 *         triggered by "mVar1[Var3:Var4, Var5:Var6] = mVar2"
	 *         the result is stored in mVar7
	 *
	 * Frame RDDs are keyed by the 1-based global row index of each block
	 * (JavaPairRDD<Long,FrameBlock>), i.e., blocks always span all columns.
	 */

	/**
	 * Constructor for right indexing (rangeReIndex): out = in[rl:ru, cl:cu].
	 */
	public FrameIndexingSPInstruction(Operator op, CPOperand in, CPOperand rl, CPOperand ru, CPOperand cl, CPOperand cu, CPOperand out, SparkAggType aggtype, String opcode, String istr) {
		super(op, in, rl, ru, cl, cu, out, aggtype, opcode, istr);
	}

	/**
	 * Constructor for left indexing (leftIndex/mapLeftIndex): out = lhs; out[rl:ru, cl:cu] = rhs.
	 */
	public FrameIndexingSPInstruction(Operator op, CPOperand lhsInput, CPOperand rhsInput, CPOperand rl, CPOperand ru, CPOperand cl, CPOperand cu, CPOperand out, String opcode, String istr) {
		super(op, lhsInput, rhsInput, rl, ru, cl, cu, out, opcode, istr);
	}

	/**
	 * Dispatches on the opcode: "rangeReIndex" performs right indexing (slice),
	 * "leftIndex"/"mapLeftIndex" perform left indexing (in-place-style update of a
	 * sub-range with the rhs frame). Registers output RDD handles and lineage in
	 * the Spark execution context.
	 *
	 * @param ec execution context (must be a SparkExecutionContext)
	 * @throws DMLRuntimeException on unknown opcode, unknown rhs dimensions, or
	 *         an index range that does not match the rhs dimensions
	 */
	@Override
	public void processInstruction(ExecutionContext ec)
		throws DMLRuntimeException
	{
		SparkExecutionContext sec = (SparkExecutionContext)ec;
		String opcode = getOpcode();

		//get indexing range (1-based, inclusive bounds)
		long rl = ec.getScalarInput(rowLower.getName(), rowLower.getValueType(), rowLower.isLiteral()).getLongValue();
		long ru = ec.getScalarInput(rowUpper.getName(), rowUpper.getValueType(), rowUpper.isLiteral()).getLongValue();
		long cl = ec.getScalarInput(colLower.getName(), colLower.getValueType(), colLower.isLiteral()).getLongValue();
		long cu = ec.getScalarInput(colUpper.getName(), colUpper.getValueType(), colUpper.isLiteral()).getLongValue();
		IndexRange ixrange = new IndexRange(rl, ru, cl, cu);

		//right indexing
		if( opcode.equalsIgnoreCase("rangeReIndex") )
		{
			//update and check output dimensions
			MatrixCharacteristics mcIn = sec.getMatrixCharacteristics(input1.getName());
			MatrixCharacteristics mcOut = sec.getMatrixCharacteristics(output.getName());
			mcOut.set(ru-rl+1, cu-cl+1, mcIn.getRowsPerBlock(), mcIn.getColsPerBlock());
			checkValidOutputDimensions(mcOut);

			//execute right indexing operation (partitioning-preserving if possible)
			JavaPairRDD<Long,FrameBlock> in1 = sec.getFrameBinaryBlockRDDHandleForVariable( input1.getName() );
			JavaPairRDD<Long,FrameBlock> out = null;
			if( isPartitioningPreservingRightIndexing(mcIn, ixrange) ) {
				//all rows selected: keys are unchanged, so mapPartitionsToPair
				//with preservesPartitioning=true avoids a shuffle
				out = in1.mapPartitionsToPair(
						new SliceBlockPartitionFunction(ixrange, mcOut), true);
			}
			else {
				//general case: drop out-of-range blocks, then slice and re-key
				out = in1.filter(new IsFrameBlockInRange(rl, ru, mcOut))
					     .mapToPair(new SliceBlock(ixrange, mcOut));
			}

			//put output RDD handle into symbol table
			sec.setRDDHandleForVariable(output.getName(), out);
			sec.addLineageRDD(output.getName(), input1.getName());

			//update schema of output with subset of input schema
			sec.getFrameObject(output.getName()).setSchema(
				sec.getFrameObject(input1.getName()).getSchema((int)cl, (int)cu));
		}
		//left indexing
		else if ( opcode.equalsIgnoreCase("leftIndex") || opcode.equalsIgnoreCase("mapLeftIndex"))
		{
			JavaPairRDD<Long,FrameBlock> in1 = sec.getFrameBinaryBlockRDDHandleForVariable( input1.getName() );
			PartitionedBroadcast<FrameBlock> broadcastIn2 = null;
			JavaPairRDD<Long,FrameBlock> in2 = null;
			JavaPairRDD<Long,FrameBlock> out = null;

			//update and check output dimensions (output inherits lhs dimensions)
			MatrixCharacteristics mcOut = sec.getMatrixCharacteristics(output.getName());
			MatrixCharacteristics mcLeft = ec.getMatrixCharacteristics(input1.getName());
			mcOut.set(mcLeft.getRows(), mcLeft.getCols(), mcLeft.getRowsPerBlock(), mcLeft.getColsPerBlock());
			checkValidOutputDimensions(mcOut);

			//note: always frame rhs, scalars are preprocessed via cast to 1x1 frame
			MatrixCharacteristics mcRight = ec.getMatrixCharacteristics(input2.getName());

			//sanity check matching index range and rhs dimensions
			if(!mcRight.dimsKnown()) {
				throw new DMLRuntimeException("The right input frame dimensions are not specified for FrameIndexingSPInstruction");
			}
			if(!(ru-rl+1 == mcRight.getRows() && cu-cl+1 == mcRight.getCols())) {
				throw new DMLRuntimeException("Invalid index range of leftindexing: ["+rl+":"+ru+","+cl+":"+cu+"] vs ["+mcRight.getRows()+"x"+mcRight.getCols()+"]." );
			}

			if(opcode.equalsIgnoreCase("mapLeftIndex")) {
				//broadcast variant: rhs is small enough to ship to all executors
				broadcastIn2 = sec.getBroadcastForFrameVariable( input2.getName());

				//partitioning-preserving mappartitions (key access required for broadcast lookup)
				out = in1.mapPartitionsToPair(
						new LeftIndexPartitionFunction(broadcastIn2, ixrange, mcOut), true);
			}
			else { //general case
				// zero-out lhs in the target range
				in1 = in1.flatMapToPair(new ZeroOutLHS(false, ixrange, mcLeft));

				// slice rhs, shift to target coordinates, and merge with lhs
				in2 = sec.getFrameBinaryBlockRDDHandleForVariable( input2.getName() )
					    .flatMapToPair(new SliceRHSForLeftIndexing(ixrange, mcLeft));

				out = FrameRDDAggregateUtils.mergeByKey(in1.union(in2));
			}

			sec.setRDDHandleForVariable(output.getName(), out);
			sec.addLineageRDD(output.getName(), input1.getName());
			if( broadcastIn2 != null)
				sec.addLineageBroadcast(output.getName(), input2.getName());
			if(in2 != null)
				sec.addLineageRDD(output.getName(), input2.getName());
		}
		else
			throw new DMLRuntimeException("Invalid opcode (" + opcode +") encountered in FrameIndexingSPInstruction.");
	}

	/**
	 * Right indexing preserves partitioning iff all rows are selected
	 * (entire column/s): block keys (row indices) then stay unchanged.
	 */
	private boolean isPartitioningPreservingRightIndexing(MatrixCharacteristics mcIn, IndexRange ixrange) {
		return ( mcIn.dimsKnown() &&
			(ixrange.rowStart==1 && ixrange.rowEnd==mcIn.getRows() )); //Entire Column/s
	}

	/**
	 * @throws DMLRuntimeException if the (updated) output dimensions are unknown
	 */
	private static void checkValidOutputDimensions(MatrixCharacteristics mcOut)
		throws DMLRuntimeException
	{
		if(!mcOut.dimsKnown()) {
			throw new DMLRuntimeException("FrameIndexingSPInstruction: The updated output dimensions are invalid: " + mcOut);
		}
	}

	/**
	 * Shifts rhs blocks into the coordinate system of the lhs (target) frame,
	 * potentially splitting one rhs block across multiple target blocks
	 * (delegates to OperationsOnMatrixValues.performShift).
	 */
	private static class SliceRHSForLeftIndexing implements PairFlatMapFunction<Tuple2<Long,FrameBlock>, Long, FrameBlock>
	{
		private static final long serialVersionUID = 5724800998701216440L;

		private IndexRange _ixrange = null; //target range in lhs coordinates
		private int _brlen = -1;            //block rows of lhs (capped by #rows)
		private int _bclen = -1;            //block cols of lhs (= #cols, frame blocks span all columns)
		private long _rlen = -1;            //total rows of lhs
		private long _clen = -1;            //total cols of lhs

		public SliceRHSForLeftIndexing(IndexRange ixrange, MatrixCharacteristics mcLeft) {
			_ixrange = ixrange;
			_rlen = mcLeft.getRows();
			_clen = mcLeft.getCols();
			_brlen = (int) Math.min(OptimizerUtils.getDefaultFrameSize(), _rlen);
			_bclen = (int) mcLeft.getCols();
		}

		@Override
		public Iterator<Tuple2<Long, FrameBlock>> call(Tuple2<Long, FrameBlock> rightKV)
			throws Exception
		{
			Pair<Long,FrameBlock> in = SparkUtils.toIndexedFrameBlock(rightKV);
			ArrayList<Pair<Long,FrameBlock>> out = new ArrayList<Pair<Long,FrameBlock>>();
			OperationsOnMatrixValues.performShift(in, _ixrange, _brlen, _bclen, _rlen, _clen, out);
			return SparkUtils.fromIndexedFrameBlock(out).iterator();
		}
	}

	/**
	 * Zeroes out (complement=false) the target index range within each lhs block,
	 * emitting one output block per overlapped target block so the zeroed lhs can
	 * later be merged with the shifted rhs.
	 */
	private static class ZeroOutLHS implements PairFlatMapFunction<Tuple2<Long,FrameBlock>, Long,FrameBlock>
	{
		private static final long serialVersionUID = -2672267231152496854L;

		private boolean _complement = false; //if true, zero out everything EXCEPT the range (presumably; only false is used here)
		private IndexRange _ixrange = null;
		private int _brlen = -1;
		private int _bclen = -1;
		private long _rlen = -1;

		public ZeroOutLHS(boolean complement, IndexRange range, MatrixCharacteristics mcLeft) {
			_complement = complement;
			_ixrange = range;
			_brlen = (int) OptimizerUtils.getDefaultFrameSize();
			_bclen = (int) mcLeft.getCols();
			_rlen = mcLeft.getRows();
		}

		@Override
		public Iterator<Tuple2<Long, FrameBlock>> call(Tuple2<Long, FrameBlock> kv)
			throws Exception
		{
			ArrayList<Pair<Long,FrameBlock>> out = new ArrayList<Pair<Long,FrameBlock>>();

			//working copy of the remaining range (rowStart is advanced per target block)
			IndexRange curBlockRange = new IndexRange(_ixrange.rowStart, _ixrange.rowEnd, _ixrange.colStart, _ixrange.colEnd);

			// Global index of row (1-based): start row of the aligned target block containing kv
			long lGblStartRow = ((kv._1.longValue()-1)/_brlen)*_brlen+1;
			FrameBlock zeroBlk = null;
			int iMaxRowsToCopy = 0;

			// Starting local location (0-based) of target block where to start copy.
			int iRowStartDest = UtilFunctions.computeCellInBlock(kv._1, _brlen);
			//walk the source block in chunks, one aligned target block per iteration
			for(int iRowStartSrc = 0; iRowStartSrc<kv._2.getNumRows(); iRowStartSrc += iMaxRowsToCopy, lGblStartRow += _brlen) {
				IndexRange range = UtilFunctions.getSelectedRangeForZeroOut(new Pair<Long, FrameBlock>(kv._1, kv._2), _brlen, _bclen, curBlockRange, lGblStartRow-1, lGblStartRow);
				if(range.rowStart == -1 && range.rowEnd == -1 && range.colStart == -1 && range.colEnd == -1) {
					throw new Exception("Error while getting range for zero-out");
				}
				//Maximum range of rows in target block (clipped at frame end)
				int iMaxRows=(int) Math.min(_brlen, _rlen-lGblStartRow+1);

				// Maximum number of rows to be copied from source block to target.
				iMaxRowsToCopy = Math.min(iMaxRows, kv._2.getNumRows()-iRowStartSrc);
				iMaxRowsToCopy = Math.min(iMaxRowsToCopy, iMaxRows-iRowStartDest);

				// Zero out the applicable range in this block
				zeroBlk = (FrameBlock) kv._2.zeroOutOperations(new FrameBlock(), range, _complement, iRowStartSrc, iRowStartDest, iMaxRows, iMaxRowsToCopy);
				out.add(new Pair<Long, FrameBlock>(lGblStartRow, zeroBlk));
				curBlockRange.rowStart = lGblStartRow + _brlen;
				iRowStartDest = UtilFunctions.computeCellInBlock(iRowStartDest+iMaxRowsToCopy+1, _brlen);
			}
			return SparkUtils.fromIndexedFrameBlock(out).iterator();
		}
	}

	/**
	 * Broadcast-based left indexing: each lhs block overlapping the target range
	 * is updated in place from the corresponding slice of the broadcast rhs;
	 * partitioning is preserved since keys never change.
	 */
	private static class LeftIndexPartitionFunction implements PairFlatMapFunction<Iterator<Tuple2<Long,FrameBlock>>, Long, FrameBlock>
	{
		private static final long serialVersionUID = -911940376947364915L;

		private PartitionedBroadcast<FrameBlock> _binput; //broadcast rhs frame
		private IndexRange _ixrange = null;               //target range in lhs coordinates

		public LeftIndexPartitionFunction(PartitionedBroadcast<FrameBlock> binput, IndexRange ixrange, MatrixCharacteristics mc)
		{
			_binput = binput;
			_ixrange = ixrange;
		}

		@Override
		public LazyIterableIterator<Tuple2<Long, FrameBlock>> call(Iterator<Tuple2<Long, FrameBlock>> arg0)
			throws Exception
		{
			return new LeftIndexPartitionIterator(arg0);
		}

		private class LeftIndexPartitionIterator extends LazyIterableIterator<Tuple2<Long, FrameBlock>>
		{
			public LeftIndexPartitionIterator(Iterator<Tuple2<Long, FrameBlock>> in) {
				super(in);
			}

			@Override
			protected Tuple2<Long, FrameBlock> computeNext(Tuple2<Long, FrameBlock> arg)
				throws Exception
			{
				int iNumRowsInBlock = arg._2.getNumRows();
				int iNumCols = arg._2.getNumColumns();
				//blocks outside the target range pass through unchanged
				if(!UtilFunctions.isInFrameBlockRange(arg._1(), iNumRowsInBlock, iNumCols, _ixrange)) {
					return arg;
				}

				// Calculate global index of left hand side block (intersection of block and range)
				long lhs_rl = Math.max(_ixrange.rowStart, arg._1); //Math.max(_ixrange.rowStart, (arg._1-1)*iNumRowsInBlock + 1);
				long lhs_ru = Math.min(_ixrange.rowEnd, arg._1+iNumRowsInBlock-1);
				long lhs_cl = Math.max(_ixrange.colStart, 1);
				long lhs_cu = Math.min(_ixrange.colEnd, iNumCols);

				// Calculate global index of right hand side block (1-based, rhs coordinates)
				long rhs_rl = lhs_rl - _ixrange.rowStart + 1;
				long rhs_ru = rhs_rl + (lhs_ru - lhs_rl);
				long rhs_cl = lhs_cl - _ixrange.colStart + 1;
				long rhs_cu = rhs_cl + (lhs_cu - lhs_cl);

				// Provide local zero-based index to leftIndexingOperations
				int lhs_lrl = (int)(lhs_rl- arg._1);
				int lhs_lru = (int)(lhs_ru- arg._1);
				int lhs_lcl = (int)lhs_cl-1;
				int lhs_lcu = (int)lhs_cu-1;

				FrameBlock ret = arg._2;
				int brlen = OptimizerUtils.DEFAULT_BLOCKSIZE;
				//iterate over the rhs row range in partition-block-aligned steps,
				//since the broadcast can only be sliced one partition block at a time
				long rhs_rl_pb = rhs_rl;
				long rhs_ru_pb = Math.min(rhs_ru, (((rhs_rl-1)/brlen)+1)*brlen);
				while(rhs_rl_pb <= rhs_ru_pb) {
					// Provide global zero-based index to sliceOperations, but only for one RHS partition block at a time.
					FrameBlock slicedRHSMatBlock = _binput.sliceOperations(rhs_rl_pb, rhs_ru_pb, rhs_cl, rhs_cu, new FrameBlock());

					// Provide local zero-based index to leftIndexingOperations
					int lhs_lrl_pb = (int) (lhs_lrl + (rhs_rl_pb - rhs_rl));
					int lhs_lru_pb = (int) (lhs_lru + (rhs_ru_pb - rhs_ru));
					ret = ret.leftIndexingOperations(slicedRHSMatBlock, lhs_lrl_pb, lhs_lru_pb, lhs_lcl, lhs_lcu, new FrameBlock());
					rhs_rl_pb = rhs_ru_pb + 1;
					rhs_ru_pb = Math.min(rhs_ru, rhs_ru_pb+brlen);
				}
				return new Tuple2<Long, FrameBlock>(arg._1, ret);
			}
		}
	}

	/**
	 * Right-indexing slice for the general case: cuts the requested row/column
	 * range out of an in-range block and re-keys it relative to the output frame.
	 */
	private static class SliceBlock implements PairFunction<Tuple2<Long, FrameBlock>, Long, FrameBlock>
	{
		private static final long serialVersionUID = -5270171193018691692L;

		private IndexRange _ixrange;

		public SliceBlock(IndexRange ixrange, MatrixCharacteristics mcOut) {
			_ixrange = ixrange;
		}

		@Override
		public Tuple2<Long, FrameBlock> call(Tuple2<Long, FrameBlock> kv)
			throws Exception
		{
			long rowindex = kv._1();
			FrameBlock in = kv._2();

			//prepare local index range (block guaranteed to be in range)
			int rl = (int) ((rowindex > _ixrange.rowStart) ? 0 : _ixrange.rowStart-rowindex);
			int ru = (int) ((_ixrange.rowEnd-rowindex >= in.getNumRows()) ? in.getNumRows()-1 : _ixrange.rowEnd-rowindex);

			//slice out the block
			FrameBlock out = in.sliceOperations(rl, ru, (int)(_ixrange.colStart-1), (int)(_ixrange.colEnd-1), new FrameBlock());

			//return block with shifted row index (1-based in output coordinates)
			long rowindex2 = (rowindex > _ixrange.rowStart) ? rowindex-_ixrange.rowStart+1 : 1;
			return new Tuple2<Long,FrameBlock>(rowindex2, out);
		}
	}

	/**
	 * Partitioning-preserving column slice: applied only when all rows are
	 * selected, so block keys stay unchanged and no shuffle is required.
	 */
	private static class SliceBlockPartitionFunction implements PairFlatMapFunction<Iterator<Tuple2<Long, FrameBlock>>, Long, FrameBlock>
	{
		private static final long serialVersionUID = -1655390518299307588L;

		private IndexRange _ixrange;

		public SliceBlockPartitionFunction(IndexRange ixrange, MatrixCharacteristics mcOut) {
			_ixrange = ixrange;
		}

		@Override
		public LazyIterableIterator<Tuple2<Long, FrameBlock>> call(Iterator<Tuple2<Long, FrameBlock>> arg0)
			throws Exception
		{
			return new SliceBlockPartitionIterator(arg0);
		}

		/**
		 * NOTE: this function is only applied for slicing columns (which preserved all rows
		 * and hence the existing partitioning).
		 */
		private class SliceBlockPartitionIterator extends LazyIterableIterator<Tuple2<Long, FrameBlock>>
		{
			public SliceBlockPartitionIterator(Iterator<Tuple2<Long, FrameBlock>> in) {
				super(in);
			}

			@Override
			protected Tuple2<Long, FrameBlock> computeNext(Tuple2<Long, FrameBlock> arg)
				throws Exception
			{
				long rowindex = arg._1();
				FrameBlock in = arg._2();

				//slice out the requested columns, keeping all rows
				FrameBlock out = in.sliceOperations(0, in.getNumRows()-1,
					(int)_ixrange.colStart-1, (int)_ixrange.colEnd-1, new FrameBlock());

				//return block with unchanged row index
				return new Tuple2<Long,FrameBlock>(rowindex, out);
			}
		}
	}
}