@@ -1240,7 +1240,7 @@ extern "C" void backward_sam_gpu(float *in_w_h_c_delta, int size, int channel_si
}
__global__ void rotate_weights_kernel(const float *src_weight_gpu, float *weight_deform_gpu, int nweights, int n, int kernel_size, int angle, int reverse)
__global__ void smooth_rotate_weights_kernel(const float *src_weight_gpu, float *weight_deform_gpu, int nweights, int n, int kernel_size, int angle, int reverse)
{
    const int index = blockIdx.x*blockDim.x + threadIdx.x;
    const int kernel_area = kernel_size * kernel_size;
@@ -1296,12 +1296,12 @@ __global__ void rotate_weights_kernel(const float *src_weight_gpu, float *weigh
}
extern "C" void rotate_weights_gpu(const float *src_weight_gpu, float *weight_deform_gpu, int nweights, int n, int size, int angle, int reverse)
extern "C" void smooth_ rotate_weights_gpu(const float *src_weight_gpu, float *weight_deform_gpu, int nweights, int n, int size, int angle, int reverse)
{
    const int kernel_area = size*size;
    const int block_size = BLOCK;
    const int num_blocks = get_number_of_blocks(nweights / kernel_area, block_size);
    rotate_weights_kernel << <num_blocks, block_size, 0, get_cuda_stream() >> > (src_weight_gpu, weight_deform_gpu, nweights, n, size, angle, reverse);
    smooth_rotate_weights_kernel << <num_blocks, block_size, 0, get_cuda_stream() >> > (src_weight_gpu, weight_deform_gpu, nweights, n, size, angle, reverse);
    CHECK_CUDA(cudaPeekAtLastError());
}
@@ -1396,6 +1396,90 @@ extern "C" void sway_and_flip_weights_gpu(const float *src_weight_gpu, float *we
}
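// The kernel below splits the filters into four equal stages: stage 0 copies each
// kernel_size x kernel_size slice unchanged, while stages 1-3 copy it rotated by
// 90, 180 or 270 degrees clockwise. With reverse != 0 the source/destination
// indices are swapped, which applies the inverse mapping when restoring weights.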
__global__ void rotate_weights_kernel(const float *src_weight_gpu, float *weight_deform_gpu, int nweights, int n, int kernel_size, int reverse)
{
    const int index = blockIdx.x*blockDim.x + threadIdx.x;
    const int kernel_area = kernel_size * kernel_size;
    const int i = index * kernel_area;
    const int stage_step = (nweights / kernel_area) / 4; // 4 stages
    const int stage_id = index / stage_step;
    // nweights = (c / groups) * n * size * size;
    // kernel_area = size*size
    if (i < nweights)
    {
        // if(reverse)
        if (stage_id == 0) {
            // simple copy
            for (int x = 0; x < kernel_size; ++x) {
                for (int y = 0; y < kernel_size; ++y) {
                    const int src_i = x + y*kernel_size + i;
                    const int dst_i = x + y*kernel_size + i;
                    if (reverse) weight_deform_gpu[src_i] = src_weight_gpu[dst_i];
                    else weight_deform_gpu[dst_i] = src_weight_gpu[src_i];
                }
            }
        }
        else if (stage_id == 1)
        {
            // 90 degree clockwise rotation - 1
            for (int x = 0; x < kernel_size; ++x) {
                for (int y = 0; y < kernel_size; ++y) {
                    const int src_i = x + y*kernel_size + i;
                    const int dst_i = (kernel_size - 1 - y) + x*kernel_size + i;
                    if (reverse) weight_deform_gpu[src_i] = src_weight_gpu[dst_i];
                    else weight_deform_gpu[dst_i] = src_weight_gpu[src_i];
                }
            }
        }
        else if (stage_id == 2)
        {
            // 180 degree clockwise rotation - 2
            for (int x = 0; x < kernel_size; ++x) {
                for (int y = 0; y < kernel_size; ++y) {
                    const int src_i = x + y*kernel_size + i;
                    const int dst_i = (kernel_size - 1 - x) + (kernel_size - 1 - y)*kernel_size + i;
                    if (reverse) weight_deform_gpu[src_i] = src_weight_gpu[dst_i];
                    else weight_deform_gpu[dst_i] = src_weight_gpu[src_i];
                }
            }
        }
        else if (stage_id == 3)
        {
            // 270 degree clockwise rotation - 3
            for (int x = 0; x < kernel_size; ++x) {
                for (int y = 0; y < kernel_size; ++y) {
                    const int src_i = x + y*kernel_size + i;
                    const int dst_i = y + (kernel_size - 1 - x)*kernel_size + i;
                    if (reverse) weight_deform_gpu[src_i] = src_weight_gpu[dst_i];
                    else weight_deform_gpu[dst_i] = src_weight_gpu[src_i];
                }
            }
        }
    }
}
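/* Worked index check (illustrative, kernel_size = 3): stage 1 maps source cell
   (x, y), i.e. column x, row y of the 3x3 window, to column (kernel_size - 1 - y),
   row x; e.g. (0,0) -> (2,0) and (2,0) -> (2,2), which is a 90-degree clockwise
   rotation. Stages 2 and 3 compose the same rotation two and three times. */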
extern "C" void rotate_weights_gpu(const float *src_weight_gpu, float *weight_deform_gpu, int nweights, int n, int size, int reverse)
{
    const int kernel_area = size*size;
    const int block_size = BLOCK;
    const int num_blocks = get_number_of_blocks(nweights / kernel_area, block_size);
    rotate_weights_kernel << <num_blocks, block_size, 0, get_cuda_stream() >> > (src_weight_gpu, weight_deform_gpu, nweights, n, size, reverse);
    CHECK_CUDA(cudaPeekAtLastError());
}
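// Launch note: each thread deforms one size x size filter slice, so the grid covers
// nweights / kernel_area threads; get_number_of_blocks() is assumed to round that
// count up to a whole number of BLOCK-sized blocks.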
__global__ void reduce_and_expand_array_kernel(const float *src_gpu, float *dst_gpu, int current_size, int groups)
{
    const int index = blockIdx.x*blockDim.x + threadIdx.x;