/** * (C) Copyright IBM Corp. 2010, 2015 * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. *  */ package com.ibm.bi.dml.runtime.instructions.cp; import com.ibm.bi.dml.parser.Expression.DataType; import com.ibm.bi.dml.parser.Expression.ValueType; import com.ibm.bi.dml.runtime.DMLRuntimeException; import com.ibm.bi.dml.runtime.DMLUnsupportedOperationException; import com.ibm.bi.dml.runtime.controlprogram.caching.MatrixObject; import com.ibm.bi.dml.runtime.controlprogram.context.ExecutionContext; import com.ibm.bi.dml.runtime.instructions.Instruction; import com.ibm.bi.dml.runtime.instructions.InstructionUtils; import com.ibm.bi.dml.runtime.matrix.data.MatrixBlock; import com.ibm.bi.dml.runtime.matrix.operators.Operator; import com.ibm.bi.dml.runtime.matrix.operators.SimpleOperator; import com.ibm.bi.dml.runtime.util.IndexRange; public class MatrixIndexingCPInstruction extends UnaryCPInstruction { /* * This class implements the matrix indexing functionality inside CP. * Example instructions: * rangeReIndex:mVar1:Var2:Var3:Var4:Var5:mVar6 * input=mVar1, output=mVar6, * bounds = (Var2,Var3,Var4,Var5) * rowindex_lower: Var2, rowindex_upper: Var3 * colindex_lower: Var4, colindex_upper: Var5 * leftIndex:mVar1:mVar2:Var3:Var4:Var5:Var6:mVar7 * triggered by "mVar1[Var3:Var4, Var5:Var6] = mVar2" * the result is stored in mVar7 * */ protected CPOperand rowLower, rowUpper, colLower, colUpper; public MatrixIndexingCPInstruction(Operator op, CPOperand in, CPOperand rl, CPOperand ru, CPOperand cl, CPOperand cu, CPOperand out, String opcode, String istr){ super(op, in, out, opcode, istr); rowLower = rl; rowUpper = ru; colLower = cl; colUpper = cu; } public MatrixIndexingCPInstruction(Operator op, CPOperand lhsInput, CPOperand rhsInput, CPOperand rl, CPOperand ru, CPOperand cl, CPOperand cu, CPOperand out, String opcode, String istr){ super(op, lhsInput, rhsInput, out, opcode, istr); rowLower = rl; rowUpper = ru; colLower = cl; colUpper = cu; } public static Instruction parseInstruction ( String str ) throws DMLRuntimeException { String[] parts = InstructionUtils.getInstructionPartsWithValueType(str); String opcode = parts[0]; if ( opcode.equalsIgnoreCase("rangeReIndex") ) { if ( parts.length == 7 ) { // Example: rangeReIndex:mVar1:Var2:Var3:Var4:Var5:mVar6 CPOperand in, rl, ru, cl, cu, out; in = new CPOperand(); rl = new CPOperand(); ru = new CPOperand(); cl = new CPOperand(); cu = new CPOperand(); out = new CPOperand(); in.split(parts[1]); rl.split(parts[2]); ru.split(parts[3]); cl.split(parts[4]); cu.split(parts[5]); out.split(parts[6]); return new MatrixIndexingCPInstruction(new SimpleOperator(null), in, rl, ru, cl, cu, out, opcode, str); } else { throw new DMLRuntimeException("Invalid number of operands in instruction: " + str); } } else if ( opcode.equalsIgnoreCase("leftIndex")) { if ( parts.length == 8 ) { // Example: leftIndex:mVar1:mvar2:Var3:Var4:Var5:Var6:mVar7 CPOperand lhsInput, rhsInput, rl, ru, cl, cu, out; lhsInput = new CPOperand(); rhsInput = new CPOperand(); rl = new CPOperand(); ru = new CPOperand(); cl = new CPOperand(); cu = new CPOperand(); out = new CPOperand(); lhsInput.split(parts[1]); rhsInput.split(parts[2]); rl.split(parts[3]); ru.split(parts[4]); cl.split(parts[5]); cu.split(parts[6]); out.split(parts[7]); return new MatrixIndexingCPInstruction(new SimpleOperator(null), lhsInput, rhsInput, rl, ru, cl, cu, out, opcode, str); } else { throw new DMLRuntimeException("Invalid number of operands in instruction: " + str); } } else { throw new DMLRuntimeException("Unknown opcode while parsing a MatrixIndexingCPInstruction: " + str); } } @Override public void processInstruction(ExecutionContext ec) throws DMLUnsupportedOperationException, DMLRuntimeException { String opcode = getOpcode(); //get indexing range int rl = (int)(ec.getScalarInput(rowLower.getName(), rowLower.getValueType(), rowLower.isLiteral()).getLongValue()-1); int ru = (int)(ec.getScalarInput(rowUpper.getName(), rowUpper.getValueType(), rowUpper.isLiteral()).getLongValue()-1); int cl = (int)(ec.getScalarInput(colLower.getName(), colLower.getValueType(), colLower.isLiteral()).getLongValue()-1); int cu = (int)(ec.getScalarInput(colUpper.getName(), colUpper.getValueType(), colUpper.isLiteral()).getLongValue()-1); //get original matrix MatrixObject mo = (MatrixObject)ec.getVariable(input1.getName()); //right indexing if( opcode.equalsIgnoreCase("rangeReIndex") ) { MatrixBlock resultBlock = null; if( mo.isPartitioned() ) //via data partitioning resultBlock = mo.readMatrixPartition( new IndexRange(rl+1,ru+1,cl+1,cu+1) ); else //via slicing the in-memory matrix { //execute right indexing operation MatrixBlock matBlock = ec.getMatrixInput(input1.getName()); resultBlock = matBlock.sliceOperations(rl, ru, cl, cu, new MatrixBlock()); //unpin rhs input ec.releaseMatrixInput(input1.getName()); //ensure correct sparse/dense output representation //(memory guarded by release of input) resultBlock.examSparsity(); } //unpin output ec.setMatrixOutput(output.getName(), resultBlock); } //left indexing else if ( opcode.equalsIgnoreCase("leftIndex")) { boolean inplace = mo.isUpdateInPlaceEnabled(); MatrixBlock matBlock = ec.getMatrixInput(input1.getName()); MatrixBlock resultBlock = null; if(input2.getDataType() == DataType.MATRIX) //MATRIX<-MATRIX { MatrixBlock rhsMatBlock = ec.getMatrixInput(input2.getName()); resultBlock = matBlock.leftIndexingOperations(rhsMatBlock, rl, ru, cl, cu, new MatrixBlock(), inplace); ec.releaseMatrixInput(input2.getName()); } else //MATRIX<-SCALAR { if(!(rl==ru && cl==cu)) throw new DMLRuntimeException("Invalid index range of scalar leftindexing: ["+rl+":"+ru+","+cl+":"+cu+"]." ); ScalarObject scalar = ec.getScalarInput(input2.getName(), ValueType.DOUBLE, input2.isLiteral()); resultBlock = (MatrixBlock) matBlock.leftIndexingOperations(scalar, rl, cl, new MatrixBlock(), inplace); } //unpin lhs input ec.releaseMatrixInput(input1.getName()); //ensure correct sparse/dense output representation //(memory guarded by release of input) resultBlock.examSparsity(); //unpin output ec.setMatrixOutput(output.getName(), resultBlock, inplace); } else throw new DMLRuntimeException("Invalid opcode (" + opcode +") encountered in MatrixIndexingCPInstruction."); } }