Sway conv-kernel

pull/4540/head
AlexeyAB 6 years ago
parent 7ae1ae5641
commit a08c872564
  1. 2
      include/darknet.h
  2. 5
      src/blas.h
  3. 205
      src/blas_kernels.cu
  4. 22
      src/conv_lstm_layer.c
  5. 49
      src/convolutional_kernels.cu
  6. 20
      src/convolutional_layer.c
  7. 2
      src/convolutional_layer.h
  8. 6
      src/crnn_layer.c
  9. 2
      src/maxpool_layer.c
  10. 5
      src/parser.c

@ -248,6 +248,7 @@ struct layer {
int truth;
float smooth;
float dot;
int sway;
float angle;
float jitter;
float saturation;
@ -542,6 +543,7 @@ struct layer {
float * x_norm_gpu;
float * weights_gpu;
float * weight_updates_gpu;
float * weight_deform_gpu;
float * weight_change_gpu;
float * weights_gpu16;

@ -129,6 +129,11 @@ void backward_sam_gpu(float *in_w_h_c_delta, int size, int channel_size,
void sam_gpu(float *in_w_h_c, int size, int channel_size, float *scales_c, float *out);
void rotate_weights_gpu(const float *src_weight_gpu, float *weight_deform_gpu, int nweights, int n, int size, int angle, int reverse);
void sway_and_flip_weights_gpu(const float *src_weight_gpu, float *weight_deform_gpu, int nweights, int n, int size, int angle, int reverse);
void reduce_and_expand_array_gpu(const float *src_gpu, float *dst_gpu, int size, int groups);
void expand_array_gpu(const float *src_gpu, float *dst_gpu, int size, int groups);
#endif
#ifdef __cplusplus
}

@ -1236,5 +1236,210 @@ extern "C" void backward_sam_gpu(float *in_w_h_c_delta, int size, int channel_si
in_scales_c, out_from_delta,
in_from_output, out_state_delta);
CHECK_CUDA(cudaPeekAtLastError());
}
__global__ void rotate_weights_kernel(const float *src_weight_gpu, float *weight_deform_gpu, int nweights, int n, int kernel_size, int angle, int reverse)
{
const int index = blockIdx.x*blockDim.x + threadIdx.x;
const int kernel_area = kernel_size * kernel_size;
const int i = index * kernel_area;
const int stage_step = (nweights / kernel_area) / 4; // 4 stages
const int stage_id = index / stage_step;
// nweights = (c / groups) * n * size * size;
// kernel_area = size*size
if (i < nweights)
{
// rotate left or right
if (reverse) angle = -angle;
const float cos_a = cosf(angle * 3.14159265 / 180);
const float sin_a = sinf(angle * 3.14159265 / 180);
const int x_c = kernel_size / 2;
const int y_c = kernel_size / 2;
for (int x = 0; x < kernel_size; ++x) {
for (int y = 0; y < kernel_size; ++y) {
// Xsource = x*cos(alpha) + y*sin(alpha)
// Ysource = -x*sin(alpha) + y*cos(alpha)
float x_s = x_c + (x - x_c)*cos_a + (y - y_c)*sin_a;
float y_s = y_c - (x - x_c)*sin_a + (y - y_c)*cos_a;
int x_0 = floor(x_s); // round down
int x_1 = ceil(x_s); // round up
if (x_0 == x_1) x_1 = x_0 + 1;
int y_0 = floor(y_s);
int y_1 = ceil(y_s);
if (y_0 == y_1) y_1 = y_0 + 1;
float c_x_0 = x_1 - x_s;
float c_x_1 = x_s - x_0;
float c_y_0 = y_1 - y_s;
float c_y_1 = y_s - y_0;
float val = 0;
if (x_0 >= 0 && x_0 < kernel_size && y_0 >= 0 && y_0 < kernel_size) val += src_weight_gpu[x_0 + y_0*kernel_size + i] * c_x_0 * c_y_0;
if (x_1 >= 0 && x_1 < kernel_size && y_0 >= 0 && y_0 < kernel_size) val += src_weight_gpu[x_1 + y_0*kernel_size + i] * c_x_1 * c_y_0;
if (x_0 >= 0 && x_0 < kernel_size && y_1 >= 0 && y_1 < kernel_size) val += src_weight_gpu[x_0 + y_1*kernel_size + i] * c_x_0 * c_y_1;
if (x_1 >= 0 && x_1 < kernel_size && y_1 >= 0 && y_1 < kernel_size) val += src_weight_gpu[x_1 + y_1*kernel_size + i] * c_x_1 * c_y_1;
weight_deform_gpu[x + y*kernel_size + i] = val;
}
}
}
}
extern "C" void rotate_weights_gpu(const float *src_weight_gpu, float *weight_deform_gpu, int nweights, int n, int size, int angle, int reverse)
{
const int kernel_area = size*size;
const int block_size = BLOCK;
const int num_blocks = get_number_of_blocks(nweights / kernel_area, block_size);
rotate_weights_kernel << <num_blocks, block_size, 0, get_cuda_stream() >> > (src_weight_gpu, weight_deform_gpu, nweights, n, size, angle, reverse);
CHECK_CUDA(cudaPeekAtLastError());
}
__global__ void sway_and_flip_weights_kernel(const float *src_weight_gpu, float *weight_deform_gpu, int nweights, int n, int kernel_size, int angle, int reverse)
{
const int index = blockIdx.x*blockDim.x + threadIdx.x;
const int kernel_area = kernel_size * kernel_size;
const int i = index * kernel_area;
const int stage_step = (nweights / kernel_area) / 4; // 4 stages
const int stage_id = index / stage_step;
// nweights = (c / groups) * n * size * size;
// kernel_area = size*size
if (i < nweights)
{
if (stage_id == 0) {
// simple copy
for (int x = 0; x < kernel_size; ++x) {
for (int y = 0; y < kernel_size; ++y) {
weight_deform_gpu[x + y*kernel_size + i] = src_weight_gpu[x + y*kernel_size + i];
}
}
}
else if (stage_id == 1 || stage_id == 2)
{
// rotate left or right
if (stage_id == 2) angle = -angle;
if (reverse) angle = -angle;
const float cos_a = cosf(angle * 3.14159265 / 180);
const float sin_a = sinf(angle * 3.14159265 / 180);
const int x_c = kernel_size / 2;
const int y_c = kernel_size / 2;
for (int x = 0; x < kernel_size; ++x) {
for (int y = 0; y < kernel_size; ++y) {
// Xsource = x*cos(alpha) + y*sin(alpha)
// Ysource = -x*sin(alpha) + y*cos(alpha)
float x_s = x_c + (x - x_c)*cos_a + (y - y_c)*sin_a;
float y_s = y_c - (x - x_c)*sin_a + (y - y_c)*cos_a;
int x_0 = floor(x_s); // round down
int x_1 = ceil(x_s); // round up
if (x_0 == x_1) x_1 = x_0 + 1;
int y_0 = floor(y_s);
int y_1 = ceil(y_s);
if (y_0 == y_1) y_1 = y_0 + 1;
float c_x_0 = x_1 - x_s;
float c_x_1 = x_s - x_0;
float c_y_0 = y_1 - y_s;
float c_y_1 = y_s - y_0;
float val = 0;
if (x_0 >= 0 && x_0 < kernel_size && y_0 >= 0 && y_0 < kernel_size) val += src_weight_gpu[x_0 + y_0*kernel_size + i] * c_x_0 * c_y_0;
if (x_1 >= 0 && x_1 < kernel_size && y_0 >= 0 && y_0 < kernel_size) val += src_weight_gpu[x_1 + y_0*kernel_size + i] * c_x_1 * c_y_0;
if (x_0 >= 0 && x_0 < kernel_size && y_1 >= 0 && y_1 < kernel_size) val += src_weight_gpu[x_0 + y_1*kernel_size + i] * c_x_0 * c_y_1;
if (x_1 >= 0 && x_1 < kernel_size && y_1 >= 0 && y_1 < kernel_size) val += src_weight_gpu[x_1 + y_1*kernel_size + i] * c_x_1 * c_y_1;
weight_deform_gpu[x + y*kernel_size + i] = val;
}
}
}
else if (stage_id == 3)
{
// flip
for (int x = 0; x < kernel_size; ++x) {
for (int y = 0; y < kernel_size; ++y) {
weight_deform_gpu[(kernel_size - x - 1) + y*kernel_size + i] = src_weight_gpu[x + y*kernel_size + i];
}
}
}
}
}
extern "C" void sway_and_flip_weights_gpu(const float *src_weight_gpu, float *weight_deform_gpu, int nweights, int n, int size, int angle, int reverse)
{
const int kernel_area = size*size;
const int block_size = BLOCK;
const int num_blocks = get_number_of_blocks(nweights / kernel_area, block_size);
sway_and_flip_weights_kernel << <num_blocks, block_size, 0, get_cuda_stream() >> > (src_weight_gpu, weight_deform_gpu, nweights, n, size, angle, reverse);
CHECK_CUDA(cudaPeekAtLastError());
}
__global__ void reduce_and_expand_array_kernel(const float *src_gpu, float *dst_gpu, int current_size, int groups)
{
const int index = blockIdx.x*blockDim.x + threadIdx.x;
if (index < current_size) {
float val = 0;
for (int i = 0; i < groups; ++i) {
val += src_gpu[index + i*current_size];
}
for (int i = 0; i < groups; ++i) {
dst_gpu[index + i*current_size] = val / groups;
}
}
}
extern "C" void reduce_and_expand_array_gpu(const float *src_gpu, float *dst_gpu, int size, int groups)
{
const int current_size = size / groups;
const int block_size = BLOCK;
const int num_blocks = get_number_of_blocks(current_size, block_size);
reduce_and_expand_array_kernel << <num_blocks, block_size, 0, get_cuda_stream() >> > (src_gpu, dst_gpu, current_size, groups);
CHECK_CUDA(cudaPeekAtLastError());
}
__global__ void expand_array_kernel(const float *src_gpu, float *dst_gpu, int current_size, int groups)
{
const int index = blockIdx.x*blockDim.x + threadIdx.x;
if (index < current_size) {
for (int i = 0; i < groups; ++i) {
dst_gpu[index + i*current_size] = src_gpu[index];
}
}
}
extern "C" void expand_array_gpu(const float *src_gpu, float *dst_gpu, int size, int groups)
{
const int current_size = size / groups;
const int block_size = BLOCK;
const int num_blocks = get_number_of_blocks(current_size, block_size);
expand_array_kernel << <num_blocks, block_size, 0, get_cuda_stream() >> > (src_gpu, dst_gpu, current_size, groups);
CHECK_CUDA(cudaPeekAtLastError());
}

@ -67,44 +67,44 @@ layer make_conv_lstm_layer(int batch, int h, int w, int c, int output_filters, i
// U
l.uf = (layer*)calloc(1, sizeof(layer));
*(l.uf) = make_convolutional_layer(batch, steps, h, w, c, output_filters, groups, size, stride, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0, 0, NULL, 0, train);
*(l.uf) = make_convolutional_layer(batch, steps, h, w, c, output_filters, groups, size, stride, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0, 0, NULL, 0, 0, train);
l.uf->batch = batch;
if (l.workspace_size < l.uf->workspace_size) l.workspace_size = l.uf->workspace_size;
l.ui = (layer*)calloc(1, sizeof(layer));
*(l.ui) = make_convolutional_layer(batch, steps, h, w, c, output_filters, groups, size, stride, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0, 0, NULL, 0, train);
*(l.ui) = make_convolutional_layer(batch, steps, h, w, c, output_filters, groups, size, stride, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0, 0, NULL, 0, 0, train);
l.ui->batch = batch;
if (l.workspace_size < l.ui->workspace_size) l.workspace_size = l.ui->workspace_size;
l.ug = (layer*)calloc(1, sizeof(layer));
*(l.ug) = make_convolutional_layer(batch, steps, h, w, c, output_filters, groups, size, stride, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0, 0, NULL, 0, train);
*(l.ug) = make_convolutional_layer(batch, steps, h, w, c, output_filters, groups, size, stride, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0, 0, NULL, 0, 0, train);
l.ug->batch = batch;
if (l.workspace_size < l.ug->workspace_size) l.workspace_size = l.ug->workspace_size;
l.uo = (layer*)calloc(1, sizeof(layer));
*(l.uo) = make_convolutional_layer(batch, steps, h, w, c, output_filters, groups, size, stride, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0, 0, NULL, 0, train);
*(l.uo) = make_convolutional_layer(batch, steps, h, w, c, output_filters, groups, size, stride, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0, 0, NULL, 0, 0, train);
l.uo->batch = batch;
if (l.workspace_size < l.uo->workspace_size) l.workspace_size = l.uo->workspace_size;
// W
l.wf = (layer*)calloc(1, sizeof(layer));
*(l.wf) = make_convolutional_layer(batch, steps, h, w, output_filters, output_filters, groups, size, stride, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0, 0, NULL, 0, train);
*(l.wf) = make_convolutional_layer(batch, steps, h, w, output_filters, output_filters, groups, size, stride, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0, 0, NULL, 0, 0, train);
l.wf->batch = batch;
if (l.workspace_size < l.wf->workspace_size) l.workspace_size = l.wf->workspace_size;
l.wi = (layer*)calloc(1, sizeof(layer));
*(l.wi) = make_convolutional_layer(batch, steps, h, w, output_filters, output_filters, groups, size, stride, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0, 0, NULL, 0, train);
*(l.wi) = make_convolutional_layer(batch, steps, h, w, output_filters, output_filters, groups, size, stride, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0, 0, NULL, 0, 0, train);
l.wi->batch = batch;
if (l.workspace_size < l.wi->workspace_size) l.workspace_size = l.wi->workspace_size;
l.wg = (layer*)calloc(1, sizeof(layer));
*(l.wg) = make_convolutional_layer(batch, steps, h, w, output_filters, output_filters, groups, size, stride, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0, 0, NULL, 0, train);
*(l.wg) = make_convolutional_layer(batch, steps, h, w, output_filters, output_filters, groups, size, stride, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0, 0, NULL, 0, 0, train);
l.wg->batch = batch;
if (l.workspace_size < l.wg->workspace_size) l.workspace_size = l.wg->workspace_size;
l.wo = (layer*)calloc(1, sizeof(layer));
*(l.wo) = make_convolutional_layer(batch, steps, h, w, output_filters, output_filters, groups, size, stride, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0, 0, NULL, 0, train);
*(l.wo) = make_convolutional_layer(batch, steps, h, w, output_filters, output_filters, groups, size, stride, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0, 0, NULL, 0, 0, train);
l.wo->batch = batch;
if (l.workspace_size < l.wo->workspace_size) l.workspace_size = l.wo->workspace_size;
@ -112,21 +112,21 @@ layer make_conv_lstm_layer(int batch, int h, int w, int c, int output_filters, i
// V
l.vf = (layer*)calloc(1, sizeof(layer));
if (l.peephole) {
*(l.vf) = make_convolutional_layer(batch, steps, h, w, output_filters, output_filters, groups, size, stride, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0, 0, NULL, 0, train);
*(l.vf) = make_convolutional_layer(batch, steps, h, w, output_filters, output_filters, groups, size, stride, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0, 0, NULL, 0, 0, train);
l.vf->batch = batch;
if (l.workspace_size < l.vf->workspace_size) l.workspace_size = l.vf->workspace_size;
}
l.vi = (layer*)calloc(1, sizeof(layer));
if (l.peephole) {
*(l.vi) = make_convolutional_layer(batch, steps, h, w, output_filters, output_filters, groups, size, stride, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0, 0, NULL, 0, train);
*(l.vi) = make_convolutional_layer(batch, steps, h, w, output_filters, output_filters, groups, size, stride, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0, 0, NULL, 0, 0, train);
l.vi->batch = batch;
if (l.workspace_size < l.vi->workspace_size) l.workspace_size = l.vi->workspace_size;
}
l.vo = (layer*)calloc(1, sizeof(layer));
if (l.peephole) {
*(l.vo) = make_convolutional_layer(batch, steps, h, w, output_filters, output_filters, groups, size, stride, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0, 0, NULL, 0, train);
*(l.vo) = make_convolutional_layer(batch, steps, h, w, output_filters, output_filters, groups, size, stride, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0, 0, NULL, 0, 0, train);
l.vo->batch = batch;
if (l.workspace_size < l.vo->workspace_size) l.workspace_size = l.vo->workspace_size;
}

@ -1188,6 +1188,38 @@ void push_convolutional_layer(convolutional_layer l)
void update_convolutional_layer_gpu(layer l, int batch, float learning_rate_init, float momentum, float decay)
{
/*
for (int angle = 0; angle < 360; angle++) {
printf(" angle = %d \n", angle);
sway_and_flip_weights_gpu(l.weights_gpu, l.weight_deform_gpu, l.nweights, l.n, l.size, angle, 0);
cuda_pull_array(l.weight_deform_gpu, l.weights, l.nweights);
visualize_convolutional_layer(l, "weights", NULL);
wait_key_cv(10);
}
*/
if (l.sway) {
//for (l.angle = 0; l.angle < 360; l.angle++)
//{
sway_and_flip_weights_gpu(l.weight_updates_gpu, l.weight_deform_gpu, l.nweights, l.n, l.size, l.angle, 1);
//simple_copy_ongpu(l.nweights, l.weight_updates_gpu, l.weight_deform_gpu);
reduce_and_expand_array_gpu(l.weight_deform_gpu, l.weight_updates_gpu, l.nweights, 4);
//printf(" angle = %f \n", l.angle);
//cuda_pull_array(l.weight_updates_gpu, l.weights, l.nweights);
//visualize_convolutional_layer(l, "weights", NULL);
//wait_key_cv(10);
//}
}
float learning_rate = learning_rate_init*l.learning_rate_scale;
//float momentum = a.momentum;
//float decay = a.decay;
@ -1221,6 +1253,23 @@ void update_convolutional_layer_gpu(layer l, int batch, float learning_rate_init
scal_ongpu(l.n, momentum, l.scale_updates_gpu, 1);
}
}
if (l.sway) {
//for (l.angle = 0; l.angle < 360; l.angle += 4)
//{
expand_array_gpu(l.weights_gpu, l.weight_deform_gpu, l.nweights, 4);
//simple_copy_ongpu(l.nweights, l.weight_deform_gpu, l.weights_gpu);
sway_and_flip_weights_gpu(l.weight_deform_gpu, l.weights_gpu, l.nweights, l.n, l.size, l.angle, 0);
//printf(" angle = %f \n", l.angle);
//cuda_pull_array(l.weights_gpu, l.weights, l.nweights);
//visualize_convolutional_layer(l, "weights", NULL);
//wait_key_cv(10);
//}
}
//if (l.clip) {
// constrain_gpu(l.nweights, l.clip, l.weights_gpu, 1);
//}

@ -370,7 +370,7 @@ void free_convolutional_batchnorm(convolutional_layer *l)
}
}
convolutional_layer make_convolutional_layer(int batch, int steps, int h, int w, int c, int n, int groups, int size, int stride_x, int stride_y, int dilation, int padding, ACTIVATION activation, int batch_normalize, int binary, int xnor, int adam, int use_bin_output, int index, int antialiasing, convolutional_layer *share_layer, int assisted_excitation, int train)
convolutional_layer make_convolutional_layer(int batch, int steps, int h, int w, int c, int n, int groups, int size, int stride_x, int stride_y, int dilation, int padding, ACTIVATION activation, int batch_normalize, int binary, int xnor, int adam, int use_bin_output, int index, int antialiasing, convolutional_layer *share_layer, int assisted_excitation, int sway, int train)
{
int total_batch = batch*steps;
int i;
@ -388,6 +388,7 @@ convolutional_layer make_convolutional_layer(int batch, int steps, int h, int w,
stride_x = stride_y = l.stride = l.stride_x = l.stride_y = 1; // use stride=1 in host-layer
}
l.sway = sway;
l.assisted_excitation = assisted_excitation;
l.share_layer = share_layer;
l.index = index;
@ -530,15 +531,20 @@ convolutional_layer make_convolutional_layer(int batch, int steps, int h, int w,
#ifdef GPU
if (l.activation == SWISH || l.activation == MISH) {
l.activation_input_gpu = cuda_make_array(l.activation_input, total_batch*l.outputs);
}
l.forward_gpu = forward_convolutional_layer_gpu;
l.backward_gpu = backward_convolutional_layer_gpu;
l.update_gpu = update_convolutional_layer_gpu;
if(gpu_index >= 0){
if (l.activation == SWISH || l.activation == MISH) {
l.activation_input_gpu = cuda_make_array(l.activation_input, total_batch*l.outputs);
}
if (l.sway) l.weight_deform_gpu = cuda_make_array(NULL, l.nweights);
if (adam) {
l.m_gpu = cuda_make_array(l.m, l.nweights);
l.v_gpu = cuda_make_array(l.v, l.nweights);
@ -660,7 +666,7 @@ convolutional_layer make_convolutional_layer(int batch, int steps, int h, int w,
blur_size = 2;
blur_pad = 0;
}
*(l.input_layer) = make_convolutional_layer(batch, steps, out_h, out_w, n, n, n, blur_size, blur_stride_x, blur_stride_y, 1, blur_pad, LINEAR, 0, 0, 0, 0, 0, index, 0, NULL, 0, train);
*(l.input_layer) = make_convolutional_layer(batch, steps, out_h, out_w, n, n, n, blur_size, blur_stride_x, blur_stride_y, 1, blur_pad, LINEAR, 0, 0, 0, 0, 0, index, 0, NULL, 0, 0, train);
const int blur_nweights = n * blur_size * blur_size; // (n / n) * n * blur_size * blur_size;
int i;
if (blur_size == 2) {
@ -715,7 +721,7 @@ void denormalize_convolutional_layer(convolutional_layer l)
void test_convolutional_layer()
{
convolutional_layer l = make_convolutional_layer(1, 1, 5, 5, 3, 2, 1, 5, 2, 2, 1, 1, LEAKY, 1, 0, 0, 0, 0, 0, 0, NULL, 0, 0);
convolutional_layer l = make_convolutional_layer(1, 1, 5, 5, 3, 2, 1, 5, 2, 2, 1, 1, LEAKY, 1, 0, 0, 0, 0, 0, 0, NULL, 0, 0, 0);
l.batch_normalize = 1;
float data[] = {1,1,1,1,1,
1,1,1,1,1,
@ -1517,7 +1523,7 @@ image *visualize_convolutional_layer(convolutional_layer l, char *window, image
image dc = collapse_image_layers(delta, 1);
char buff[256];
sprintf(buff, "%s: Output", window);
//show_image(dc, buff);
show_image(dc, buff);
//save_image(dc, buff);
free_image(dc);
return single_weights;

@ -31,7 +31,7 @@ void cuda_convert_f32_to_f16(float* input_f32, size_t size, float *output_f16);
void free_convolutional_batchnorm(convolutional_layer *l);
size_t get_convolutional_workspace_size(layer l);
convolutional_layer make_convolutional_layer(int batch, int steps, int h, int w, int c, int n, int groups, int size, int stride_x, int stride_y, int dilation, int padding, ACTIVATION activation, int batch_normalize, int binary, int xnor, int adam, int use_bin_output, int index, int antialiasing, convolutional_layer *share_layer, int assisted_excitation, int train);
convolutional_layer make_convolutional_layer(int batch, int steps, int h, int w, int c, int n, int groups, int size, int stride_x, int stride_y, int dilation, int padding, ACTIVATION activation, int batch_normalize, int binary, int xnor, int adam, int use_bin_output, int index, int antialiasing, convolutional_layer *share_layer, int assisted_excitation, int sway, int train);
void denormalize_convolutional_layer(convolutional_layer l);
void set_specified_workspace_limit(convolutional_layer *l, size_t workspace_size_limit);
void resize_convolutional_layer(convolutional_layer *layer, int w, int h);

@ -51,17 +51,17 @@ layer make_crnn_layer(int batch, int h, int w, int c, int hidden_filters, int ou
l.state = (float*)calloc(l.hidden * l.batch * (l.steps + 1), sizeof(float));
l.input_layer = (layer*)calloc(1, sizeof(layer));
*(l.input_layer) = make_convolutional_layer(batch, steps, h, w, c, hidden_filters, groups, size, stride, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0, 0, NULL, 0, train);
*(l.input_layer) = make_convolutional_layer(batch, steps, h, w, c, hidden_filters, groups, size, stride, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0, 0, NULL, 0, 0, train);
l.input_layer->batch = batch;
if (l.workspace_size < l.input_layer->workspace_size) l.workspace_size = l.input_layer->workspace_size;
l.self_layer = (layer*)calloc(1, sizeof(layer));
*(l.self_layer) = make_convolutional_layer(batch, steps, h, w, hidden_filters, hidden_filters, groups, size, stride, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0, 0, NULL, 0, train);
*(l.self_layer) = make_convolutional_layer(batch, steps, h, w, hidden_filters, hidden_filters, groups, size, stride, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0, 0, NULL, 0, 0, train);
l.self_layer->batch = batch;
if (l.workspace_size < l.self_layer->workspace_size) l.workspace_size = l.self_layer->workspace_size;
l.output_layer = (layer*)calloc(1, sizeof(layer));
*(l.output_layer) = make_convolutional_layer(batch, steps, h, w, hidden_filters, output_filters, groups, size, stride, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0, 0, NULL, 0, train);
*(l.output_layer) = make_convolutional_layer(batch, steps, h, w, hidden_filters, output_filters, groups, size, stride, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0, 0, NULL, 0, 0, train);
l.output_layer->batch = batch;
if (l.workspace_size < l.output_layer->workspace_size) l.workspace_size = l.output_layer->workspace_size;

@ -123,7 +123,7 @@ maxpool_layer make_maxpool_layer(int batch, int h, int w, int c, int size, int s
blur_size = 2;
blur_pad = 0;
}
*(l.input_layer) = make_convolutional_layer(batch, 1, l.out_h, l.out_w, l.out_c, l.out_c, l.out_c, blur_size, blur_stride_x, blur_stride_y, 1, blur_pad, LINEAR, 0, 0, 0, 0, 0, 1, 0, NULL, 0, train);
*(l.input_layer) = make_convolutional_layer(batch, 1, l.out_h, l.out_w, l.out_c, l.out_c, l.out_c, blur_size, blur_stride_x, blur_stride_y, 1, blur_pad, LINEAR, 0, 0, 0, 0, 0, 1, 0, NULL, 0, 0, train);
const int blur_nweights = l.out_c * blur_size * blur_size; // (n / n) * n * blur_size * blur_size;
int i;
if (blur_size == 2) {

@ -199,11 +199,12 @@ convolutional_layer parse_convolutional(list *options, size_params params)
int binary = option_find_int_quiet(options, "binary", 0);
int xnor = option_find_int_quiet(options, "xnor", 0);
int use_bin_output = option_find_int_quiet(options, "bin_output", 0);
int sway = option_find_int_quiet(options, "sway", 0);
convolutional_layer layer = make_convolutional_layer(batch,1,h,w,c,n,groups,size,stride_x,stride_y,dilation,padding,activation, batch_normalize, binary, xnor, params.net.adam, use_bin_output, params.index, antialiasing, share_layer, assisted_excitation, params.train);
convolutional_layer layer = make_convolutional_layer(batch,1,h,w,c,n,groups,size,stride_x,stride_y,dilation,padding,activation, batch_normalize, binary, xnor, params.net.adam, use_bin_output, params.index, antialiasing, share_layer, assisted_excitation, sway, params.train);
layer.flipped = option_find_int_quiet(options, "flipped", 0);
layer.dot = option_find_float_quiet(options, "dot", 0);
layer.angle = option_find_float_quiet(options, "angle", 15);
if(params.net.adam){
layer.B1 = params.net.B1;

Loading…
Cancel
Save