/*- * * * Copyright 2016 Skymind,Inc. * * * * Licensed under the Apache License, Version 2.0 (the "License"); * * you may not use this file except in compliance with the License. * * You may obtain a copy of the License at * * * * http://www.apache.org/licenses/LICENSE-2.0 * * * * Unless required by applicable law or agreed to in writing, software * * distributed under the License is distributed on an "AS IS" BASIS, * * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * * See the License for the specific language governing permissions and * * limitations under the License. * */ package org.deeplearning4j.nn.layers.convolution; import lombok.extern.slf4j.Slf4j; import org.bytedeco.javacpp.*; import org.deeplearning4j.berkeley.Pair; import org.deeplearning4j.nn.conf.ConvolutionMode; import org.deeplearning4j.nn.conf.layers.ConvolutionLayer.AlgoMode; import org.deeplearning4j.nn.conf.layers.ConvolutionLayer.BwdFilterAlgo; import org.deeplearning4j.nn.conf.layers.ConvolutionLayer.BwdDataAlgo; import org.deeplearning4j.nn.conf.layers.ConvolutionLayer.FwdAlgo; import org.deeplearning4j.nn.gradient.DefaultGradient; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.layers.BaseCudnnHelper; import org.deeplearning4j.nn.params.ConvolutionParamInitializer; import org.deeplearning4j.util.ConvolutionUtils; import org.nd4j.jita.allocator.Allocator; import org.nd4j.jita.allocator.impl.AtomicAllocator; import org.nd4j.jita.conf.CudaEnvironment; import org.nd4j.linalg.activations.IActivation; import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.executioner.GridExecutioner; import org.nd4j.linalg.api.shape.Shape; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.jcublas.context.CudaContext; import static org.bytedeco.javacpp.cuda.*; import static org.bytedeco.javacpp.cudnn.*; /** * cuDNN-based 
helper for the convolution layer.
 *
 * <p>Each public method (re)configures the cuDNN descriptors held by a single, reused
 * {@code CudnnConvolutionContext}, obtains buffer pointers for its arrays from ND4J's
 * {@link AtomicAllocator}, binds the cuDNN handle to that allocator context's CUDA stream and
 * then calls the matching cuDNN routine:
 * <ul>
 * <li>{@link #preOutput} - forward convolution followed by a bias addition;</li>
 * <li>{@link #activate} - in-place activation of an array, using the last-configured output
 * descriptor;</li>
 * <li>{@link #backpropGradient} - bias, weight and input (epsilon) gradients.</li>
 * </ul>
 *
 * <p>The descriptors and the scratch {@code workSpace} buffer are mutable instance state shared
 * by all calls, and no synchronization is visible in this class. NOTE(review): presumably each
 * helper instance is confined to a single layer/thread - confirm before sharing one across
 * threads.
 *
 * <p>{@code dataType}, {@code tensorFormat}, {@code alpha}, {@code beta}, {@code sizeInBytes},
 * {@code DataCache}, {@code CudnnContext} and {@code checkCudnn} are not declared in this class;
 * presumably they come from {@link BaseCudnnHelper}.
 *
 * @author saudet
 */
@Slf4j
public class CudnnConvolutionHelper extends BaseCudnnHelper implements ConvolutionHelper {

    /**
     * Bundles the cuDNN descriptor handles used by this helper with the base cuDNN context.
     * The no-arg constructor creates every handle and registers a {@link Deallocator} with
     * JavaCPP, so that {@link #destroyHandles()} runs when this context is deallocated.
     */
    private static class CudnnConvolutionContext extends CudnnContext {

        /**
         * JavaCPP deallocator: a copy of the owning context (built with the copy constructor)
         * whose {@link #deallocate()} destroys the descriptor handles.
         */
        private static class Deallocator extends CudnnConvolutionContext implements Pointer.Deallocator {
            Deallocator(CudnnConvolutionContext c) {
                super(c);
            }

            @Override
            public void deallocate() {
                destroyHandles();
            }
        }

        // Descriptors that the public methods (re)configure before each cuDNN call:
        // input (src), output or input-gradient (dst), per-channel bias, incoming gradient (delta).
        private cudnnTensorStruct srcTensorDesc = new cudnnTensorStruct(), dstTensorDesc = new cudnnTensorStruct(),
                        biasTensorDesc = new cudnnTensorStruct(), deltaTensorDesc = new cudnnTensorStruct();
        private cudnnFilterStruct filterDesc = new cudnnFilterStruct();
        private cudnnConvolutionStruct convDesc = new cudnnConvolutionStruct();
        private cudnnActivationStruct activationDesc = new cudnnActivationStruct();

        /** Creates all cuDNN handles and registers a deallocator that destroys them. */
        public CudnnConvolutionContext() {
            createHandles();
            deallocator(new Deallocator(this));
        }

        /**
         * Copy constructor used by {@link Deallocator}: wraps each of {@code c}'s descriptor
         * pointers in a new Java object (presumably addressing the same native descriptors, per
         * JavaCPP's pointer-copy constructors). No new handles are created here.
         */
        public CudnnConvolutionContext(CudnnConvolutionContext c) {
            super(c);
            srcTensorDesc = new cudnnTensorStruct(c.srcTensorDesc);
            dstTensorDesc = new cudnnTensorStruct(c.dstTensorDesc);
            biasTensorDesc = new cudnnTensorStruct(c.biasTensorDesc);
            deltaTensorDesc = new cudnnTensorStruct(c.deltaTensorDesc);
            filterDesc = new cudnnFilterStruct(c.filterDesc);
            convDesc = new cudnnConvolutionStruct(c.convDesc);
            activationDesc = new cudnnActivationStruct(c.activationDesc);
        }

        /** Creates the base handles (via super) and then every descriptor declared above. */
        @Override
        protected void createHandles() {
            super.createHandles();
            checkCudnn(cudnnCreateTensorDescriptor(srcTensorDesc));
            checkCudnn(cudnnCreateTensorDescriptor(dstTensorDesc));
            checkCudnn(cudnnCreateTensorDescriptor(biasTensorDesc));
            checkCudnn(cudnnCreateTensorDescriptor(deltaTensorDesc));
            checkCudnn(cudnnCreateFilterDescriptor(filterDesc));
            checkCudnn(cudnnCreateConvolutionDescriptor(convDesc));
            checkCudnn(cudnnCreateActivationDescriptor(activationDesc));
        }

        /** Destroys every descriptor declared above, then delegates to super for the base handles. */
        @Override
        protected void destroyHandles() {
            checkCudnn(cudnnDestroyActivationDescriptor(activationDesc));
            checkCudnn(cudnnDestroyConvolutionDescriptor(convDesc));
            checkCudnn(cudnnDestroyFilterDescriptor(filterDesc));
            checkCudnn(cudnnDestroyTensorDescriptor(srcTensorDesc));
            checkCudnn(cudnnDestroyTensorDescriptor(dstTensorDesc));
            checkCudnn(cudnnDestroyTensorDescriptor(biasTensorDesc));
            checkCudnn(cudnnDestroyTensorDescriptor(deltaTensorDesc));
            super.destroyHandles();
        }
    }

    // Descriptor/handle bundle reused by every call on this helper instance.
    private CudnnConvolutionContext cudnnContext = new CudnnConvolutionContext();
    // Scratch buffer handed to cuDNN; it is only ever replaced by a larger one, never shrunk.
    private DataCache workSpace = new DataCache();

    /**
     * Computes the convolution layer's gradients with cuDNN.
     *
     * <p>The bias and weight gradients are written directly into {@code biasGradView} and
     * {@code weightGradView}. The gradient with respect to the input is written into a newly
     * allocated {@code 'c'}-order array of shape [miniBatch, inDepth, inH, inW]. That array is
     * placed in the {@code ComputationGraph.workspaceExternal} workspace when it exists.
     *
     * @param input           layer input; dims 0, 2 and 3 are read as miniBatch, inH and inW
     *                        (the channel count is taken from {@code weights} instead)
     * @param weights         filter weights; dims 0..3 are read as outDepth, inDepth, kH, kW
     * @param delta           gradient w.r.t. the layer's pre-output; duplicated first if its
     *                        strides are not descending-C / ascending-F
     * @param kernel          kernel size; used only for the output-size and Same-mode padding
     *                        computation (the filter dimensions come from {@code weights})
     * @param strides         convolution strides, passed to cuDNN as (strides[0], strides[1])
     * @param pad             padding, passed to cuDNN as (pad[0], pad[1]); recomputed when
     *                        {@code convolutionMode} is Same
     * @param biasGradView    destination for the bias gradient
     * @param weightGradView  destination for the weight gradient
     * @param afn             not used by this method
     * @param mode            algorithm selection mode
     * @param bwdFilterAlgo   backward-filter algorithm; honoured only when {@code mode} is
     *                        USER_SPECIFIED and both algorithms are non-null
     * @param bwdDataAlgo     backward-data algorithm; same condition as {@code bwdFilterAlgo}
     * @param convolutionMode convolution mode
     * @return a gradient holding the bias and weight views, paired with the input gradient
     * @throws IllegalArgumentException if a user-specified algorithm has no mapping below
     */
    @Override
    public Pair<Gradient, INDArray> backpropGradient(INDArray input, INDArray weights, INDArray delta, int[] kernel,
                    int[] strides, int[] pad, INDArray biasGradView, INDArray weightGradView, IActivation afn,
                    AlgoMode mode, BwdFilterAlgo bwdFilterAlgo, BwdDataAlgo bwdDataAlgo,
                    ConvolutionMode convolutionMode) {
        int miniBatch = input.size(0);
        int inH = input.size(2);
        int inW = input.size(3);

        int outDepth = weights.size(0);
        int inDepth = weights.size(1);
        int kH = weights.size(2);
        int kW = weights.size(3);

        int[] outSize;
        if (convolutionMode == ConvolutionMode.Same) {
            outSize = ConvolutionUtils.getOutputSize(input, kernel, strides, null, convolutionMode); //Also performs validation
            // Same mode ignores the caller's pad and derives it from the output size.
            pad = ConvolutionUtils.getSameModeBottomRightPadding(outSize, new int[] {inH, inW}, kernel, strides);
        } else {
            outSize = ConvolutionUtils.getOutputSize(input, kernel, strides, pad, convolutionMode); //Also performs validation
        }

        int outH = outSize[0];
        int outW = outSize[1];

        if (!Shape.strideDescendingCAscendingF(delta)) {
            // apparently not supported by cuDNN
            delta = delta.dup();
        }

        int[] srcStride = input.stride();
        int[] deltaStride = delta.stride();
        // Single-element out-arrays receiving the chosen backward-filter (algo1) and
        // backward-data (algo2) algorithm ids.
        int[] algo1 = new int[1];
        int[] algo2 = new int[1];

        // Drain any queued ND4J ops before issuing cuDNN work.
        if (Nd4j.getExecutioner() instanceof GridExecutioner)
            ((GridExecutioner) Nd4j.getExecutioner()).flushQueue();

        // Describe input and delta with their actual strides, so their memory layout is
        // passed to cuDNN explicitly.
        checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.srcTensorDesc, dataType, miniBatch, inDepth, inH, inW,
                        srcStride[0], srcStride[1], srcStride[2], srcStride[3]));
        checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.deltaTensorDesc, dataType, miniBatch, outDepth, outH, outW,
                        deltaStride[0], deltaStride[1], deltaStride[2], deltaStride[3]));
        // The literal (1, 1) arguments are presumably the unit upscale/dilation factors of the
        // cuDNN v5 API - see the cuDNN reference.
        checkCudnn(cudnnSetConvolution2dDescriptor_v5(cudnnContext.convDesc, pad[0], pad[1], strides[0], strides[1], 1,
                        1, CUDNN_CROSS_CORRELATION, dataType));
        checkCudnn(cudnnSetFilter4dDescriptor(cudnnContext.filterDesc, dataType, tensorFormat, outDepth, inDepth, kH, kW));

        if (mode == AlgoMode.USER_SPECIFIED && bwdFilterAlgo != null && bwdDataAlgo != null) {
            // Map the user-selected enums one-to-one onto cuDNN's algorithm constants.
            switch (bwdFilterAlgo) {
                case ALGO_0:
                    algo1[0] = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0;
                    break;
                case ALGO_1:
                    algo1[0] = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1;
                    break;
                case FFT:
                    algo1[0] = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT;
                    break;
                case ALGO_3:
                    algo1[0] = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3;
                    break;
                case WINOGRAD:
                    algo1[0] = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD;
                    break;
                case WINOGRAD_NONFUSED:
                    algo1[0] = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED;
                    break;
                case FFT_TILING:
                    algo1[0] = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING;
                    break;
                case COUNT:
                    // NOTE(review): judging by its name, the ALGO_COUNT constant is the size of
                    // cuDNN's algorithm enum rather than an algorithm; cuDNN will presumably reject
                    // it in the calls below - confirm.
                    algo1[0] = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_COUNT;
                    break;
                default:
                    throw new IllegalArgumentException("Unknown BwdFilterAlgo: " + bwdFilterAlgo);
            }

            switch (bwdDataAlgo) {
                case ALGO_0:
                    algo2[0] = CUDNN_CONVOLUTION_BWD_DATA_ALGO_0;
                    break;
                case ALGO_1:
                    algo2[0] = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1;
                    break;
                case FFT:
                    algo2[0] = CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT;
                    break;
                case FFT_TILING:
                    algo2[0] = CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING;
                    break;
                case WINOGRAD:
                    algo2[0] = CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD;
                    break;
                case WINOGRAD_NONFUSED:
                    algo2[0] = CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED;
                    break;
                case COUNT:
                    // NOTE(review): same concern as the backward-filter COUNT mapping above.
                    algo2[0] = CUDNN_CONVOLUTION_BWD_DATA_ALGO_COUNT;
                    break;
                default:
                    throw new IllegalArgumentException("Unknown BwdDataAlgo: " + bwdDataAlgo);
            }
        } else {
            // Let cuDNN pick the algorithms. NO_WORKSPACE restricts the choice to workspace-free
            // algorithms; every other mode (including USER_SPECIFIED with a null algorithm)
            // asks for the fastest. The 0 is the workspace memory limit argument.
            checkCudnn(cudnnGetConvolutionBackwardFilterAlgorithm(cudnnContext, cudnnContext.srcTensorDesc,
                            cudnnContext.deltaTensorDesc, cudnnContext.convDesc, cudnnContext.filterDesc,
                            mode == AlgoMode.NO_WORKSPACE ? CUDNN_CONVOLUTION_BWD_FILTER_NO_WORKSPACE
                                            : CUDNN_CONVOLUTION_BWD_FILTER_PREFER_FASTEST,
                            0, algo1));
            checkCudnn(cudnnGetConvolutionBackwardDataAlgorithm(cudnnContext, cudnnContext.filterDesc,
                            cudnnContext.deltaTensorDesc, cudnnContext.convDesc, cudnnContext.srcTensorDesc,
                            mode == AlgoMode.NO_WORKSPACE ? CUDNN_CONVOLUTION_BWD_DATA_NO_WORKSPACE
                                            : CUDNN_CONVOLUTION_BWD_DATA_PREFER_FASTEST,
                            0, algo2));
        }

        // Allocate the input gradient (epsilon) in the external workspace when one exists,
        // otherwise as a regular array.
        INDArray epsNext;
        if (Nd4j.getWorkspaceManager().checkIfWorkspaceExists(ComputationGraph.workspaceExternal)) {
            try (MemoryWorkspace workspace = Nd4j.getWorkspaceManager()
                            .getWorkspaceForCurrentThread(ComputationGraph.workspaceExternal).notifyScopeBorrowed()) {
                epsNext = Nd4j.create(new int[] {miniBatch, inDepth, inH, inW}, 'c');
            }
        } else
            epsNext = Nd4j.create(new int[] {miniBatch, inDepth, inH, inW}, 'c');

        int[] dstStride = epsNext.stride();

        // All six arrays are registered as written for this action. The pointers below are the
        // buffers handed to cuDNN.
        Allocator allocator = AtomicAllocator.getInstance();
        CudaContext context = allocator.getFlowController().prepareActionAllWrite(input, weights, weightGradView,
                        biasGradView, delta, epsNext);
        Pointer srcData = allocator.getPointer(input, context);
        Pointer filterData = allocator.getPointer(weights, context);
        Pointer filterGradData = allocator.getPointer(weightGradView, context);
        Pointer biasGradData = allocator.getPointer(biasGradView, context);
        Pointer deltaData = allocator.getPointer(delta, context);
        Pointer dstData = allocator.getPointer(epsNext, context);

        // Run all following cuDNN work on the allocator context's stream.
        checkCudnn(cudnnSetStream(cudnnContext, new CUstream_st(context.getOldStream())));

        // NOTE: this overwrites the dst descriptor that preOutput() configured (see activate()).
        checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.dstTensorDesc, dataType, miniBatch, inDepth, inH, inW,
                        dstStride[0], dstStride[1], dstStride[2], dstStride[3]));
        // Query the scratch space each chosen algorithm needs; the shared sizeInBytes out-pointer
        // is read right after each query.
        checkCudnn(cudnnGetConvolutionBackwardFilterWorkspaceSize(cudnnContext, cudnnContext.srcTensorDesc,
                        cudnnContext.deltaTensorDesc, cudnnContext.convDesc, cudnnContext.filterDesc, algo1[0],
                        sizeInBytes));
        long sizeInBytes1 = sizeInBytes.get(0);
        checkCudnn(cudnnGetConvolutionBackwardDataWorkspaceSize(cudnnContext, cudnnContext.filterDesc,
                        cudnnContext.deltaTensorDesc, cudnnContext.convDesc, cudnnContext.dstTensorDesc, algo2[0],
                        sizeInBytes));
        long sizeInBytes2 = sizeInBytes.get(0);
        // Grow (never shrink) the shared scratch buffer so that it covers both passes.
        if (sizeInBytes1 > workSpace.capacity() || sizeInBytes2 > workSpace.capacity()) {
            workSpace.deallocate();
            workSpace = new DataCache(Math.max(sizeInBytes1, sizeInBytes2));
        }

        // Bias descriptor: 1 x outDepth x 1 x 1, i.e. one value per output channel.
        // NOTE(review): alpha/beta are not declared here; presumably 1 and 0, so the three calls
        // below overwrite (rather than accumulate into) their destinations - confirm.
        checkCudnn(cudnnSetTensor4dDescriptor(cudnnContext.biasTensorDesc, tensorFormat, dataType, 1, outDepth, 1, 1));
        checkCudnn(cudnnConvolutionBackwardBias(cudnnContext, alpha, cudnnContext.deltaTensorDesc, deltaData, beta,
                        cudnnContext.biasTensorDesc, biasGradData));
        checkCudnn(cudnnConvolutionBackwardFilter(cudnnContext, alpha, cudnnContext.srcTensorDesc, srcData,
                        cudnnContext.deltaTensorDesc, deltaData, cudnnContext.convDesc, algo1[0], workSpace,
                        workSpace.capacity(), beta, cudnnContext.filterDesc, filterGradData));
        checkCudnn(cudnnConvolutionBackwardData(cudnnContext, alpha, cudnnContext.filterDesc, filterData,
                        cudnnContext.deltaTensorDesc, deltaData, cudnnContext.convDesc, algo2[0], workSpace,
                        workSpace.capacity(), beta, cudnnContext.dstTensorDesc, dstData));

        allocator.getFlowController().registerActionAllWrite(context, input, weights, weightGradView, biasGradView,
                        delta, epsNext);

        Gradient retGradient = new DefaultGradient();
        retGradient.setGradientFor(ConvolutionParamInitializer.BIAS_KEY, biasGradView);
        retGradient.setGradientFor(ConvolutionParamInitializer.WEIGHT_KEY, weightGradView, 'c');

        // In debug configuration, block until the stream has finished.
        if (CudaEnvironment.getInstance().getConfiguration().isDebug())
            context.syncOldStream();

        return new Pair<>(retGradient, epsNext);
    }

    /**
     * Computes the convolution pre-output (convolution plus bias) with cuDNN.
     *
     * <p>The result is a newly allocated, uninitialized array of shape
     * [miniBatch, outDepth, outH, outW], filled by cuDNN. It is placed in the
     * {@code ComputationGraph.workspaceExternal} workspace when that exists. This call also
     * leaves the context's dst descriptor describing the returned array, which
     * {@link #activate} relies on.
     *
     * @param input           layer input; dims 0, 2 and 3 are read as miniBatch, inH and inW
     * @param weights         filter weights; dims 0..3 are read as outDepth, inDepth, kH, kW
     * @param bias            bias, described to cuDNN as a 1 x outDepth x 1 x 1 tensor
     * @param kernel          kernel size; used only for the output-size and Same-mode padding
     *                        computation
     * @param strides         convolution strides, passed to cuDNN as (strides[0], strides[1])
     * @param pad             padding, passed to cuDNN as (pad[0], pad[1]); recomputed when
     *                        {@code convolutionMode} is Same
     * @param mode            algorithm selection mode
     * @param fwdAlgo         forward algorithm; honoured only when {@code mode} is USER_SPECIFIED
     *                        and it is non-null
     * @param convolutionMode convolution mode
     * @return the pre-activation output
     * @throws IllegalArgumentException if a user-specified algorithm has no mapping below
     */
    @Override
    public INDArray preOutput(INDArray input, INDArray weights, INDArray bias, int[] kernel, int[] strides, int[] pad,
                    AlgoMode mode, FwdAlgo fwdAlgo, ConvolutionMode convolutionMode) {
        int miniBatch = input.size(0);
        int inH = input.size(2);
        int inW = input.size(3);

        int outDepth = weights.size(0);
        int inDepth = weights.size(1);
        int kH = weights.size(2);
        int kW = weights.size(3);

        int[] srcStride = input.stride();

        // Drain any queued ND4J ops before issuing cuDNN work.
        if (Nd4j.getExecutioner() instanceof GridExecutioner)
            ((GridExecutioner) Nd4j.getExecutioner()).flushQueue();

        int[] outSize;
        if (convolutionMode == ConvolutionMode.Same) {
            outSize = ConvolutionUtils.getOutputSize(input, kernel, strides, null, convolutionMode); //Also performs validation
            // Same mode ignores the caller's pad and derives it from the output size.
            pad = ConvolutionUtils.getSameModeBottomRightPadding(outSize, new int[] {inH, inW}, kernel, strides);
        } else {
            outSize = ConvolutionUtils.getOutputSize(input, kernel, strides, pad, convolutionMode); //Also performs validation
        }

        // Output is uninitialized: the forward convolution below writes every element.
        INDArray z;
        if (Nd4j.getWorkspaceManager().checkIfWorkspaceExists(ComputationGraph.workspaceExternal)) {
            try (MemoryWorkspace workspace = Nd4j.getWorkspaceManager()
                            .getWorkspaceForCurrentThread(ComputationGraph.workspaceExternal).notifyScopeBorrowed()) {
                z = Nd4j.createUninitialized(new int[]{miniBatch, outDepth, outSize[0], outSize[1]});
            }
        } else
            z = Nd4j.createUninitialized(new int[]{miniBatch, outDepth, outSize[0], outSize[1]});

        checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.srcTensorDesc, dataType, miniBatch, inDepth, inH, inW,
                        srcStride[0], srcStride[1], srcStride[2], srcStride[3]));
        checkCudnn(cudnnSetFilter4dDescriptor(cudnnContext.filterDesc, dataType, tensorFormat, outDepth, inDepth, kH, kW));
        // The literal (1, 1) arguments are presumably the unit upscale/dilation factors of the
        // cuDNN v5 API - see the cuDNN reference.
        checkCudnn(cudnnSetConvolution2dDescriptor_v5(cudnnContext.convDesc, pad[0], pad[1], strides[0], strides[1], 1,
                        1, CUDNN_CROSS_CORRELATION, dataType));

        // find dimension of convolution output
        // (kept for reference only; z above is sized from outSize instead)
        //        checkCudnn(cudnnGetConvolution2dForwardOutputDim(cudnnContext.convDesc, cudnnContext.srcTensorDesc, cudnnContext.filterDesc, n, c, h, w));
        //        INDArray z = Nd4j.createUninitialized(new int[]{n[0],c[0],h[0],w[0]},'c');

        // Single-element out-array receiving the chosen forward algorithm id.
        int[] algo = new int[1];

        // Describe z with its actual strides, so any memory order returned by
        // createUninitialized is handled.
        int[] dstStride = z.stride();
        checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.dstTensorDesc, dataType, miniBatch, outDepth, outSize[0],
                        outSize[1], dstStride[0], dstStride[1], dstStride[2], dstStride[3]));

        if (mode == AlgoMode.USER_SPECIFIED && fwdAlgo != null) {
            // Map the user-selected enum one-to-one onto cuDNN's algorithm constants.
            switch (fwdAlgo) {
                case IMPLICIT_GEMM:
                    algo[0] = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM;
                    break;
                case IMPLICIT_PRECOMP_GEMM:
                    algo[0] = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM;
                    break;
                case GEMM:
                    algo[0] = CUDNN_CONVOLUTION_FWD_ALGO_GEMM;
                    break;
                case DIRECT:
                    algo[0] = CUDNN_CONVOLUTION_FWD_ALGO_DIRECT;
                    break;
                case FFT:
                    algo[0] = CUDNN_CONVOLUTION_FWD_ALGO_FFT;
                    break;
                case FFT_TILING:
                    algo[0] = CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING;
                    break;
                case WINOGRAD:
                    algo[0] = CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD;
                    break;
                case WINOGRAD_NONFUSED:
                    algo[0] = CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED;
                    break;
                case COUNT:
                    // NOTE(review): judging by its name, the ALGO_COUNT constant is the size of
                    // cuDNN's algorithm enum rather than an algorithm; cuDNN will presumably reject
                    // it below - confirm.
                    algo[0] = CUDNN_CONVOLUTION_FWD_ALGO_COUNT;
                    break;
                default:
                    throw new IllegalArgumentException("Unknown FwdAlgo: " + fwdAlgo);
            }
        } else {
            // Let cuDNN pick: workspace-free algorithms for NO_WORKSPACE, otherwise the fastest.
            // The 0 is the workspace memory limit argument.
            checkCudnn(cudnnGetConvolutionForwardAlgorithm(cudnnContext, cudnnContext.srcTensorDesc,
                            cudnnContext.filterDesc, cudnnContext.convDesc, cudnnContext.dstTensorDesc,
                            mode == AlgoMode.NO_WORKSPACE ? CUDNN_CONVOLUTION_FWD_NO_WORKSPACE
                                            : CUDNN_CONVOLUTION_FWD_PREFER_FASTEST,
                            0, algo));
        }

        Allocator allocator = AtomicAllocator.getInstance();
        CudaContext context = allocator.getFlowController().prepareAction(z, input, weights, bias);
        Pointer srcData = allocator.getPointer(input, context);
        Pointer filterData = allocator.getPointer(weights, context);
        Pointer biasData = allocator.getPointer(bias, context);
        Pointer dstData = allocator.getPointer(z, context);

        // Run all following cuDNN work on the allocator context's stream.
        checkCudnn(cudnnSetStream(cudnnContext, new CUstream_st(context.getOldStream())));
        checkCudnn(cudnnGetConvolutionForwardWorkspaceSize(cudnnContext, cudnnContext.srcTensorDesc,
                        cudnnContext.filterDesc, cudnnContext.convDesc, cudnnContext.dstTensorDesc, algo[0],
                        sizeInBytes));
        // Grow (never shrink) the shared scratch buffer.
        if (sizeInBytes.get(0) > workSpace.capacity()) {
            workSpace.deallocate();
            workSpace = new DataCache(sizeInBytes.get(0));
        }
        checkCudnn(cudnnConvolutionForward(cudnnContext, alpha, cudnnContext.srcTensorDesc, srcData,
                        cudnnContext.filterDesc, filterData, cudnnContext.convDesc, algo[0], workSpace,
                        workSpace.capacity(), beta, cudnnContext.dstTensorDesc, dstData));

        // Add the per-channel bias (1 x outDepth x 1 x 1) onto z. alpha is passed as the
        // destination scale as well, which keeps the convolution result unchanged assuming
        // alpha == 1 (alpha is not declared here - confirm in BaseCudnnHelper).
        checkCudnn(cudnnSetTensor4dDescriptor(cudnnContext.biasTensorDesc, tensorFormat, dataType, 1, outDepth, 1, 1));
        checkCudnn(cudnnAddTensor(cudnnContext, alpha, cudnnContext.biasTensorDesc, biasData, alpha,
                        cudnnContext.dstTensorDesc, dstData));

        allocator.registerAction(context, z, input, weights, bias);

        // In debug configuration, block until the stream has finished.
        if (CudaEnvironment.getInstance().getConfiguration().isDebug())
            context.syncOldStream();

        return z;
    }

    /**
     * Applies the activation function to {@code z} in place with cuDNN.
     *
     * <p>The activation is chosen by {@code afn.toString()}. The recognised values are
     * "identity" (no-op), "sigmoid", "relu", "tanh", "softmax" and "logsoftmax"; the two softmax
     * variants run in channel mode.
     *
     * <p>NOTE(review): {@code z} is described to cuDNN through {@code cudnnContext.dstTensorDesc},
     * which this method does not set. It holds whatever the last {@link #preOutput} or
     * {@link #backpropGradient} call configured. This appears to require that {@code preOutput}
     * produced {@code z} immediately before this call - confirm with callers.
     *
     * @param z   pre-activation array; overwritten with the activation
     * @param afn activation function, matched by its {@code toString()} value
     * @return {@code z} (now activated), or {@code null} when {@code afn} is not one of the
     *         recognised names (presumably telling the caller to fall back to a non-cuDNN path)
     */
    @Override
    public INDArray activate(INDArray z, IActivation afn) {
        // Drain any queued ND4J ops before issuing cuDNN work.
        if (Nd4j.getExecutioner() instanceof GridExecutioner)
            ((GridExecutioner) Nd4j.getExecutioner()).flushQueue();

        INDArray activation = z;

        Allocator allocator = AtomicAllocator.getInstance();
        CudaContext context = allocator.getFlowController().prepareAction(z);
        Pointer dstData = allocator.getPointer(z, context);

        checkCudnn(cudnnSetStream(cudnnContext, new CUstream_st(context.getOldStream())));
        // Every branch below uses dstData as both source and destination, i.e. works in place.
        // The trailing 0 of cudnnSetActivationDescriptor is the activation coefficient,
        // presumably unused by these three modes.
        switch (afn.toString()) {
            case "identity":
                break;
            case "sigmoid":
                checkCudnn(cudnnSetActivationDescriptor(cudnnContext.activationDesc, CUDNN_ACTIVATION_SIGMOID,
                                CUDNN_PROPAGATE_NAN, 0));
                checkCudnn(cudnnActivationForward(cudnnContext, cudnnContext.activationDesc, alpha,
                                cudnnContext.dstTensorDesc, dstData, beta, cudnnContext.dstTensorDesc, dstData));
                break;
            case "relu":
                checkCudnn(cudnnSetActivationDescriptor(cudnnContext.activationDesc, CUDNN_ACTIVATION_RELU,
                                CUDNN_PROPAGATE_NAN, 0));
                checkCudnn(cudnnActivationForward(cudnnContext, cudnnContext.activationDesc, alpha,
                                cudnnContext.dstTensorDesc, dstData, beta, cudnnContext.dstTensorDesc, dstData));
                break;
            case "tanh":
                checkCudnn(cudnnSetActivationDescriptor(cudnnContext.activationDesc, CUDNN_ACTIVATION_TANH,
                                CUDNN_PROPAGATE_NAN, 0));
                checkCudnn(cudnnActivationForward(cudnnContext, cudnnContext.activationDesc, alpha,
                                cudnnContext.dstTensorDesc, dstData, beta, cudnnContext.dstTensorDesc, dstData));
                break;
            case "softmax":
                checkCudnn(cudnnSoftmaxForward(cudnnContext, CUDNN_SOFTMAX_ACCURATE, CUDNN_SOFTMAX_MODE_CHANNEL, alpha,
                                cudnnContext.dstTensorDesc, dstData, beta, cudnnContext.dstTensorDesc, dstData));
                break;
            case "logsoftmax":
                checkCudnn(cudnnSoftmaxForward(cudnnContext, CUDNN_SOFTMAX_LOG, CUDNN_SOFTMAX_MODE_CHANNEL, alpha,
                                cudnnContext.dstTensorDesc, dstData, beta, cudnnContext.dstTensorDesc, dstData));
                break;
            default:
                // Unsupported activation: signal it to the caller with a null result.
                activation = null;
        }

        // NOTE(review): in the default branch this registers null rather than the prepared z;
        // how registerAction treats a null result is not visible here - confirm.
        allocator.registerAction(context, activation);

        // In debug configuration, block until the stream has finished.
        if (CudaEnvironment.getInstance().getConfiguration().isDebug())
            context.syncOldStream();

        return activation;
    }
}