/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysml.runtime.instructions.spark;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.function.PairFlatMapFunction;
import org.apache.spark.broadcast.Broadcast;
import org.apache.sysml.parser.Expression.DataType;
import org.apache.sysml.parser.Expression.ValueType;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysml.runtime.functionobjects.SwapIndex;
import org.apache.sysml.runtime.instructions.InstructionUtils;
import org.apache.sysml.runtime.instructions.cp.CPOperand;
import org.apache.sysml.runtime.instructions.spark.data.LazyIterableIterator;
import org.apache.sysml.runtime.instructions.spark.functions.ExtractBlockForBinaryReblock;
import org.apache.sysml.runtime.instructions.spark.utils.RDDAggregateUtils;
import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
import org.apache.sysml.runtime.matrix.MatrixFormatMetaData;
import org.apache.sysml.runtime.matrix.data.ConvolutionParameters;
import org.apache.sysml.runtime.matrix.data.InputInfo;
import org.apache.sysml.runtime.matrix.data.LibMatrixDNN;
import org.apache.sysml.runtime.matrix.data.LibMatrixNative;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
import org.apache.sysml.runtime.matrix.data.OutputInfo;
import org.apache.sysml.runtime.matrix.operators.ReorgOperator;
import org.apache.sysml.runtime.util.ConvolutionUtils;
import org.apache.sysml.utils.NativeHelper;
import scala.Tuple2;
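/**
 * Spark instruction for deep-learning operators over binary-block matrix RDDs.
 * Parsing covers conv2d, conv2d_bias_add, bias_add, maxpooling/relu_maxpooling,
 * and the conv2d/maxpooling backward variants, but processInstruction currently
 * executes only conv2d, conv2d_bias_add, maxpooling, and relu_maxpooling: the
 * input is reblocked so that each block holds complete rows (one image per row),
 * filter/bias operands are broadcast to all executors, and the computation is
 * delegated per partition to LibMatrixDNN (or LibMatrixNative if a native BLAS
 * library is loaded).
 */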
public class ConvolutionSPInstruction extends UnarySPInstruction {
private CPOperand _in2;
private CPOperand _in3;
private ArrayList<CPOperand> _input_shape;
private ArrayList<CPOperand> _filter_shape;
private ArrayList<CPOperand> _stride = new ArrayList<CPOperand>();
private ArrayList<CPOperand> _padding = new ArrayList<CPOperand>();
public ConvolutionSPInstruction(CPOperand in, CPOperand out, String opcode,
String istr, ArrayList<CPOperand> stride,
ArrayList<CPOperand> padding, ArrayList<CPOperand> input_shape,
ArrayList<CPOperand> filter_shape) {
super(new ReorgOperator(SwapIndex.getSwapIndexFnObject()), in, out,
opcode, istr);
_sptype = SPINSTRUCTION_TYPE.Convolution;
_stride = stride;
_padding = padding;
_input_shape = input_shape;
_filter_shape = filter_shape;
}
public ConvolutionSPInstruction(CPOperand in, CPOperand in2, CPOperand out,
String opcode, String istr, ArrayList<CPOperand> stride,
ArrayList<CPOperand> padding, ArrayList<CPOperand> input_shape,
ArrayList<CPOperand> filter_shape) {
super(new ReorgOperator(SwapIndex.getSwapIndexFnObject()), in, out,
opcode, istr);
_in2 = in2;
_sptype = SPINSTRUCTION_TYPE.Convolution;
_stride = stride;
_padding = padding;
_input_shape = input_shape;
_filter_shape = filter_shape;
}
public ConvolutionSPInstruction(CPOperand in, CPOperand in2, CPOperand in3,
CPOperand out, String opcode, String istr,
ArrayList<CPOperand> stride, ArrayList<CPOperand> padding,
ArrayList<CPOperand> input_shape, ArrayList<CPOperand> filter_shape) {
super(new ReorgOperator(SwapIndex.getSwapIndexFnObject()), in, out,
opcode, istr);
_in2 = in2;
_in3 = in3;
_sptype = SPINSTRUCTION_TYPE.Convolution;
_stride = stride;
_padding = padding;
_input_shape = input_shape;
_filter_shape = filter_shape;
}
public ConvolutionSPInstruction(CPOperand in, CPOperand in2, CPOperand out,
String opcode, String istr) {
super(new ReorgOperator(SwapIndex.getSwapIndexFnObject()), in, out,
opcode, istr);
_in2 = in2;
_sptype = SPINSTRUCTION_TYPE.Convolution;
}
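/**
 * Parses the serialized instruction string into a ConvolutionSPInstruction.
 * The opcode (parts[0]) determines the operand layout; see the per-branch
 * comments below.
 */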
public static ConvolutionSPInstruction parseInstruction( String str ) throws DMLRuntimeException {
CPOperand in = new CPOperand("", ValueType.UNKNOWN, DataType.UNKNOWN);
CPOperand out = new CPOperand("", ValueType.UNKNOWN, DataType.UNKNOWN);
String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
String opcode = parts[0];
if (opcode.equalsIgnoreCase("maxpooling") || opcode.equalsIgnoreCase("relu_maxpooling")) {
InstructionUtils.checkNumFields(parts, 14);
// Operand layout: in, stride1, stride2, padding1, padding2,
// input_shape1, input_shape2, input_shape3, input_shape4,
// filter_shape1, filter_shape2, filter_shape3, filter_shape4, out
in.split(parts[1]);
out.split(parts[14]);
ArrayList<CPOperand> stride = new ArrayList<CPOperand>();
ArrayList<CPOperand> padding = new ArrayList<CPOperand>();
ArrayList<CPOperand> input_shape = new ArrayList<CPOperand>();
ArrayList<CPOperand> filter_shape = new ArrayList<CPOperand>();
stride.add(new CPOperand(parts[2]));
stride.add(new CPOperand(parts[3]));
padding.add(new CPOperand(parts[4]));
padding.add(new CPOperand(parts[5]));
input_shape.add(new CPOperand(parts[6]));
input_shape.add(new CPOperand(parts[7]));
input_shape.add(new CPOperand(parts[8]));
input_shape.add(new CPOperand(parts[9]));
filter_shape.add(new CPOperand(parts[10]));
filter_shape.add(new CPOperand(parts[11]));
filter_shape.add(new CPOperand(parts[12]));
filter_shape.add(new CPOperand(parts[13]));
return new ConvolutionSPInstruction(in, out, opcode, str, stride,
padding, input_shape, filter_shape);
}
else if (opcode.equalsIgnoreCase("maxpooling_backward")
|| opcode.equalsIgnoreCase("conv2d")
|| opcode.equalsIgnoreCase("conv2d_backward_filter")
|| opcode.equalsIgnoreCase("conv2d_backward_data")) {
InstructionUtils.checkNumFields(parts, 15);
// Operand layout: in, in2 (filter for conv2d, dout for the backward ops),
// stride1, stride2, padding1, padding2,
// input_shape1, input_shape2, input_shape3, input_shape4,
// filter_shape1, filter_shape2, filter_shape3, filter_shape4, out
in.split(parts[1]);
CPOperand in2 = new CPOperand("", ValueType.UNKNOWN, DataType.UNKNOWN);
in2.split(parts[2]);
out.split(parts[15]);
ArrayList<CPOperand> stride = new ArrayList<CPOperand>();
ArrayList<CPOperand> padding = new ArrayList<CPOperand>();
ArrayList<CPOperand> input_shape = new ArrayList<CPOperand>();
ArrayList<CPOperand> filter_shape = new ArrayList<CPOperand>();
stride.add(new CPOperand(parts[3]));
stride.add(new CPOperand(parts[4]));
padding.add(new CPOperand(parts[5]));
padding.add(new CPOperand(parts[6]));
input_shape.add(new CPOperand(parts[7]));
input_shape.add(new CPOperand(parts[8]));
input_shape.add(new CPOperand(parts[9]));
input_shape.add(new CPOperand(parts[10]));
filter_shape.add(new CPOperand(parts[11]));
filter_shape.add(new CPOperand(parts[12]));
filter_shape.add(new CPOperand(parts[13]));
filter_shape.add(new CPOperand(parts[14]));
return new ConvolutionSPInstruction(in, in2, out, opcode, str, stride,
padding, input_shape, filter_shape);
}
else if (opcode.equalsIgnoreCase("conv2d_bias_add")) {
InstructionUtils.checkNumFields(parts, 16);
// Operand layout: in, bias (in2), filter (in3),
// stride1, stride2, padding1, padding2,
// input_shape1, input_shape2, input_shape3, input_shape4,
// filter_shape1, filter_shape2, filter_shape3, filter_shape4, out
in.split(parts[1]);
CPOperand in2 = new CPOperand("", ValueType.UNKNOWN, DataType.UNKNOWN);
in2.split(parts[2]);
CPOperand in3 = new CPOperand("", ValueType.UNKNOWN, DataType.UNKNOWN);
in3.split(parts[3]);
out.split(parts[16]);
ArrayList<CPOperand> stride = new ArrayList<CPOperand>();
ArrayList<CPOperand> padding = new ArrayList<CPOperand>();
ArrayList<CPOperand> input_shape = new ArrayList<CPOperand>();
ArrayList<CPOperand> filter_shape = new ArrayList<CPOperand>();
stride.add(new CPOperand(parts[4]));
stride.add(new CPOperand(parts[5]));
padding.add(new CPOperand(parts[6]));
padding.add(new CPOperand(parts[7]));
input_shape.add(new CPOperand(parts[8]));
input_shape.add(new CPOperand(parts[9]));
input_shape.add(new CPOperand(parts[10]));
input_shape.add(new CPOperand(parts[11]));
filter_shape.add(new CPOperand(parts[12]));
filter_shape.add(new CPOperand(parts[13]));
filter_shape.add(new CPOperand(parts[14]));
filter_shape.add(new CPOperand(parts[15]));
return new ConvolutionSPInstruction(in, in2, in3, out, opcode, str, stride,
padding, input_shape, filter_shape);
}
else if (opcode.equalsIgnoreCase("bias_add")) {
InstructionUtils.checkNumFields(parts, 3);
in.split(parts[1]);
CPOperand in2 = new CPOperand("", ValueType.UNKNOWN, DataType.UNKNOWN);
in2.split(parts[2]);
out.split(parts[3]);
return new ConvolutionSPInstruction(in, in2, out, opcode, str);
}
else {
throw new DMLRuntimeException("Unknown opcode while parsing a ConvolutionCPInstruction: " + str);
}
}
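/**
 * Ensures every block of the input RDD spans all columns of the matrix
 * (i.e., one complete image per row), reblocking and merging partial
 * blocks when the current blocking does not match.
 */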
private JavaPairRDD<MatrixIndexes,MatrixBlock> reblockAsRectangularMatrices(SparkExecutionContext sec, String name, int numRowsPerBlock) throws DMLRuntimeException {
JavaPairRDD<MatrixIndexes,MatrixBlock> in1 = sec.getBinaryBlockRDDHandleForVariable( name );
MatrixCharacteristics mcRdd = sec.getMatrixCharacteristics(name);
if(mcRdd.getColsPerBlock() < mcRdd.getCols() || mcRdd.getRowsPerBlock() != numRowsPerBlock) {
MatrixCharacteristics mcOut = new MatrixCharacteristics(mcRdd);
mcOut.setColsPerBlock((int)mcRdd.getCols());
mcOut.setRowsPerBlock(numRowsPerBlock);
in1 = RDDAggregateUtils.mergeByKey(in1.flatMapToPair(new ExtractBlockForBinaryReblock(mcRdd, mcOut)));
// TODO: Inject a checkpoint to avoid repeating this reblock for the validation set
// sec.setRDDHandleForVariable(name, in1);
// sec.setMetaData(name, new MatrixDimensionsMetaData(mcOut));
}
return in1;
}
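/**
 * Reads the matrix variable into the driver and broadcasts it to all executors.
 * Assumes the matrix fits in driver memory (see the TODO on large filters in
 * processInstruction).
 */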
private Broadcast<MatrixBlock> getBroadcast(SparkExecutionContext sec, String name) throws DMLRuntimeException {
MatrixBlock mb = sec.getMatrixInput( name );
sec.releaseMatrixInput(name);
return sec.getSparkContext().broadcast(mb);
}
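/**
 * Executes conv2d, conv2d_bias_add, maxpooling, or relu_maxpooling: reblocks
 * the input into one-row blocks, broadcasts the filter/bias, reads the scalar
 * stride/padding/shape operands, and applies the operator map-side per partition.
 */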
@Override
public void processInstruction(ExecutionContext ec)
throws DMLRuntimeException {
SparkExecutionContext sec = (SparkExecutionContext)ec;
if(instOpcode.equalsIgnoreCase("conv2d") || instOpcode.equalsIgnoreCase("conv2d_bias_add")
|| instOpcode.equalsIgnoreCase("maxpooling") || instOpcode.equalsIgnoreCase("relu_maxpooling")) {
String rddVar = input1.getName();
int numRowsPerBlock = 1;
JavaPairRDD<MatrixIndexes,MatrixBlock> inputRDD = reblockAsRectangularMatrices(sec, rddVar, numRowsPerBlock);
MatrixCharacteristics mcRdd = sec.getMatrixCharacteristics(rddVar);
// ------------------------------------
// TODO: Handle large filters > 2G
Broadcast<MatrixBlock> filterBroadcast = null;
Broadcast<MatrixBlock> biasBroadcast = null;
if(instOpcode.equalsIgnoreCase("conv2d")) {
filterBroadcast = getBroadcast(sec, _in2.getName());
}
else if(instOpcode.equalsIgnoreCase("conv2d_bias_add")) {
filterBroadcast = getBroadcast(sec, _in3.getName());
biasBroadcast = getBroadcast(sec, _in2.getName());
}
// ------------------------------------
int pad_h = getScalarInput(ec, _padding, 0);
int pad_w = getScalarInput(ec, _padding, 1);
int stride_h = getScalarInput(ec, _stride, 0);
int stride_w = getScalarInput(ec, _stride, 1);
// int N = getScalarInput(ec, _input_shape, 0);
int C = getScalarInput(ec, _input_shape, 1);
int H = getScalarInput(ec, _input_shape, 2);
int W = getScalarInput(ec, _input_shape, 3);
int K = getScalarInput(ec, _filter_shape, 0);
int R = getScalarInput(ec, _filter_shape, 2);
int S = getScalarInput(ec, _filter_shape, 3);
int P = (int) ConvolutionUtils.getP(H, R, stride_h, pad_h);
int Q = (int) ConvolutionUtils.getQ(W, S, stride_w, pad_w);
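// N is set to numRowsPerBlock (one image per block); numThreads is 1, presumably
// because parallelism comes from Spark tasks rather than multi-threaded operators.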
ConvolutionParameters params = new ConvolutionParameters(numRowsPerBlock, C, H, W, K, R, S, stride_h, stride_w, pad_h, pad_w, 1);
boolean enableNativeBLAS = NativeHelper.isNativeLibraryLoaded();
JavaPairRDD<MatrixIndexes,MatrixBlock> out = inputRDD.mapPartitionsToPair(new RDDConv2dMapMMFunction(filterBroadcast, params, instOpcode, biasBroadcast, mcRdd.getRows(), enableNativeBLAS), true);
//put output RDD handle into symbol table
sec.setRDDHandleForVariable(output.getName(), out);
sec.addLineageRDD(output.getName(), rddVar);
long nnz = -1; // TODO: Handle nnz
long numCols = ((long)K)*((long)P)*((long)Q);
if(instOpcode.equalsIgnoreCase("maxpooling") || instOpcode.equalsIgnoreCase("relu_maxpooling")) {
numCols = ((long)C)*((long)P)*((long)Q);
}
if(numCols > Integer.MAX_VALUE) {
throw new DMLRuntimeException("The current operator doesnot support large outputs.");
}
sec.setMetaData(output.getName(),
new MatrixFormatMetaData(new MatrixCharacteristics(mcRdd.getRows(), numCols, numRowsPerBlock, (int)numCols, nnz), OutputInfo.BinaryBlockOutputInfo, InputInfo.BinaryBlockInputInfo));
}
else {
throw new DMLRuntimeException("Not implemented: " + instOpcode);
}
}
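// Resolves a scalar operand (stride, padding, or shape dimension) from the
// symbol table and narrows it to an int.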
private int getScalarInput(ExecutionContext ec, ArrayList<CPOperand> aL,
int index) throws DMLRuntimeException {
return (int) ec.getScalarInput(aL.get(index).getName(),
aL.get(index).getValueType(), aL.get(index).isLiteral())
.getLongValue();
}
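/**
 * Map-side function that applies the convolution/pooling operator to each
 * rectangular input block; the filter and bias arrive as broadcast variables.
 */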
private static class RDDConv2dMapMMFunction implements PairFlatMapFunction<Iterator<Tuple2<MatrixIndexes, MatrixBlock>>, MatrixIndexes, MatrixBlock> {
private static final long serialVersionUID = -2106155380020232155L;
Broadcast<MatrixBlock> filterBroadcast = null;
Broadcast<MatrixBlock> biasBroadcast = null;
ConvolutionParameters params = null;
String instOpcode = null;
boolean enableNative;
long numRows = 0;
public RDDConv2dMapMMFunction(Broadcast<MatrixBlock> filterBroadcast,
ConvolutionParameters params, String instOpcode, Broadcast<MatrixBlock> biasBroadcast, long numRows, boolean enableNativeBLAS) {
this.filterBroadcast = filterBroadcast;
this.params = params;
this.instOpcode = instOpcode;
this.biasBroadcast = biasBroadcast;
this.numRows = numRows;
this.enableNative = enableNativeBLAS;
}
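// Dispatches on the opcode; empty inputs short-circuit to an empty (sparse) output block.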
private MatrixBlock processRectangularBlock(MatrixBlock matBlock) throws Exception {
MatrixBlock outputBlock = null;
if(instOpcode.equalsIgnoreCase("conv2d")) {
MatrixBlock filter = filterBroadcast.getValue();
if(filter.isEmptyBlock() || matBlock.isEmptyBlock()) {
outputBlock = new MatrixBlock(params.N, params.K*params.P*params.Q, true);
}
else {
outputBlock = getDenseOutputBlock(params.N, params.K*params.P*params.Q);
if(enableNative)
LibMatrixNative.conv2d(matBlock, filter, outputBlock, params);
else
LibMatrixDNN.conv2d(matBlock, filter, outputBlock, params);
}
}
else if (instOpcode.equalsIgnoreCase("conv2d_bias_add")) {
MatrixBlock filter = filterBroadcast.getValue();
MatrixBlock bias = biasBroadcast.getValue();
if((filter.isEmptyBlock() || matBlock.isEmptyBlock()) && bias.isEmptyBlock()) {
outputBlock = new MatrixBlock(params.N, params.K*params.P*params.Q, true);
}
else {
outputBlock = getDenseOutputBlock(params.N, params.K*params.P*params.Q);
if(!bias.isEmptyBlock())
params.bias = bias;
if(enableNative)
LibMatrixNative.conv2d(matBlock, filter, outputBlock, params);
else
LibMatrixDNN.conv2d(matBlock, filter, outputBlock, params);
}
}
else if(instOpcode.equalsIgnoreCase("maxpooling") || instOpcode.equalsIgnoreCase("relu_maxpooling")) {
if(matBlock.isEmptyBlock()) {
outputBlock = new MatrixBlock(params.N, params.C*params.P*params.Q, true);
}
else {
outputBlock = getDenseOutputBlock(params.N, params.C*params.P*params.Q);
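// Plain maxpooling seeds the output with -Double.MAX_VALUE so that any input value
// wins the max; relu_maxpooling keeps the zero-initialized block, which presumably
// acts as the implicit ReLU floor.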
if(instOpcode.equalsIgnoreCase("maxpooling"))
Arrays.fill(outputBlock.getDenseBlock(), -Double.MAX_VALUE);
LibMatrixDNN.maxpooling(matBlock, outputBlock, params);
}
}
else {
throw new RuntimeException("Not implemented");
}
return outputBlock;
}
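// Conv/pooling outputs are typically dense, so the dense block is allocated up front.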
private MatrixBlock getDenseOutputBlock(int numRows, int numCols) throws DMLRuntimeException {
MatrixBlock outputBlock = new MatrixBlock(numRows, numCols, false);
outputBlock.allocateDenseBlock();
return outputBlock;
}
@Override
public Iterator<Tuple2<MatrixIndexes, MatrixBlock>> call(
Iterator<Tuple2<MatrixIndexes, MatrixBlock>> arg0)
throws Exception {
return new MapsideConvolutionPartitionIterator(arg0);
}
// Lazily iterates over the partition to avoid materializing all output blocks at once
private class MapsideConvolutionPartitionIterator extends LazyIterableIterator<Tuple2<MatrixIndexes, MatrixBlock>> {
public MapsideConvolutionPartitionIterator(Iterator<Tuple2<MatrixIndexes, MatrixBlock>> in) {
super(in);
}
@Override
protected Tuple2<MatrixIndexes, MatrixBlock> computeNext(Tuple2<MatrixIndexes, MatrixBlock> arg) throws Exception {
if(arg._1.getRowIndex() > numRows || arg._1.getColumnIndex() != 1) {
throw new RuntimeException("Expected the inputs to be reblocked as rectangular RDD");
}
MatrixBlock out = processRectangularBlock(arg._2);
if(out.getNumRows() != 1) {
throw new RuntimeException("Expected the output to have 1 row");
}
return new Tuple2<MatrixIndexes, MatrixBlock>(arg._1, out);
}
}
}
}