/** * (C) Copyright IBM Corp. 2010, 2015 * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. *  */ package com.ibm.bi.dml.runtime.instructions.spark; import java.util.ArrayList; import java.util.Iterator; import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.function.PairFlatMapFunction; import org.apache.spark.api.java.function.PairFunction; import scala.Tuple2; import com.ibm.bi.dml.hops.AggBinaryOp.SparkAggType; import com.ibm.bi.dml.runtime.DMLRuntimeException; import com.ibm.bi.dml.runtime.DMLUnsupportedOperationException; import com.ibm.bi.dml.runtime.controlprogram.context.ExecutionContext; import com.ibm.bi.dml.runtime.controlprogram.context.SparkExecutionContext; import com.ibm.bi.dml.runtime.instructions.Instruction; import com.ibm.bi.dml.runtime.instructions.InstructionUtils; import com.ibm.bi.dml.runtime.instructions.cp.CPOperand; import com.ibm.bi.dml.runtime.instructions.spark.data.LazyIterableIterator; import com.ibm.bi.dml.runtime.instructions.spark.data.PartitionedBroadcastMatrix; import com.ibm.bi.dml.runtime.instructions.spark.functions.IsBlockInRange; import com.ibm.bi.dml.runtime.instructions.spark.utils.RDDAggregateUtils; import com.ibm.bi.dml.runtime.instructions.spark.utils.SparkUtils; import com.ibm.bi.dml.runtime.matrix.MatrixCharacteristics; import com.ibm.bi.dml.runtime.matrix.data.MatrixBlock; import com.ibm.bi.dml.runtime.matrix.data.MatrixIndexes; import com.ibm.bi.dml.runtime.matrix.data.OperationsOnMatrixValues; import com.ibm.bi.dml.runtime.matrix.mapred.IndexedMatrixValue; import com.ibm.bi.dml.runtime.matrix.operators.Operator; import com.ibm.bi.dml.runtime.matrix.operators.SimpleOperator; import com.ibm.bi.dml.runtime.util.IndexRange; import com.ibm.bi.dml.runtime.util.UtilFunctions; public class MatrixIndexingSPInstruction extends UnarySPInstruction { /* * This class implements the matrix indexing functionality inside CP. * Example instructions: * rangeReIndex:mVar1:Var2:Var3:Var4:Var5:mVar6 * input=mVar1, output=mVar6, * bounds = (Var2,Var3,Var4,Var5) * rowindex_lower: Var2, rowindex_upper: Var3 * colindex_lower: Var4, colindex_upper: Var5 * leftIndex:mVar1:mVar2:Var3:Var4:Var5:Var6:mVar7 * triggered by "mVar1[Var3:Var4, Var5:Var6] = mVar2" * the result is stored in mVar7 * */ protected CPOperand rowLower, rowUpper, colLower, colUpper; protected SparkAggType _aggType = null; public MatrixIndexingSPInstruction(Operator op, CPOperand in, CPOperand rl, CPOperand ru, CPOperand cl, CPOperand cu, CPOperand out, SparkAggType aggtype, String opcode, String istr) { super(op, in, out, opcode, istr); rowLower = rl; rowUpper = ru; colLower = cl; colUpper = cu; _aggType = aggtype; } public MatrixIndexingSPInstruction(Operator op, CPOperand lhsInput, CPOperand rhsInput, CPOperand rl, CPOperand ru, CPOperand cl, CPOperand cu, CPOperand out, String opcode, String istr) { super(op, lhsInput, rhsInput, out, opcode, istr); rowLower = rl; rowUpper = ru; colLower = cl; colUpper = cu; } public static Instruction parseInstruction ( String str ) throws DMLRuntimeException { String[] parts = InstructionUtils.getInstructionPartsWithValueType(str); String opcode = parts[0]; if ( opcode.equalsIgnoreCase("rangeReIndex") ) { if ( parts.length == 8 ) { // Example: rangeReIndex:mVar1:Var2:Var3:Var4:Var5:mVar6 CPOperand in = new CPOperand(parts[1]); CPOperand rl = new CPOperand(parts[2]); CPOperand ru = new CPOperand(parts[3]); CPOperand cl = new CPOperand(parts[4]); CPOperand cu = new CPOperand(parts[5]); CPOperand out = new CPOperand(parts[6]); SparkAggType aggtype = SparkAggType.valueOf(parts[7]); return new MatrixIndexingSPInstruction(new SimpleOperator(null), in, rl, ru, cl, cu, out, aggtype, opcode, str); } else { throw new DMLRuntimeException("Invalid number of operands in instruction: " + str); } } else if ( opcode.equalsIgnoreCase("leftIndex") || opcode.equalsIgnoreCase("mapLeftIndex")) { if ( parts.length == 8 ) { // Example: leftIndex:mVar1:mvar2:Var3:Var4:Var5:Var6:mVar7 CPOperand lhsInput = new CPOperand(parts[1]); CPOperand rhsInput = new CPOperand(parts[2]); CPOperand rl = new CPOperand(parts[3]); CPOperand ru = new CPOperand(parts[4]); CPOperand cl = new CPOperand(parts[5]); CPOperand cu = new CPOperand(parts[6]); CPOperand out = new CPOperand(parts[7]); return new MatrixIndexingSPInstruction(new SimpleOperator(null), lhsInput, rhsInput, rl, ru, cl, cu, out, opcode, str); } else { throw new DMLRuntimeException("Invalid number of operands in instruction: " + str); } } else { throw new DMLRuntimeException("Unknown opcode while parsing a MatrixIndexingSPInstruction: " + str); } } @Override public void processInstruction(ExecutionContext ec) throws DMLUnsupportedOperationException, DMLRuntimeException { SparkExecutionContext sec = (SparkExecutionContext)ec; String opcode = getOpcode(); //get indexing range long rl = ec.getScalarInput(rowLower.getName(), rowLower.getValueType(), rowLower.isLiteral()).getLongValue(); long ru = ec.getScalarInput(rowUpper.getName(), rowUpper.getValueType(), rowUpper.isLiteral()).getLongValue(); long cl = ec.getScalarInput(colLower.getName(), colLower.getValueType(), colLower.isLiteral()).getLongValue(); long cu = ec.getScalarInput(colUpper.getName(), colUpper.getValueType(), colUpper.isLiteral()).getLongValue(); IndexRange ixrange = new IndexRange(rl, ru, cl, cu); //right indexing if( opcode.equalsIgnoreCase("rangeReIndex") ) { //update and check output dimensions MatrixCharacteristics mcIn = sec.getMatrixCharacteristics(input1.getName()); MatrixCharacteristics mcOut = sec.getMatrixCharacteristics(output.getName()); mcOut.set(ru-rl+1, cu-cl+1, mcIn.getRowsPerBlock(), mcIn.getColsPerBlock()); checkValidOutputDimensions(mcOut); //execute right indexing operation (partitioning-preserving if possible) JavaPairRDD<MatrixIndexes,MatrixBlock> in1 = sec.getBinaryBlockRDDHandleForVariable( input1.getName() ); JavaPairRDD<MatrixIndexes,MatrixBlock> out = null; if( isPartitioningPreservingRightIndexing(mcIn, ixrange) ) { out = in1.mapPartitionsToPair( new SliceBlockPartitionFunction(ixrange, mcOut), true); } else{ out = in1.filter(new IsBlockInRange(rl, ru, cl, cu, mcOut)) .flatMapToPair(new SliceBlock(ixrange, mcOut)); //aggregation if required if( _aggType != SparkAggType.NONE ) out = RDDAggregateUtils.mergeByKey(out); } //put output RDD handle into symbol table sec.setRDDHandleForVariable(output.getName(), out); sec.addLineageRDD(output.getName(), input1.getName()); } //left indexing else if ( opcode.equalsIgnoreCase("leftIndex") || opcode.equalsIgnoreCase("mapLeftIndex")) { JavaPairRDD<MatrixIndexes,MatrixBlock> in1 = sec.getBinaryBlockRDDHandleForVariable( input1.getName() ); PartitionedBroadcastMatrix broadcastIn2 = null; JavaPairRDD<MatrixIndexes,MatrixBlock> in2 = null; JavaPairRDD<MatrixIndexes,MatrixBlock> out = null; //update and check output dimensions MatrixCharacteristics mcOut = sec.getMatrixCharacteristics(output.getName()); MatrixCharacteristics mcLeft = ec.getMatrixCharacteristics(input1.getName()); mcOut.set(mcLeft.getRows(), mcLeft.getCols(), mcLeft.getRowsPerBlock(), mcLeft.getColsPerBlock()); checkValidOutputDimensions(mcOut); //note: always matrix rhs, scalars are preprocessed via cast to 1x1 matrix MatrixCharacteristics mcRight = ec.getMatrixCharacteristics(input2.getName()); //sanity check matching index range and rhs dimensions if(!mcRight.dimsKnown()) { throw new DMLRuntimeException("The right input matrix dimensions are not specified for MatrixIndexingSPInstruction"); } if(!(ru-rl+1 == mcRight.getRows() && cu-cl+1 == mcRight.getCols())) { throw new DMLRuntimeException("Invalid index range of leftindexing: ["+rl+":"+ru+","+cl+":"+cu+"] vs ["+mcRight.getRows()+"x"+mcRight.getCols()+"]." ); } if(opcode.equalsIgnoreCase("mapLeftIndex")) { broadcastIn2 = sec.getBroadcastForVariable( input2.getName() ); //partitioning-preserving mappartitions (key access required for broadcast loopkup) out = in1.mapPartitionsToPair( new LeftIndexPartitionFunction(broadcastIn2, ixrange, mcOut), true); } else { // Zero-out LHS in1 = in1.mapToPair(new ZeroOutLHS(false, mcLeft.getRowsPerBlock(), mcLeft.getColsPerBlock(), rl, ru, cl, cu)); // Slice RHS to merge for LHS in2 = sec.getBinaryBlockRDDHandleForVariable( input2.getName() ) .flatMapToPair(new SliceRHSForLeftIndexing(rl, cl, mcLeft.getRowsPerBlock(), mcLeft.getColsPerBlock(), mcLeft.getRows(), mcLeft.getCols())); out = RDDAggregateUtils.mergeByKey(in1.union(in2)); } sec.setRDDHandleForVariable(output.getName(), out); sec.addLineageRDD(output.getName(), input1.getName()); if( broadcastIn2 != null) sec.addLineageBroadcast(output.getName(), input2.getName()); if(in2 != null) sec.addLineageRDD(output.getName(), input2.getName()); } else throw new DMLRuntimeException("Invalid opcode (" + opcode +") encountered in MatrixIndexingSPInstruction."); } /** * * @param mcOut * @throws DMLRuntimeException */ private static void checkValidOutputDimensions(MatrixCharacteristics mcOut) throws DMLRuntimeException { if(!mcOut.dimsKnown()) { throw new DMLRuntimeException("MatrixIndexingSPInstruction: The updated output dimensions are invalid: " + mcOut); } } /** * * @param mcIn * @param ixrange * @return */ private boolean isPartitioningPreservingRightIndexing(MatrixCharacteristics mcIn, IndexRange ixrange) { return ( mcIn.dimsKnown() && (ixrange.rowStart==1 && ixrange.rowEnd==mcIn.getRows() && mcIn.getCols()<=mcIn.getColsPerBlock() ) //1-1 column block indexing ||(ixrange.colStart==1 && ixrange.colEnd==mcIn.getCols() && mcIn.getRows()<=mcIn.getRowsPerBlock() )); //1-1 row block indexing } /** * */ private static class SliceRHSForLeftIndexing implements PairFlatMapFunction<Tuple2<MatrixIndexes,MatrixBlock>, MatrixIndexes, MatrixBlock> { private static final long serialVersionUID = 5724800998701216440L; private long rl; private long cl; private int brlen; private int bclen; private long lhs_rlen; private long lhs_clen; public SliceRHSForLeftIndexing(long rl, long cl, int brlen, int bclen, long lhs_rlen, long lhs_clen) { this.rl = rl; this.cl = cl; this.brlen = brlen; this.bclen = bclen; this.lhs_rlen = lhs_rlen; this.lhs_clen = lhs_clen; } @Override public Iterable<Tuple2<MatrixIndexes, MatrixBlock>> call(Tuple2<MatrixIndexes, MatrixBlock> rightKV) throws Exception { ArrayList<Tuple2<MatrixIndexes, MatrixBlock>> retVal = new ArrayList<Tuple2<MatrixIndexes,MatrixBlock>>(); long start_lhs_globalRowIndex = rl + (rightKV._1.getRowIndex()-1)*brlen; long start_lhs_globalColIndex = cl + (rightKV._1.getColumnIndex()-1)*bclen; long end_lhs_globalRowIndex = start_lhs_globalRowIndex + rightKV._2.getNumRows() - 1; long end_lhs_globalColIndex = start_lhs_globalColIndex + rightKV._2.getNumColumns() - 1; long start_lhs_rowIndex = UtilFunctions.blockIndexCalculation(start_lhs_globalRowIndex, brlen); long end_lhs_rowIndex = UtilFunctions.blockIndexCalculation(end_lhs_globalRowIndex, brlen); long start_lhs_colIndex = UtilFunctions.blockIndexCalculation(start_lhs_globalColIndex, bclen); long end_lhs_colIndex = UtilFunctions.blockIndexCalculation(end_lhs_globalColIndex, bclen); for(long leftRowIndex = start_lhs_rowIndex; leftRowIndex <= end_lhs_rowIndex; leftRowIndex++) { for(long leftColIndex = start_lhs_colIndex; leftColIndex <= end_lhs_colIndex; leftColIndex++) { // Calculate global index of right hand side block long lhs_rl = Math.max((leftRowIndex-1)*brlen+1, start_lhs_globalRowIndex); long lhs_ru = Math.min(leftRowIndex*brlen, end_lhs_globalRowIndex); long lhs_cl = Math.max((leftColIndex-1)*bclen+1, start_lhs_globalColIndex); long lhs_cu = Math.min(leftColIndex*bclen, end_lhs_globalColIndex); int lhs_lrl = UtilFunctions.cellInBlockCalculation(lhs_rl, brlen); int lhs_lru = UtilFunctions.cellInBlockCalculation(lhs_ru, brlen); int lhs_lcl = UtilFunctions.cellInBlockCalculation(lhs_cl, bclen); int lhs_lcu = UtilFunctions.cellInBlockCalculation(lhs_cu, bclen); long rhs_rl = lhs_rl - rl + 1; long rhs_ru = rhs_rl + (lhs_ru - lhs_rl); long rhs_cl = lhs_cl - cl + 1; long rhs_cu = rhs_cl + (lhs_cu - lhs_cl); int rhs_lrl = UtilFunctions.cellInBlockCalculation(rhs_rl, brlen); int rhs_lru = UtilFunctions.cellInBlockCalculation(rhs_ru, brlen); int rhs_lcl = UtilFunctions.cellInBlockCalculation(rhs_cl, bclen); int rhs_lcu = UtilFunctions.cellInBlockCalculation(rhs_cu, bclen); MatrixBlock slicedRHSBlk = rightKV._2.sliceOperations(rhs_lrl, rhs_lru, rhs_lcl, rhs_lcu, new MatrixBlock()); int lbrlen = UtilFunctions.computeBlockSize(lhs_rlen, leftRowIndex, brlen); int lbclen = UtilFunctions.computeBlockSize(lhs_clen, leftColIndex, bclen); MatrixBlock resultBlock = new MatrixBlock(lbrlen, lbclen, false); resultBlock = resultBlock.leftIndexingOperations(slicedRHSBlk, lhs_lrl, lhs_lru, lhs_lcl, lhs_lcu, null, false); retVal.add(new Tuple2<MatrixIndexes, MatrixBlock>(new MatrixIndexes(leftRowIndex, leftColIndex), resultBlock)); } } return retVal; } } /** * */ private static class ZeroOutLHS implements PairFunction<Tuple2<MatrixIndexes,MatrixBlock>, MatrixIndexes,MatrixBlock> { private static final long serialVersionUID = -3581795160948484261L; private boolean complementary = false; private int brlen; int bclen; private IndexRange indexRange; private long rl; long ru; long cl; long cu; public ZeroOutLHS(boolean complementary, int brlen, int bclen, long rl, long ru, long cl, long cu) { this.complementary = complementary; this.brlen = brlen; this.bclen = bclen; this.rl = rl; this.ru = ru; this.cl = cl; this.cu = cu; this.indexRange = new IndexRange(rl, ru, cl, cu); } @Override public Tuple2<MatrixIndexes, MatrixBlock> call(Tuple2<MatrixIndexes, MatrixBlock> kv) throws Exception { if( !UtilFunctions.isInBlockRange(kv._1(), brlen, bclen, rl, ru, cl, cu) ) { return kv; } IndexRange range = UtilFunctions.getSelectedRangeForZeroOut(new IndexedMatrixValue(kv._1, kv._2), brlen, bclen, indexRange); if(range.rowStart == -1 && range.rowEnd == -1 && range.colStart == -1 && range.colEnd == -1) { throw new Exception("Error while getting range for zero-out"); } MatrixBlock zeroBlk = (MatrixBlock) kv._2.zeroOutOperations(new MatrixBlock(), range, complementary); return new Tuple2<MatrixIndexes, MatrixBlock>(kv._1, zeroBlk); } } /** * */ private static class LeftIndexPartitionFunction implements PairFlatMapFunction<Iterator<Tuple2<MatrixIndexes,MatrixBlock>>, MatrixIndexes, MatrixBlock> { private static final long serialVersionUID = 1757075506076838258L; private PartitionedBroadcastMatrix _binput; private IndexRange _ixrange; private int _brlen; private int _bclen; public LeftIndexPartitionFunction(PartitionedBroadcastMatrix binput, IndexRange ixrange, MatrixCharacteristics mc) { _binput = binput; _ixrange = ixrange; _brlen = mc.getRowsPerBlock(); _bclen = mc.getColsPerBlock(); } @Override public Iterable<Tuple2<MatrixIndexes, MatrixBlock>> call(Iterator<Tuple2<MatrixIndexes, MatrixBlock>> arg0) throws Exception { return new LeftIndexPartitionIterator(arg0); } /** * */ private class LeftIndexPartitionIterator extends LazyIterableIterator<Tuple2<MatrixIndexes, MatrixBlock>> { public LeftIndexPartitionIterator(Iterator<Tuple2<MatrixIndexes, MatrixBlock>> in) { super(in); } @Override protected Tuple2<MatrixIndexes, MatrixBlock> computeNext(Tuple2<MatrixIndexes, MatrixBlock> arg) throws Exception { if(!UtilFunctions.isInBlockRange(arg._1(), _brlen, _bclen, _ixrange)) { return arg; } // Calculate global index of left hand side block long lhs_rl = Math.max(_ixrange.rowStart, (arg._1.getRowIndex()-1)*_brlen + 1); long lhs_ru = Math.min(_ixrange.rowEnd, arg._1.getRowIndex()*_brlen); long lhs_cl = Math.max(_ixrange.colStart, (arg._1.getColumnIndex()-1)*_bclen + 1); long lhs_cu = Math.min(_ixrange.colEnd, arg._1.getColumnIndex()*_bclen); // Calculate global index of right hand side block long rhs_rl = lhs_rl - _ixrange.rowStart + 1; long rhs_ru = rhs_rl + (lhs_ru - lhs_rl); long rhs_cl = lhs_cl - _ixrange.colStart + 1; long rhs_cu = rhs_cl + (lhs_cu - lhs_cl); // Provide global zero-based index to sliceOperations MatrixBlock slicedRHSMatBlock = _binput.sliceOperations(rhs_rl, rhs_ru, rhs_cl, rhs_cu, new MatrixBlock()); // Provide local zero-based index to leftIndexingOperations int lhs_lrl = UtilFunctions.cellInBlockCalculation(lhs_rl, _brlen); int lhs_lru = UtilFunctions.cellInBlockCalculation(lhs_ru, _brlen); int lhs_lcl = UtilFunctions.cellInBlockCalculation(lhs_cl, _bclen); int lhs_lcu = UtilFunctions.cellInBlockCalculation(lhs_cu, _bclen); MatrixBlock ret = arg._2.leftIndexingOperations(slicedRHSMatBlock, lhs_lrl, lhs_lru, lhs_lcl, lhs_lcu, new MatrixBlock(), false); return new Tuple2<MatrixIndexes, MatrixBlock>(arg._1, ret); } } } /** * */ private static class SliceBlock implements PairFlatMapFunction<Tuple2<MatrixIndexes,MatrixBlock>, MatrixIndexes, MatrixBlock> { private static final long serialVersionUID = 5733886476413136826L; private IndexRange _ixrange; private int _brlen; private int _bclen; public SliceBlock(IndexRange ixrange, MatrixCharacteristics mcOut) { _ixrange = ixrange; _brlen = mcOut.getRowsPerBlock(); _bclen = mcOut.getColsPerBlock(); } @Override public Iterable<Tuple2<MatrixIndexes, MatrixBlock>> call(Tuple2<MatrixIndexes, MatrixBlock> kv) throws Exception { IndexedMatrixValue in = SparkUtils.toIndexedMatrixBlock(kv); ArrayList<IndexedMatrixValue> outlist = new ArrayList<IndexedMatrixValue>(); OperationsOnMatrixValues.performSlice(in, _ixrange, _brlen, _bclen, outlist); return SparkUtils.fromIndexedMatrixBlock(outlist); } } /** * */ private static class SliceBlockPartitionFunction implements PairFlatMapFunction<Iterator<Tuple2<MatrixIndexes,MatrixBlock>>, MatrixIndexes, MatrixBlock> { private static final long serialVersionUID = -8111291718258309968L; private IndexRange _ixrange; private int _brlen; private int _bclen; public SliceBlockPartitionFunction(IndexRange ixrange, MatrixCharacteristics mcOut) { _ixrange = ixrange; _brlen = mcOut.getRowsPerBlock(); _bclen = mcOut.getColsPerBlock(); } @Override public Iterable<Tuple2<MatrixIndexes, MatrixBlock>> call(Iterator<Tuple2<MatrixIndexes, MatrixBlock>> arg0) throws Exception { return new SliceBlockPartitionIterator(arg0); } private class SliceBlockPartitionIterator extends LazyIterableIterator<Tuple2<MatrixIndexes, MatrixBlock>> { public SliceBlockPartitionIterator(Iterator<Tuple2<MatrixIndexes, MatrixBlock>> in) { super(in); } @Override protected Tuple2<MatrixIndexes, MatrixBlock> computeNext(Tuple2<MatrixIndexes, MatrixBlock> arg) throws Exception { IndexedMatrixValue in = SparkUtils.toIndexedMatrixBlock(arg); ArrayList<IndexedMatrixValue> outlist = new ArrayList<IndexedMatrixValue>(); OperationsOnMatrixValues.performSlice(in, _ixrange, _brlen, _bclen, outlist); assert(outlist.size() == 1); //1-1 row/column block indexing return SparkUtils.fromIndexedMatrixBlock(outlist.get(0)); } } } }