/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysml.runtime.instructions.spark;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.sysml.lops.BinaryM.VectorType;
import org.apache.sysml.parser.Expression.DataType;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysml.runtime.instructions.InstructionUtils;
import org.apache.sysml.runtime.instructions.cp.CPOperand;
import org.apache.sysml.runtime.instructions.cp.ScalarObject;
import org.apache.sysml.runtime.instructions.spark.data.PartitionedBroadcast;
import org.apache.sysml.runtime.instructions.spark.functions.MatrixMatrixBinaryOpFunction;
import org.apache.sysml.runtime.instructions.spark.functions.MatrixScalarUnaryFunction;
import org.apache.sysml.runtime.instructions.spark.functions.MatrixVectorBinaryOpPartitionFunction;
import org.apache.sysml.runtime.instructions.spark.functions.OuterVectorBinaryOpFunction;
import org.apache.sysml.runtime.instructions.spark.functions.ReplicateVectorFunction;
import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
import org.apache.sysml.runtime.matrix.operators.BinaryOperator;
import org.apache.sysml.runtime.matrix.operators.Operator;
import org.apache.sysml.runtime.matrix.operators.ScalarOperator;
/**
 * Base class for Spark instructions that take two inputs and produce one output.
 *
 * It bundles instruction-string parsing, the common execution paths for
 * matrix-matrix, matrix-broadcast-vector and matrix-scalar operations, and
 * helpers that validate input characteristics and infer output characteristics.
 */
public abstract class BinarySPInstruction extends ComputationSPInstruction
{
    /**
     * Creates the instruction by forwarding every argument to the superclass constructor.
     *
     * @param op     operator handed to the superclass
     * @param in1    first input operand
     * @param in2    second input operand
     * @param out    output operand
     * @param opcode instruction opcode
     * @param istr   instruction string
     */
    public BinarySPInstruction(Operator op, CPOperand in1, CPOperand in2, CPOperand out, String opcode, String istr) {
        super(op, in1, in2, out, opcode, istr);
    }

    /**
     * Parses an instruction string with two inputs and one output. The given
     * operands are populated in place from fields 1 to 3 of the instruction.
     *
     * @param instr instruction string
     * @param in1   operand populated from field 1
     * @param in2   operand populated from field 2
     * @param out   operand populated from field 3
     * @return field 0 of the instruction (the opcode)
     * @throws DMLRuntimeException if DMLRuntimeException occurs
     */
    protected static String parseBinaryInstruction(String instr, CPOperand in1, CPOperand in2, CPOperand out)
        throws DMLRuntimeException
    {
        return splitOperands(instr, in1, in2, out);
    }

    /**
     * Parses an instruction string with three inputs and one output. The given
     * operands are populated in place from fields 1 to 4 of the instruction.
     *
     * @param instr instruction string
     * @param in1   operand populated from field 1
     * @param in2   operand populated from field 2
     * @param in3   operand populated from field 3
     * @param out   operand populated from field 4
     * @return field 0 of the instruction (the opcode)
     * @throws DMLRuntimeException if DMLRuntimeException occurs
     */
    protected static String parseBinaryInstruction(String instr, CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out)
        throws DMLRuntimeException
    {
        return splitOperands(instr, in1, in2, in3, out);
    }

    /**
     * Splits the instruction into its fields, checks that there is one field
     * per operand, and lets operand i parse field i+1.
     *
     * @param instr    instruction string
     * @param operands operands to populate, in field order
     * @return field 0 of the instruction (the opcode)
     * @throws DMLRuntimeException if DMLRuntimeException occurs
     */
    private static String splitOperands(String instr, CPOperand... operands)
        throws DMLRuntimeException
    {
        final String[] fields = InstructionUtils.getInstructionPartsWithValueType(instr);
        InstructionUtils.checkNumFields(fields, operands.length);
        for( int pos = 0; pos < operands.length; pos++ ) {
            operands[pos].split(fields[pos + 1]);
        }
        return fields[0];
    }

    /**
     * Shared execution path for binary operations over two RDD inputs.
     *
     * Both inputs are validated, replicated where {@link #getNumReplicas}
     * reports more than one replica, joined on their block indexes and
     * combined block-wise with the instruction's binary operator.
     *
     * @param ec execution context
     * @throws DMLRuntimeException if DMLRuntimeException occurs
     */
    protected void processMatrixMatrixBinaryInstruction(ExecutionContext ec)
        throws DMLRuntimeException
    {
        final SparkExecutionContext sec = (SparkExecutionContext) ec;

        //reject unknown or incompatible inputs before touching any RDD
        checkMatrixMatrixBinaryCharacteristics(sec);

        //look up both inputs and their characteristics
        final String leftName = input1.getName();
        final String rightName = input2.getName();
        JavaPairRDD<MatrixIndexes,MatrixBlock> lhs = sec.getBinaryBlockRDDHandleForVariable(leftName);
        JavaPairRDD<MatrixIndexes,MatrixBlock> rhs = sec.getBinaryBlockRDDHandleForVariable(rightName);
        final MatrixCharacteristics lhsMc = sec.getMatrixCharacteristics(leftName);
        final MatrixCharacteristics rhsMc = sec.getMatrixCharacteristics(rightName);
        final BinaryOperator binaryOp = (BinaryOperator) _optr;

        //replicate vector inputs if needed (matrix-vector or outer operations)
        final long leftReplicas = getNumReplicas(lhsMc, rhsMc, true);
        final long rightReplicas = getNumReplicas(lhsMc, rhsMc, false);
        if( leftReplicas > 1 ) {
            lhs = lhs.flatMapToPair(new ReplicateVectorFunction(false, leftReplicas));
        }
        if( rightReplicas > 1 ) {
            //flag is set when the right input is a single row and the left has several rows
            final boolean rightIsRowVector = rhsMc.getRows() == 1 && lhsMc.getRows() > 1;
            rhs = rhs.flatMapToPair(new ReplicateVectorFunction(rightIsRowVector, rightReplicas));
        }

        //pair up blocks with equal indexes and apply the operator
        final JavaPairRDD<MatrixIndexes,MatrixBlock> result =
            lhs.join(rhs).mapValues(new MatrixMatrixBinaryOpFunction(binaryOp));

        //register output RDD and its lineage
        final String outName = output.getName();
        updateBinaryOutputMatrixCharacteristics(sec);
        sec.setRDDHandleForVariable(outName, result);
        sec.addLineageRDD(outName, leftName);
        sec.addLineageRDD(outName, rightName);
    }

    /**
     * Shared execution path for binary operations between an RDD input
     * (first operand) and a broadcast input (second operand).
     *
     * @param ec    execution context
     * @param vtype vector type passed on to the map-side matrix-vector function
     * @throws DMLRuntimeException if DMLRuntimeException occurs
     */
    protected void processMatrixBVectorBinaryInstruction(ExecutionContext ec, VectorType vtype)
        throws DMLRuntimeException
    {
        final SparkExecutionContext sec = (SparkExecutionContext) ec;

        //reject unknown or incompatible inputs before touching any RDD
        checkMatrixMatrixBinaryCharacteristics(sec);

        //look up RDD input, broadcast input and their characteristics
        final String matrixName = input1.getName();
        final String vectorName = input2.getName();
        final JavaPairRDD<MatrixIndexes,MatrixBlock> matrix = sec.getBinaryBlockRDDHandleForVariable(matrixName);
        final PartitionedBroadcast<MatrixBlock> vector = sec.getBroadcastForVariable(vectorName);
        final MatrixCharacteristics matrixMc = sec.getMatrixCharacteristics(matrixName);
        final MatrixCharacteristics vectorMc = sec.getMatrixCharacteristics(vectorName);
        final BinaryOperator binaryOp = (BinaryOperator) _optr;

        //outer operation: multi-row single-column left with single-row multi-column right
        final boolean outerOperation = matrixMc.getRows() > 1 && matrixMc.getCols() == 1
            && vectorMc.getRows() == 1 && vectorMc.getCols() > 1;

        final JavaPairRDD<MatrixIndexes,MatrixBlock> result;
        if( !outerOperation ) {
            //default case: mapPartitionsToPair (with the preserve-partitioning flag set)
            //is used so that partitioning information survives binary mv operations, whose
            //keys are guaranteed not to change; mapValues is not an option because the
            //function needs the block key for broadcast lookups.
            //alternative: result = matrix.mapToPair(new MatrixVectorBinaryOpFunction(binaryOp, vector, vtype));
            result = matrix.mapPartitionsToPair(
                new MatrixVectorBinaryOpPartitionFunction(binaryOp, vector, vtype), true);
        }
        else {
            result = matrix.flatMapToPair(new OuterVectorBinaryOpFunction(binaryOp, vector));
        }

        //register output RDD and its lineage
        final String outName = output.getName();
        updateBinaryOutputMatrixCharacteristics(sec);
        sec.setRDDHandleForVariable(outName, result);
        sec.addLineageRDD(outName, matrixName);
        sec.addLineageBroadcast(outName, vectorName);
    }

    /**
     * Shared execution path for binary operations between a matrix and a scalar.
     * Whichever operand has data type MATRIX is treated as the matrix; the other
     * one is read as the scalar.
     *
     * Note that the scalar value is written into this instruction's operator
     * via {@code setConstant} before the operator is applied.
     *
     * @param ec execution context
     * @throws DMLRuntimeException if DMLRuntimeException occurs
     */
    protected void processMatrixScalarBinaryInstruction(ExecutionContext ec)
        throws DMLRuntimeException
    {
        final SparkExecutionContext sec = (SparkExecutionContext) ec;

        //decide once which operand is the matrix and which the scalar
        final boolean matrixIsFirst = (input1.getDataType() == DataType.MATRIX);
        final CPOperand matrixOperand = matrixIsFirst ? input1 : input2;
        final CPOperand scalarOperand = matrixIsFirst ? input2 : input1;

        //fetch the matrix RDD
        final String matrixName = matrixOperand.getName();
        final JavaPairRDD<MatrixIndexes,MatrixBlock> matrix = sec.getBinaryBlockRDDHandleForVariable(matrixName);

        //fetch the scalar and bind it to the operator
        final ScalarObject scalarValue = (ScalarObject) ec.getScalarInput(
            scalarOperand.getName(), scalarOperand.getValueType(), scalarOperand.isLiteral());
        final ScalarOperator scalarOp = (ScalarOperator) _optr;
        scalarOp.setConstant(scalarValue.getDoubleValue());

        //apply the operator to every block
        final JavaPairRDD<MatrixIndexes,MatrixBlock> result =
            matrix.mapValues(new MatrixScalarUnaryFunction(scalarOp));

        //register output RDD and its lineage
        final String outName = output.getName();
        updateUnaryOutputMatrixCharacteristics(sec, matrixName, outName);
        sec.setRDDHandleForVariable(outName, result);
        sec.addLineageRDD(outName, matrixName);
    }

    /**
     * Infers the output characteristics as (rows of input1) x (cols of input2)
     * with the block sizes of input1, but only if the output dimensions are not
     * already known.
     *
     * @param sec            spark execution context
     * @param checkCommonDim if true, cols of input1 must equal rows of input2
     * @throws DMLRuntimeException if the output dimensions are unknown and the
     *         input dimensions are unknown, the block sizes differ, or the
     *         requested common-dimension check fails
     */
    protected void updateBinaryMMOutputMatrixCharacteristics(SparkExecutionContext sec, boolean checkCommonDim)
        throws DMLRuntimeException
    {
        final MatrixCharacteristics left = sec.getMatrixCharacteristics(input1.getName());
        final MatrixCharacteristics right = sec.getMatrixCharacteristics(input2.getName());
        final MatrixCharacteristics result = sec.getMatrixCharacteristics(output.getName());

        //nothing to infer (and nothing validated) when the output is already known
        if( result.dimsKnown() ) {
            return;
        }

        if( !(left.dimsKnown() && right.dimsKnown()) ) {
            throw new DMLRuntimeException("The output dimensions are not specified and cannot be inferred from inputs.");
        }
        final boolean sameBlockSizes = left.getRowsPerBlock() == right.getRowsPerBlock()
            && left.getColsPerBlock() == right.getColsPerBlock();
        if( !sameBlockSizes ) {
            throw new DMLRuntimeException("Incompatible block sizes for BinarySPInstruction.");
        }
        if( checkCommonDim && left.getCols() != right.getRows() ) {
            throw new DMLRuntimeException("Incompatible dimensions for BinarySPInstruction");
        }

        result.set(left.getRows(), right.getCols(), left.getRowsPerBlock(), left.getColsPerBlock());
    }

    /**
     * Infers output characteristics for append operations: unknown dimensions
     * are derived by adding up the columns (cbind) or rows (rbind) of the
     * inputs, and an unknown number of non-zeros is derived as the sum of the
     * inputs' non-zeros when both are known.
     *
     * @param sec   spark execution context
     * @param cbind true for column-wise append, false for row-wise append
     * @throws DMLRuntimeException if the output dimensions are unknown and
     *         cannot be inferred from the inputs
     */
    protected void updateBinaryAppendOutputMatrixCharacteristics(SparkExecutionContext sec, boolean cbind)
        throws DMLRuntimeException
    {
        final MatrixCharacteristics left = sec.getMatrixCharacteristics(input1.getName());
        final MatrixCharacteristics right = sec.getMatrixCharacteristics(input2.getName());
        final MatrixCharacteristics result = sec.getMatrixCharacteristics(output.getName());

        //derive dimensions that are not known yet
        if( !result.dimsKnown() ) {
            if( !(left.dimsKnown() && right.dimsKnown()) ) {
                throw new DMLRuntimeException("The output dimensions are not specified and cannot be inferred from inputs.");
            }
            final long outRows = cbind ? left.getRows() : left.getRows() + right.getRows();
            final long outCols = cbind ? left.getCols() + right.getCols() : left.getCols();
            result.set(outRows, outCols, left.getRowsPerBlock(), left.getColsPerBlock());
        }

        //derive the number of non-zeros if it is not known yet
        if( !result.nnzKnown() ) {
            if( left.nnzKnown() && right.nnzKnown() ) {
                result.setNonZeros(left.getNonZeros() + right.getNonZeros());
            }
        }
    }

    /**
     * Computes how many replicas of one input are needed for the join in
     * {@link #processMatrixMatrixBinaryInstruction(ExecutionContext)}.
     *
     * For the left input, the number of column blocks of mc2 is returned if
     * mc1 has exactly one column. For the right input, the number of row
     * blocks of mc1 is returned if mc2 is a single row while mc1 has several
     * rows, or the number of column blocks of mc1 if mc2 is a single column
     * while mc1 has several columns. In all other cases the result is 1.
     *
     * @param mc1  characteristics of the left input
     * @param mc2  characteristics of the right input
     * @param left true to compute replicas of the left input, false for the right
     * @return number of replicas, 1 meaning no replication
     */
    protected long getNumReplicas(MatrixCharacteristics mc1, MatrixCharacteristics mc2, boolean left)
    {
        if( left ) {
            //outer
            return (mc1.getCols() == 1) ? blockCount(mc2.getCols(), mc2.getColsPerBlock()) : 1;
        }

        //outer, row vector
        if( mc2.getRows() == 1 && mc1.getRows() > 1 ) {
            return blockCount(mc1.getRows(), mc1.getRowsPerBlock());
        }
        //col vector
        if( mc2.getCols() == 1 && mc1.getCols() > 1 ) {
            return blockCount(mc1.getCols(), mc1.getColsPerBlock());
        }
        //matrix-matrix
        return 1;
    }

    /**
     * Returns length divided by blockLength, rounded up, using floating-point
     * division followed by {@link Math#ceil(double)}.
     *
     * @param length      dimension length
     * @param blockLength block length along the same dimension
     * @return rounded-up quotient
     */
    private static long blockCount(long length, long blockLength) {
        return (long) Math.ceil((double) length / blockLength);
    }

    /**
     * Validates the two inputs of a binary operation: dimensions must be known,
     * the shapes must be identical or form a matrix-column-vector,
     * matrix-row-vector or outer column-vector-row-vector combination, and the
     * block sizes must match.
     *
     * @param sec spark execution context
     * @throws DMLRuntimeException if any of the checks fails
     */
    protected void checkMatrixMatrixBinaryCharacteristics(SparkExecutionContext sec)
        throws DMLRuntimeException
    {
        final MatrixCharacteristics left = sec.getMatrixCharacteristics(input1.getName());
        final MatrixCharacteristics right = sec.getMatrixCharacteristics(input2.getName());

        //unknown input dimensions
        final boolean allDimsKnown = left.dimsKnown() && right.dimsKnown();
        if( !allDimsKnown ) {
            throw new DMLRuntimeException("Unknown dimensions matrix-matrix binary operations: "
                + describePair(left.getRows(), left.getCols(), right.getRows(), right.getCols()));
        }

        //dimension mismatch: at least one of the supported shape combinations must hold
        final boolean equalRows = left.getRows() == right.getRows();
        final boolean equalCols = left.getCols() == right.getCols();
        final boolean sameShape = equalRows && equalCols;
        final boolean matrixColVector = equalRows && right.getCols() == 1;
        final boolean matrixRowVector = equalCols && right.getRows() == 1;
        final boolean outerVectors = left.getCols() == 1 && right.getRows() == 1;
        if( !(sameShape || matrixColVector || matrixRowVector || outerVectors) ) {
            throw new DMLRuntimeException("Dimensions mismatch matrix-matrix binary operations: "
                + describePair(left.getRows(), left.getCols(), right.getRows(), right.getCols()));
        }

        //block size mismatch
        final boolean sameBlockSizes = left.getRowsPerBlock() == right.getRowsPerBlock()
            && left.getColsPerBlock() == right.getColsPerBlock();
        if( !sameBlockSizes ) {
            throw new DMLRuntimeException("Blocksize mismatch matrix-matrix binary operations: "
                + describePair(left.getRowsPerBlock(), left.getColsPerBlock(),
                    right.getRowsPerBlock(), right.getColsPerBlock()));
        }
    }

    /**
     * Formats two row/column pairs as {@code [r1xc1 vs r2xc2]} for error messages.
     *
     * @param rows1 first row value
     * @param cols1 first column value
     * @param rows2 second row value
     * @param cols2 second column value
     * @return formatted pair description
     */
    private static String describePair(long rows1, long cols1, long rows2, long cols2) {
        return "[" + rows1 + "x" + cols1 + " vs " + rows2 + "x" + cols2 + "]";
    }

    /**
     * Validates the two inputs of an append operation: dimensions must be
     * known, the non-appended dimension must match (rows for cbind, columns
     * for rbind) and the block sizes must match. Two optional checks can be
     * requested in addition.
     *
     * @param sec            spark execution context
     * @param cbind          true for column-wise append, false for row-wise append
     * @param checkSingleBlk if true, the summed column count of both inputs must
     *                       not exceed the column block size of input1
     * @param checkAligned   if true, the column count of input1 must be a multiple
     *                       of its column block size
     * @throws DMLRuntimeException if any of the checks fails
     */
    protected void checkBinaryAppendInputCharacteristics(SparkExecutionContext sec, boolean cbind, boolean checkSingleBlk, boolean checkAligned)
        throws DMLRuntimeException
    {
        final MatrixCharacteristics left = sec.getMatrixCharacteristics(input1.getName());
        final MatrixCharacteristics right = sec.getMatrixCharacteristics(input2.getName());

        //mandatory checks; the first violation is reported
        if( !(left.dimsKnown() && right.dimsKnown()) ) {
            throw new DMLRuntimeException("The dimensions unknown for inputs");
        }
        if( cbind && left.getRows() != right.getRows() ) {
            throw new DMLRuntimeException("The number of rows of inputs should match for append-cbind instruction");
        }
        if( !cbind && left.getCols() != right.getCols() ) {
            throw new DMLRuntimeException("The number of columns of inputs should match for append-rbind instruction");
        }
        final boolean sameBlockSizes = left.getRowsPerBlock() == right.getRowsPerBlock()
            && left.getColsPerBlock() == right.getColsPerBlock();
        if( !sameBlockSizes ) {
            throw new DMLRuntimeException("The block sizes donot match for input matrices");
        }

        //optional checks
        //NOTE(review): both look at columns only, independent of cbind - confirm against callers
        if( checkSingleBlk && left.getCols() + right.getCols() > left.getColsPerBlock() ) {
            throw new DMLRuntimeException("Output must have at most one column block");
        }
        if( checkAligned && left.getCols() % left.getColsPerBlock() != 0 ) {
            throw new DMLRuntimeException("Input matrices are not aligned to blocksize boundaries. Wrong append selected");
        }
    }
}