|
|
|
@ -146,8 +146,12 @@ void cudnn_convolutional_setup(layer *l) |
|
|
|
|
cudnnSetTensor4dDescriptor(l->srcTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l->batch, l->c, l->h, l->w);
|
|
|
|
|
cudnnSetTensor4dDescriptor(l->dstTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l->batch, l->out_c, l->out_h, l->out_w);
|
|
|
|
|
cudnnSetFilter4dDescriptor(l->weightDesc, CUDNN_DATA_FLOAT, CUDNN_TENSOR_NCHW, l->n, l->c, l->size, l->size);
|
|
|
|
|
cudnnSetConvolution2dDescriptor(l->convDesc, l->pad, l->pad, l->stride, l->stride, 1, 1, CUDNN_CROSS_CORRELATION); |
|
|
|
|
cudnnGetConvolutionForwardAlgorithm(cudnn_handle(), |
|
|
|
|
#if(CUDNN_MAJOR >= 6) |
|
|
|
|
cudnnSetConvolution2dDescriptor(l->convDesc, l->pad, l->pad, l->stride, l->stride, 1, 1, CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT); // cudnn 6.0
|
|
|
|
|
#else |
|
|
|
|
cudnnSetConvolution2dDescriptor(l->convDesc, l->pad, l->pad, l->stride, l->stride, 1, 1, CUDNN_CROSS_CORRELATION); // cudnn 5.1
|
|
|
|
|
#endif |
|
|
|
|
cudnnGetConvolutionForwardAlgorithm(cudnn_handle(), |
|
|
|
|
l->srcTensorDesc, |
|
|
|
|
l->weightDesc, |
|
|
|
|
l->convDesc, |
|
|
|
|