package org.nd4j.nativeblas;
import lombok.Getter;
import org.bytedeco.javacpp.Loader;
import org.bytedeco.javacpp.Pointer;
import org.nd4j.context.Nd4jContext;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.Properties;
/**
 * Singleton that instantiates and holds the {@link NativeOps} implementation used by this process.
 *
 * <p>The implementation class name is read from the {@code Nd4j.NATIVE_OPS} key: a JVM system
 * property with that name wins, otherwise the value found in the {@link Nd4jContext}
 * configuration is used. After the instance has been created and
 * {@code initializeDevicesAndFunctions()} has been called, the OpenMP thread count is set from
 * the {@code OMP_NUM_THREADS} environment variable when it holds a positive integer, or from
 * the detected core/chip counts otherwise.
 *
 * <p>Any failure during this initialisation is rethrown as a {@link RuntimeException} carrying
 * the original cause.
 *
 * @author raver119@gmail.com
 * @author saudet
 */
public class NativeOpsHolder {
    // NOTE: must be declared before INSTANCE, because the constructor logs during static init.
    private static final Logger log = LoggerFactory.getLogger(NativeOpsHolder.class);

    private static final NativeOpsHolder INSTANCE = new NativeOpsHolder();

    /** The backend instance created during class initialisation. */
    @Getter
    private final NativeOps deviceNativeOps;

    /**
     * Loads the configured {@link NativeOps} class, initialises it and configures its thread count.
     *
     * @throws RuntimeException if the class cannot be resolved, instantiated or initialised
     */
    private NativeOpsHolder() {
        try {
            Class<? extends NativeOps> nativeOpsClazz =
                    Class.forName(resolveNativeOpsClassName()).asSubclass(NativeOps.class);
            // Class.newInstance() is deprecated; go through the no-arg constructor explicitly.
            deviceNativeOps = nativeOpsClazz.getDeclaredConstructor().newInstance();
            deviceNativeOps.initializeDevicesAndFunctions();
            configureOmpThreads(deviceNativeOps);
            log.info("Number of threads used for NativeOps: {}", deviceNativeOps.ompGetMaxThreads());
        } catch (Exception | Error e) {
            throw new RuntimeException(
                    "ND4J is probably missing dependencies. For more information, please refer to: http://nd4j.org/getstarted.html",
                    e);
        }
    }

    /**
     * Determines the fully-qualified name of the {@link NativeOps} implementation to load.
     *
     * <p>The configured value is only dereferenced when no system property overrides it, so a
     * configuration without the key no longer fails with a {@link NullPointerException} when the
     * system property is present.
     *
     * @return the class name from the system property, or from the configuration as a fallback
     * @throws IllegalStateException if neither source provides a value
     */
    private static String resolveNativeOpsClassName() {
        Properties props = Nd4jContext.getInstance().getConf();
        Object configured = props.get(Nd4j.NATIVE_OPS);
        String fallback = configured == null ? null : configured.toString();
        String name = System.getProperty(Nd4j.NATIVE_OPS, fallback);
        if (name == null) {
            throw new IllegalStateException("No NativeOps implementation configured: '" + Nd4j.NATIVE_OPS
                    + "' is set neither as a system property nor in the ND4J configuration");
        }
        return name;
    }

    /**
     * Sets the OpenMP thread count on the given backend.
     *
     * <p>A usable {@code OMP_NUM_THREADS} value takes precedence. Otherwise the count is
     * {@code cores / chips} (at least 1) when both are reported as positive by {@link Loader},
     * and {@code ops.getCores(availableProcessors)} as the last resort.
     *
     * @param ops the backend to configure
     */
    private static void configureOmpThreads(NativeOps ops) {
        int requested = parseRequestedThreads(System.getenv("OMP_NUM_THREADS"));
        if (requested > 0) {
            ops.setOmpNumThreads(requested);
            return;
        }

        int cores = Loader.totalCores();
        int chips = Loader.totalChips();
        if (chips > 0 && cores > 0) {
            ops.setOmpNumThreads(Math.max(1, cores / chips));
        } else {
            ops.setOmpNumThreads(ops.getCores(Runtime.getRuntime().availableProcessors()));
        }
    }

    /**
     * Parses a thread count taken from the environment.
     *
     * <p>Malformed or non-positive values are reported with a warning instead of aborting the
     * whole backend initialisation with an unrelated-looking error.
     *
     * @param raw the raw environment value, may be {@code null}
     * @return the positive thread count, or {@code -1} when the value is absent or unusable
     */
    private static int parseRequestedThreads(String raw) {
        if (raw == null) {
            return -1;
        }
        String trimmed = raw.trim();
        if (trimmed.isEmpty()) {
            return -1;
        }
        try {
            int parsed = Integer.parseInt(trimmed);
            if (parsed > 0) {
                return parsed;
            }
        } catch (NumberFormatException ignored) {
            // Reported by the warning below; automatic detection is used instead.
        }
        log.warn("Ignoring invalid OMP_NUM_THREADS value '{}'; falling back to automatic detection", raw);
        return -1;
    }

    /**
     * Returns the process-wide holder.
     *
     * @return the singleton instance created when this class was initialised
     */
    public static NativeOpsHolder getInstance() {
        return INSTANCE;
    }
}