/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysml.runtime.instructions.spark;
import java.util.ArrayList;
import java.util.Iterator;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.PairFlatMapFunction;
import org.apache.spark.api.java.function.PairFunction;
import scala.Tuple2;
import org.apache.sysml.lops.Ternary;
import org.apache.sysml.parser.Expression.ValueType;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysml.runtime.functionobjects.CTable;
import org.apache.sysml.runtime.instructions.Instruction;
import org.apache.sysml.runtime.instructions.InstructionUtils;
import org.apache.sysml.runtime.instructions.cp.CPOperand;
import org.apache.sysml.runtime.instructions.spark.utils.RDDAggregateUtils;
import org.apache.sysml.runtime.instructions.spark.utils.SparkUtils;
import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
import org.apache.sysml.runtime.matrix.data.CTableMap;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.data.MatrixCell;
import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
import org.apache.sysml.runtime.matrix.data.OperationsOnMatrixValues;
import org.apache.sysml.runtime.matrix.data.Pair;
import org.apache.sysml.runtime.matrix.mapred.IndexedMatrixValue;
import org.apache.sysml.runtime.matrix.operators.Operator;
import org.apache.sysml.runtime.matrix.operators.SimpleOperator;
import org.apache.sysml.runtime.util.LongLongDoubleHashMap.LLDoubleEntry;
import org.apache.sysml.runtime.util.UtilFunctions;
public class TernarySPInstruction extends ComputationSPInstruction
{
// Output row dimension: parsed as a numeric literal when _dim1Literal is true,
// otherwise used as the name of a scalar variable looked up at execution time.
private String _outDim1;
// Output column dimension: same encoding as _outDim1, governed by _dim2Literal.
private String _outDim2;
// True if _outDim1 holds a literal value rather than a variable name.
private boolean _dim1Literal;
// True if _outDim2 holds a literal value rather than a variable name.
private boolean _dim2Literal;
// True for the "ctableexpand" opcode; forces the CTABLE_EXPAND_SCALAR_WEIGHT operation type.
private boolean _isExpand;
// Forwarded to MatrixBlock.ternaryOperations for the scalar-weight ctable variants.
private boolean _ignoreZeros;
/**
 * Creates a ternary (ctable) Spark instruction.
 *
 * @param op          operator handed to the superclass
 * @param in1         first input operand
 * @param in2         second input operand
 * @param in3         third input operand
 * @param out         output operand
 * @param outputDim1  output row dimension (literal value or scalar variable name)
 * @param dim1Literal whether {@code outputDim1} is a literal
 * @param outputDim2  output column dimension (literal value or scalar variable name)
 * @param dim2Literal whether {@code outputDim2} is a literal
 * @param isExpand    whether this is the expand variant of ctable
 * @param ignoreZeros flag forwarded to the scalar-weight ctable computation
 * @param opcode      instruction opcode
 * @param istr        full instruction string
 */
public TernarySPInstruction(Operator op, CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out,
        String outputDim1, boolean dim1Literal, String outputDim2, boolean dim2Literal,
        boolean isExpand, boolean ignoreZeros, String opcode, String istr)
{
    super(op, in1, in2, in3, out, opcode, istr);

    // output dimension descriptors
    this._outDim1 = outputDim1;
    this._outDim2 = outputDim2;
    this._dim1Literal = dim1Literal;
    this._dim2Literal = dim2Literal;

    // operation flags
    this._isExpand = isExpand;
    this._ignoreZeros = ignoreZeros;
}
/**
 * Parses the string representation of a {@code ctable} or {@code ctableexpand}
 * instruction into a {@link TernarySPInstruction}.
 *
 * Layout of the parsed parts: [0] opcode, [1..3] input operands, [4] and [5] output
 * dimension descriptors (value and is-literal flag separated by
 * {@code Instruction.LITERAL_PREFIX}), [6] output operand, [7] ignore-zeros flag.
 *
 * @param inst instruction string
 * @return the parsed instruction
 * @throws DMLRuntimeException if the opcode is neither {@code ctable} nor
 *         {@code ctableexpand}, or if a dimension descriptor is malformed
 */
public static TernarySPInstruction parseInstruction(String inst)
    throws DMLRuntimeException
{
    String[] parts = InstructionUtils.getInstructionPartsWithValueType(inst);
    InstructionUtils.checkNumFields ( parts, 7 );

    //handle opcode
    String opcode = parts[0];
    boolean isExpand = opcode.equalsIgnoreCase("ctableexpand");
    if( !isExpand && !opcode.equalsIgnoreCase("ctable") ) {
        throw new DMLRuntimeException("Unexpected opcode in TernarySPInstruction: " + inst);
    }

    //handle operands
    CPOperand in1 = new CPOperand(parts[1]);
    CPOperand in2 = new CPOperand(parts[2]);
    CPOperand in3 = new CPOperand(parts[3]);

    //handle known dimension information
    String[] dim1Fields = splitDimensionField(parts[4], inst);
    String[] dim2Fields = splitDimensionField(parts[5], inst);

    CPOperand out = new CPOperand(parts[6]);
    boolean ignoreZeros = Boolean.parseBoolean(parts[7]);

    // ctable does not require any operator, so we simply pass-in a dummy operator with null functionobject
    return new TernarySPInstruction(new SimpleOperator(null), in1, in2, in3, out,
        dim1Fields[0], Boolean.parseBoolean(dim1Fields[1]),
        dim2Fields[0], Boolean.parseBoolean(dim2Fields[1]),
        isExpand, ignoreZeros, opcode, inst);
}

/**
 * Splits an output dimension descriptor into its value and is-literal flag.
 *
 * @param field dimension descriptor taken from the instruction string
 * @param inst  full instruction string, used for error reporting only
 * @return array whose first two entries are the dimension value and the literal flag
 * @throws DMLRuntimeException if the descriptor does not contain both parts
 */
private static String[] splitDimensionField(String field, String inst)
    throws DMLRuntimeException
{
    String[] fields = field.split(Instruction.LITERAL_PREFIX);
    if( fields.length < 2 ) {
        // fail with a descriptive error instead of an ArrayIndexOutOfBoundsException
        throw new DMLRuntimeException("Malformed output dimension '" + field
            + "' in TernarySPInstruction: " + inst);
    }
    return fields;
}
/**
 * Executes the ctable operation on Spark.
 *
 * Steps: (1) determine the ctable variant from the input data types (or the expand
 * flag), (2) build either block-aligned inputs for the map-side ctable computation
 * or, for the expand variant with weight 1, the output cells directly, (3) sum the
 * cells per output index, (4) filter by the requested output dimensions or compute
 * the dimensions from the data if both are unknown (-1), and (5) register the
 * resulting binary-cell RDD, its characteristics, and its lineage.
 *
 * @param ec execution context, cast to {@link SparkExecutionContext}
 * @throws DMLRuntimeException on an invalid ctable variant or inconsistent output dimensions
 */
@Override
public void processInstruction(ExecutionContext ec)
    throws DMLRuntimeException
{
    SparkExecutionContext sec = (SparkExecutionContext) ec;

    //get input rdd handle of the first input; the others are fetched on demand
    JavaPairRDD<MatrixIndexes, MatrixBlock> rdd1 = sec.getBinaryBlockRDDHandleForVariable(input1.getName());
    JavaPairRDD<MatrixIndexes, MatrixBlock> rdd2;
    JavaPairRDD<MatrixIndexes, MatrixBlock> rdd3;
    double scalar2 = -1;
    double scalar3 = -1;

    Ternary.OperationTypes op = Ternary.findCtableOperationByInputDataTypes(
        input1.getDataType(), input2.getDataType(), input3.getDataType());
    if( _isExpand ) {
        op = Ternary.OperationTypes.CTABLE_EXPAND_SCALAR_WEIGHT;
    }

    MatrixCharacteristics mcIn = sec.getMatrixCharacteristics(input1.getName());
    MatrixCharacteristics mcOut = sec.getMatrixCharacteristics(output.getName());

    // block sizes of the first input (the output block sizes are set to -1 below)
    int brlen = mcIn.getRowsPerBlock();
    int bclen = mcIn.getColsPerBlock();

    JavaPairRDD<MatrixIndexes, ArrayList<MatrixBlock>> alignedBlocks = null;
    JavaPairRDD<MatrixIndexes, Double> cells = null;
    boolean usesInput2 = false;
    boolean usesInput3 = false;

    switch( op )
    {
        case CTABLE_TRANSFORM: //(VECTOR)
            // F=ctable(A,B,W)
            rdd2 = sec.getBinaryBlockRDDHandleForVariable(input2.getName());
            rdd3 = sec.getBinaryBlockRDDHandleForVariable(input3.getName());
            usesInput2 = true;
            usesInput3 = true;
            alignedBlocks = rdd1.cogroup(rdd2).cogroup(rdd3)
                .mapToPair(new MapThreeMBIterableIntoAL());
            break;

        case CTABLE_EXPAND_SCALAR_WEIGHT: //(VECTOR)
            // F = ctable(seq,A) or F = ctable(seq,B,1)
            scalar3 = sec.getScalarInput(input3.getName(), input3.getValueType(), input3.isLiteral()).getDoubleValue();
            if( scalar3 == 1 ) {
                rdd2 = sec.getBinaryBlockRDDHandleForVariable(input2.getName());
                usesInput2 = true;
                cells = rdd2.flatMapToPair(new ExpandScalarCtableOperation(brlen));
                break;
            }
            // fall through: a weight other than 1 is handled like the generic scalar-weight case

        case CTABLE_TRANSFORM_SCALAR_WEIGHT: //(VECTOR/MATRIX)
            // F = ctable(A,B) or F = ctable(A,B,1)
            rdd2 = sec.getBinaryBlockRDDHandleForVariable(input2.getName());
            usesInput2 = true;
            scalar3 = sec.getScalarInput(input3.getName(), input3.getValueType(), input3.isLiteral()).getDoubleValue();
            alignedBlocks = rdd1.cogroup(rdd2).mapToPair(new MapTwoMBIterableIntoAL());
            break;

        case CTABLE_TRANSFORM_HISTOGRAM: //(VECTOR)
            // F=ctable(A,1) or F = ctable(A,1,1)
            scalar2 = sec.getScalarInput(input2.getName(), input2.getValueType(), input2.isLiteral()).getDoubleValue();
            scalar3 = sec.getScalarInput(input3.getName(), input3.getValueType(), input3.isLiteral()).getDoubleValue();
            alignedBlocks = rdd1.mapToPair(new MapMBIntoAL());
            break;

        case CTABLE_TRANSFORM_WEIGHTED_HISTOGRAM: //(VECTOR)
            // F=ctable(A,1,W)
            rdd3 = sec.getBinaryBlockRDDHandleForVariable(input3.getName());
            usesInput3 = true;
            scalar2 = sec.getScalarInput(input2.getName(), input2.getValueType(), input2.isLiteral()).getDoubleValue();
            alignedBlocks = rdd1.cogroup(rdd3).mapToPair(new MapTwoMBIterableIntoAL());
            break;

        default:
            throw new DMLRuntimeException("Encountered an invalid ctable operation ("+op+") while executing instruction: " + this.toString());
    }

    // map-side ctable computation for every variant that produced aligned input blocks
    JavaPairRDD<MatrixIndexes, CTableMap> ctables = null;
    if( alignedBlocks != null ) {
        ctables = alignedBlocks.mapToPair(new PerformCTableMapSideOperation(op, scalar2,
            scalar3, this.instString, (SimpleOperator)_optr, _ignoreZeros));
    }

    // exactly one of (cells, ctables) must be set at this point
    if( cells == null ) {
        if( ctables == null ) {
            throw new DMLRuntimeException("Incorrect ctable operation");
        }
        // aggregate the partial ctables into one value per output cell
        cells = ctables.values()
            .flatMapToPair(new ExtractBinaryCellsFromCTable());
        cells = RDDAggregateUtils.sumCellsByKeyStable(cells);
    }
    else if( ctables != null ) {
        throw new DMLRuntimeException("Incorrect ctable operation");
    }

    // handle known/unknown dimensions
    long outputDim1 = _dim1Literal
        ? (long) Double.parseDouble(_outDim1)
        : (sec.getScalarInput(_outDim1, ValueType.DOUBLE, false)).getLongValue();
    long outputDim2 = _dim2Literal
        ? (long) Double.parseDouble(_outDim2)
        : (sec.getScalarInput(_outDim2, ValueType.DOUBLE, false)).getLongValue();

    boolean rowsUnknown = (outputDim1 == -1);
    boolean colsUnknown = (outputDim2 == -1);
    boolean findDimensions = rowsUnknown && colsUnknown;

    MatrixCharacteristics mcBinaryCells = null;
    if( !findDimensions ) {
        // either both dimensions are given or none; a mix is rejected
        if( rowsUnknown != colsUnknown ) {
            throw new DMLRuntimeException("Incorrect output dimensions passed to TernarySPInstruction:" + outputDim1 + " " + outputDim2);
        }
        mcBinaryCells = new MatrixCharacteristics(outputDim1, outputDim2, brlen, bclen);

        // filtering according to given dimensions
        cells = cells.filter(new FilterCells(mcBinaryCells.getRows(), mcBinaryCells.getCols()));
    }

    // convert double values to matrix cell
    JavaPairRDD<MatrixIndexes, MatrixCell> binaryCells = cells.mapToPair(new ConvertToBinaryCell());

    // find dimensions if necessary (w/ cache for reblock)
    if( findDimensions ) {
        binaryCells = SparkUtils.cacheBinaryCellRDD(binaryCells);
        mcBinaryCells = SparkUtils.computeMatrixCharacteristics(binaryCells);
    }

    //store output rdd handle
    sec.setRDDHandleForVariable(output.getName(), binaryCells);
    mcOut.set(mcBinaryCells);

    // Since we are outputing binary cells, we set block sizes = -1
    mcOut.setRowsPerBlock(-1);
    mcOut.setColsPerBlock(-1);

    // lineage: first input always, the others only if consumed as RDDs
    sec.addLineageRDD(output.getName(), input1.getName());
    if( usesInput2 ) {
        sec.addLineageRDD(output.getName(), input2.getName());
    }
    if( usesInput3 ) {
        sec.addLineageRDD(output.getName(), input3.getName());
    }
}
/**
 * Emits one output cell per row of an input block for the expand variant of
 * ctable with weight 1. The global row index of each block row is paired with the
 * value in column 0 of that row and passed through the shared {@link CTable}
 * function object.
 */
private static class ExpandScalarCtableOperation implements PairFlatMapFunction<Tuple2<MatrixIndexes,MatrixBlock>, MatrixIndexes, Double>
{
    private static final long serialVersionUID = -12552669148928288L;

    /** Number of rows per block, used to derive global row indexes. */
    private final int _brlen;

    public ExpandScalarCtableOperation(int brlen) {
        _brlen = brlen;
    }

    @Override
    public Iterator<Tuple2<MatrixIndexes, Double>> call(Tuple2<MatrixIndexes, MatrixBlock> arg0)
        throws Exception
    {
        final MatrixIndexes blockIx = arg0._1();
        final MatrixBlock block = arg0._2(); //col-vector
        final CTable ctableFn = CTable.getCTableFnObject();
        final int numRows = block.getNumRows();

        //create an output cell per matrix block row (aligned w/ original source position)
        ArrayList<Tuple2<MatrixIndexes, Double>> cells = new ArrayList<Tuple2<MatrixIndexes,Double>>();
        int r = 0;
        while( r < numRows )
        {
            //compute global target indexes (via ctable obj for error handling consistency)
            long globalRow = UtilFunctions.computeCellIndex(blockIx.getRowIndex(), _brlen, r);
            double target = block.quickGetValue(r, 0);
            Pair<MatrixIndexes,Double> entry = ctableFn.execute(globalRow, target, 1.0);

            //indirect construction over pair to avoid tuple2 dependency in general ctable obj;
            //entries with a row index below 1 are treated as rejected and dropped
            boolean accepted = entry.getKey().getRowIndex() >= 1;
            if( accepted ) {
                cells.add(new Tuple2<MatrixIndexes,Double>(entry.getKey(), entry.getValue()));
            }
            r++;
        }

        return cells.iterator();
    }
}
/**
 * Converts the result of a two-way cogroup into a list holding exactly one block
 * per input (first input at position 0, second at position 1). Fails if either
 * side contributes no block or more than one block for the key.
 */
private static class MapTwoMBIterableIntoAL implements PairFunction<Tuple2<MatrixIndexes,Tuple2<Iterable<MatrixBlock>,Iterable<MatrixBlock>>>, MatrixIndexes, ArrayList<MatrixBlock>> {

    private static final long serialVersionUID = 271459913267735850L;

    /**
     * Returns the single block of {@code blks}.
     *
     * @param blks   blocks of one input for the current key
     * @param retVal block found so far ({@code null} if none)
     * @return the single block
     * @throws Exception if no block or more than one block is present
     */
    private MatrixBlock extractBlock(Iterable<MatrixBlock> blks, MatrixBlock retVal) throws Exception {
        MatrixBlock found = retVal;
        Iterator<MatrixBlock> it = blks.iterator();
        while( it.hasNext() ) {
            MatrixBlock candidate = it.next();
            if( found != null ) {
                throw new Exception("ERROR: More than 1 matrixblock found for one of the inputs at a given index");
            }
            found = candidate;
        }
        if( found == null ) {
            throw new Exception("ERROR: No matrixblock found for one of the inputs at a given index");
        }
        return found;
    }

    @Override
    public Tuple2<MatrixIndexes, ArrayList<MatrixBlock>> call(
            Tuple2<MatrixIndexes, Tuple2<Iterable<MatrixBlock>, Iterable<MatrixBlock>>> kv)
        throws Exception
    {
        MatrixBlock left = extractBlock(kv._2._1, null);
        MatrixBlock right = extractBlock(kv._2._2, null);

        // collect both blocks in input order
        ArrayList<MatrixBlock> blocks = new ArrayList<MatrixBlock>();
        blocks.add(left);
        blocks.add(right);
        return new Tuple2<MatrixIndexes, ArrayList<MatrixBlock>>(kv._1, blocks);
    }
}
/**
 * Converts the result of a nested three-way cogroup
 * ({@code in1.cogroup(in2).cogroup(in3)}) into a list holding exactly one block per
 * input, in input order. Fails if any input contributes no block or more than one
 * block for the key.
 */
private static class MapThreeMBIterableIntoAL implements PairFunction<Tuple2<MatrixIndexes,Tuple2<Iterable<Tuple2<Iterable<MatrixBlock>,Iterable<MatrixBlock>>>,Iterable<MatrixBlock>>>, MatrixIndexes, ArrayList<MatrixBlock>> {

    private static final long serialVersionUID = -4873754507037646974L;

    /**
     * Returns the single block of {@code blks}, taking into account a block that
     * was already found in a previous call.
     *
     * @param blks   blocks of one input for the current key
     * @param retVal block found so far ({@code null} if none)
     * @return the single block
     * @throws Exception if no block or more than one block is present
     */
    private MatrixBlock extractBlock(Iterable<MatrixBlock> blks, MatrixBlock retVal) throws Exception {
        for(MatrixBlock blk1 : blks) {
            if(retVal != null) {
                throw new Exception("ERROR: More than 1 matrixblock found for one of the inputs at a given index");
            }
            retVal = blk1;
        }
        if(retVal == null) {
            throw new Exception("ERROR: No matrixblock found for one of the inputs at a given index");
        }
        return retVal;
    }

    @Override
    public Tuple2<MatrixIndexes, ArrayList<MatrixBlock>> call(
            Tuple2<MatrixIndexes, Tuple2<Iterable<Tuple2<Iterable<MatrixBlock>, Iterable<MatrixBlock>>>, Iterable<MatrixBlock>>> kv)
        throws Exception
    {
        MatrixBlock in1 = null;
        MatrixBlock in2 = null;
        for(Tuple2<Iterable<MatrixBlock>, Iterable<MatrixBlock>> blks : kv._2._1) {
            in1 = extractBlock(blks._1, in1);
            in2 = extractBlock(blks._2, in2);
        }

        // An empty outer iterable (key only present in the third input) would otherwise
        // leave in1/in2 null and silently pass null blocks downstream.
        if(in1 == null || in2 == null) {
            throw new Exception("ERROR: No matrixblock found for one of the inputs at a given index");
        }

        MatrixBlock in3 = extractBlock(kv._2._2, null);

        // Now return unflatten AL
        ArrayList<MatrixBlock> inputs = new ArrayList<MatrixBlock>();
        inputs.add(in1);
        inputs.add(in2);
        inputs.add(in3);
        return new Tuple2<MatrixIndexes, ArrayList<MatrixBlock>>(kv._1, inputs);
    }
}
private static class PerformCTableMapSideOperation implements PairFunction<Tuple2<MatrixIndexes,ArrayList<MatrixBlock>>, MatrixIndexes, CTableMap> {
private static final long serialVersionUID = 5348127596473232337L;
Ternary.OperationTypes ctableOp;
double scalar_input2; double scalar_input3;
String instString;
Operator optr;
boolean ignoreZeros;
public PerformCTableMapSideOperation(Ternary.OperationTypes ctableOp, double scalar_input2, double scalar_input3, String instString, Operator optr, boolean ignoreZeros) {
this.ctableOp = ctableOp;
this.scalar_input2 = scalar_input2;
this.scalar_input3 = scalar_input3;
this.instString = instString;
this.optr = optr;
this.ignoreZeros = ignoreZeros;
}
private void expectedALSize(int length, ArrayList<MatrixBlock> al) throws Exception {
if(al.size() != length) {
throw new Exception("Expected arraylist of size:" + length + ", but found " + al.size());
}
}
@Override
public Tuple2<MatrixIndexes, CTableMap> call(
Tuple2<MatrixIndexes, ArrayList<MatrixBlock>> kv) throws Exception {
CTableMap ctableResult = new CTableMap();
MatrixBlock ctableResultBlock = null;
IndexedMatrixValue in1, in2, in3 = null;
in1 = new IndexedMatrixValue(kv._1, kv._2.get(0));
MatrixBlock matBlock1 = kv._2.get(0);
switch( ctableOp )
{
case CTABLE_TRANSFORM: {
in2 = new IndexedMatrixValue(kv._1, kv._2.get(1));
in3 = new IndexedMatrixValue(kv._1, kv._2.get(2));
expectedALSize(3, kv._2);
if(in1==null || in2==null || in3 == null )
break;
else
OperationsOnMatrixValues.performTernary(in1.getIndexes(), in1.getValue(), in2.getIndexes(), in2.getValue(),
in3.getIndexes(), in3.getValue(), ctableResult, ctableResultBlock, optr);
break;
}
case CTABLE_TRANSFORM_SCALAR_WEIGHT:
case CTABLE_EXPAND_SCALAR_WEIGHT:
{
// 3rd input is a scalar
in2 = new IndexedMatrixValue(kv._1, kv._2.get(1));
expectedALSize(2, kv._2);
if(in1==null || in2==null )
break;
else
matBlock1.ternaryOperations((SimpleOperator)optr, kv._2.get(1), scalar_input3, ignoreZeros, ctableResult, ctableResultBlock);
break;
}
case CTABLE_TRANSFORM_HISTOGRAM: {
expectedALSize(1, kv._2);
OperationsOnMatrixValues.performTernary(in1.getIndexes(), in1.getValue(), scalar_input2,
scalar_input3, ctableResult, ctableResultBlock, optr);
break;
}
case CTABLE_TRANSFORM_WEIGHTED_HISTOGRAM: {
// 2nd and 3rd inputs are scalars
expectedALSize(2, kv._2);
in3 = new IndexedMatrixValue(kv._1, kv._2.get(1)); // Note: kv._2.get(1), not kv._2.get(2)
if(in1==null || in3==null)
break;
else
OperationsOnMatrixValues.performTernary(in1.getIndexes(), in1.getValue(), scalar_input2,
in3.getIndexes(), in3.getValue(), ctableResult, ctableResultBlock, optr);
break;
}
default:
throw new DMLRuntimeException("Unrecognized opcode in Tertiary Instruction: " + instString);
}
return new Tuple2<MatrixIndexes, CTableMap>(kv._1, ctableResult);
}
}
/**
 * Wraps the single matrix block of each pair into a one-element list, keeping the key.
 */
private static class MapMBIntoAL implements PairFunction<Tuple2<MatrixIndexes, MatrixBlock>, MatrixIndexes, ArrayList<MatrixBlock>> {

    private static final long serialVersionUID = 2068398913653350125L;

    @Override
    public Tuple2<MatrixIndexes, ArrayList<MatrixBlock>> call(
            Tuple2<MatrixIndexes, MatrixBlock> kv) throws Exception
    {
        MatrixIndexes key = kv._1();
        ArrayList<MatrixBlock> singleton = new ArrayList<MatrixBlock>();
        singleton.add(kv._2());
        return new Tuple2<MatrixIndexes, ArrayList<MatrixBlock>>(key, singleton);
    }
}
/**
 * Flattens a {@link CTableMap} into (cell index, value) pairs, one per map entry.
 */
private static class ExtractBinaryCellsFromCTable implements PairFlatMapFunction<CTableMap, MatrixIndexes, Double> {

    private static final long serialVersionUID = -5933677686766674444L;

    @SuppressWarnings("deprecation")
    @Override
    public Iterator<Tuple2<MatrixIndexes, Double>> call(CTableMap ctableMap)
        throws Exception
    {
        ArrayList<Tuple2<MatrixIndexes, Double>> cells = new ArrayList<Tuple2<MatrixIndexes, Double>>();
        for( LLDoubleEntry entry : ctableMap.entrySet() ) {
            // key1/key2 are used as row/column index of the output cell
            MatrixIndexes cellIx = new MatrixIndexes(entry.key1, entry.key2);
            double cellValue = entry.value;
            cells.add(new Tuple2<MatrixIndexes, Double>(cellIx, cellValue));
        }
        return cells.iterator();
    }
}
/**
 * Wraps the double value of each pair into a {@link MatrixCell}, keeping the key.
 */
private static class ConvertToBinaryCell implements PairFunction<Tuple2<MatrixIndexes,Double>, MatrixIndexes, MatrixCell> {

    private static final long serialVersionUID = 7481186480851982800L;

    @Override
    public Tuple2<MatrixIndexes, MatrixCell> call(
            Tuple2<MatrixIndexes, Double> kv) throws Exception
    {
        double value = kv._2().doubleValue();
        return new Tuple2<MatrixIndexes, MatrixCell>(kv._1(), new MatrixCell(value));
    }
}
/**
 * Keeps only cells whose row and column index lie within the given output
 * dimensions. Cells with a non-positive row or column index are considered
 * invalid and cause an exception.
 */
private static class FilterCells implements Function<Tuple2<MatrixIndexes,Double>, Boolean> {

    private static final long serialVersionUID = 108448577697623247L;

    /** Number of output rows (inclusive upper bound for row indexes). */
    long rlen;
    /** Number of output columns (inclusive upper bound for column indexes). */
    long clen;

    public FilterCells(long rlen, long clen) {
        this.rlen = rlen;
        this.clen = clen;
    }

    @Override
    public Boolean call(Tuple2<MatrixIndexes, Double> kv) throws Exception {
        long row = kv._1.getRowIndex();
        long col = kv._1.getColumnIndex();

        boolean valid = row > 0 && col > 0;
        if( !valid ) {
            throw new Exception("Incorrect cell values in TernarySPInstruction:" + kv._1);
        }

        // inside the requested output dimensions?
        return row <= rlen && col <= clen;
    }
}
}