/**
* (C) Copyright IBM Corp. 2010, 2015
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package com.ibm.bi.dml.runtime.instructions.spark;
import com.ibm.bi.dml.runtime.DMLRuntimeException;
import com.ibm.bi.dml.runtime.controlprogram.context.SparkExecutionContext;
import com.ibm.bi.dml.runtime.instructions.cp.CPOperand;
import com.ibm.bi.dml.runtime.matrix.MatrixCharacteristics;
import com.ibm.bi.dml.runtime.matrix.operators.Operator;
/**
 * Base class for Spark instructions that read up to three input operands and
 * produce a single output operand. It stores the operands and offers helpers
 * that fill in the output's matrix characteristics from the inputs when the
 * output dimensions are not yet known.
 */
public abstract class ComputationSPInstruction extends SPInstruction {

    /** Operand describing the instruction's result. */
    public CPOperand output;

    /** Input operands; {@code input3} is {@code null} for two-input instructions. */
    public CPOperand input1, input2, input3;

    /**
     * Creates an instruction with two inputs; {@code input3} is set to {@code null}.
     *
     * @param op     operator passed through to the superclass
     * @param in1    first input operand
     * @param in2    second input operand
     * @param out    output operand
     * @param opcode opcode passed through to the superclass
     * @param istr   instruction string passed through to the superclass
     */
    public ComputationSPInstruction ( Operator op, CPOperand in1, CPOperand in2, CPOperand out, String opcode, String istr ) {
        this(op, in1, in2, null, out, opcode, istr);
    }

    /**
     * Creates an instruction with three inputs.
     *
     * @param op     operator passed through to the superclass
     * @param in1    first input operand
     * @param in2    second input operand
     * @param in3    third input operand
     * @param out    output operand
     * @param opcode opcode passed through to the superclass
     * @param istr   instruction string passed through to the superclass
     */
    public ComputationSPInstruction ( Operator op, CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out, String opcode, String istr ) {
        super(op, opcode, istr);
        input1 = in1;
        input2 = in2;
        input3 = in3;
        output = out;
    }

    /**
     * Returns the name of the output operand.
     *
     * @return the value of {@code output.getName()}
     */
    public String getOutputVariableName() {
        return output.getName();
    }

    /**
     * Updates the output matrix characteristics from {@code input1}, using the
     * names of {@code input1} and {@code output} as lookup keys.
     *
     * @param sec execution context used to look up matrix characteristics
     * @throws DMLRuntimeException if neither the output nor the input dimensions are known
     */
    protected void updateUnaryOutputMatrixCharacteristics(SparkExecutionContext sec)
        throws DMLRuntimeException
    {
        updateUnaryOutputMatrixCharacteristics(sec, input1.getName(), output.getName());
    }

    /**
     * If the characteristics registered under {@code nameOut} have unknown
     * dimensions, copies rows, columns and both block sizes from the
     * characteristics registered under {@code nameIn}. Does nothing when the
     * output dimensions are already known.
     *
     * @param sec     execution context used to look up matrix characteristics
     * @param nameIn  variable name of the input matrix
     * @param nameOut variable name of the output matrix
     * @throws DMLRuntimeException if neither the output nor the input dimensions are known
     */
    protected void updateUnaryOutputMatrixCharacteristics(SparkExecutionContext sec, String nameIn, String nameOut)
        throws DMLRuntimeException
    {
        MatrixCharacteristics mc1 = sec.getMatrixCharacteristics(nameIn);
        MatrixCharacteristics mcOut = sec.getMatrixCharacteristics(nameOut);

        if( mcOut.dimsKnown() ) {
            return; // nothing to infer
        }
        if( !mc1.dimsKnown() ) {
            throw new DMLRuntimeException("The output dimensions are not specified and cannot be inferred from input:" + mc1.toString() + " " + mcOut.toString());
        }
        mcOut.set(mc1.getRows(), mc1.getCols(), mc1.getRowsPerBlock(), mc1.getColsPerBlock());
    }

    /**
     * If the output characteristics have unknown dimensions, infers them from
     * the two inputs. When {@code input1} is a column vector (rows &gt; 1,
     * cols == 1) and {@code input2} is a row vector (rows == 1, cols &gt; 1),
     * the output takes its rows and row block size from {@code input1} and its
     * columns and column block size from {@code input2}. Otherwise the output
     * takes all four values from {@code input1}. Does nothing when the output
     * dimensions are already known.
     *
     * @param sec execution context used to look up matrix characteristics
     * @throws DMLRuntimeException if the output dimensions are unknown and the
     *         dimensions of {@code input1} are unknown as well
     */
    protected void updateBinaryOutputMatrixCharacteristics(SparkExecutionContext sec)
        throws DMLRuntimeException
    {
        MatrixCharacteristics mcIn1 = sec.getMatrixCharacteristics(input1.getName());
        MatrixCharacteristics mcIn2 = sec.getMatrixCharacteristics(input2.getName());
        MatrixCharacteristics mcOut = sec.getMatrixCharacteristics(output.getName());
        boolean outer = (mcIn1.getRows()>1 && mcIn1.getCols()==1 && mcIn2.getRows()==1 && mcIn2.getCols()>1);

        if( mcOut.dimsKnown() ) {
            return; // nothing to infer
        }
        // Only input1 is required to have known dimensions; input2 is not checked here.
        if( !mcIn1.dimsKnown() ) {
            throw new DMLRuntimeException("The output dimensions are not specified and cannot be inferred from input:" + mcIn1.toString() + " " + mcIn2.toString() + " " + mcOut.toString());
        }

        if( outer ) {
            mcOut.set(mcIn1.getRows(), mcIn2.getCols(), mcIn1.getRowsPerBlock(), mcIn2.getColsPerBlock());
        }
        else {
            // The column block size must come from getColsPerBlock(); the previous
            // code passed getRowsPerBlock() twice.
            mcOut.set(mcIn1.getRows(), mcIn1.getCols(), mcIn1.getRowsPerBlock(), mcIn1.getColsPerBlock());
        }
    }
}