|
|
|
@@ -744,7 +744,7 @@ void backward_convolutional_layer_gpu(convolutional_layer l, network_state state |
|
|
|
|
assert((l.nweights) > 0); |
|
|
|
|
cuda_convert_f32_to_f16(l.weight_updates_gpu, l.nweights, l.weight_updates_gpu16); |
|
|
|
|
|
|
|
|
|
if (!state.net.adversarial) { |
|
|
|
|
if (!state.net.adversarial && !l.train_only_bn) { |
|
|
|
|
CHECK_CUDNN(cudnnConvolutionBackwardFilter(cudnn_handle(), |
|
|
|
|
&one, |
|
|
|
|
l.srcTensorDesc16, |
|
|
|
@@ -796,7 +796,7 @@ void backward_convolutional_layer_gpu(convolutional_layer l, network_state state |
|
|
|
|
backward_batchnorm_layer_gpu(l, state); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
if (!state.net.adversarial) { |
|
|
|
|
if (!state.net.adversarial && !l.train_only_bn) { |
|
|
|
|
// calculate conv weight updates |
|
|
|
|
// if used: beta=1 then loss decreases faster |
|
|
|
|
CHECK_CUDNN(cudnnConvolutionBackwardFilter(cudnn_handle(), |
|
|
|
@@ -857,17 +857,19 @@ void backward_convolutional_layer_gpu(convolutional_layer l, network_state state |
|
|
|
|
|
|
|
|
|
float *im = state.input + (i*l.groups + j)*l.c / l.groups*l.h*l.w; |
|
|
|
|
|
|
|
|
|
//im2col_ongpu(im, l.c / l.groups, l.h, l.w, l.size, l.stride, l.pad, state.workspace); |
|
|
|
|
im2col_gpu_ext(im, // input |
|
|
|
|
l.c / l.groups, // input channels |
|
|
|
|
l.h, l.w, // input size (h, w) |
|
|
|
|
l.size, l.size, // kernel size (h, w) |
|
|
|
|
l.pad * l.dilation, l.pad * l.dilation, // padding (h, w) |
|
|
|
|
l.stride_y, l.stride_x, // stride (h, w) |
|
|
|
|
l.dilation, l.dilation, // dilation (h, w) |
|
|
|
|
state.workspace); // output |
|
|
|
|
//gemm_ongpu(0, 1, m, n, k, 1, a + i*m*k, k, b, k, 1, c, n); |
|
|
|
|
gemm_ongpu(0, 1, m, n, k, 1, a, k, b, k, 1, c, n); |
|
|
|
|
if (!state.net.adversarial && !l.train_only_bn) { |
|
|
|
|
//im2col_ongpu(im, l.c / l.groups, l.h, l.w, l.size, l.stride, l.pad, state.workspace); |
|
|
|
|
im2col_gpu_ext(im, // input |
|
|
|
|
l.c / l.groups, // input channels |
|
|
|
|
l.h, l.w, // input size (h, w) |
|
|
|
|
l.size, l.size, // kernel size (h, w) |
|
|
|
|
l.pad * l.dilation, l.pad * l.dilation, // padding (h, w) |
|
|
|
|
l.stride_y, l.stride_x, // stride (h, w) |
|
|
|
|
l.dilation, l.dilation, // dilation (h, w) |
|
|
|
|
state.workspace); // output |
|
|
|
|
//gemm_ongpu(0, 1, m, n, k, 1, a + i*m*k, k, b, k, 1, c, n); |
|
|
|
|
gemm_ongpu(0, 1, m, n, k, 1, a, k, b, k, 1, c, n); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
if (state.delta) { |
|
|
|
|
if (l.binary || l.xnor) swap_binary(&l); |
|
|
|
|