Stretch conv-kernel

pull/4540/head
AlexeyAB 6 years ago
parent b78aa3961b
commit f2fc239096
  1. 1
      src/blas.h
  2. 103
      src/blas_kernels.cu
  3. 10
      src/convolutional_kernels.cu

@ -130,6 +130,7 @@ void backward_sam_gpu(float *in_w_h_c_delta, int size, int channel_size,
void sam_gpu(float *in_w_h_c, int size, int channel_size, float *scales_c, float *out); void sam_gpu(float *in_w_h_c, int size, int channel_size, float *scales_c, float *out);
void smooth_rotate_weights_gpu(const float *src_weight_gpu, float *weight_deform_gpu, int nweights, int n, int size, int angle, int reverse); void smooth_rotate_weights_gpu(const float *src_weight_gpu, float *weight_deform_gpu, int nweights, int n, int size, int angle, int reverse);
void stretch_weights_gpu(const float *src_weight_gpu, float *weight_deform_gpu, int nweights, int n, int size, float scale, int reverse);
void sway_and_flip_weights_gpu(const float *src_weight_gpu, float *weight_deform_gpu, int nweights, int n, int size, int angle, int reverse); void sway_and_flip_weights_gpu(const float *src_weight_gpu, float *weight_deform_gpu, int nweights, int n, int size, int angle, int reverse);
void rotate_weights_gpu(const float *src_weight_gpu, float *weight_deform_gpu, int nweights, int n, int size, int reverse); void rotate_weights_gpu(const float *src_weight_gpu, float *weight_deform_gpu, int nweights, int n, int size, int reverse);
void reduce_and_expand_array_gpu(const float *src_gpu, float *dst_gpu, int size, int groups); void reduce_and_expand_array_gpu(const float *src_gpu, float *dst_gpu, int size, int groups);

@ -1324,6 +1324,109 @@ extern "C" void smooth_rotate_weights_gpu(const float *src_weight_gpu, float *we
} }
__global__ void stretch_weights_kernel(const float *src_weight_gpu, float *weight_deform_gpu, int nweights, int n, int kernel_size, float scale, int reverse)
{
const int index = blockIdx.x*blockDim.x + threadIdx.x;
const int kernel_area = kernel_size * kernel_size;
const int i = index * kernel_area;
const int stage_step = (nweights / kernel_area) / 4; // 4 stages
const int stage_id = index / stage_step;
// nweights = (c / groups) * n * size * size;
// kernel_area = size*size
if (i < nweights)
{
if (stage_id == 0) {
// simple copy
for (int x = 0; x < kernel_size; ++x) {
for (int y = 0; y < kernel_size; ++y) {
weight_deform_gpu[x + y*kernel_size + i] = src_weight_gpu[x + y*kernel_size + i];
}
}
}
else if (stage_id > 0)
{
if (stage_id == 1) scale = 0.65;
else if (stage_id == 2) scale = 0.8;
else if (stage_id == 3) scale = 1.3;
if (reverse) scale = 1 / scale;
const int x_c = kernel_size / 2;
const int y_c = kernel_size / 2;
float dropout_sum = 0;
for (int y = 0; y < kernel_size; ++y) {
for (int x = 0; x < kernel_size; ++x) {
// Xsource = x_c + (x_d - x_c) / scale
// Ysource = y_c + (y_d - y_c) / scale
float x_s = x_c + (x - x_c) / scale;
float y_s = y_c + (y - y_c) / scale;
int x_0 = floor(x_s); // round down
int x_1 = ceil(x_s); // round up
if (x_0 == x_1) x_1 = x_0 + 1;
int y_0 = floor(y_s);
int y_1 = ceil(y_s);
if (y_0 == y_1) y_1 = y_0 + 1;
float c_x_0 = x_1 - x_s;
float c_x_1 = x_s - x_0;
float c_y_0 = y_1 - y_s;
float c_y_1 = y_s - y_0;
float val = 0;
if (x_0 >= 0 && x_0 < kernel_size && y_0 >= 0 && y_0 < kernel_size) val += src_weight_gpu[x_0 + y_0*kernel_size + i] * c_x_0 * c_y_0;
else dropout_sum += c_x_0 * c_y_0;
if (x_1 >= 0 && x_1 < kernel_size && y_0 >= 0 && y_0 < kernel_size) val += src_weight_gpu[x_1 + y_0*kernel_size + i] * c_x_1 * c_y_0;
else dropout_sum += c_x_1 * c_y_0;
if (x_0 >= 0 && x_0 < kernel_size && y_1 >= 0 && y_1 < kernel_size) val += src_weight_gpu[x_0 + y_1*kernel_size + i] * c_x_0 * c_y_1;
else dropout_sum += c_x_0 * c_y_1;
if (x_1 >= 0 && x_1 < kernel_size && y_1 >= 0 && y_1 < kernel_size) val += src_weight_gpu[x_1 + y_1*kernel_size + i] * c_x_1 * c_y_1;
else dropout_sum += c_x_1 * c_y_1;
weight_deform_gpu[x + y*kernel_size + i] = val;
}
}
// compensate for dropped items
//const float coef = (kernel_size*kernel_size) / (kernel_size*kernel_size - dropout_sum);
for (int y = 0; y < kernel_size; ++y) {
for (int x = 0; x < kernel_size; ++x) {
weight_deform_gpu[x + y*kernel_size + i] /= scale;// *= coef;
}
}
}
}
}
extern "C" void stretch_weights_gpu(const float *src_weight_gpu, float *weight_deform_gpu, int nweights, int n, int size, float scale, int reverse)
{
const int kernel_area = size*size;
const int block_size = BLOCK;
const int num_blocks = get_number_of_blocks(nweights / kernel_area, block_size);
stretch_weights_kernel << <num_blocks, block_size, 0, get_cuda_stream() >> > (src_weight_gpu, weight_deform_gpu, nweights, n, size, scale, reverse);
CHECK_CUDA(cudaPeekAtLastError());
}
__global__ void sway_and_flip_weights_kernel(const float *src_weight_gpu, float *weight_deform_gpu, int nweights, int n, int kernel_size, int angle, int reverse) __global__ void sway_and_flip_weights_kernel(const float *src_weight_gpu, float *weight_deform_gpu, int nweights, int n, int kernel_size, int angle, int reverse)
{ {
const int index = blockIdx.x*blockDim.x + threadIdx.x; const int index = blockIdx.x*blockDim.x + threadIdx.x;

@ -1202,18 +1202,21 @@ void update_convolutional_layer_gpu(layer l, int batch, float learning_rate_init
if (l.deform) { if (l.deform) {
//for (l.angle = 0; l.angle < 360; l.angle++) //for (l.angle = 0; l.angle < 360; l.angle += 1)
//{ //{
//stretch_weights_gpu(l.weight_updates_gpu, l.weight_deform_gpu, l.nweights, l.n, l.size, l.angle/180, 1);
//else simple_copy_ongpu(l.nweights, l.weight_updates_gpu, l.weight_deform_gpu);
if (l.rotate) rotate_weights_gpu(l.weight_updates_gpu, l.weight_deform_gpu, l.nweights, l.n, l.size, 1); if (l.rotate) rotate_weights_gpu(l.weight_updates_gpu, l.weight_deform_gpu, l.nweights, l.n, l.size, 1);
else if (l.sway) sway_and_flip_weights_gpu(l.weight_updates_gpu, l.weight_deform_gpu, l.nweights, l.n, l.size, l.angle, 1); else if (l.sway) sway_and_flip_weights_gpu(l.weight_updates_gpu, l.weight_deform_gpu, l.nweights, l.n, l.size, l.angle, 1);
else if (l.stretch) stretch_weights_gpu(l.weight_updates_gpu, l.weight_deform_gpu, l.nweights, l.n, l.size, 0, 1);
//simple_copy_ongpu(l.nweights, l.weight_updates_gpu, l.weight_deform_gpu); //simple_copy_ongpu(l.nweights, l.weight_updates_gpu, l.weight_deform_gpu);
reduce_and_expand_array_gpu(l.weight_deform_gpu, l.weight_updates_gpu, l.nweights, 4); reduce_and_expand_array_gpu(l.weight_deform_gpu, l.weight_updates_gpu, l.nweights, 4);
//printf(" angle = %f \n", l.angle); //printf(" angle = %f \n", l.angle);
//cuda_pull_array(l.weight_updates_gpu, l.weights, l.nweights); //cuda_pull_array(l.weight_deform_gpu, l.weights, l.nweights);
//visualize_convolutional_layer(l, "weights", NULL); //visualize_convolutional_layer(l, "weights", NULL);
//wait_key_cv(10); //wait_key_cv(10);
//} //}
@ -1256,7 +1259,7 @@ void update_convolutional_layer_gpu(layer l, int batch, float learning_rate_init
} }
if (l.deform) { if (l.deform) {
//for (l.angle = 0; l.angle < 50; l.angle += 0.1) //for (l.angle = 0; l.angle < 360; l.angle += 4)
//{ //{
expand_array_gpu(l.weights_gpu, l.weight_deform_gpu, l.nweights, 4); expand_array_gpu(l.weights_gpu, l.weight_deform_gpu, l.nweights, 4);
@ -1264,6 +1267,7 @@ void update_convolutional_layer_gpu(layer l, int batch, float learning_rate_init
if (l.rotate) rotate_weights_gpu(l.weight_deform_gpu, l.weights_gpu, l.nweights, l.n, l.size, 0); if (l.rotate) rotate_weights_gpu(l.weight_deform_gpu, l.weights_gpu, l.nweights, l.n, l.size, 0);
else if (l.sway) sway_and_flip_weights_gpu(l.weight_deform_gpu, l.weights_gpu, l.nweights, l.n, l.size, l.angle, 0); else if (l.sway) sway_and_flip_weights_gpu(l.weight_deform_gpu, l.weights_gpu, l.nweights, l.n, l.size, l.angle, 0);
else if (l.stretch) stretch_weights_gpu(l.weight_deform_gpu, l.weights_gpu, l.nweights, l.n, l.size, 0, 0);
//printf(" angle = %f, reverse = %d \n", l.angle, 0); //printf(" angle = %f, reverse = %d \n", l.angle, 0);
//cuda_pull_array(l.weights_gpu, l.weights, l.nweights); //cuda_pull_array(l.weights_gpu, l.weights, l.nweights);

Loading…
Cancel
Save