/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package org.apache.sysml.runtime.instructions.cp;

import java.util.ArrayList;
import java.util.Arrays;

import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysml.runtime.functionobjects.SwapIndex;
import org.apache.sysml.runtime.instructions.InstructionUtils;
import org.apache.sysml.runtime.matrix.data.ConvolutionParameters;
import org.apache.sysml.runtime.matrix.data.LibMatrixDNN;
import org.apache.sysml.runtime.matrix.data.LibMatrixNative;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.operators.ReorgOperator;
import org.apache.sysml.runtime.util.ConvolutionUtils;
import org.apache.sysml.utils.NativeHelper;

public class ConvolutionCPInstruction extends UnaryCPInstruction {
	private CPOperand _in2;
	private CPOperand _in3;
	private ArrayList<CPOperand> _input_shape;
	private ArrayList<CPOperand> _filter_shape;
	private ArrayList<CPOperand> _stride = new ArrayList<CPOperand>();
	private ArrayList<CPOperand> _padding = new ArrayList<CPOperand>();
	private int _numThreads = -1;

	public ConvolutionCPInstruction(CPOperand in, CPOperand in2, CPOperand out, String opcode,
			String istr, int numThreads) throws DMLRuntimeException {
		super(new ReorgOperator(SwapIndex.getSwapIndexFnObject()), in, out, opcode, istr);
		if( !(opcode.equals("bias_add") || opcode.equals("relu_backward") || opcode.equals("bias_multiply")) ) {
			throw new DMLRuntimeException("Incorrect usage. "
				+ "Expected the opcode to be bias_add or bias_multiply or relu_backward, but found " + opcode);
		}
		_in2 = in2;
		_cptype = CPINSTRUCTION_TYPE.Convolution;
		_numThreads = numThreads;
	}

	public ConvolutionCPInstruction(CPOperand in, CPOperand out, String opcode, String istr,
			ArrayList<CPOperand> stride, ArrayList<CPOperand> padding, ArrayList<CPOperand> input_shape,
			ArrayList<CPOperand> filter_shape, int numThreads) {
		super(new ReorgOperator(SwapIndex.getSwapIndexFnObject()), in, out, opcode, istr);
		_cptype = CPINSTRUCTION_TYPE.Convolution;
		_stride = stride;
		_padding = padding;
		_input_shape = input_shape;
		_filter_shape = filter_shape;
		_numThreads = numThreads;
	}

	public ConvolutionCPInstruction(CPOperand in, CPOperand in2, CPOperand out, String opcode,
			String istr, ArrayList<CPOperand> stride, ArrayList<CPOperand> padding,
			ArrayList<CPOperand> input_shape, ArrayList<CPOperand> filter_shape, int numThreads) {
		super(new ReorgOperator(SwapIndex.getSwapIndexFnObject()), in, out, opcode, istr);
		_in2 = in2;
		_cptype = CPINSTRUCTION_TYPE.Convolution;
		_stride = stride;
		_padding = padding;
		_input_shape = input_shape;
		_filter_shape = filter_shape;
		_numThreads = numThreads;
	}

	public ConvolutionCPInstruction(CPOperand in, CPOperand in2, CPOperand in3, CPOperand out,
			String opcode, String istr, ArrayList<CPOperand> stride, ArrayList<CPOperand> padding,
			ArrayList<CPOperand> input_shape, ArrayList<CPOperand> filter_shape, int numThreads) {
		super(new ReorgOperator(SwapIndex.getSwapIndexFnObject()), in, out, opcode, istr);
		_in2 = in2;
		_in3 = in3;
		_cptype = CPINSTRUCTION_TYPE.Convolution;
		_stride = stride;
		_padding = padding;
		_input_shape = input_shape;
		_filter_shape = filter_shape;
		_numThreads = numThreads;
	}

	public static ConvolutionCPInstruction parseInstruction(String str) throws DMLRuntimeException {
		String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
		String opcode = parts[0];
		if (opcode.equalsIgnoreCase("maxpooling") || opcode.equalsIgnoreCase("relu_maxpooling")) {
			InstructionUtils.checkNumFields(parts, 15);
			// stride1, stride2, padding1, padding2,
			// input_shape1, input_shape2, input_shape3, input_shape4,
			// filter_shape1, filter_shape2, filter_shape3, filter_shape4, k
			CPOperand in = new CPOperand(parts[1]);
			CPOperand out = new CPOperand(parts[14]);
			ArrayList<CPOperand> stride = new ArrayList<CPOperand>();
			ArrayList<CPOperand> padding = new ArrayList<CPOperand>();
			ArrayList<CPOperand> input_shape = new ArrayList<CPOperand>();
			ArrayList<CPOperand> filter_shape = new ArrayList<CPOperand>();
			stride.add(new CPOperand(parts[2]));
			stride.add(new CPOperand(parts[3]));
			padding.add(new CPOperand(parts[4]));
			padding.add(new CPOperand(parts[5]));
			input_shape.add(new CPOperand(parts[6]));
			input_shape.add(new CPOperand(parts[7]));
			input_shape.add(new CPOperand(parts[8]));
			input_shape.add(new CPOperand(parts[9]));
			filter_shape.add(new CPOperand(parts[10]));
			filter_shape.add(new CPOperand(parts[11]));
			filter_shape.add(new CPOperand(parts[12]));
			filter_shape.add(new CPOperand(parts[13]));
			int k = Integer.parseInt(parts[15]);
			return new ConvolutionCPInstruction(in, out, opcode, str, stride, padding,
					input_shape, filter_shape, k);
		}
		else if (opcode.equalsIgnoreCase("maxpooling_backward") || opcode.equalsIgnoreCase("relu_maxpooling_backward")
				|| opcode.equalsIgnoreCase("conv2d")
				|| opcode.equalsIgnoreCase("conv2d_backward_filter")
				|| opcode.equalsIgnoreCase("conv2d_backward_data")) {
			InstructionUtils.checkNumFields(parts, 16);
			// dout, stride1, stride2, padding1, padding2,
			// input_shape1, input_shape2,
			// input_shape3, input_shape4,
			// filter_shape1, filter_shape2, filter_shape3, filter_shape4, k
			CPOperand in = new CPOperand(parts[1]);
			CPOperand in2 = new CPOperand(parts[2]);
			CPOperand out = new CPOperand(parts[15]);
			ArrayList<CPOperand> stride = new ArrayList<CPOperand>();
			ArrayList<CPOperand> padding = new ArrayList<CPOperand>();
			ArrayList<CPOperand> input_shape = new ArrayList<CPOperand>();
			ArrayList<CPOperand> filter_shape = new ArrayList<CPOperand>();
			stride.add(new CPOperand(parts[3]));
			stride.add(new CPOperand(parts[4]));
			padding.add(new CPOperand(parts[5]));
			padding.add(new CPOperand(parts[6]));
			input_shape.add(new CPOperand(parts[7]));
			input_shape.add(new CPOperand(parts[8]));
			input_shape.add(new CPOperand(parts[9]));
			input_shape.add(new CPOperand(parts[10]));
			filter_shape.add(new CPOperand(parts[11]));
			filter_shape.add(new CPOperand(parts[12]));
			filter_shape.add(new CPOperand(parts[13]));
			filter_shape.add(new CPOperand(parts[14]));
			int k = Integer.parseInt(parts[16]);
			return new ConvolutionCPInstruction(in, in2, out, opcode, str, stride, padding,
					input_shape, filter_shape, k);
		}
		else if (opcode.equalsIgnoreCase("conv2d_bias_add")) {
			InstructionUtils.checkNumFields(parts, 17);
			// dout, stride1, stride2, padding1, padding2,
			// input_shape1, input_shape2, input_shape3, input_shape4,
			// filter_shape1, filter_shape2, filter_shape3, filter_shape4, k
			CPOperand in = new CPOperand(parts[1]);
			CPOperand in2 = new CPOperand(parts[2]);
			CPOperand in3 = new CPOperand(parts[3]);
			CPOperand out = new CPOperand(parts[16]);
			ArrayList<CPOperand> stride = new ArrayList<CPOperand>();
			ArrayList<CPOperand> padding = new ArrayList<CPOperand>();
			ArrayList<CPOperand> input_shape = new ArrayList<CPOperand>();
			ArrayList<CPOperand> filter_shape = new ArrayList<CPOperand>();
			stride.add(new CPOperand(parts[4]));
			stride.add(new CPOperand(parts[5]));
			padding.add(new CPOperand(parts[6]));
			padding.add(new CPOperand(parts[7]));
			input_shape.add(new CPOperand(parts[8]));
			input_shape.add(new CPOperand(parts[9]));
			input_shape.add(new CPOperand(parts[10]));
			input_shape.add(new CPOperand(parts[11]));
			filter_shape.add(new CPOperand(parts[12]));
			filter_shape.add(new CPOperand(parts[13]));
			filter_shape.add(new CPOperand(parts[14]));
			filter_shape.add(new CPOperand(parts[15]));
			int k = Integer.parseInt(parts[17]);
			return new ConvolutionCPInstruction(in, in2, in3, out, opcode, str, stride, padding,
					input_shape, filter_shape, k);
		}
		else if (opcode.equalsIgnoreCase("bias_add") || opcode.equals("relu_backward") || opcode.equalsIgnoreCase("bias_multiply")) {
			InstructionUtils.checkNumFields(parts, 4);
			CPOperand in = new CPOperand(parts[1]);
			CPOperand in2 = new CPOperand(parts[2]);
			CPOperand out = new CPOperand(parts[3]);
			int k = Integer.parseInt(parts[4]);
			return new ConvolutionCPInstruction(in, in2, out, opcode, str, k);
		}
		else {
			throw new DMLRuntimeException("Unknown opcode while parsing a ConvolutionCPInstruction: " + str);
		}
	}

	private int getScalarInput(ExecutionContext ec, ArrayList<CPOperand> aL, int index)
			throws DMLRuntimeException {
		return (int) ec.getScalarInput(aL.get(index).getName(),
				aL.get(index).getValueType(), aL.get(index).isLiteral()).getLongValue();
	}

	@SuppressWarnings("unused")
	public void processReluBackwardInstruction(ExecutionContext ec) throws DMLRuntimeException {
		// (X > 0) * dout
		MatrixBlock input = ec.getMatrixInput(input1.getName());
		MatrixBlock dout = ec.getMatrixInput(_in2.getName());
		MatrixBlock outputBlock = new MatrixBlock(input.getNumRows(), input.getNumColumns(),
				LibMatrixDNN.SUPPORTS_SPARSE_OUTPUTS && (input.isInSparseFormat() || dout.isInSparseFormat()));
		if( !input.isEmpty() && !dout.isEmpty() ) {
			outputBlock.allocateDenseOrSparseBlock();
			LibMatrixDNN.reluBackward(input, dout, outputBlock, _numThreads);
		}
		// release inputs/outputs
		ec.releaseMatrixInput(input1.getName());
		ec.releaseMatrixInput(_in2.getName());
		ec.setMatrixOutput(getOutputVariableName(), outputBlock);
	}

	public void processBiasAddInstruction(ExecutionContext ec) throws DMLRuntimeException {
		MatrixBlock input = ec.getMatrixInput(input1.getName());
		MatrixBlock bias = ec.getMatrixInput(_in2.getName());
		MatrixBlock outputBlock = null;
		if(bias.getNumColumns() != 1) {
			throw new DMLRuntimeException("Expected the number of columns of bias matrix to be 1, but found " + bias.getNumColumns());
		}
		if(input.isEmpty() && bias.isEmpty()) {
			outputBlock = new MatrixBlock(input.getNumRows(), input.getNumColumns(), true);
		}
		else if(bias.isEmpty()) {
			outputBlock = new MatrixBlock(input);
		}
		else {
			// As we always fill the output first with bias
			outputBlock = new MatrixBlock(input.getNumRows(), input.getNumColumns(), false);
			outputBlock.allocateDenseBlock();
			LibMatrixDNN.biasAdd(input, bias, outputBlock, _numThreads);
		}
		// release inputs/outputs
		ec.releaseMatrixInput(input1.getName());
		ec.releaseMatrixInput(_in2.getName());
		ec.setMatrixOutput(getOutputVariableName(), outputBlock);
	}

	public void processBiasMultiplyInstruction(ExecutionContext ec) throws DMLRuntimeException {
		MatrixBlock input = ec.getMatrixInput(input1.getName());
		MatrixBlock bias = ec.getMatrixInput(_in2.getName());
		MatrixBlock outputBlock = null;
		if(bias.getNumColumns() != 1) {
			throw new DMLRuntimeException("Expected the number of columns of bias matrix to be 1, but found " + bias.getNumColumns());
		}
		if(bias.isEmpty()) {
			// Anything multiplied by zero is zero
			outputBlock = new MatrixBlock(input.getNumRows(), input.getNumColumns(), true);
		}
		else {
			// As we always fill the output first with bias
			outputBlock = new MatrixBlock(input.getNumRows(), input.getNumColumns(), false);
			outputBlock.allocateDenseBlock();
			LibMatrixDNN.biasMultiply(input, bias, outputBlock, _numThreads);
		}
		// release inputs/outputs
		ec.releaseMatrixInput(input1.getName());
		ec.releaseMatrixInput(_in2.getName());
		ec.setMatrixOutput(getOutputVariableName(), outputBlock);
	}

	// Assumption: enableNative && NativeHelper.isNativeLibraryLoaded() is true.
	// This increases the number of native calls, for example in the cases where the filter is sparse but the input is dense.
	private boolean isFilterSparse(MatrixBlock filter) throws DMLRuntimeException {
		// use long arithmetic to avoid int overflow for very large filters
		long numElems = ((long) filter.getNumRows()) * filter.getNumColumns();
		// Convert the filter to dense if it is less than 10 MB in dense format (which handles almost all the cases).
		// In fact, using a threshold of 1 MB would still be sufficient for common CNNs.
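		// Converting such small sparse filters to dense here lets the callers below take the
		// native conv2d path, which is only used when both the filter and the input are dense.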
		if(filter.isInSparseFormat() && numElems < 10e+6)
			filter.sparseToDense();
		return filter.isInSparseFormat();
	}

	@Override
	public void processInstruction(ExecutionContext ec) throws DMLRuntimeException {
		if (instOpcode.equalsIgnoreCase("bias_add")) {
			processBiasAddInstruction(ec);
			return;
		}
		else if (instOpcode.equalsIgnoreCase("bias_multiply")) {
			processBiasMultiplyInstruction(ec);
			return;
		}
		else if (instOpcode.equalsIgnoreCase("relu_backward")) {
			processReluBackwardInstruction(ec);
			return;
		}
		// acquire inputs
		MatrixBlock outputBlock = null;
		MatrixBlock matBlock = ec.getMatrixInput(input1.getName());
		int pad_h = getScalarInput(ec, _padding, 0);
		int pad_w = getScalarInput(ec, _padding, 1);
		int stride_h = getScalarInput(ec, _stride, 0);
		int stride_w = getScalarInput(ec, _stride, 1);
		int N = getScalarInput(ec, _input_shape, 0);
		int C = getScalarInput(ec, _input_shape, 1);
		int H = getScalarInput(ec, _input_shape, 2);
		int W = getScalarInput(ec, _input_shape, 3);
		int K = getScalarInput(ec, _filter_shape, 0);
		int R = getScalarInput(ec, _filter_shape, 2);
		int S = getScalarInput(ec, _filter_shape, 3);
		int P = (int) ConvolutionUtils.getP(H, R, stride_h, pad_h);
		int Q = (int) ConvolutionUtils.getQ(W, S, stride_w, pad_w);

		ConvolutionParameters params = new ConvolutionParameters(N, C, H, W, K, R, S,
				stride_h, stride_w, pad_h, pad_w, _numThreads);
		params.enableNative = NativeHelper.isNativeLibraryLoaded();
		if (instOpcode.equalsIgnoreCase("maxpooling") || instOpcode.equalsIgnoreCase("relu_maxpooling")) {
			if(matBlock.isEmpty()) {
				outputBlock = new MatrixBlock(N, C*P*Q, true);
			}
			else {
				outputBlock = getDenseOutputBlock(N, C*P*Q);
				if(instOpcode.equalsIgnoreCase("maxpooling"))
					Arrays.fill(outputBlock.getDenseBlock(), -Double.MAX_VALUE);
				LibMatrixDNN.maxpooling(matBlock, outputBlock, params);
			}
		}
		else if (instOpcode.equalsIgnoreCase("maxpooling_backward") || instOpcode.equalsIgnoreCase("relu_maxpooling_backward")) {
			MatrixBlock dout = ec.getMatrixInput(_in2.getName());
			if(matBlock.isEmpty() || dout.isEmpty()) {
				outputBlock = new MatrixBlock(N, C*H*W, true);
			}
			else {
				outputBlock = getDenseOutputBlock(N, C*H*W);
				if(instOpcode.equalsIgnoreCase("maxpooling_backward"))
					LibMatrixDNN.maxpoolingBackward(matBlock, dout, outputBlock, params, false);
				else
					LibMatrixDNN.maxpoolingBackward(matBlock, dout, outputBlock, params, true);
			}
			ec.releaseMatrixInput(_in2.getName());
		}
		else if (instOpcode.equalsIgnoreCase("conv2d")) {
			MatrixBlock filter = ec.getMatrixInput(_in2.getName());
			if(filter.isEmpty() || matBlock.isEmpty()) {
				outputBlock = new MatrixBlock(N, K*P*Q, true);
			}
			else {
				outputBlock = getDenseOutputBlock(N, K*P*Q);
				if(params.enableNative && !isFilterSparse(filter) && !matBlock.isInSparseFormat())
					LibMatrixNative.conv2d(matBlock, filter, outputBlock, params);
				else
					LibMatrixDNN.conv2d(matBlock, filter, outputBlock, params);
			}
			ec.releaseMatrixInput(_in2.getName());
		}
		else if (instOpcode.equalsIgnoreCase("conv2d_bias_add")) {
			MatrixBlock filter = ec.getMatrixInput(_in3.getName());
			MatrixBlock bias = ec.getMatrixInput(_in2.getName());
			if((filter.isEmpty() || matBlock.isEmpty()) && bias.isEmpty()) {
				outputBlock = new MatrixBlock(N, K*P*Q, true);
			}
			else {
				outputBlock = getDenseOutputBlock(N, K*P*Q);
				if(!bias.isEmpty()) {
					params.bias = bias;
				}
				if(params.enableNative && !isFilterSparse(filter) && !matBlock.isInSparseFormat())
					LibMatrixNative.conv2d(matBlock, filter, outputBlock, params);
				else
					LibMatrixDNN.conv2d(matBlock, filter, outputBlock, params);
			}
			ec.releaseMatrixInput(_in3.getName());
			ec.releaseMatrixInput(_in2.getName());
		}
		else if (instOpcode.equalsIgnoreCase("conv2d_backward_filter")) {
			MatrixBlock dout = ec.getMatrixInput(_in2.getName());
			if(dout.isEmpty() || matBlock.isEmpty()) {
				outputBlock = new MatrixBlock(K, C*R*S, true);
			}
			else {
				outputBlock = getDenseOutputBlock(K, C*R*S);
				if(params.enableNative && !matBlock.isInSparseFormat() && !dout.isInSparseFormat())
					LibMatrixNative.conv2dBackwardFilter(matBlock, dout, outputBlock, params);
				else
					LibMatrixDNN.conv2dBackwardFilter(matBlock, dout, outputBlock, params);
			}
			ec.releaseMatrixInput(_in2.getName());
		}
		else if (instOpcode.equalsIgnoreCase("conv2d_backward_data")) {
			MatrixBlock dout = ec.getMatrixInput(_in2.getName());
			if(dout.isEmpty() || matBlock.isEmpty()) {
				outputBlock = new MatrixBlock(N, C * H * W, true);
			}
			else {
				outputBlock = getDenseOutputBlock(N, C * H * W);
				if(params.enableNative && !isFilterSparse(matBlock) && !dout.isInSparseFormat())
					LibMatrixNative.conv2dBackwardData(matBlock, dout, outputBlock, params);
				else
					LibMatrixDNN.conv2dBackwardData(matBlock, dout, outputBlock, params);
			}
			ec.releaseMatrixInput(_in2.getName());
		}
		else {
			throw new DMLRuntimeException("Unsupported op code " + instOpcode);
		}
		// release inputs/outputs
		ec.releaseMatrixInput(input1.getName());
		ec.setMatrixOutput(getOutputVariableName(), outputBlock);
	}

	private MatrixBlock getDenseOutputBlock(int numRows, int numCols) throws DMLRuntimeException {
		MatrixBlock outputBlock = new MatrixBlock(numRows, numCols, false);
		outputBlock.allocateDenseBlock();
		return outputBlock;
	}
}
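/*
 * Operand layout sketch, derived from parseInstruction and processInstruction above
 * (illustrative only, not an authoritative format description): for the conv2d opcode the
 * instruction parts are, in order,
 *   [0] opcode, [1] input, [2] filter, [3-4] stride1/stride2, [5-6] padding1/padding2,
 *   [7-10] input_shape (read as N, C, H, W), [11-14] filter_shape (read as K, _, R, S),
 *   [15] output, [16] k (the degree of parallelism passed on as numThreads).
 */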