/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysml.runtime.matrix.data;
import org.apache.sysml.api.DMLScript;
import org.apache.sysml.hops.OptimizerUtils;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.utils.NativeHelper;
import org.apache.sysml.utils.Statistics;
public class LibMatrixNative {
	
	/**
	 * Heuristic deciding whether a matrix multiplication is memory-bound, in which case
	 * the native BLAS call is skipped in favor of the Java implementation. For now, any
	 * product where one of the dimensions is 1 (vector operands or vector output) is
	 * considered memory-bound. Further heuristics could be encapsulated here.
	 *
	 * @param m1Rlen number of rows of the lhs matrix
	 * @param m1Clen number of columns of the lhs matrix (= number of rows of the rhs)
	 * @param m2Clen number of columns of the rhs matrix
	 * @return true if the multiplication is treated as memory-bound
	 */
	private static boolean isMatMultMemoryBound(int m1Rlen, int m1Clen, int m2Clen) {
		return m1Rlen == 1 || m1Clen == 1 || m2Clen == 1;
	}
	
	/**
	 * Checks that every given block has an allocated dense array, i.e., that its
	 * {@code denseBlock} can safely be handed to the native library.
	 *
	 * @param blocks matrix blocks that will be passed to a native call
	 * @return true if no block has a null dense array
	 */
	private static boolean hasDenseArrays(MatrixBlock... blocks) {
		for(MatrixBlock mb : blocks) {
			if(mb.denseBlock == null)
				return false;
		}
		return true;
	}
	
	/**
	 * Performs matrix multiplication using the native library if BLAS is available,
	 * or else falls back to the Java implementation in {@link LibMatrixMult}.
	 * The output sparsity is examined afterwards.
	 *
	 * @param m1 lhs matrix block
	 * @param m2 rhs matrix block
	 * @param ret output matrix block
	 * @param k number of threads (non-positive means the native helper's maximum)
	 * @throws DMLRuntimeException if an error occurs during the multiplication
	 */
	public static void matrixMult(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, int k) throws DMLRuntimeException {
		matrixMult(m1, m2, ret, k, true);
	}
	
	/**
	 * Performs matrix multiplication using the native library if BLAS is available,
	 * or else falls back to the Java implementation in {@link LibMatrixMult}.
	 * The native path is only taken when the library is loaded, the product is not
	 * considered memory-bound, and both inputs are dense with allocated arrays.
	 *
	 * @param m1 lhs matrix block
	 * @param m2 rhs matrix block
	 * @param ret output matrix block
	 * @param k number of threads (non-positive means the native helper's maximum)
	 * @param examSparsity if true, call {@code examSparsity()} on the output so it may switch representation
	 * @throws DMLRuntimeException if an error occurs during the multiplication
	 */
	public static void matrixMult(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, int k, boolean examSparsity) throws DMLRuntimeException {
		// Sanity check: a non-positive thread count means "use the maximum"
		k = k <= 0 ? NativeHelper.getMaxNumThreads() : k;
		
		// empty inputs yield an empty output without any computation
		if (m1.isEmptyBlock() || m2.isEmptyBlock()) {
			ret.setNonZeros(0);
			if(examSparsity)
				ret.examSparsity(); // turn empty dense into sparse
			return;
		}
		
		if (NativeHelper.isNativeLibraryLoaded()
			&& !isMatMultMemoryBound(m1.rlen, m1.clen, m2.clen)
			&& !m1.isInSparseFormat() && !m2.isInSparseFormat()
			&& hasDenseArrays(m1, m2)) {
			ret.sparse = false;
			ret.allocateDenseBlock();
			long start = DMLScript.STATISTICS ? System.nanoTime() : 0;
			if (NativeHelper.matrixMultDenseDense(m1.denseBlock, m2.denseBlock,
				ret.denseBlock, m1.getNumRows(), m1.getNumColumns(), m2.getNumColumns(), k)) {
				if(DMLScript.STATISTICS) {
					Statistics.nativeLibMatrixMultTime += System.nanoTime() - start;
					Statistics.numNativeLibMatrixMultCalls.increment();
				}
				// post-processing: the native call only reports success/failure,
				// so the number of non-zeros has to be recomputed here
				ret.recomputeNonZeros();
				if(examSparsity)
					ret.examSparsity();
				return;
			}
			// native call failed: record the failure and fall back to Java
			Statistics.incrementNativeFailuresCounter();
		}
		
		// Java fallback: library not loaded, memory-bound shape, sparse input, or native failure
		if (k == 1) {
			LibMatrixMult.matrixMult(m1, m2, ret, examSparsity);
		}
		else {
			// NOTE(review): the multi-threaded overload takes no examSparsity flag,
			// so a caller's examSparsity=false is not forwarded here -- confirm intent
			LibMatrixMult.matrixMult(m1, m2, ret, k);
		}
	}
	
	/**
	 * This method performs convolution (i.e. cross-correlation) operation on input.
	 * If {@code params.bias} is set, the bias is added within the same native call.
	 * Falls back to {@link LibMatrixDNN#conv2d} if the native library is unavailable,
	 * an input is sparse or has no allocated dense array, or the native call fails.
	 *
	 * @param input input batch
	 * @param filter filter
	 * @param outputBlock output of convolution
	 * @param params convolution parameters
	 * @throws DMLRuntimeException if the input checks or the computation fail
	 */
	public static void conv2d(MatrixBlock input, MatrixBlock filter, MatrixBlock outputBlock, ConvolutionParameters params) throws DMLRuntimeException {
		LibMatrixDNN.checkInputsConv2d(input, filter, outputBlock, params);
		params.numThreads = params.numThreads <= 0 ? NativeHelper.getMaxNumThreads() : params.numThreads;
		if(NativeHelper.isNativeLibraryLoaded() && !input.isInSparseFormat() && !filter.isInSparseFormat()) {
			setNumThreads(params);
			boolean withBias = params.bias != null;
			if(withBias && params.bias.isInSparseFormat())
				params.bias.sparseToDense(); // Bias matrix is usually extremely small
			// unallocated dense arrays (e.g., an empty bias) must not reach the native library
			boolean nativeReady = hasDenseArrays(input, filter, outputBlock)
				&& (!withBias || hasDenseArrays(params.bias));
			if(nativeReady) {
				long start = DMLScript.STATISTICS ? System.nanoTime() : 0;
				int nnz;
				if(withBias) {
					nnz = NativeHelper.conv2dBiasAddDense(input.denseBlock, params.bias.denseBlock, filter.denseBlock, outputBlock.denseBlock,
						params.N, params.C, params.H, params.W,
						params.K, params.R, params.S, params.stride_h, params.stride_w, params.pad_h, params.pad_w,
						params.P, params.Q, params.numThreads);
				}
				else {
					nnz = NativeHelper.conv2dDense(input.denseBlock, filter.denseBlock, outputBlock.denseBlock, params.N, params.C, params.H, params.W,
						params.K, params.R, params.S, params.stride_h, params.stride_w, params.pad_h, params.pad_w,
						params.P, params.Q, params.numThreads);
				}
				if(nnz != -1) {
					if(DMLScript.STATISTICS) {
						Statistics.nativeConv2dTime += System.nanoTime() - start;
						Statistics.numNativeConv2dCalls.increment();
					}
					// post-processing: maintain nnz as reported by the native call
					outputBlock.setNonZeros(nnz);
					return;
				}
				// a return value of -1 signals a native failure
				Statistics.incrementNativeFailuresCounter();
			}
		}
		// Fall back to Java when failures, sparse or unallocated inputs
		LibMatrixDNN.conv2d(input, filter, outputBlock, params);
	}
	
	/**
	 * Constrains the number of threads for a native convolution call: first via
	 * {@link OptimizerUtils#getConstrainedNumThreads}, then forced to 1 unless the
	 * output is thread-safe and more than one thread remains.
	 *
	 * @param params convolution parameters whose {@code numThreads} is updated in place
	 */
	private static void setNumThreads(ConvolutionParameters params) {
		params.numThreads = OptimizerUtils.getConstrainedNumThreads(params.numThreads);
		if (!(params.isOutputThreadSafe() && params.numThreads > 1))
			params.numThreads = 1;
	}
	
	/**
	 * This method computes the backpropagation errors for the filter of a convolution operation.
	 * Falls back to {@link LibMatrixDNN#conv2dBackwardFilter} if the native library is
	 * unavailable, an input is sparse or has no allocated dense array, or the native call fails.
	 *
	 * @param input input image
	 * @param dout errors from next layer
	 * @param outputBlock output errors
	 * @param params convolution parameters
	 * @throws DMLRuntimeException if the input checks or the computation fail
	 */
	public static void conv2dBackwardFilter(MatrixBlock input, MatrixBlock dout, MatrixBlock outputBlock, ConvolutionParameters params) throws DMLRuntimeException {
		LibMatrixDNN.checkInputsConv2dBackwardFilter(input, dout, outputBlock, params);
		params.numThreads = params.numThreads <= 0 ? NativeHelper.getMaxNumThreads() : params.numThreads;
		if(NativeHelper.isNativeLibraryLoaded() && !dout.isInSparseFormat() && !input.isInSparseFormat()
			&& hasDenseArrays(input, dout, outputBlock)) {
			setNumThreads(params);
			long start = DMLScript.STATISTICS ? System.nanoTime() : 0;
			int nnz = NativeHelper.conv2dBackwardFilterDense(input.denseBlock, dout.denseBlock, outputBlock.denseBlock, params.N, params.C, params.H, params.W,
				params.K, params.R, params.S, params.stride_h, params.stride_w, params.pad_h, params.pad_w,
				params.P, params.Q, params.numThreads);
			if(nnz != -1) {
				if(DMLScript.STATISTICS) {
					Statistics.nativeConv2dBwdFilterTime += System.nanoTime() - start;
					Statistics.numNativeConv2dBwdFilterCalls.increment();
				}
				// post-processing: maintain nnz as reported by the native call
				outputBlock.setNonZeros(nnz);
				return;
			}
			// a return value of -1 signals a native failure
			Statistics.incrementNativeFailuresCounter();
		}
		// Fall back to Java when failures, sparse or unallocated inputs
		LibMatrixDNN.conv2dBackwardFilter(input, dout, outputBlock, params);
	}
	
	/**
	 * This method computes the backpropagation errors for the previous layer of a convolution operation.
	 * Falls back to {@link LibMatrixDNN#conv2dBackwardData} if the native library is
	 * unavailable, an input is sparse or has no allocated dense array, or the native call fails.
	 *
	 * @param filter filter used in conv2d
	 * @param dout errors from next layer
	 * @param outputBlock output errors
	 * @param params convolution parameters
	 * @throws DMLRuntimeException if the input checks or the computation fail
	 */
	public static void conv2dBackwardData(MatrixBlock filter, MatrixBlock dout, MatrixBlock outputBlock, ConvolutionParameters params) throws DMLRuntimeException {
		LibMatrixDNN.checkInputsConv2dBackwardData(filter, dout, outputBlock, params);
		params.numThreads = params.numThreads <= 0 ? NativeHelper.getMaxNumThreads() : params.numThreads;
		if(NativeHelper.isNativeLibraryLoaded() && !dout.isInSparseFormat() && !filter.isInSparseFormat()
			&& hasDenseArrays(filter, dout, outputBlock)) {
			setNumThreads(params);
			long start = DMLScript.STATISTICS ? System.nanoTime() : 0;
			int nnz = NativeHelper.conv2dBackwardDataDense(filter.denseBlock, dout.denseBlock, outputBlock.denseBlock, params.N, params.C, params.H, params.W,
				params.K, params.R, params.S, params.stride_h, params.stride_w, params.pad_h, params.pad_w,
				params.P, params.Q, params.numThreads);
			if(nnz != -1) {
				if(DMLScript.STATISTICS) {
					Statistics.nativeConv2dBwdDataTime += System.nanoTime() - start;
					Statistics.numNativeConv2dBwdDataCalls.increment();
				}
				// post-processing: maintain nnz as reported by the native call
				outputBlock.setNonZeros(nnz);
				return;
			}
			// a return value of -1 signals a native failure
			Statistics.incrementNativeFailuresCounter();
		}
		// Fall back to Java when failures, sparse or unallocated inputs
		LibMatrixDNN.conv2dBackwardData(filter, dout, outputBlock, params);
	}
}