Some fix for CUDNN_HALF

Branch: pull/2095/head
Author: AlexeyAB (7 years ago)
Parent: e9226be3ed
Commit: cb998db949
Changed files:
  1. src/convolutional_kernels.cu (6 changed lines)
  2. src/convolutional_layer.c (47 changed lines)
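Summary of the change: in forward_convolutional_layer_gpu() and backward_convolutional_layer_gpu(), the FP16 path no longer requires state.index != 0 and, during training, is enabled only once iteration_num exceeds 3*state.net.burn_in instead of 2*. A new get_workspace_size16() queries cuDNN for the workspace needed by the FP16 forward, backward-filter, and backward-data algorithms, and the layer workspace is grown to the larger of the FP32 and FP16 requirements wherever it is (re)computed in make_convolutional_layer() and resize_convolutional_layer(). The half-precision tensor descriptors in cudnn_convolutional_setup() are now also set up regardless of the CUDNN_HALF define (the #ifdef/#endif guard is commented out).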

src/convolutional_kernels.cu
@@ -295,7 +295,8 @@ void forward_convolutional_layer_gpu(convolutional_layer l, network_state state)
//#ifdef CUDNN_HALF
//if (state.use_mixed_precision) {
int iteration_num = (*state.net.seen) / (state.net.batch*state.net.subdivisions);
-if (state.index != 0 && state.net.cudnn_half && !l.xnor && (!state.train || iteration_num > 2*state.net.burn_in))
+if (state.net.cudnn_half && !l.xnor && (!state.train || iteration_num > 3*state.net.burn_in))
//if(state.index != 0)
{
//printf("\n CUDNN_HALF!!! state.index = %d \n", state.index);
@@ -475,7 +476,8 @@ void backward_convolutional_layer_gpu(convolutional_layer l, network_state state)
//#ifdef CUDNN_HALF
int iteration_num = (*state.net.seen) / (state.net.batch*state.net.subdivisions);
-if (state.index != 0 && state.net.cudnn_half && !l.xnor && (!state.train || iteration_num > 2*state.net.burn_in))
+if (state.net.cudnn_half && !l.xnor && (!state.train || iteration_num > 3*state.net.burn_in))
//if (state.index != 0)
{
const size_t input16_size = l.batch*l.c*l.w*l.h;

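Taken on its own, the updated guard in both kernels boils down to the predicate sketched below. This is a standalone restatement for illustration, not code from the repository; the function name and plain scalar parameters are invented.

    #include <stdbool.h>
    #include <stddef.h>

    /* Sketch of the new FP16 gate: the state.index != 0 requirement is dropped,
     * and during training the half-precision path waits until the iteration
     * count exceeds 3*burn_in (previously 2*burn_in). */
    bool use_cudnn_half_path(bool cudnn_half, bool xnor, bool train,
                             size_t seen, int batch, int subdivisions, int burn_in)
    {
        size_t iteration_num = seen / ((size_t)batch * subdivisions);
        return cudnn_half && !xnor && (!train || iteration_num > (size_t)(3 * burn_in));
    }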
src/convolutional_layer.c
@@ -136,6 +136,43 @@ size_t get_workspace_size(layer l){
return (size_t)l.out_h*l.out_w*l.size*l.size*l.c*sizeof(float);
}
+size_t get_workspace_size16(layer l) {
+#if defined(CUDNN) && defined(CUDNN_HALF)
+    if (gpu_index >= 0) {
+        size_t most = 0;
+        size_t s = 0;
+        cudnnGetConvolutionForwardWorkspaceSize(cudnn_handle(),
+            l.srcTensorDesc16,
+            l.weightDesc16,
+            l.convDesc,
+            l.dstTensorDesc16,
+            l.fw_algo16,
+            &s);
+        if (s > most) most = s;
+        cudnnGetConvolutionBackwardFilterWorkspaceSize(cudnn_handle(),
+            l.srcTensorDesc16,
+            l.ddstTensorDesc16,
+            l.convDesc,
+            l.dweightDesc16,
+            l.bf_algo16,
+            &s);
+        if (s > most) most = s;
+        cudnnGetConvolutionBackwardDataWorkspaceSize(cudnn_handle(),
+            l.weightDesc16,
+            l.ddstTensorDesc16,
+            l.convDesc,
+            l.dsrcTensorDesc16,
+            l.bd_algo16,
+            &s);
+        if (s > most) most = s;
+        return most;
+    }
+#endif
+    return 0;
+    //if (l.xnor) return (size_t)l.bit_align*l.size*l.size*l.c * sizeof(float);
+    //return (size_t)l.out_h*l.out_w*l.size*l.size*l.c * sizeof(float);
+}
#ifdef GPU
#ifdef CUDNN
void cudnn_convolutional_setup(layer *l, int cudnn_preference)
@@ -177,7 +214,7 @@ void cudnn_convolutional_setup(layer *l, int cudnn_preference)
cudnnSetTensor4dDescriptor(l->dstTensorDesc, CUDNN_TENSOR_NCHW, data_type, l->batch, l->out_c, l->out_h, l->out_w);
cudnnSetFilter4dDescriptor(l->weightDesc, data_type, CUDNN_TENSOR_NCHW, l->n, l->c, l->size, l->size);
-#ifdef CUDNN_HALF
+//#ifdef CUDNN_HALF
// backward delta
cudnnSetTensor4dDescriptor(l->dsrcTensorDesc16, CUDNN_TENSOR_NCHW, CUDNN_DATA_HALF, l->batch, l->c, l->h, l->w);
cudnnSetTensor4dDescriptor(l->ddstTensorDesc16, CUDNN_TENSOR_NCHW, CUDNN_DATA_HALF, l->batch, l->out_c, l->out_h, l->out_w);
@@ -190,7 +227,7 @@ void cudnn_convolutional_setup(layer *l, int cudnn_preference)
// batch norm
cudnnSetTensor4dDescriptor(l->normDstTensorDescF16, CUDNN_TENSOR_NCHW, CUDNN_DATA_HALF, l->batch, l->out_c, l->out_h, l->out_w);
-#endif
+//#endif
// batch norm
cudnnSetTensor4dDescriptor(l->normTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 1, l->out_c, 1, 1);
@@ -429,6 +466,8 @@ convolutional_layer make_convolutional_layer(int batch, int h, int w, int c, int
}
#endif
l.workspace_size = get_workspace_size(l);
+size_t workspace_size16 = get_workspace_size16(l);
+if (workspace_size16 > l.workspace_size) l.workspace_size = workspace_size16;
l.activation = activation;
//fprintf(stderr, "conv %5d %2d x%2d /%2d %4d x%4d x%4d -> %4d x%4d x%4d\n", n, size, size, stride, w, h, c, l.out_w, l.out_h, l.out_c);
@@ -532,6 +571,8 @@ void resize_convolutional_layer(convolutional_layer *l, int w, int h)
#endif
#endif
l->workspace_size = get_workspace_size(*l);
+size_t workspace_size16 = get_workspace_size16(*l);
+if (workspace_size16 > l->workspace_size) l->workspace_size = workspace_size16;
#ifdef CUDNN
// check for excessive memory consumption
@@ -542,6 +583,8 @@ void resize_convolutional_layer(convolutional_layer *l, int w, int h)
printf(" used slow CUDNN algo without Workspace! Need memory: %zu, available: %zu\n", l->workspace_size, (free_byte < total_byte/2) ? free_byte : total_byte/2);
cudnn_convolutional_setup(l, cudnn_smallest);
l->workspace_size = get_workspace_size(*l);
+size_t workspace_size16 = get_workspace_size16(*l);
+if (workspace_size16 > l->workspace_size) l->workspace_size = workspace_size16;
}
#endif
}
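The same two added lines follow each call to get_workspace_size() in make_convolutional_layer() and resize_convolutional_layer(). As a standalone sketch (helper name invented for illustration), the sizing rule they implement is simply:

    #include <stddef.h>

    /* Sketch: the shared conv workspace must satisfy whichever cuDNN algorithm
     * set needs more memory, FP32 (get_workspace_size) or FP16 (get_workspace_size16). */
    size_t combined_workspace_size(size_t size_fp32, size_t size_fp16)
    {
        return (size_fp16 > size_fp32) ? size_fp16 : size_fp32;
    }

Presumably, without this step an FP16 algorithm could require more workspace than a buffer sized only from the FP32 algorithms.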
