Added assisted_excitation=1 for [convolutional] layer on GPU

pull/3435/merge
AlexeyAB 6 years ago
parent a63782ca89
commit be5d0d6693
  1. 3
      include/darknet.h
  2. 22
      src/conv_lstm_layer.c
  3. 191
      src/convolutional_kernels.cu
  4. 25
      src/convolutional_layer.c
  5. 3
      src/convolutional_layer.h
  6. 6
      src/crnn_layer.c
  7. 2
      src/maxpool_layer.c
  8. 6
      src/parser.c

@ -537,6 +537,9 @@ struct layer {
float * rand_gpu;
float * squared_gpu;
float * norms_gpu;
float *gt_gpu;
float *a_avg_gpu;
#ifdef CUDNN
cudnnTensorDescriptor_t srcTensorDesc, dstTensorDesc;
cudnnTensorDescriptor_t srcTensorDesc16, dstTensorDesc16;

@ -66,44 +66,44 @@ layer make_conv_lstm_layer(int batch, int h, int w, int c, int output_filters, i
// U
l.uf = (layer*)calloc(1, sizeof(layer));
*(l.uf) = make_convolutional_layer(batch, steps, h, w, c, output_filters, groups, size, stride, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0, 0, NULL);
*(l.uf) = make_convolutional_layer(batch, steps, h, w, c, output_filters, groups, size, stride, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0, 0, NULL, 0);
l.uf->batch = batch;
if (l.workspace_size < l.uf->workspace_size) l.workspace_size = l.uf->workspace_size;
l.ui = (layer*)calloc(1, sizeof(layer));
*(l.ui) = make_convolutional_layer(batch, steps, h, w, c, output_filters, groups, size, stride, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0, 0, NULL);
*(l.ui) = make_convolutional_layer(batch, steps, h, w, c, output_filters, groups, size, stride, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0, 0, NULL, 0);
l.ui->batch = batch;
if (l.workspace_size < l.ui->workspace_size) l.workspace_size = l.ui->workspace_size;
l.ug = (layer*)calloc(1, sizeof(layer));
*(l.ug) = make_convolutional_layer(batch, steps, h, w, c, output_filters, groups, size, stride, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0, 0, NULL);
*(l.ug) = make_convolutional_layer(batch, steps, h, w, c, output_filters, groups, size, stride, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0, 0, NULL, 0);
l.ug->batch = batch;
if (l.workspace_size < l.ug->workspace_size) l.workspace_size = l.ug->workspace_size;
l.uo = (layer*)calloc(1, sizeof(layer));
*(l.uo) = make_convolutional_layer(batch, steps, h, w, c, output_filters, groups, size, stride, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0, 0, NULL);
*(l.uo) = make_convolutional_layer(batch, steps, h, w, c, output_filters, groups, size, stride, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0, 0, NULL, 0);
l.uo->batch = batch;
if (l.workspace_size < l.uo->workspace_size) l.workspace_size = l.uo->workspace_size;
// W
l.wf = (layer*)calloc(1, sizeof(layer));
*(l.wf) = make_convolutional_layer(batch, steps, h, w, output_filters, output_filters, groups, size, stride, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0, 0, NULL);
*(l.wf) = make_convolutional_layer(batch, steps, h, w, output_filters, output_filters, groups, size, stride, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0, 0, NULL, 0);
l.wf->batch = batch;
if (l.workspace_size < l.wf->workspace_size) l.workspace_size = l.wf->workspace_size;
l.wi = (layer*)calloc(1, sizeof(layer));
*(l.wi) = make_convolutional_layer(batch, steps, h, w, output_filters, output_filters, groups, size, stride, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0, 0, NULL);
*(l.wi) = make_convolutional_layer(batch, steps, h, w, output_filters, output_filters, groups, size, stride, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0, 0, NULL, 0);
l.wi->batch = batch;
if (l.workspace_size < l.wi->workspace_size) l.workspace_size = l.wi->workspace_size;
l.wg = (layer*)calloc(1, sizeof(layer));
*(l.wg) = make_convolutional_layer(batch, steps, h, w, output_filters, output_filters, groups, size, stride, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0, 0, NULL);
*(l.wg) = make_convolutional_layer(batch, steps, h, w, output_filters, output_filters, groups, size, stride, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0, 0, NULL, 0);
l.wg->batch = batch;
if (l.workspace_size < l.wg->workspace_size) l.workspace_size = l.wg->workspace_size;
l.wo = (layer*)calloc(1, sizeof(layer));
*(l.wo) = make_convolutional_layer(batch, steps, h, w, output_filters, output_filters, groups, size, stride, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0, 0, NULL);
*(l.wo) = make_convolutional_layer(batch, steps, h, w, output_filters, output_filters, groups, size, stride, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0, 0, NULL, 0);
l.wo->batch = batch;
if (l.workspace_size < l.wo->workspace_size) l.workspace_size = l.wo->workspace_size;
@ -111,21 +111,21 @@ layer make_conv_lstm_layer(int batch, int h, int w, int c, int output_filters, i
// V
l.vf = (layer*)calloc(1, sizeof(layer));
if (l.peephole) {
*(l.vf) = make_convolutional_layer(batch, steps, h, w, output_filters, output_filters, groups, size, stride, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0, 0, NULL);
*(l.vf) = make_convolutional_layer(batch, steps, h, w, output_filters, output_filters, groups, size, stride, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0, 0, NULL, 0);
l.vf->batch = batch;
if (l.workspace_size < l.vf->workspace_size) l.workspace_size = l.vf->workspace_size;
}
l.vi = (layer*)calloc(1, sizeof(layer));
if (l.peephole) {
*(l.vi) = make_convolutional_layer(batch, steps, h, w, output_filters, output_filters, groups, size, stride, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0, 0, NULL);
*(l.vi) = make_convolutional_layer(batch, steps, h, w, output_filters, output_filters, groups, size, stride, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0, 0, NULL, 0);
l.vi->batch = batch;
if (l.workspace_size < l.vi->workspace_size) l.workspace_size = l.vi->workspace_size;
}
l.vo = (layer*)calloc(1, sizeof(layer));
if (l.peephole) {
*(l.vo) = make_convolutional_layer(batch, steps, h, w, output_filters, output_filters, groups, size, stride, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0, 0, NULL);
*(l.vo) = make_convolutional_layer(batch, steps, h, w, output_filters, output_filters, groups, size, stride, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0, 0, NULL, 0);
l.vo->batch = batch;
if (l.workspace_size < l.vo->workspace_size) l.workspace_size = l.vo->workspace_size;
}

@ -605,6 +605,8 @@ void forward_convolutional_layer_gpu(convolutional_layer l, network_state state)
fix_nan_and_inf(l.output_gpu, l.outputs*l.batch);
}
if(l.assisted_excitation && state.train) assisted_excitation_forward_gpu(l, state);
if (l.antialiasing) {
network_state s = { 0 };
s.train = state.train;
@ -890,6 +892,195 @@ void backward_convolutional_layer_gpu(convolutional_layer l, network_state state
}
}
static box float_to_box_stride(float *f, int stride)
{
box b = { 0 };
b.x = f[0];
b.y = f[1 * stride];
b.w = f[2 * stride];
b.h = f[3 * stride];
return b;
}
__global__ void calc_avg_activation_kernel(float *src, float *dst, int size, int channels, int batches)
{
int i = blockIdx.x * blockDim.x + threadIdx.x;
int xy = i % size;
int b = i / size;
if (i < size*batches) {
dst[i] = 0;
for (int c = 0; c < channels; ++c) {
dst[i] += src[xy + size*(c + channels*b)];
}
dst[i] = dst[i] / channels;
}
}
#include <iostream>
void calc_avg_activation_gpu(float *src, float *dst, int size, int channels, int batches)
{
const int num_blocks = get_number_of_blocks(size*batches, BLOCK);
std::cout << " size = " << size << ", channels = " << channels << ", batches = " << batches << std::endl;
calc_avg_activation_kernel << <num_blocks, BLOCK, 0, get_cuda_stream() >> > (src, dst, size, channels, batches);
}
__global__ void assisted_activation_kernel(float alpha, float *output, float *gt_gpu, float *a_avg_gpu, int size, int channels, int batches)
{
int i = blockIdx.x * blockDim.x + threadIdx.x;
int xy = i % size;
int b = i / size;
if (b < batches) {
for (int c = 0; c < channels; ++c) {
output[xy + size*(c + channels*b)] += alpha * gt_gpu[i] * a_avg_gpu[i];
}
}
}
void assisted_activation_gpu(float alpha, float *output, float *gt_gpu, float *a_avg_gpu, int size, int channels, int batches)
{
const int num_blocks = get_number_of_blocks(size*batches, BLOCK);
assisted_activation_kernel << <num_blocks, BLOCK, 0, get_cuda_stream() >> > (alpha, output, gt_gpu, a_avg_gpu, size, channels, batches);
}
void assisted_excitation_forward_gpu(convolutional_layer l, network_state state)
{
const int iteration_num = (*state.net.seen) / (state.net.batch*state.net.subdivisions);
// epoch
const float epoch = (float)(*state.net.seen) / state.net.train_images_num;
// calculate alpha
//const float alpha = (1 + cos(3.141592 * iteration_num)) / (2 * state.net.max_batches);
//const float alpha = (1 + cos(3.141592 * epoch)) / (2 * state.net.max_batches);
const float alpha = (1 + cos(3.141592 * iteration_num / state.net.max_batches)) / 2;
//printf("\n epoch = %f, alpha = %f, seen = %d, max_batches = %d, train_images_num = %d \n",
// epoch, alpha, (*state.net.seen), state.net.max_batches, state.net.train_images_num);
//const int size = l.outputs * l.batch;
float *a_avg = (float *)calloc(l.out_w * l.out_h * l.batch, sizeof(float));
float *gt = (float *)calloc(l.out_w * l.out_h * l.batch, sizeof(float));
int b;
int w, h, c;
l.max_boxes = state.net.num_boxes;
l.truths = l.max_boxes*(4 + 1);
int num_truth = l.batch*l.truths;
float *truth_cpu = (float *)calloc(num_truth, sizeof(float));
cuda_pull_array(state.truth, truth_cpu, num_truth);
//cudaStreamSynchronize(get_cuda_stream());
//CHECK_CUDA(cudaPeekAtLastError());
for (b = 0; b < l.batch; ++b)
{
// calculate G
int t;
for (t = 0; t < state.net.num_boxes; ++t) {
box truth = float_to_box_stride(truth_cpu + t*(4 + 1) + b*l.truths, 1);
if (!truth.x) break; // continue;
int left = floor((truth.x - truth.w / 2) * l.out_w);
int right = ceil((truth.x + truth.w / 2) * l.out_w);
int top = floor((truth.y - truth.h / 2) * l.out_h);
int bottom = ceil((truth.y + truth.h / 2) * l.out_h);
for (w = left; w <= right; w++) {
for (h = top; h < bottom; h++) {
gt[w + l.out_w * h + l.out_w*l.out_h*b] = 1;
}
}
}
}
cuda_push_array(l.gt_gpu, gt, l.out_w * l.out_h * l.batch);
//cudaStreamSynchronize(get_cuda_stream());
//CHECK_CUDA(cudaPeekAtLastError());
// calc avg_output on GPU - for whole batch
calc_avg_activation_gpu(l.output_gpu, l.a_avg_gpu, l.out_w * l.out_h, l.out_c, l.batch);
//cudaStreamSynchronize(get_cuda_stream());
//CHECK_CUDA(cudaPeekAtLastError());
// calc new output
assisted_activation_gpu(alpha, l.output_gpu, l.gt_gpu, l.a_avg_gpu, l.out_w * l.out_h, l.out_c, l.batch);
//cudaStreamSynchronize(get_cuda_stream());
//CHECK_CUDA(cudaPeekAtLastError());
/*
for (b = 0; b < l.batch; ++b)
{
// calculate average A
for (w = 0; w < l.out_w; w++) {
for (h = 0; h < l.out_h; h++) {
for (c = 0; c < l.out_c; c++) {
a_avg[w + l.out_w*(h + l.out_h*b)] += l.output[w + l.out_w*(h + l.out_h*(c + l.out_c*b))];
}
a_avg[w + l.out_w*(h + l.out_h*b)] /= l.out_c; // a_avg / d
}
}
}
// change activation
for (b = 0; b < l.batch; ++b)
{
for (w = 0; w < l.out_w; w++) {
for (h = 0; h < l.out_h; h++) {
for (c = 0; c < l.out_c; c++)
{
// a = a + alpha(t) + e(c,i,j) = a + alpha(t) + g(i,j) * avg_a(i,j) / channels
l.output[w + l.out_w*(h + l.out_h*(c + l.out_c*b))] +=
alpha *
g[w + l.out_w*(h + l.out_h*b)] *
a_avg[w + l.out_w*(h + l.out_h*b)];
//l.output[w + l.out_w*(h + l.out_h*(c + l.out_c*b))] =
// alpha * g[w + l.out_w*(h + l.out_h*b)] * a_avg[w + l.out_w*(h + l.out_h*b)];
}
}
}
}
*/
if (0) // visualize ground truth
{
#ifdef OPENCV
cuda_pull_array(l.output_gpu, l.output, l.outputs * l.batch);
cudaStreamSynchronize(get_cuda_stream());
CHECK_CUDA(cudaPeekAtLastError());
for (b = 0; b < l.batch; ++b)
{
image img = float_to_image(l.out_w, l.out_h, 1, &gt[l.out_w*l.out_h*b]);
char buff[100];
sprintf(buff, "a_excitation_%d", b);
show_image_cv(img, buff);
image img2 = float_to_image(l.out_w, l.out_h, 1, &l.output[l.out_w*l.out_h*l.out_c*b]);
char buff2[100];
sprintf(buff2, "a_excitation_act_%d", b);
show_image_cv(img2, buff2);
wait_key_cv(5);
}
wait_until_press_key_cv();
#endif // OPENCV
}
free(truth_cpu);
free(gt);
free(a_avg);
}
void pull_convolutional_layer(convolutional_layer l)
{
cuda_pull_array_async(l.weights_gpu, l.weights, l.nweights);

@ -332,7 +332,7 @@ void cudnn_convolutional_setup(layer *l, int cudnn_preference)
#endif
#endif
convolutional_layer make_convolutional_layer(int batch, int steps, int h, int w, int c, int n, int groups, int size, int stride_x, int stride_y, int dilation, int padding, ACTIVATION activation, int batch_normalize, int binary, int xnor, int adam, int use_bin_output, int index, int antialiasing, convolutional_layer *share_layer)
convolutional_layer make_convolutional_layer(int batch, int steps, int h, int w, int c, int n, int groups, int size, int stride_x, int stride_y, int dilation, int padding, ACTIVATION activation, int batch_normalize, int binary, int xnor, int adam, int use_bin_output, int index, int antialiasing, convolutional_layer *share_layer, int assisted_excitation)
{
int total_batch = batch*steps;
int i;
@ -349,6 +349,7 @@ convolutional_layer make_convolutional_layer(int batch, int steps, int h, int w,
stride_x = stride_y = l.stride = l.stride_x = l.stride_y = 1; // use stride=1 in host-layer
}
l.assisted_excitation = assisted_excitation;
l.share_layer = share_layer;
l.index = index;
l.h = h;
@ -503,7 +504,7 @@ convolutional_layer make_convolutional_layer(int batch, int steps, int h, int w,
#ifdef CUDNN_HALF
l.weights_gpu16 = cuda_make_array(NULL, l.nweights / 2 + 1);
l.weight_updates_gpu16 = cuda_make_array(NULL, l.nweights / 2 + 1);
#endif
#endif // CUDNN_HALF
l.biases_gpu = cuda_make_array(l.biases, n);
l.bias_updates_gpu = cuda_make_array(l.bias_updates, n);
}
@ -547,19 +548,27 @@ convolutional_layer make_convolutional_layer(int batch, int steps, int h, int w,
l.x_gpu = cuda_make_array(l.output, total_batch*out_h*out_w*n);
l.x_norm_gpu = cuda_make_array(l.output, total_batch*out_h*out_w*n);
}
if (l.assisted_excitation)
{
const int size = l.out_w * l.out_h * l.batch;
l.gt_gpu = cuda_make_array(NULL, size);
l.a_avg_gpu = cuda_make_array(NULL, size);
}
#ifdef CUDNN
create_convolutional_cudnn_tensors(&l);
cudnn_convolutional_setup(&l, cudnn_fastest);
#endif
#endif // CUDNN
}
#endif
#endif // GPU
l.workspace_size = get_convolutional_workspace_size(l);
//fprintf(stderr, "conv %5d %2d x%2d /%2d %4d x%4d x%4d -> %4d x%4d x%4d\n", n, size, size, stride, w, h, c, l.out_w, l.out_h, l.out_c);
l.bflops = (2.0 * l.nweights * l.out_h*l.out_w) / 1000000000.;
if (l.xnor && l.use_bin_output) fprintf(stderr, "convXB");
else if (l.xnor) fprintf(stderr, "convX ");
else if(l.share_layer) fprintf(stderr, "convS ");
else if (l.share_layer) fprintf(stderr, "convS ");
else if (l.assisted_excitation) fprintf(stderr, "convAE");
else fprintf(stderr, "conv ");
if (groups > 1) fprintf(stderr, "%5d/%4d ", n, groups);
@ -579,7 +588,7 @@ convolutional_layer make_convolutional_layer(int batch, int steps, int h, int w,
printf("AA: ");
l.input_layer = (layer*)calloc(1, sizeof(layer));
const int blur_size = 3;
*(l.input_layer) = make_convolutional_layer(batch, steps, out_h, out_w, n, n, n, blur_size, blur_stride_x, blur_stride_y, 1, blur_size / 2, LINEAR, 0, 0, 0, 0, 0, index, 0, NULL);
*(l.input_layer) = make_convolutional_layer(batch, steps, out_h, out_w, n, n, n, blur_size, blur_stride_x, blur_stride_y, 1, blur_size / 2, LINEAR, 0, 0, 0, 0, 0, index, 0, NULL, 0);
const int blur_nweights = n * blur_size * blur_size; // (n / n) * n * blur_size * blur_size;
int i;
for (i = 0; i < blur_nweights; i += (blur_size*blur_size)) {
@ -636,7 +645,7 @@ void denormalize_convolutional_layer(convolutional_layer l)
void test_convolutional_layer()
{
convolutional_layer l = make_convolutional_layer(1, 1, 5, 5, 3, 2, 1, 5, 2, 2, 1, 1, LEAKY, 1, 0, 0, 0, 0, 0, 0, NULL);
convolutional_layer l = make_convolutional_layer(1, 1, 5, 5, 3, 2, 1, 5, 2, 2, 1, 1, LEAKY, 1, 0, 0, 0, 0, 0, 0, NULL, 0);
l.batch_normalize = 1;
float data[] = {1,1,1,1,1,
1,1,1,1,1,
@ -1236,7 +1245,7 @@ void assisted_excitation_forward(convolutional_layer l, network_state state)
}
}
if(0) // visualize ground truth
if(1) // visualize ground truth
{
#ifdef OPENCV
for (b = 0; b < l.batch; ++b)

@ -30,7 +30,7 @@ void cuda_convert_f32_to_f16(float* input_f32, size_t size, float *output_f16);
#endif
size_t get_convolutional_workspace_size(layer l);
convolutional_layer make_convolutional_layer(int batch, int steps, int h, int w, int c, int n, int groups, int size, int stride_x, int stride_y, int dilation, int padding, ACTIVATION activation, int batch_normalize, int binary, int xnor, int adam, int use_bin_output, int index, int antialiasing, convolutional_layer *share_layer);
convolutional_layer make_convolutional_layer(int batch, int steps, int h, int w, int c, int n, int groups, int size, int stride_x, int stride_y, int dilation, int padding, ACTIVATION activation, int batch_normalize, int binary, int xnor, int adam, int use_bin_output, int index, int antialiasing, convolutional_layer *share_layer, int assisted_excitation);
void denormalize_convolutional_layer(convolutional_layer l);
void resize_convolutional_layer(convolutional_layer *layer, int w, int h);
void forward_convolutional_layer(const convolutional_layer layer, network_state state);
@ -57,6 +57,7 @@ int convolutional_out_width(convolutional_layer layer);
void rescale_weights(convolutional_layer l, float scale, float trans);
void rgbgr_weights(convolutional_layer l);
void assisted_excitation_forward(convolutional_layer l, network_state state);
void assisted_excitation_forward_gpu(convolutional_layer l, network_state state);
#ifdef __cplusplus
}

@ -50,17 +50,17 @@ layer make_crnn_layer(int batch, int h, int w, int c, int hidden_filters, int ou
l.state = (float*)calloc(l.hidden * l.batch * (l.steps + 1), sizeof(float));
l.input_layer = (layer*)calloc(1, sizeof(layer));
*(l.input_layer) = make_convolutional_layer(batch, steps, h, w, c, hidden_filters, groups, size, stride, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0, 0, NULL);
*(l.input_layer) = make_convolutional_layer(batch, steps, h, w, c, hidden_filters, groups, size, stride, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0, 0, NULL, 0);
l.input_layer->batch = batch;
if (l.workspace_size < l.input_layer->workspace_size) l.workspace_size = l.input_layer->workspace_size;
l.self_layer = (layer*)calloc(1, sizeof(layer));
*(l.self_layer) = make_convolutional_layer(batch, steps, h, w, hidden_filters, hidden_filters, groups, size, stride, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0, 0, NULL);
*(l.self_layer) = make_convolutional_layer(batch, steps, h, w, hidden_filters, hidden_filters, groups, size, stride, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0, 0, NULL, 0);
l.self_layer->batch = batch;
if (l.workspace_size < l.self_layer->workspace_size) l.workspace_size = l.self_layer->workspace_size;
l.output_layer = (layer*)calloc(1, sizeof(layer));
*(l.output_layer) = make_convolutional_layer(batch, steps, h, w, hidden_filters, output_filters, groups, size, stride, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0, 0, NULL);
*(l.output_layer) = make_convolutional_layer(batch, steps, h, w, hidden_filters, output_filters, groups, size, stride, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0, 0, NULL, 0);
l.output_layer->batch = batch;
if (l.workspace_size < l.output_layer->workspace_size) l.workspace_size = l.output_layer->workspace_size;

@ -107,7 +107,7 @@ maxpool_layer make_maxpool_layer(int batch, int h, int w, int c, int size, int s
printf("AA: ");
l.input_layer = (layer*)calloc(1, sizeof(layer));
const int blur_size = 3;
*(l.input_layer) = make_convolutional_layer(batch, 1, l.out_h, l.out_w, l.out_c, l.out_c, l.out_c, blur_size, blur_stride_x, blur_stride_y, 1, blur_size / 2, LINEAR, 0, 0, 0, 0, 0, 1, 0, NULL);
*(l.input_layer) = make_convolutional_layer(batch, 1, l.out_h, l.out_w, l.out_c, l.out_c, l.out_c, blur_size, blur_stride_x, blur_stride_y, 1, blur_size / 2, LINEAR, 0, 0, 0, 0, 0, 1, 0, NULL, 0);
const int blur_nweights = l.out_c * blur_size * blur_size; // (n / n) * n * blur_size * blur_size;
int i;
for (i = 0; i < blur_nweights; i += (blur_size*blur_size)) {

@ -170,6 +170,8 @@ convolutional_layer parse_convolutional(list *options, size_params params, netwo
char *activation_s = option_find_str(options, "activation", "logistic");
ACTIVATION activation = get_activation(activation_s);
int assisted_excitation = option_find_float_quiet(options, "assisted_excitation", 0);
int share_index = option_find_int_quiet(options, "share_index", -1000000000);
convolutional_layer *share_layer = NULL;
if(share_index >= 0) share_layer = &net.layers[share_index];
@ -186,10 +188,10 @@ convolutional_layer parse_convolutional(list *options, size_params params, netwo
int xnor = option_find_int_quiet(options, "xnor", 0);
int use_bin_output = option_find_int_quiet(options, "bin_output", 0);
convolutional_layer layer = make_convolutional_layer(batch,1,h,w,c,n,groups,size,stride_x,stride_y,dilation,padding,activation, batch_normalize, binary, xnor, params.net.adam, use_bin_output, params.index, antialiasing, share_layer);
convolutional_layer layer = make_convolutional_layer(batch,1,h,w,c,n,groups,size,stride_x,stride_y,dilation,padding,activation, batch_normalize, binary, xnor, params.net.adam, use_bin_output, params.index, antialiasing, share_layer, assisted_excitation);
layer.flipped = option_find_int_quiet(options, "flipped", 0);
layer.dot = option_find_float_quiet(options, "dot", 0);
layer.assisted_excitation = option_find_float_quiet(options, "assisted_excitation", 0);
if(params.net.adam){
layer.B1 = params.net.B1;

Loading…
Cancel
Save