From be5d0d66933e50585688bc86bb42786de55893ab Mon Sep 17 00:00:00 2001 From: AlexeyAB Date: Tue, 3 Sep 2019 01:35:05 +0300 Subject: [PATCH] Added assisted_excitation=1 for [convolutional] layer on GPU --- include/darknet.h | 3 + src/conv_lstm_layer.c | 22 ++-- src/convolutional_kernels.cu | 191 +++++++++++++++++++++++++++++++++++ src/convolutional_layer.c | 25 +++-- src/convolutional_layer.h | 3 +- src/crnn_layer.c | 6 +- src/maxpool_layer.c | 2 +- src/parser.c | 6 +- 8 files changed, 232 insertions(+), 26 deletions(-) diff --git a/include/darknet.h b/include/darknet.h index a7a62b47..e78abe6a 100644 --- a/include/darknet.h +++ b/include/darknet.h @@ -537,6 +537,9 @@ struct layer { float * rand_gpu; float * squared_gpu; float * norms_gpu; + + float *gt_gpu; + float *a_avg_gpu; #ifdef CUDNN cudnnTensorDescriptor_t srcTensorDesc, dstTensorDesc; cudnnTensorDescriptor_t srcTensorDesc16, dstTensorDesc16; diff --git a/src/conv_lstm_layer.c b/src/conv_lstm_layer.c index a6da3bf0..4ae67b44 100644 --- a/src/conv_lstm_layer.c +++ b/src/conv_lstm_layer.c @@ -66,44 +66,44 @@ layer make_conv_lstm_layer(int batch, int h, int w, int c, int output_filters, i // U l.uf = (layer*)calloc(1, sizeof(layer)); - *(l.uf) = make_convolutional_layer(batch, steps, h, w, c, output_filters, groups, size, stride, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0, 0, NULL); + *(l.uf) = make_convolutional_layer(batch, steps, h, w, c, output_filters, groups, size, stride, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0, 0, NULL, 0); l.uf->batch = batch; if (l.workspace_size < l.uf->workspace_size) l.workspace_size = l.uf->workspace_size; l.ui = (layer*)calloc(1, sizeof(layer)); - *(l.ui) = make_convolutional_layer(batch, steps, h, w, c, output_filters, groups, size, stride, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0, 0, NULL); + *(l.ui) = make_convolutional_layer(batch, steps, h, w, c, output_filters, groups, size, stride, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0, 0, NULL, 0); l.ui->batch = batch; if (l.workspace_size < l.ui->workspace_size) l.workspace_size = l.ui->workspace_size; l.ug = (layer*)calloc(1, sizeof(layer)); - *(l.ug) = make_convolutional_layer(batch, steps, h, w, c, output_filters, groups, size, stride, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0, 0, NULL); + *(l.ug) = make_convolutional_layer(batch, steps, h, w, c, output_filters, groups, size, stride, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0, 0, NULL, 0); l.ug->batch = batch; if (l.workspace_size < l.ug->workspace_size) l.workspace_size = l.ug->workspace_size; l.uo = (layer*)calloc(1, sizeof(layer)); - *(l.uo) = make_convolutional_layer(batch, steps, h, w, c, output_filters, groups, size, stride, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0, 0, NULL); + *(l.uo) = make_convolutional_layer(batch, steps, h, w, c, output_filters, groups, size, stride, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0, 0, NULL, 0); l.uo->batch = batch; if (l.workspace_size < l.uo->workspace_size) l.workspace_size = l.uo->workspace_size; // W l.wf = (layer*)calloc(1, sizeof(layer)); - *(l.wf) = make_convolutional_layer(batch, steps, h, w, output_filters, output_filters, groups, size, stride, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0, 0, NULL); + *(l.wf) = make_convolutional_layer(batch, steps, h, w, output_filters, output_filters, groups, size, stride, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0, 0, NULL, 0); l.wf->batch = batch; if (l.workspace_size < l.wf->workspace_size) l.workspace_size = l.wf->workspace_size; l.wi = (layer*)calloc(1, sizeof(layer)); - *(l.wi) = make_convolutional_layer(batch, steps, h, w, output_filters, output_filters, groups, size, stride, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0, 0, NULL); + *(l.wi) = make_convolutional_layer(batch, steps, h, w, output_filters, output_filters, groups, size, stride, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0, 0, NULL, 0); l.wi->batch = batch; if (l.workspace_size < l.wi->workspace_size) l.workspace_size = l.wi->workspace_size; l.wg = (layer*)calloc(1, sizeof(layer)); - *(l.wg) = make_convolutional_layer(batch, steps, h, w, output_filters, output_filters, groups, size, stride, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0, 0, NULL); + *(l.wg) = make_convolutional_layer(batch, steps, h, w, output_filters, output_filters, groups, size, stride, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0, 0, NULL, 0); l.wg->batch = batch; if (l.workspace_size < l.wg->workspace_size) l.workspace_size = l.wg->workspace_size; l.wo = (layer*)calloc(1, sizeof(layer)); - *(l.wo) = make_convolutional_layer(batch, steps, h, w, output_filters, output_filters, groups, size, stride, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0, 0, NULL); + *(l.wo) = make_convolutional_layer(batch, steps, h, w, output_filters, output_filters, groups, size, stride, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0, 0, NULL, 0); l.wo->batch = batch; if (l.workspace_size < l.wo->workspace_size) l.workspace_size = l.wo->workspace_size; @@ -111,21 +111,21 @@ layer make_conv_lstm_layer(int batch, int h, int w, int c, int output_filters, i // V l.vf = (layer*)calloc(1, sizeof(layer)); if (l.peephole) { - *(l.vf) = make_convolutional_layer(batch, steps, h, w, output_filters, output_filters, groups, size, stride, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0, 0, NULL); + *(l.vf) = make_convolutional_layer(batch, steps, h, w, output_filters, output_filters, groups, size, stride, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0, 0, NULL, 0); l.vf->batch = batch; if (l.workspace_size < l.vf->workspace_size) l.workspace_size = l.vf->workspace_size; } l.vi = (layer*)calloc(1, sizeof(layer)); if (l.peephole) { - *(l.vi) = make_convolutional_layer(batch, steps, h, w, output_filters, output_filters, groups, size, stride, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0, 0, NULL); + *(l.vi) = make_convolutional_layer(batch, steps, h, w, output_filters, output_filters, groups, size, stride, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0, 0, NULL, 0); l.vi->batch = batch; if (l.workspace_size < l.vi->workspace_size) l.workspace_size = l.vi->workspace_size; } l.vo = (layer*)calloc(1, sizeof(layer)); if (l.peephole) { - *(l.vo) = make_convolutional_layer(batch, steps, h, w, output_filters, output_filters, groups, size, stride, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0, 0, NULL); + *(l.vo) = make_convolutional_layer(batch, steps, h, w, output_filters, output_filters, groups, size, stride, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0, 0, NULL, 0); l.vo->batch = batch; if (l.workspace_size < l.vo->workspace_size) l.workspace_size = l.vo->workspace_size; } diff --git a/src/convolutional_kernels.cu b/src/convolutional_kernels.cu index b476ac76..566fb893 100644 --- a/src/convolutional_kernels.cu +++ b/src/convolutional_kernels.cu @@ -605,6 +605,8 @@ void forward_convolutional_layer_gpu(convolutional_layer l, network_state state) fix_nan_and_inf(l.output_gpu, l.outputs*l.batch); } + if(l.assisted_excitation && state.train) assisted_excitation_forward_gpu(l, state); + if (l.antialiasing) { network_state s = { 0 }; s.train = state.train; @@ -890,6 +892,195 @@ void backward_convolutional_layer_gpu(convolutional_layer l, network_state state } } +static box float_to_box_stride(float *f, int stride) +{ + box b = { 0 }; + b.x = f[0]; + b.y = f[1 * stride]; + b.w = f[2 * stride]; + b.h = f[3 * stride]; + return b; +} + +__global__ void calc_avg_activation_kernel(float *src, float *dst, int size, int channels, int batches) +{ + int i = blockIdx.x * blockDim.x + threadIdx.x; + int xy = i % size; + int b = i / size; + + if (i < size*batches) { + dst[i] = 0; + for (int c = 0; c < channels; ++c) { + dst[i] += src[xy + size*(c + channels*b)]; + } + dst[i] = dst[i] / channels; + } +} + +#include + +void calc_avg_activation_gpu(float *src, float *dst, int size, int channels, int batches) +{ + const int num_blocks = get_number_of_blocks(size*batches, BLOCK); + + std::cout << " size = " << size << ", channels = " << channels << ", batches = " << batches << std::endl; + calc_avg_activation_kernel << > > (src, dst, size, channels, batches); +} + + +__global__ void assisted_activation_kernel(float alpha, float *output, float *gt_gpu, float *a_avg_gpu, int size, int channels, int batches) +{ + int i = blockIdx.x * blockDim.x + threadIdx.x; + int xy = i % size; + int b = i / size; + + if (b < batches) { + for (int c = 0; c < channels; ++c) { + output[xy + size*(c + channels*b)] += alpha * gt_gpu[i] * a_avg_gpu[i]; + } + } +} + +void assisted_activation_gpu(float alpha, float *output, float *gt_gpu, float *a_avg_gpu, int size, int channels, int batches) +{ + const int num_blocks = get_number_of_blocks(size*batches, BLOCK); + + assisted_activation_kernel << > > (alpha, output, gt_gpu, a_avg_gpu, size, channels, batches); +} + +void assisted_excitation_forward_gpu(convolutional_layer l, network_state state) +{ + const int iteration_num = (*state.net.seen) / (state.net.batch*state.net.subdivisions); + + // epoch + const float epoch = (float)(*state.net.seen) / state.net.train_images_num; + + // calculate alpha + //const float alpha = (1 + cos(3.141592 * iteration_num)) / (2 * state.net.max_batches); + //const float alpha = (1 + cos(3.141592 * epoch)) / (2 * state.net.max_batches); + const float alpha = (1 + cos(3.141592 * iteration_num / state.net.max_batches)) / 2; + + //printf("\n epoch = %f, alpha = %f, seen = %d, max_batches = %d, train_images_num = %d \n", + // epoch, alpha, (*state.net.seen), state.net.max_batches, state.net.train_images_num); + + //const int size = l.outputs * l.batch; + + float *a_avg = (float *)calloc(l.out_w * l.out_h * l.batch, sizeof(float)); + float *gt = (float *)calloc(l.out_w * l.out_h * l.batch, sizeof(float)); + + int b; + int w, h, c; + + l.max_boxes = state.net.num_boxes; + l.truths = l.max_boxes*(4 + 1); + + int num_truth = l.batch*l.truths; + float *truth_cpu = (float *)calloc(num_truth, sizeof(float)); + cuda_pull_array(state.truth, truth_cpu, num_truth); + //cudaStreamSynchronize(get_cuda_stream()); + //CHECK_CUDA(cudaPeekAtLastError()); + + for (b = 0; b < l.batch; ++b) + { + // calculate G + int t; + for (t = 0; t < state.net.num_boxes; ++t) { + box truth = float_to_box_stride(truth_cpu + t*(4 + 1) + b*l.truths, 1); + if (!truth.x) break; // continue; + + int left = floor((truth.x - truth.w / 2) * l.out_w); + int right = ceil((truth.x + truth.w / 2) * l.out_w); + int top = floor((truth.y - truth.h / 2) * l.out_h); + int bottom = ceil((truth.y + truth.h / 2) * l.out_h); + + for (w = left; w <= right; w++) { + for (h = top; h < bottom; h++) { + gt[w + l.out_w * h + l.out_w*l.out_h*b] = 1; + } + } + } + } + + cuda_push_array(l.gt_gpu, gt, l.out_w * l.out_h * l.batch); + //cudaStreamSynchronize(get_cuda_stream()); + //CHECK_CUDA(cudaPeekAtLastError()); + + // calc avg_output on GPU - for whole batch + calc_avg_activation_gpu(l.output_gpu, l.a_avg_gpu, l.out_w * l.out_h, l.out_c, l.batch); + //cudaStreamSynchronize(get_cuda_stream()); + //CHECK_CUDA(cudaPeekAtLastError()); + + // calc new output + assisted_activation_gpu(alpha, l.output_gpu, l.gt_gpu, l.a_avg_gpu, l.out_w * l.out_h, l.out_c, l.batch); + //cudaStreamSynchronize(get_cuda_stream()); + //CHECK_CUDA(cudaPeekAtLastError()); + + + + /* + for (b = 0; b < l.batch; ++b) + { + // calculate average A + for (w = 0; w < l.out_w; w++) { + for (h = 0; h < l.out_h; h++) { + for (c = 0; c < l.out_c; c++) { + a_avg[w + l.out_w*(h + l.out_h*b)] += l.output[w + l.out_w*(h + l.out_h*(c + l.out_c*b))]; + } + a_avg[w + l.out_w*(h + l.out_h*b)] /= l.out_c; // a_avg / d + } + } + } + + // change activation + for (b = 0; b < l.batch; ++b) + { + for (w = 0; w < l.out_w; w++) { + for (h = 0; h < l.out_h; h++) { + for (c = 0; c < l.out_c; c++) + { + // a = a + alpha(t) + e(c,i,j) = a + alpha(t) + g(i,j) * avg_a(i,j) / channels + l.output[w + l.out_w*(h + l.out_h*(c + l.out_c*b))] += + alpha * + g[w + l.out_w*(h + l.out_h*b)] * + a_avg[w + l.out_w*(h + l.out_h*b)]; + + //l.output[w + l.out_w*(h + l.out_h*(c + l.out_c*b))] = + // alpha * g[w + l.out_w*(h + l.out_h*b)] * a_avg[w + l.out_w*(h + l.out_h*b)]; + } + } + } + } + */ + + if (0) // visualize ground truth + { +#ifdef OPENCV + cuda_pull_array(l.output_gpu, l.output, l.outputs * l.batch); + cudaStreamSynchronize(get_cuda_stream()); + CHECK_CUDA(cudaPeekAtLastError()); + + for (b = 0; b < l.batch; ++b) + { + image img = float_to_image(l.out_w, l.out_h, 1, >[l.out_w*l.out_h*b]); + char buff[100]; + sprintf(buff, "a_excitation_%d", b); + show_image_cv(img, buff); + + image img2 = float_to_image(l.out_w, l.out_h, 1, &l.output[l.out_w*l.out_h*l.out_c*b]); + char buff2[100]; + sprintf(buff2, "a_excitation_act_%d", b); + show_image_cv(img2, buff2); + wait_key_cv(5); + } + wait_until_press_key_cv(); +#endif // OPENCV + } + + free(truth_cpu); + free(gt); + free(a_avg); +} + void pull_convolutional_layer(convolutional_layer l) { cuda_pull_array_async(l.weights_gpu, l.weights, l.nweights); diff --git a/src/convolutional_layer.c b/src/convolutional_layer.c index c5c59576..157058eb 100644 --- a/src/convolutional_layer.c +++ b/src/convolutional_layer.c @@ -332,7 +332,7 @@ void cudnn_convolutional_setup(layer *l, int cudnn_preference) #endif #endif -convolutional_layer make_convolutional_layer(int batch, int steps, int h, int w, int c, int n, int groups, int size, int stride_x, int stride_y, int dilation, int padding, ACTIVATION activation, int batch_normalize, int binary, int xnor, int adam, int use_bin_output, int index, int antialiasing, convolutional_layer *share_layer) +convolutional_layer make_convolutional_layer(int batch, int steps, int h, int w, int c, int n, int groups, int size, int stride_x, int stride_y, int dilation, int padding, ACTIVATION activation, int batch_normalize, int binary, int xnor, int adam, int use_bin_output, int index, int antialiasing, convolutional_layer *share_layer, int assisted_excitation) { int total_batch = batch*steps; int i; @@ -349,6 +349,7 @@ convolutional_layer make_convolutional_layer(int batch, int steps, int h, int w, stride_x = stride_y = l.stride = l.stride_x = l.stride_y = 1; // use stride=1 in host-layer } + l.assisted_excitation = assisted_excitation; l.share_layer = share_layer; l.index = index; l.h = h; @@ -503,7 +504,7 @@ convolutional_layer make_convolutional_layer(int batch, int steps, int h, int w, #ifdef CUDNN_HALF l.weights_gpu16 = cuda_make_array(NULL, l.nweights / 2 + 1); l.weight_updates_gpu16 = cuda_make_array(NULL, l.nweights / 2 + 1); -#endif +#endif // CUDNN_HALF l.biases_gpu = cuda_make_array(l.biases, n); l.bias_updates_gpu = cuda_make_array(l.bias_updates, n); } @@ -547,19 +548,27 @@ convolutional_layer make_convolutional_layer(int batch, int steps, int h, int w, l.x_gpu = cuda_make_array(l.output, total_batch*out_h*out_w*n); l.x_norm_gpu = cuda_make_array(l.output, total_batch*out_h*out_w*n); } + + if (l.assisted_excitation) + { + const int size = l.out_w * l.out_h * l.batch; + l.gt_gpu = cuda_make_array(NULL, size); + l.a_avg_gpu = cuda_make_array(NULL, size); + } #ifdef CUDNN create_convolutional_cudnn_tensors(&l); cudnn_convolutional_setup(&l, cudnn_fastest); -#endif +#endif // CUDNN } -#endif +#endif // GPU l.workspace_size = get_convolutional_workspace_size(l); //fprintf(stderr, "conv %5d %2d x%2d /%2d %4d x%4d x%4d -> %4d x%4d x%4d\n", n, size, size, stride, w, h, c, l.out_w, l.out_h, l.out_c); l.bflops = (2.0 * l.nweights * l.out_h*l.out_w) / 1000000000.; if (l.xnor && l.use_bin_output) fprintf(stderr, "convXB"); else if (l.xnor) fprintf(stderr, "convX "); - else if(l.share_layer) fprintf(stderr, "convS "); + else if (l.share_layer) fprintf(stderr, "convS "); + else if (l.assisted_excitation) fprintf(stderr, "convAE"); else fprintf(stderr, "conv "); if (groups > 1) fprintf(stderr, "%5d/%4d ", n, groups); @@ -579,7 +588,7 @@ convolutional_layer make_convolutional_layer(int batch, int steps, int h, int w, printf("AA: "); l.input_layer = (layer*)calloc(1, sizeof(layer)); const int blur_size = 3; - *(l.input_layer) = make_convolutional_layer(batch, steps, out_h, out_w, n, n, n, blur_size, blur_stride_x, blur_stride_y, 1, blur_size / 2, LINEAR, 0, 0, 0, 0, 0, index, 0, NULL); + *(l.input_layer) = make_convolutional_layer(batch, steps, out_h, out_w, n, n, n, blur_size, blur_stride_x, blur_stride_y, 1, blur_size / 2, LINEAR, 0, 0, 0, 0, 0, index, 0, NULL, 0); const int blur_nweights = n * blur_size * blur_size; // (n / n) * n * blur_size * blur_size; int i; for (i = 0; i < blur_nweights; i += (blur_size*blur_size)) { @@ -636,7 +645,7 @@ void denormalize_convolutional_layer(convolutional_layer l) void test_convolutional_layer() { - convolutional_layer l = make_convolutional_layer(1, 1, 5, 5, 3, 2, 1, 5, 2, 2, 1, 1, LEAKY, 1, 0, 0, 0, 0, 0, 0, NULL); + convolutional_layer l = make_convolutional_layer(1, 1, 5, 5, 3, 2, 1, 5, 2, 2, 1, 1, LEAKY, 1, 0, 0, 0, 0, 0, 0, NULL, 0); l.batch_normalize = 1; float data[] = {1,1,1,1,1, 1,1,1,1,1, @@ -1236,7 +1245,7 @@ void assisted_excitation_forward(convolutional_layer l, network_state state) } } - if(0) // visualize ground truth + if(1) // visualize ground truth { #ifdef OPENCV for (b = 0; b < l.batch; ++b) diff --git a/src/convolutional_layer.h b/src/convolutional_layer.h index 1012663a..0072ce54 100644 --- a/src/convolutional_layer.h +++ b/src/convolutional_layer.h @@ -30,7 +30,7 @@ void cuda_convert_f32_to_f16(float* input_f32, size_t size, float *output_f16); #endif size_t get_convolutional_workspace_size(layer l); -convolutional_layer make_convolutional_layer(int batch, int steps, int h, int w, int c, int n, int groups, int size, int stride_x, int stride_y, int dilation, int padding, ACTIVATION activation, int batch_normalize, int binary, int xnor, int adam, int use_bin_output, int index, int antialiasing, convolutional_layer *share_layer); +convolutional_layer make_convolutional_layer(int batch, int steps, int h, int w, int c, int n, int groups, int size, int stride_x, int stride_y, int dilation, int padding, ACTIVATION activation, int batch_normalize, int binary, int xnor, int adam, int use_bin_output, int index, int antialiasing, convolutional_layer *share_layer, int assisted_excitation); void denormalize_convolutional_layer(convolutional_layer l); void resize_convolutional_layer(convolutional_layer *layer, int w, int h); void forward_convolutional_layer(const convolutional_layer layer, network_state state); @@ -57,6 +57,7 @@ int convolutional_out_width(convolutional_layer layer); void rescale_weights(convolutional_layer l, float scale, float trans); void rgbgr_weights(convolutional_layer l); void assisted_excitation_forward(convolutional_layer l, network_state state); +void assisted_excitation_forward_gpu(convolutional_layer l, network_state state); #ifdef __cplusplus } diff --git a/src/crnn_layer.c b/src/crnn_layer.c index e3114fc9..588db741 100644 --- a/src/crnn_layer.c +++ b/src/crnn_layer.c @@ -50,17 +50,17 @@ layer make_crnn_layer(int batch, int h, int w, int c, int hidden_filters, int ou l.state = (float*)calloc(l.hidden * l.batch * (l.steps + 1), sizeof(float)); l.input_layer = (layer*)calloc(1, sizeof(layer)); - *(l.input_layer) = make_convolutional_layer(batch, steps, h, w, c, hidden_filters, groups, size, stride, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0, 0, NULL); + *(l.input_layer) = make_convolutional_layer(batch, steps, h, w, c, hidden_filters, groups, size, stride, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0, 0, NULL, 0); l.input_layer->batch = batch; if (l.workspace_size < l.input_layer->workspace_size) l.workspace_size = l.input_layer->workspace_size; l.self_layer = (layer*)calloc(1, sizeof(layer)); - *(l.self_layer) = make_convolutional_layer(batch, steps, h, w, hidden_filters, hidden_filters, groups, size, stride, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0, 0, NULL); + *(l.self_layer) = make_convolutional_layer(batch, steps, h, w, hidden_filters, hidden_filters, groups, size, stride, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0, 0, NULL, 0); l.self_layer->batch = batch; if (l.workspace_size < l.self_layer->workspace_size) l.workspace_size = l.self_layer->workspace_size; l.output_layer = (layer*)calloc(1, sizeof(layer)); - *(l.output_layer) = make_convolutional_layer(batch, steps, h, w, hidden_filters, output_filters, groups, size, stride, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0, 0, NULL); + *(l.output_layer) = make_convolutional_layer(batch, steps, h, w, hidden_filters, output_filters, groups, size, stride, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0, 0, NULL, 0); l.output_layer->batch = batch; if (l.workspace_size < l.output_layer->workspace_size) l.workspace_size = l.output_layer->workspace_size; diff --git a/src/maxpool_layer.c b/src/maxpool_layer.c index 12392621..27d33860 100644 --- a/src/maxpool_layer.c +++ b/src/maxpool_layer.c @@ -107,7 +107,7 @@ maxpool_layer make_maxpool_layer(int batch, int h, int w, int c, int size, int s printf("AA: "); l.input_layer = (layer*)calloc(1, sizeof(layer)); const int blur_size = 3; - *(l.input_layer) = make_convolutional_layer(batch, 1, l.out_h, l.out_w, l.out_c, l.out_c, l.out_c, blur_size, blur_stride_x, blur_stride_y, 1, blur_size / 2, LINEAR, 0, 0, 0, 0, 0, 1, 0, NULL); + *(l.input_layer) = make_convolutional_layer(batch, 1, l.out_h, l.out_w, l.out_c, l.out_c, l.out_c, blur_size, blur_stride_x, blur_stride_y, 1, blur_size / 2, LINEAR, 0, 0, 0, 0, 0, 1, 0, NULL, 0); const int blur_nweights = l.out_c * blur_size * blur_size; // (n / n) * n * blur_size * blur_size; int i; for (i = 0; i < blur_nweights; i += (blur_size*blur_size)) { diff --git a/src/parser.c b/src/parser.c index b89bf0ac..97d6aef9 100644 --- a/src/parser.c +++ b/src/parser.c @@ -170,6 +170,8 @@ convolutional_layer parse_convolutional(list *options, size_params params, netwo char *activation_s = option_find_str(options, "activation", "logistic"); ACTIVATION activation = get_activation(activation_s); + int assisted_excitation = option_find_float_quiet(options, "assisted_excitation", 0); + int share_index = option_find_int_quiet(options, "share_index", -1000000000); convolutional_layer *share_layer = NULL; if(share_index >= 0) share_layer = &net.layers[share_index]; @@ -186,10 +188,10 @@ convolutional_layer parse_convolutional(list *options, size_params params, netwo int xnor = option_find_int_quiet(options, "xnor", 0); int use_bin_output = option_find_int_quiet(options, "bin_output", 0); - convolutional_layer layer = make_convolutional_layer(batch,1,h,w,c,n,groups,size,stride_x,stride_y,dilation,padding,activation, batch_normalize, binary, xnor, params.net.adam, use_bin_output, params.index, antialiasing, share_layer); + convolutional_layer layer = make_convolutional_layer(batch,1,h,w,c,n,groups,size,stride_x,stride_y,dilation,padding,activation, batch_normalize, binary, xnor, params.net.adam, use_bin_output, params.index, antialiasing, share_layer, assisted_excitation); layer.flipped = option_find_int_quiet(options, "flipped", 0); layer.dot = option_find_float_quiet(options, "dot", 0); - layer.assisted_excitation = option_find_float_quiet(options, "assisted_excitation", 0); + if(params.net.adam){ layer.B1 = params.net.B1;