/*-
*
* * Copyright 2017 Skymind,Inc.
* *
* * Licensed under the Apache License, Version 2.0 (the "License");
* * you may not use this file except in compliance with the License.
* * You may obtain a copy of the License at
* *
* * http://www.apache.org/licenses/LICENSE-2.0
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS,
* * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* * See the License for the specific language governing permissions and
* * limitations under the License.
*
*/
package org.deeplearning4j.nn.layers;
import lombok.extern.slf4j.Slf4j;
import org.bytedeco.javacpp.*;
import org.bytedeco.javacpp.indexer.HalfIndexer;
import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.factory.Nd4j;
import static org.bytedeco.javacpp.cuda.*;
import static org.bytedeco.javacpp.cudnn.*;
/**
* Functionality shared by all cuDNN-based helpers.
*
* @author saudet
*/
@Slf4j
public abstract class BaseCudnnHelper {
protected static void checkCuda(int error) {
if (error != cudaSuccess) {
throw new RuntimeException("CUDA error = " + error + ": " + cudaGetErrorString(error).getString());
}
}
protected static void checkCudnn(int status) {
if (status != CUDNN_STATUS_SUCCESS) {
throw new RuntimeException("cuDNN status = " + status + ": " + cudnnGetErrorString(status).getString());
}
}
protected static class CudnnContext extends cudnnContext {
protected static class Deallocator extends CudnnContext implements Pointer.Deallocator {
Deallocator(CudnnContext c) {
super(c);
}
@Override
public void deallocate() {
destroyHandles();
}
}
public CudnnContext() {
// insure that cuDNN initializes on the same device as ND4J for this thread
Nd4j.create(1);
AtomicAllocator.getInstance();
// This needs to be called in subclasses:
// createHandles();
// deallocator(new Deallocator(this));
}
public CudnnContext(CudnnContext c) {
super(c);
}
protected void createHandles() {
checkCudnn(cudnnCreate(this));
}
protected void destroyHandles() {
checkCudnn(cudnnDestroy(this));
}
}
protected static class DataCache extends Pointer {
static class Deallocator extends DataCache implements Pointer.Deallocator {
Deallocator(DataCache c) {
super(c);
}
@Override
public void deallocate() {
checkCuda(cudaFree(this));
setNull();
}
}
static class HostDeallocator extends DataCache implements Pointer.Deallocator {
HostDeallocator(DataCache c) {
super(c);
}
@Override
public void deallocate() {
checkCuda(cudaFreeHost(this));
setNull();
}
}
public DataCache() {}
public DataCache(long size) {
position = 0;
limit = capacity = size;
int error = cudaMalloc(this, size);
if (error != cudaSuccess) {
log.warn("Cannot allocate " + size + " bytes of device memory (CUDA error = " + error
+ "), proceeding with host memory");
checkCuda(cudaMallocHost(this, size));
deallocator(new HostDeallocator(this));
} else {
deallocator(new Deallocator(this));
}
}
public DataCache(DataCache c) {
super(c);
}
}
protected static class TensorArray extends PointerPointer<cudnnTensorStruct> {
static class Deallocator extends TensorArray implements Pointer.Deallocator {
Pointer owner;
Deallocator(TensorArray a, Pointer owner) {
this.address = a.address;
this.capacity = a.capacity;
this.owner = owner;
}
@Override
public void deallocate() {
for (int i = 0; i < capacity; i++) {
cudnnTensorStruct t = this.get(cudnnTensorStruct.class, i);
checkCudnn(cudnnDestroyTensorDescriptor(t));
}
owner.deallocate();
owner = null;
setNull();
}
}
TensorArray() {}
TensorArray(long size) {
PointerPointer p = new PointerPointer(size);
p.deallocate(false);
this.address = p.address();
this.limit = p.limit();
this.capacity = p.capacity();
cudnnTensorStruct t = new cudnnTensorStruct();
for (int i = 0; i < capacity; i++) {
checkCudnn(cudnnCreateTensorDescriptor(t));
this.put(i, t);
}
deallocator(new Deallocator(this, p));
}
TensorArray(TensorArray a) {
super(a);
}
}
protected static final int tensorFormat = CUDNN_TENSOR_NCHW;
protected int dataType = Nd4j.dataType() == DataBuffer.Type.DOUBLE ? CUDNN_DATA_DOUBLE
: Nd4j.dataType() == DataBuffer.Type.FLOAT ? CUDNN_DATA_FLOAT : CUDNN_DATA_HALF;
protected int dataTypeSize = Nd4j.dataType() == DataBuffer.Type.DOUBLE ? 8
: Nd4j.dataType() == DataBuffer.Type.FLOAT ? 4 : 2;
protected Pointer alpha = Nd4j.dataType() == DataBuffer.Type.DOUBLE ? new DoublePointer(1.0)
: Nd4j.dataType() == DataBuffer.Type.FLOAT ? new FloatPointer(1.0f)
: new ShortPointer(new short[] {(short) HalfIndexer.fromFloat(1.0f)});
protected Pointer beta = Nd4j.dataType() == DataBuffer.Type.DOUBLE ? new DoublePointer(0.0)
: Nd4j.dataType() == DataBuffer.Type.FLOAT ? new FloatPointer(0.0f)
: new ShortPointer(new short[] {(short) HalfIndexer.fromFloat(0.0f)});;
protected SizeTPointer sizeInBytes = new SizeTPointer(1);
public boolean checkSupported() {
boolean supported = true;
if (Nd4j.dataType() == DataBuffer.Type.HALF) {
supported = false;
log.warn("Not supported: DataBuffer.Type.HALF");
}
return supported;
}
}