/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysml.utils;
import java.io.IOException;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import java.util.HashMap;
import java.util.Vector;
import java.io.InputStream;
import java.io.OutputStream;
import java.io.File;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.IOUtils;
import org.apache.commons.lang.SystemUtils;
import org.apache.sysml.conf.ConfigurationManager;
import org.apache.sysml.conf.DMLConfig;
import org.apache.sysml.hops.OptimizerUtils;
/**
* This class helps in loading native library.
* By default, it first tries to load Intel MKL, else tries to load OpenBLAS.
*/
public class NativeHelper {
private static boolean isSystemMLLoaded = false;
private static final Log LOG = LogFactory.getLog(NativeHelper.class.getName());
private static HashMap<String, String> supportedArchitectures = new HashMap<String, String>();
public static String blasType;
private static int maxNumThreads = -1;
private static boolean setMaxNumThreads = false;
static {
// Note: we only support 64 bit Java on x86 and AMD machine
supportedArchitectures.put("x86_64", "x86_64");
supportedArchitectures.put("amd64", "x86_64");
}
private static boolean attemptedLoading = false;
private static String hintOnFailures = "";
// Performing loading in a method instead of a static block will throw a detailed stack trace in case of fatal errors
private static void init() {
// Only Linux supported for BLAS
if(!SystemUtils.IS_OS_LINUX)
return;
// attemptedLoading variable ensures that we don't try to load SystemML and other dependencies
// again and again especially in the parfor (hence the double-checking with synchronized).
if(!attemptedLoading) {
DMLConfig dmlConfig = ConfigurationManager.getDMLConfig();
// -------------------------------------------------------------------------------------
// We allow BLAS to be enabled or disabled or explicitly selected in one of the two ways:
// 1. DML Configuration: native.blas (boolean flag)
// 2. Environment variable: SYSTEMML_BLAS (can be set to mkl, openblas or none)
// The option 1 will be removed in later SystemML versions.
// The option 2 is useful for two reasons:
// - Developer testing of different BLAS
// - Provides fine-grained control. Certain machines could use mkl while others use openblas, etc.
String userSpecifiedBLAS = (dmlConfig == null) ? "auto" : dmlConfig.getTextValue(DMLConfig.NATIVE_BLAS).trim().toLowerCase();
if(userSpecifiedBLAS.equals("auto") || userSpecifiedBLAS.equals("mkl") || userSpecifiedBLAS.equals("openblas")) {
long start = System.nanoTime();
if(!supportedArchitectures.containsKey(SystemUtils.OS_ARCH)) {
LOG.info("Unsupported architecture for native BLAS:" + SystemUtils.OS_ARCH);
return;
}
synchronized(NativeHelper.class) {
if(!attemptedLoading) {
// -----------------------------------------------------------------------------
// =============================================================================
// By default, we will native.blas=true and we will attempt to load MKL first.
// If MKL is not enabled then we try to load OpenBLAS.
// If both MKL and OpenBLAS are not available we fall back to Java BLAS.
if(userSpecifiedBLAS.equals("auto")) {
blasType = isMKLAvailable() ? "mkl" : isOpenBLASAvailable() ? "openblas" : null;
if(blasType == null)
LOG.info("Unable to load either MKL or OpenBLAS due to " + hintOnFailures);
}
else if(userSpecifiedBLAS.equals("mkl")) {
blasType = isMKLAvailable() ? "mkl" : null;
if(blasType == null)
LOG.info("Unable to load MKL due to " + hintOnFailures);
}
else if(userSpecifiedBLAS.equals("openblas")) {
blasType = isOpenBLASAvailable() ? "openblas" : null;
if(blasType == null)
LOG.info("Unable to load OpenBLAS due to " + hintOnFailures);
}
else {
// Only thrown at development time.
throw new RuntimeException("Unsupported BLAS:" + userSpecifiedBLAS);
}
// =============================================================================
if(blasType != null && loadLibraryHelper("libsystemml_" + blasType + "-Linux-x86_64.so")) {
String blasPathAndHint = "";
// ------------------------------------------------------------
// This logic gets the list of native libraries that are loaded
if(LOG.isDebugEnabled()) {
// Only perform the checking of library paths when DEBUG is enabled to avoid runtime overhead.
try {
java.lang.reflect.Field loadedLibraryNamesField = ClassLoader.class.getDeclaredField("loadedLibraryNames");
loadedLibraryNamesField.setAccessible(true);
@SuppressWarnings("unchecked")
Vector<String> libraries = (Vector<String>) loadedLibraryNamesField.get(ClassLoader.getSystemClassLoader());
LOG.debug("List of native libraries loaded:" + libraries);
for(String library : libraries) {
if(library.contains("libmkl_rt") || library.contains("libopenblas")) {
blasPathAndHint = " from the path " + library;
break;
}
}
} catch (NoSuchFieldException | SecurityException | IllegalArgumentException | IllegalAccessException e) {
LOG.debug("Error while finding list of native libraries:" + e.getMessage());
}
}
// ------------------------------------------------------------
LOG.info("Using native blas: " + blasType + blasPathAndHint);
isSystemMLLoaded = true;
}
}
}
double timeToLoadInMilliseconds = (System.nanoTime()-start)*1e-6;
if(timeToLoadInMilliseconds > 1000)
LOG.warn("Time to load native blas: " + timeToLoadInMilliseconds + " milliseconds.");
}
else {
LOG.warn("Using internal Java BLAS as native BLAS support the configuration 'native.blas'=" + userSpecifiedBLAS + ".");
}
attemptedLoading = true;
}
}
public static boolean isNativeLibraryLoaded() {
init();
if(maxNumThreads == -1)
maxNumThreads = OptimizerUtils.getConstrainedNumThreads(-1);
if(isSystemMLLoaded && !setMaxNumThreads && maxNumThreads != -1) {
// This method helps us decide whether to use GetPrimitiveArrayCritical or GetDoubleArrayElements in JNI as each has different tradeoffs.
// In current implementation, we always use GetPrimitiveArrayCritical as it has proven to be fastest.
// We can revisit this decision later and hence I would not recommend removing this method.
setMaxNumThreads(maxNumThreads);
setMaxNumThreads = true;
}
return isSystemMLLoaded;
}
public static int getMaxNumThreads() {
if(maxNumThreads == -1)
maxNumThreads = OptimizerUtils.getConstrainedNumThreads(-1);
return maxNumThreads;
}
private static boolean isMKLAvailable() {
return loadBLAS("mkl_rt", null);
}
private static boolean isOpenBLASAvailable() {
if(!loadBLAS("gomp", "gomp required for loading OpenBLAS-enabled SystemML library"))
return false;
return loadBLAS("openblas", null);
}
private static boolean loadBLAS(String blas, String optionalMsg) {
try {
System.loadLibrary(blas);
return true;
}
catch (UnsatisfiedLinkError e) {
if(!hintOnFailures.contains(blas))
hintOnFailures = hintOnFailures + blas + " ";
if(optionalMsg != null)
LOG.debug("Unable to load " + blas + "(" + optionalMsg + "):" + e.getMessage());
else
LOG.debug("Unable to load " + blas + ":" + e.getMessage());
return false;
}
}
private static boolean loadLibraryHelper(String path) {
InputStream in = null; OutputStream out = null;
try {
// This logic is added because Java doesnot allow to load library from a resource file.
in = NativeHelper.class.getResourceAsStream("/lib/"+path);
if(in != null) {
File temp = File.createTempFile(path, "");
temp.deleteOnExit();
out = FileUtils.openOutputStream(temp);
IOUtils.copy(in, out);
in.close(); in = null;
out.close(); out = null;
System.load(temp.getAbsolutePath());
return true;
}
else
LOG.warn("No lib available in the jar:" + path);
} catch(IOException e) {
LOG.warn("Unable to load library " + path + " from resource:" + e.getMessage());
} finally {
if(out != null)
try {
out.close();
} catch (IOException e) {}
if(in != null)
try {
in.close();
} catch (IOException e) {}
}
return false;
}
// TODO: Add pmm, wsloss, mmchain, etc.
public static native boolean matrixMultDenseDense(double [] m1, double [] m2, double [] ret, int m1rlen, int m1clen, int m2clen, int numThreads);
private static native boolean tsmm(double [] m1, double [] ret, int m1rlen, int m1clen, boolean isLeftTranspose, int numThreads);
// ----------------------------------------------------------------------------------------------------------------
// LibMatrixDNN operations:
// N = number of images, C = number of channels, H = image height, W = image width
// K = number of filters, R = filter height, S = filter width
// TODO: case not handled: sparse filters (which will only be executed in Java). Since filters are relatively smaller, this is a low priority.
// Returns -1 if failures or returns number of nonzeros
// Called by ConvolutionCPInstruction if both input and filter are dense
public static native int conv2dDense(double [] input, double [] filter, double [] ret, int N, int C, int H, int W,
int K, int R, int S, int stride_h, int stride_w, int pad_h, int pad_w, int P, int Q, int numThreads);
public static native int conv2dBiasAddDense(double [] input, double [] bias, double [] filter, double [] ret, int N, int C, int H, int W,
int K, int R, int S, int stride_h, int stride_w, int pad_h, int pad_w, int P, int Q, int numThreads);
// Called by ConvolutionCPInstruction if both input and filter are dense
public static native int conv2dBackwardFilterDense(double [] input, double [] dout, double [] ret, int N, int C, int H, int W,
int K, int R, int S, int stride_h, int stride_w, int pad_h, int pad_w, int P, int Q, int numThreads);
// If both filter and dout are dense, then called by ConvolutionCPInstruction
// Else, called by LibMatrixDNN's thread if filter is dense. dout[n] is converted to dense if sparse.
public static native int conv2dBackwardDataDense(double [] filter, double [] dout, double [] ret, int N, int C, int H, int W,
int K, int R, int S, int stride_h, int stride_w, int pad_h, int pad_w, int P, int Q, int numThreads);
// Currently only supported with numThreads = 1 and sparse input
// Called by LibMatrixDNN's thread if input is sparse. dout[n] is converted to dense if sparse.
public static native boolean conv2dBackwardFilterSparseDense(int apos, int alen, int[] aix, double[] avals, double [] rotatedDoutPtr, double [] ret, int N, int C, int H, int W,
int K, int R, int S, int stride_h, int stride_w, int pad_h, int pad_w, int P, int Q, int numThreads);
// Called by LibMatrixDNN's thread if input is sparse and filter is dense
public static native boolean conv2dSparse(int apos, int alen, int[] aix, double[] avals, double [] filter, double [] ret, int N, int C, int H, int W,
int K, int R, int S, int stride_h, int stride_w, int pad_h, int pad_w, int P, int Q, int numThreads);
// ----------------------------------------------------------------------------------------------------------------
// This method helps us decide whether to use GetPrimitiveArrayCritical or GetDoubleArrayElements in JNI as each has different tradeoffs.
// In current implementation, we always use GetPrimitiveArrayCritical as it has proven to be fastest.
// We can revisit this decision later and hence I would not recommend removing this method.
private static native void setMaxNumThreads(int numThreads);
}