|
|
|
@ -1119,17 +1119,45 @@ extern "C" void scale_channels_gpu(float *in_w_h_c, int size, int channel_size, |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
__inline__ __device__ |
|
|
|
|
float warpAllReduceSum(float val) { |
|
|
|
|
for (int mask = WARP_SIZE / 2; mask > 0; mask /= 2) |
|
|
|
|
#if CUDART_VERSION >= 9000 |
|
|
|
|
val += __shfl_xor_sync(0xffffffff, val, mask); |
|
|
|
|
#else |
|
|
|
|
val += __shfl_xor(val, mask); |
|
|
|
|
#endif |
|
|
|
|
return val; |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
__global__ void backward_scale_channels_kernel(float *in_w_h_c_delta, int size, int channel_size, |
|
|
|
|
float *in_scales_c, float *out_from_delta, |
|
|
|
|
float *in_from_output, float *out_state_delta) |
|
|
|
|
{ |
|
|
|
|
const int index = blockIdx.x*blockDim.x + threadIdx.x; |
|
|
|
|
int osd_index = index / channel_size; |
|
|
|
|
|
|
|
|
|
if (index < size) { |
|
|
|
|
out_state_delta[index / channel_size] += in_w_h_c_delta[index] * in_from_output[index]; // l.delta * from (should be divided by channel_size?) |
|
|
|
|
out_from_delta[index] += in_scales_c[index / channel_size] * in_w_h_c_delta[index]; // input * l.delta |
|
|
|
|
//out_state_delta[osd_index] += in_w_h_c_delta[index] * in_from_output[index]; // l.delta * from (should be divided by channel_size?) |
|
|
|
|
|
|
|
|
|
int warp_id = index / 32; |
|
|
|
|
int index_warp_start = warp_id * 32; |
|
|
|
|
int osd_index_warp_start = index_warp_start / channel_size; |
|
|
|
|
int osd_index_warp_end = (index_warp_start + 31) / channel_size; |
|
|
|
|
|
|
|
|
|
if (osd_index_warp_start == osd_index_warp_end) // all thread in warp process the same channel |
|
|
|
|
{ |
|
|
|
|
float sum = warpAllReduceSum(in_w_h_c_delta[index] * in_from_output[index]); // l.delta * from |
|
|
|
|
if (threadIdx.x % 32 == 0) { |
|
|
|
|
atomicAdd(&out_state_delta[osd_index], sum); |
|
|
|
|
//out_state_delta[osd_index] += sum; |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
else { |
|
|
|
|
atomicAdd(&out_state_delta[osd_index], in_w_h_c_delta[index] * in_from_output[index]); // l.delta * from |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
//out_state_delta[index / channel_size] += in_w_h_c_delta[index] / channel_size; |
|
|
|
|
//out_from_delta[index] = in_w_h_c_delta[index]; |
|
|
|
|
out_from_delta[index] += in_scales_c[osd_index] * in_w_h_c_delta[index]; // input * l.delta // atomic isn't required here |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|