diff --git a/src/convolutional_kernels.cu b/src/convolutional_kernels.cu index 03c9ab79..135a2ea9 100644 --- a/src/convolutional_kernels.cu +++ b/src/convolutional_kernels.cu @@ -74,6 +74,38 @@ void binarize_weights_gpu(float *weights, int n, int size, float *binary) check_error(cudaPeekAtLastError()); } +__global__ void cuda_f32_to_f16(float* input_f32, size_t size, half *output_f16) +{ + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < size) output_f16[idx] = input_f32[idx]; +} + +void cuda_convert_f32_to_f16(float* input_f32, size_t size, half *output_f16) { + cuda_f32_to_f16 <<< size / BLOCK + 1, BLOCK, 0, get_cuda_stream() >>> (input_f32, size, output_f16); +} + +__global__ void cuda_f16_to_f32(half* input_f16, size_t size, float *output_f32) +{ + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < size) output_f32[idx] = input_f16[idx]; +} + +void cuda_convert_f16_to_f32(half* input_f16, size_t size, float *output_f32) { + cuda_f16_to_f32 <<< size / BLOCK + 1, BLOCK, 0, get_cuda_stream() >>> (input_f16, size, output_f32); +} + +half *cuda_make_f16_from_f32_array(float *src, size_t n) +{ + half *dst16; + size_t size = sizeof(half)*n; + check_error(cudaMalloc((void **)&dst16, size)); + if (src) { + cuda_convert_f32_to_f16(src, n, dst16); + } + if (!dst16) error("Cuda malloc failed\n"); + return dst16; +} + void forward_convolutional_layer_gpu(convolutional_layer l, network_state state) { fill_ongpu(l.outputs*l.batch, 0, l.output_gpu, 1); @@ -90,9 +122,57 @@ void forward_convolutional_layer_gpu(convolutional_layer l, network_state state) } #ifdef CUDNN - float one = 1; + //float one = 1; // alpha[0], beta[0] is float for HALF and FLOAT + float alpha = 1, beta = 0; + +#ifdef CUDNN_HALF + // Note: For improved performance it is advised to use beta[0] = 0.0. + // For Tensor Core: cudnnSetConvolutionMathType() where cudnnMathType_t mathType = CUDNN_TENSOR_OP_MATH; + // 1. or CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM and use CUDNN_DATA_HALF + // 2. or CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED + // More: http://docs.nvidia.com/deeplearning/sdk/cudnn-developer-guide/index.html#tensor_ops + + const size_t input16_size = l.batch*l.c*l.w*l.h; + static size_t max_input16_size = input16_size; + static half* input16 = cuda_make_f16_from_f32_array(NULL, max_input16_size); + + const size_t output16_size = l.batch*l.out_c*l.out_h*l.out_w; + static size_t max_output16_size = output16_size; + static half* output16 = cuda_make_f16_from_f32_array(NULL, max_output16_size); + + if (max_input16_size < input16_size) { + max_input16_size = input16_size; + cuda_free((float *)input16); + input16 = cuda_make_f16_from_f32_array(state.input, max_input16_size); + } + + if (max_output16_size < output16_size) { + max_output16_size = output16_size; + cuda_free((float *)output16); + output16 = cuda_make_f16_from_f32_array(NULL, max_output16_size); + } + + cuda_convert_f32_to_f16(state.input, input16_size, input16); + + cudnnConvolutionForward(cudnn_handle(), + &alpha, + l.srcTensorDesc, + input16, + l.weightDesc, + l.weights_gpu16, + l.convDesc, + l.fw_algo, + state.workspace, + l.workspace_size, + &beta, + l.dstTensorDesc, + output16); + + cuda_convert_f16_to_f32(output16, output16_size, l.output_gpu); +#else + cudnnConvolutionForward(cudnn_handle(), - &one, + &alpha, l.srcTensorDesc, state.input, l.weightDesc, @@ -101,9 +181,11 @@ void forward_convolutional_layer_gpu(convolutional_layer l, network_state state) l.fw_algo, state.workspace, l.workspace_size, - &one, + &beta, l.dstTensorDesc, l.output_gpu); +#endif + #else int i; @@ -232,6 +314,9 @@ void pull_convolutional_layer(convolutional_layer layer) void push_convolutional_layer(convolutional_layer layer) { cuda_push_array(layer.weights_gpu, layer.weights, layer.c*layer.n*layer.size*layer.size); +#ifdef CUDNN_HALF + cuda_convert_f32_to_f16(layer.weights_gpu, layer.c*layer.n*layer.size*layer.size, (half *)layer.weights_gpu16); +#endif cuda_push_array(layer.biases_gpu, layer.biases, layer.n); cuda_push_array(layer.weight_updates_gpu, layer.weight_updates, layer.c*layer.n*layer.size*layer.size); cuda_push_array(layer.bias_updates_gpu, layer.bias_updates, layer.n); diff --git a/src/convolutional_layer.c b/src/convolutional_layer.c index 801270ab..68658086 100644 --- a/src/convolutional_layer.c +++ b/src/convolutional_layer.c @@ -139,22 +139,38 @@ size_t get_workspace_size(layer l){ #ifdef CUDNN void cudnn_convolutional_setup(layer *l, int cudnn_preference) { - cudnnSetTensor4dDescriptor(l->dsrcTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l->batch, l->c, l->h, l->w); - cudnnSetTensor4dDescriptor(l->ddstTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l->batch, l->out_c, l->out_h, l->out_w); - cudnnSetFilter4dDescriptor(l->dweightDesc, CUDNN_DATA_FLOAT, CUDNN_TENSOR_NCHW, l->n, l->c, l->size, l->size); - cudnnSetTensor4dDescriptor(l->srcTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l->batch, l->c, l->h, l->w); - cudnnSetTensor4dDescriptor(l->dstTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l->batch, l->out_c, l->out_h, l->out_w); - cudnnSetFilter4dDescriptor(l->weightDesc, CUDNN_DATA_FLOAT, CUDNN_TENSOR_NCHW, l->n, l->c, l->size, l->size); +#ifdef CUDNN_HALF + // TRUE_HALF_CONFIG is only supported on architectures with true fp16 support (compute capability 5.3 and 6.0): + // Tegra X1, Jetson TX1, DRIVE CX, DRIVE PX, Quadro GP100, Tesla P100 + const cudnnDataType_t data_type = CUDNN_DATA_HALF; +#else + cudnnDataType_t data_type = CUDNN_DATA_FLOAT; +#endif + // Tensor Core uses CUDNN_TENSOR_OP_MATH instead of CUDNN_DEFAULT_MATH + cudnnSetConvolutionMathType(l->convDesc, CUDNN_TENSOR_OP_MATH); + + // INT8_CONFIG, INT8_EXT_CONFIG, INT8x4_CONFIG and INT8x4_EXT_CONFIG are only supported + // on architectures with DP4A support (compute capability 6.1 and later). + //cudnnDataType_t data_type = CUDNN_DATA_INT8; + + cudnnSetTensor4dDescriptor(l->dsrcTensorDesc, CUDNN_TENSOR_NCHW, data_type, l->batch, l->c, l->h, l->w); + cudnnSetTensor4dDescriptor(l->ddstTensorDesc, CUDNN_TENSOR_NCHW, data_type, l->batch, l->out_c, l->out_h, l->out_w); + cudnnSetFilter4dDescriptor(l->dweightDesc, data_type, CUDNN_TENSOR_NCHW, l->n, l->c, l->size, l->size); + + cudnnSetTensor4dDescriptor(l->srcTensorDesc, CUDNN_TENSOR_NCHW, data_type, l->batch, l->c, l->h, l->w); + cudnnSetTensor4dDescriptor(l->dstTensorDesc, CUDNN_TENSOR_NCHW, data_type, l->batch, l->out_c, l->out_h, l->out_w); + cudnnSetFilter4dDescriptor(l->weightDesc, data_type, CUDNN_TENSOR_NCHW, l->n, l->c, l->size, l->size); #if(CUDNN_MAJOR >= 6) - cudnnSetConvolution2dDescriptor(l->convDesc, l->pad, l->pad, l->stride, l->stride, 1, 1, CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT); // cudnn 6.0 + cudnnSetConvolution2dDescriptor(l->convDesc, l->pad, l->pad, l->stride, l->stride, 1, 1, CUDNN_CROSS_CORRELATION, data_type); // cudnn >= 6.0 #else cudnnSetConvolution2dDescriptor(l->convDesc, l->pad, l->pad, l->stride, l->stride, 1, 1, CUDNN_CROSS_CORRELATION); // cudnn 5.1 #endif int forward_algo = CUDNN_CONVOLUTION_FWD_PREFER_FASTEST; int backward_algo = CUDNN_CONVOLUTION_BWD_DATA_PREFER_FASTEST; int backward_filter = CUDNN_CONVOLUTION_BWD_FILTER_PREFER_FASTEST; - if (cudnn_preference == cudnn_smallest) { + if (cudnn_preference == cudnn_smallest) + { forward_algo = CUDNN_CONVOLUTION_FWD_NO_WORKSPACE; backward_algo = CUDNN_CONVOLUTION_BWD_DATA_NO_WORKSPACE; backward_filter = CUDNN_CONVOLUTION_BWD_FILTER_NO_WORKSPACE; @@ -275,6 +291,9 @@ convolutional_layer make_convolutional_layer(int batch, int h, int w, int c, int } l.weights_gpu = cuda_make_array(l.weights, c*n*size*size); +#ifdef CUDNN_HALF + l.weights_gpu16 = cuda_make_array(l.weights, c*n*size*size/2); +#endif l.weight_updates_gpu = cuda_make_array(l.weight_updates, c*n*size*size); l.biases_gpu = cuda_make_array(l.biases, n); diff --git a/src/layer.h b/src/layer.h index db012f1f..0f5addac 100644 --- a/src/layer.h +++ b/src/layer.h @@ -242,6 +242,8 @@ struct layer{ float * weights_gpu; float * weight_updates_gpu; + float * weights_gpu16; + float * biases_gpu; float * bias_updates_gpu;