Added support for Tensor Cores CC >= 7.0 (V100). For FP16/32 (mixed precision) define CUDNN_HALF should be used.

pull/492/head
AlexeyAB 7 years ago
parent 85eafd3d59
commit cad4d1618f
  1. 2
      build/darknet/darknet.vcxproj
  2. 118
      src/convolutional_kernels.cu
  3. 8
      src/convolutional_layer.c
  4. 1
      src/convolutional_layer.h
  5. 2
      src/layer.c
  6. 1
      src/layer.h
  7. 3
      src/network.c
  8. 2
      src/network_kernels.cu

@ -145,7 +145,7 @@
<GenerateDebugInformation>true</GenerateDebugInformation>
<EnableCOMDATFolding>true</EnableCOMDATFolding>
<OptimizeReferences>true</OptimizeReferences>
<AdditionalLibraryDirectories>C:\opencv_3.0\opencv\build\x64\vc14\lib;C:\opencv_2.4.13\opencv\build\x64\vc12\lib;C:\opencv_2.4.13\opencv\build\x64\vc14\lib;$(CUDA_PATH)lib\$(PlatformName);$(cudnn)\lib\x64;%(AdditionalLibraryDirectories)</AdditionalLibraryDirectories>
<AdditionalLibraryDirectories>C:\opencv_3.0\opencv\build\x64\vc14\lib;C:\opencv_2.4.13\opencv\build\x64\vc14\lib;$(CUDA_PATH)lib\$(PlatformName);$(cudnn)\lib\x64;%(AdditionalLibraryDirectories)</AdditionalLibraryDirectories>
<AdditionalDependencies>..\..\3rdparty\lib\x64\pthreadVC2.lib;cublas.lib;curand.lib;cudart.lib;%(AdditionalDependencies)</AdditionalDependencies>
<OutputFile>$(OutDir)\$(TargetName)$(TargetExt)</OutputFile>
</Link>

@ -81,8 +81,8 @@ __global__ void cuda_f32_to_f16(float* input_f32, size_t size, half *output_f16)
//if (idx < size) *((unsigned short *)output_f16 + idx) = __float2half(input_f32[idx]);
}
void cuda_convert_f32_to_f16(float* input_f32, size_t size, half *output_f16) {
cuda_f32_to_f16 <<< size / BLOCK + 1, BLOCK, 0, get_cuda_stream() >>> (input_f32, size, output_f16);
void cuda_convert_f32_to_f16(float* input_f32, size_t size, float *output_f16) {
cuda_f32_to_f16 <<< size / BLOCK + 1, BLOCK, 0, get_cuda_stream() >>> (input_f32, size, (half *)output_f16);
}
__global__ void cuda_f16_to_f32(half* input_f16, size_t size, float *output_f32)
@ -92,8 +92,8 @@ __global__ void cuda_f16_to_f32(half* input_f16, size_t size, float *output_f32)
//if (idx < size) output_f32[idx] = __half2float(*((unsigned short *)input_f16 + idx));
}
void cuda_convert_f16_to_f32(half* input_f16, size_t size, float *output_f32) {
cuda_f16_to_f32 <<< size / BLOCK + 1, BLOCK, 0, get_cuda_stream() >>> (input_f16, size, output_f32);
void cuda_convert_f16_to_f32(float* input_f16, size_t size, float *output_f32) {
cuda_f16_to_f32 <<< size / BLOCK + 1, BLOCK, 0, get_cuda_stream() >>> ((half *)input_f16, size, output_f32);
}
half *cuda_make_f16_from_f32_array(float *src, size_t n)
@ -102,7 +102,7 @@ half *cuda_make_f16_from_f32_array(float *src, size_t n)
size_t size = sizeof(half)*n;
check_error(cudaMalloc((void **)&dst16, size));
if (src) {
cuda_convert_f32_to_f16(src, n, dst16);
cuda_convert_f32_to_f16(src, n, (float *)dst16);
}
if (!dst16) error("Cuda malloc failed\n");
return dst16;
@ -124,7 +124,7 @@ void forward_convolutional_layer_gpu(convolutional_layer l, network_state state)
}
#ifdef CUDNN
//float one = 1; // alpha[0], beta[0] is float for HALF and FLOAT
float one = 1; // alpha[0], beta[0] is float for HALF and FLOAT
float alpha = 1, beta = 0;
#ifdef CUDNN_HALF
@ -154,8 +154,9 @@ void forward_convolutional_layer_gpu(convolutional_layer l, network_state state)
output16 = cuda_make_f16_from_f32_array(NULL, max_output16_size);
}
cuda_convert_f32_to_f16(state.input, input16_size, input16);
cuda_convert_f32_to_f16(state.input, input16_size, (float *)input16);
//fill_ongpu(output16_size / 2, 0, (float *)output16, 1);
cudnnConvolutionForward(cudnn_handle(),
&alpha,
l.srcTensorDesc,
@ -170,11 +171,12 @@ void forward_convolutional_layer_gpu(convolutional_layer l, network_state state)
l.dstTensorDesc,
output16);
cuda_convert_f16_to_f32(output16, output16_size, l.output_gpu);
cuda_convert_f16_to_f32((float *)output16, output16_size, l.output_gpu);
#else
cudnnConvolutionForward(cudnn_handle(),
&alpha,
&one,
l.srcTensorDesc,
state.input,
l.weightDesc,
@ -183,7 +185,7 @@ void forward_convolutional_layer_gpu(convolutional_layer l, network_state state)
l.fw_algo,
state.workspace,
l.workspace_size,
&beta,
&one,
l.dstTensorDesc,
l.output_gpu);
#endif
@ -231,6 +233,87 @@ void backward_convolutional_layer_gpu(convolutional_layer l, network_state state
if(l.xnor) state.input = l.binary_input_gpu;
#ifdef CUDNN
float one = 1;
float alpha = 1, beta = 0;
#ifdef CUDNN_HALF
const size_t input16_size = l.batch*l.c*l.w*l.h;
static size_t max_input16_size = input16_size;
static half* input16 = cuda_make_f16_from_f32_array(NULL, max_input16_size);
const size_t delta16_size = l.batch*l.n*l.out_w*l.out_h;
static size_t max_delta16_size = delta16_size;
static half* delta16 = cuda_make_f16_from_f32_array(NULL, max_delta16_size);
if (max_input16_size < input16_size) {
max_input16_size = input16_size;
cuda_free((float *)input16);
input16 = cuda_make_f16_from_f32_array(state.input, max_input16_size);
}
if (max_delta16_size < delta16_size) {
max_delta16_size = delta16_size;
cuda_free((float *)delta16);
delta16 = cuda_make_f16_from_f32_array(NULL, max_delta16_size);
}
cuda_convert_f32_to_f16(state.input, input16_size, (float *)input16);
cuda_convert_f32_to_f16(l.delta_gpu, delta16_size, (float *)delta16);
// convert input: state.input (x), l.delta_gpu (y) from fp32 to fp16
// get output: l.weight_updates_gpu (dw) and convert it to fp32 (ONLY if it is fp16)
// calculate conv weight updates
// Already: l.weight_updates_gpu = (l.weight_updates_gpu - l.weight*decay*batch*subdivision)*momentum
// so we should copy f32 to f16, or compute: f16=(w_up - w*d*b*s)*m
cuda_convert_f32_to_f16(l.weight_updates_gpu, l.c*l.n*l.size*l.size, l.weight_updates_gpu16);
cudnnConvolutionBackwardFilter(cudnn_handle(),
&one,
l.srcTensorDesc,
input16, //state.input,
l.ddstTensorDesc,
delta16, //l.delta_gpu,
l.convDesc,
l.bf_algo,
state.workspace,
l.workspace_size,
&one,
l.dweightDesc,
l.weight_updates_gpu16); // l.weight_updates_gpu);
cuda_convert_f16_to_f32(l.weight_updates_gpu16, l.c*l.n*l.size*l.size, l.weight_updates_gpu);
if (state.delta) {
if (l.binary || l.xnor) swap_binary(&l);
// http://docs.nvidia.com/deeplearning/sdk/cudnn-developer-guide/index.html#cudnnConvolutionBackwardData
// calculate delta for the next layer
// convert input: l.weights_gpu (w), l.delta_gpu (dy) from fp32 to fp16
// get output: state.delta (dx) and convert it to fp32 (ONLY if it is fp16)
cudnnConvolutionBackwardData(cudnn_handle(),
&alpha,
l.weightDesc,
l.weights_gpu16, //l.weights_gpu,
l.ddstTensorDesc,
delta16, //l.delta_gpu,
l.convDesc,
l.bd_algo,
state.workspace,
l.workspace_size,
&beta,
l.dsrcTensorDesc,
input16); // state.delta);
cuda_convert_f16_to_f32((float *)input16, input16_size, state.delta);
if (l.binary || l.xnor) swap_binary(&l);
if (l.xnor) gradient_array_ongpu(original_input, l.batch*l.c*l.h*l.w, HARDTAN, state.delta);
}
#else // CUDNN_HALF
// calculate conv weight updates
// if used: beta=1 then loss decreases faster
cudnnConvolutionBackwardFilter(cudnn_handle(),
&one,
l.srcTensorDesc,
@ -248,6 +331,7 @@ void backward_convolutional_layer_gpu(convolutional_layer l, network_state state
if(state.delta){
if(l.binary || l.xnor) swap_binary(&l);
// http://docs.nvidia.com/deeplearning/sdk/cudnn-developer-guide/index.html#cudnnConvolutionBackwardData
// calculate delta for the next layer
cudnnConvolutionBackwardData(cudnn_handle(),
&one,
l.weightDesc,
@ -265,7 +349,9 @@ void backward_convolutional_layer_gpu(convolutional_layer l, network_state state
if(l.xnor) gradient_array_ongpu(original_input, l.batch*l.c*l.h*l.w, HARDTAN, state.delta);
}
#else
#endif // CUDNN_HALF
#else // CUDNN
int m = l.n;
int n = l.size*l.size*l.c;
int k = l.out_w*l.out_h;
@ -318,7 +404,7 @@ void push_convolutional_layer(convolutional_layer layer)
{
cuda_push_array(layer.weights_gpu, layer.weights, layer.c*layer.n*layer.size*layer.size);
#ifdef CUDNN_HALF
cuda_convert_f32_to_f16(layer.weights_gpu, layer.c*layer.n*layer.size*layer.size, (half *)layer.weights_gpu16);
cuda_convert_f32_to_f16(layer.weights_gpu, layer.c*layer.n*layer.size*layer.size, layer.weights_gpu16);
#endif
cuda_push_array(layer.biases_gpu, layer.biases, layer.n);
cuda_push_array(layer.weight_updates_gpu, layer.weight_updates, layer.c*layer.n*layer.size*layer.size);
@ -358,6 +444,14 @@ void update_convolutional_layer_gpu(convolutional_layer layer, int batch, float
adam_gpu(size, layer.weights_gpu, layer.m_gpu, layer.v_gpu, layer.B1, layer.B2, learning_rate/batch, layer.eps, layer.t+1);
fill_ongpu(size, 0, layer.weight_updates_gpu, 1);
}else{
// update weights:
// weights_gpu = weights_gpu*(1 - decay*lr) + weight_updates_gpu*lr / (batch*subdivision) =
// weights_gpu*(1 - 0.0005*0.001) + weight_updates_gpu*0.001/(64*8) =
// weights_gpu * 0.999 999 5 + weight_updates_gpu * 0.000 001 953125
//
// weight_updates_gpu = (weight_updates_gpu - weights_gpu*decay*batch*subdivision)*momentum =
// (weight_updates_gpu - weights_gpu * 0.0005 * 64 * 8) * 0.9 =
// weight_updates_gpu*0.9 - weights_gpu*0.2304
axpy_ongpu(size, -decay*batch, layer.weights_gpu, 1, layer.weight_updates_gpu, 1);
axpy_ongpu(size, learning_rate/batch, layer.weight_updates_gpu, 1, layer.weights_gpu, 1);
scal_ongpu(size, momentum, layer.weight_updates_gpu, 1);

@ -141,7 +141,8 @@ void cudnn_convolutional_setup(layer *l, int cudnn_preference)
{
#ifdef CUDNN_HALF
// TRUE_HALF_CONFIG is only supported on architectures with true fp16 support (compute capability 5.3 and 6.0): Tegra X1, Jetson TX1, DRIVE CX, DRIVE PX, Quadro GP100, Tesla P100
// TRUE_HALF_CONFIG is only supported on architectures with true fp16 support (compute capability 5.3 and 6.0):
// Tegra X1, Jetson TX1, DRIVE CX, DRIVE PX, Quadro GP100, Tesla P100
// PSEUDO_HALF_CONFIG is required for Tensor Cores - our case!
const cudnnDataType_t data_type = CUDNN_DATA_HALF;
#else
@ -164,10 +165,12 @@ void cudnn_convolutional_setup(layer *l, int cudnn_preference)
// on architectures with DP4A support (compute capability 6.1 and later).
//cudnnDataType_t data_type = CUDNN_DATA_INT8;
// backward delta
cudnnSetTensor4dDescriptor(l->dsrcTensorDesc, CUDNN_TENSOR_NCHW, data_type, l->batch, l->c, l->h, l->w);
cudnnSetTensor4dDescriptor(l->ddstTensorDesc, CUDNN_TENSOR_NCHW, data_type, l->batch, l->out_c, l->out_h, l->out_w);
cudnnSetFilter4dDescriptor(l->dweightDesc, data_type, CUDNN_TENSOR_NCHW, l->n, l->c, l->size, l->size);
// forward
cudnnSetTensor4dDescriptor(l->srcTensorDesc, CUDNN_TENSOR_NCHW, data_type, l->batch, l->c, l->h, l->w);
cudnnSetTensor4dDescriptor(l->dstTensorDesc, CUDNN_TENSOR_NCHW, data_type, l->batch, l->out_c, l->out_h, l->out_w);
cudnnSetFilter4dDescriptor(l->weightDesc, data_type, CUDNN_TENSOR_NCHW, l->n, l->c, l->size, l->size);
@ -302,7 +305,8 @@ convolutional_layer make_convolutional_layer(int batch, int h, int w, int c, int
l.weights_gpu = cuda_make_array(l.weights, c*n*size*size);
#ifdef CUDNN_HALF
l.weights_gpu16 = cuda_make_array(l.weights, c*n*size*size/2);
l.weights_gpu16 = cuda_make_array(l.weights, c*n*size*size / 2);
l.weight_updates_gpu16 = cuda_make_array(l.weight_updates, c*n*size*size / 2);
#endif
l.weight_updates_gpu = cuda_make_array(l.weight_updates, c*n*size*size);

@ -21,6 +21,7 @@ void add_bias_gpu(float *output, float *biases, int batch, int n, int size);
void backward_bias_gpu(float *bias_updates, float *delta, int batch, int n, int size);
#ifdef CUDNN
void cudnn_convolutional_setup(layer *l, int cudnn_preference);
void cuda_convert_f32_to_f16(float* input_f32, size_t size, float *output_f16);
#endif
#endif

@ -83,6 +83,8 @@ void free_layer(layer l)
if (l.x_norm_gpu) cuda_free(l.x_norm_gpu);
if (l.weights_gpu) cuda_free(l.weights_gpu);
if (l.weight_updates_gpu) cuda_free(l.weight_updates_gpu);
if (l.weights_gpu16) cuda_free(l.weights_gpu16);
if (l.weight_updates_gpu16) cuda_free(l.weight_updates_gpu16);
if (l.biases_gpu) cuda_free(l.biases_gpu);
if (l.bias_updates_gpu) cuda_free(l.bias_updates_gpu);
if (l.scales_gpu) cuda_free(l.scales_gpu);

@ -243,6 +243,7 @@ struct layer{
float * weight_updates_gpu;
float * weights_gpu16;
float * weight_updates_gpu16;
float * biases_gpu;
float * bias_updates_gpu;

@ -316,6 +316,8 @@ void set_batch_network(network *net, int b)
net->layers[i].batch = b;
#ifdef CUDNN
if(net->layers[i].type == CONVOLUTIONAL){
cudnn_convolutional_setup(net->layers + i, cudnn_fastest);
/*
layer *l = net->layers + i;
cudnn_convolutional_setup(l, cudnn_fastest);
// check for excessive memory consumption
@ -327,6 +329,7 @@ void set_batch_network(network *net, int b)
cudnn_convolutional_setup(l, cudnn_smallest);
l->workspace_size = get_workspace_size(*l);
}
*/
}
#endif
}

@ -117,7 +117,7 @@ void forward_backward_network_gpu(network net, float *x, float *y)
int i;
for (i = 0; i < net.n; ++i) {
layer l = net.layers[i];
cuda_convert_f32_to_f16(l.weights_gpu, l.c*l.n*l.size*l.size, (half *)l.weights_gpu16);
cuda_convert_f32_to_f16(l.weights_gpu, l.c*l.n*l.size*l.size, l.weights_gpu16);
}
#endif
forward_network_gpu(net, state);

Loading…
Cancel
Save