From 978c2e00ed21e137ee97044fa43f6f08642866b9 Mon Sep 17 00:00:00 2001
From: AlexeyAB
Date: Thu, 26 Mar 2020 00:16:57 +0300
Subject: [PATCH] self-adversarial training

---
 include/darknet.h            |  1 +
 src/batchnorm_layer.c        | 41 +++++++++++++++++------
 src/blas.h                   |  1 +
 src/blas_kernels.cu          | 17 ++++++++--
 src/convolutional_kernels.cu | 64 +++++++++++++++++++-----------------
 src/detector.c               |  2 +-
 src/image_opencv.cpp         |  4 +--
 src/network_kernels.cu       | 22 +++++++++++++
 src/parser.c                 |  3 +-
 9 files changed, 109 insertions(+), 46 deletions(-)

diff --git a/include/darknet.h b/include/darknet.h
index 464d3cf3..181af1c0 100644
--- a/include/darknet.h
+++ b/include/darknet.h
@@ -697,6 +697,7 @@ typedef struct network {
     float label_smooth_eps;
     int resize_step;
     int adversarial;
+    float adversarial_lr;
     int letter_box;
     float angle;
     float aspect;
diff --git a/src/batchnorm_layer.c b/src/batchnorm_layer.c
index 7432434c..eeba5cc5 100644
--- a/src/batchnorm_layer.c
+++ b/src/batchnorm_layer.c
@@ -231,17 +231,17 @@ void update_batchnorm_layer(layer l, int batch, float learning_rate, float momen
 
 void pull_batchnorm_layer(layer l)
 {
-    cuda_pull_array(l.biases_gpu, l.biases, l.c);
-    cuda_pull_array(l.scales_gpu, l.scales, l.c);
-    cuda_pull_array(l.rolling_mean_gpu, l.rolling_mean, l.c);
-    cuda_pull_array(l.rolling_variance_gpu, l.rolling_variance, l.c);
+    cuda_pull_array(l.biases_gpu, l.biases, l.out_c);
+    cuda_pull_array(l.scales_gpu, l.scales, l.out_c);
+    cuda_pull_array(l.rolling_mean_gpu, l.rolling_mean, l.out_c);
+    cuda_pull_array(l.rolling_variance_gpu, l.rolling_variance, l.out_c);
 }
 void push_batchnorm_layer(layer l)
 {
-    cuda_push_array(l.biases_gpu, l.biases, l.c);
-    cuda_push_array(l.scales_gpu, l.scales, l.c);
-    cuda_push_array(l.rolling_mean_gpu, l.rolling_mean, l.c);
-    cuda_push_array(l.rolling_variance_gpu, l.rolling_variance, l.c);
+    cuda_push_array(l.biases_gpu, l.biases, l.out_c);
+    cuda_push_array(l.scales_gpu, l.scales, l.out_c);
+    cuda_push_array(l.rolling_mean_gpu, l.rolling_mean, l.out_c);
+    cuda_push_array(l.rolling_variance_gpu, l.rolling_variance, l.out_c);
 }
 
 void forward_batchnorm_layer_gpu(layer l, network_state state)
@@ -249,6 +249,13 @@ void forward_batchnorm_layer_gpu(layer l, network_state state)
     if (l.type == BATCHNORM) simple_copy_ongpu(l.outputs*l.batch, state.input, l.output_gpu);
         //copy_ongpu(l.outputs*l.batch, state.input, 1, l.output_gpu, 1);
 
+    if (state.net.adversarial) {
+        normalize_gpu(l.output_gpu, l.rolling_mean_gpu, l.rolling_variance_gpu, l.batch, l.out_c, l.out_h*l.out_w);
+        scale_bias_gpu(l.output_gpu, l.scales_gpu, l.batch, l.out_c, l.out_h*l.out_w);
+        add_bias_gpu(l.output_gpu, l.biases_gpu, l.batch, l.out_c, l.out_w*l.out_h);
+        return;
+    }
+
     if (state.train) {
         simple_copy_ongpu(l.outputs*l.batch, l.output_gpu, l.x_gpu);
@@ -339,9 +346,23 @@ void forward_batchnorm_layer_gpu(layer l, network_state state)
 
 void backward_batchnorm_layer_gpu(layer l, network_state state)
 {
+    if (state.net.adversarial) {
+        inverse_variance_ongpu(l.out_c, l.rolling_variance_gpu, l.variance_gpu, 0.00001);
+
+        scale_bias_gpu(l.delta_gpu, l.variance_gpu, l.batch, l.out_c, l.out_h*l.out_w);
+        scale_bias_gpu(l.delta_gpu, l.scales_gpu, l.batch, l.out_c, l.out_h*l.out_w);
+        return;
+    }
+
     if (!state.train) {
-        l.mean_gpu = l.rolling_mean_gpu;
-        l.variance_gpu = l.rolling_variance_gpu;
+        //l.mean_gpu = l.rolling_mean_gpu;
+        //l.variance_gpu = l.rolling_variance_gpu;
+        simple_copy_ongpu(l.out_c, l.rolling_mean_gpu, l.mean_gpu);
+#ifdef CUDNN
+        inverse_variance_ongpu(l.out_c, l.rolling_variance_gpu, l.variance_gpu, 0.00001);
+#else
+        simple_copy_ongpu(l.out_c, l.rolling_variance_gpu, l.variance_gpu);
+#endif
     }
 
 #ifdef CUDNN
diff --git a/src/blas.h b/src/blas.h
index efbee945..819130a6 100644
--- a/src/blas.h
+++ b/src/blas.h
@@ -91,6 +91,7 @@ void fast_mean_gpu(float *x, int batch, int filters, int spatial, float *mean);
 void fast_variance_gpu(float *x, float *mean, int batch, int filters, int spatial, float *variance);
 void fast_v_cbn_gpu(const float *x, float *mean, int batch, int filters, int spatial, int minibatch_index, int max_minibatch_index, float *m_avg, float *v_avg, float *variance, const float alpha, float *rolling_mean_gpu, float *rolling_variance_gpu, int inverse_variance, float epsilon);
+void inverse_variance_ongpu(int size, float *src, float *dst, float epsilon);
 void normalize_scale_bias_gpu(float *x, float *mean, float *variance, float *scales, float *biases, int batch, int filters, int spatial, int inverse_variance, float epsilon);
 void compare_2_arrays_gpu(float *one, float *two, int size);
 void shortcut_gpu(int batch, int w1, int h1, int c1, float *add, int w2, int h2, int c2, float *out);
diff --git a/src/blas_kernels.cu b/src/blas_kernels.cu
index db862c24..e63d15a9 100644
--- a/src/blas_kernels.cu
+++ b/src/blas_kernels.cu
@@ -641,9 +641,9 @@ __global__ void fast_v_cbn_kernel(const float *x, float *mean, int batch, int f
 
         //if (max_minibatch_index == minibatch_index)
         {
-            rolling_mean_gpu[filter] = alpha * mean[filter] + (1 - alpha) * rolling_mean_gpu[filter];
+            if(rolling_mean_gpu) rolling_mean_gpu[filter] = alpha * mean[filter] + (1 - alpha) * rolling_mean_gpu[filter];
 
-            rolling_variance_gpu[filter] = alpha * variance_tmp + (1 - alpha) * rolling_variance_gpu[filter];
+            if(rolling_variance_gpu) rolling_variance_gpu[filter] = alpha * variance_tmp + (1 - alpha) * rolling_variance_gpu[filter];
         }
     }
 }
@@ -655,6 +655,19 @@ extern "C" void fast_v_cbn_gpu(const float *x, float *mean, int batch, int filte
     CHECK_CUDA(cudaPeekAtLastError());
 }
 
+__global__ void inverse_variance_kernel(int size, float *src, float *dst, float epsilon)
+{
+    int index = blockIdx.x*blockDim.x + threadIdx.x;
+    if (index < size)
+        dst[index] = 1.0f / sqrtf(src[index] + epsilon);
+}
+
+extern "C" void inverse_variance_ongpu(int size, float *src, float *dst, float epsilon)
+{
+    const int num_blocks = size / BLOCK + 1;
+    inverse_variance_kernel << <num_blocks, BLOCK, 0, get_cuda_stream() >> >(size, src, dst, epsilon);
+    CHECK_CUDA(cudaPeekAtLastError());
+}
 
 __global__ void normalize_scale_bias_kernel(int N, float *x, float *mean, float *variance, float *scales, float *biases, int batch, int filters, int spatial, int inverse_variance, float epsilon)
 {
diff --git a/src/convolutional_kernels.cu b/src/convolutional_kernels.cu
index 73be0c89..7c8dddf4 100644
--- a/src/convolutional_kernels.cu
+++ b/src/convolutional_kernels.cu
@@ -472,7 +472,7 @@ void forward_convolutional_layer_gpu(convolutional_layer l, network_state state)
 
         if (l.batch_normalize)
         {
-            if (state.train) // Training
+            if (state.train && !state.net.adversarial) // Training
             {
                 simple_copy_ongpu(l.outputs*l.batch / 2, output16, l.x_gpu);
                 //copy_ongpu(l.outputs*l.batch / 2, output16, 1, l.x_gpu, 1);
@@ -744,21 +744,23 @@ void backward_convolutional_layer_gpu(convolutional_layer l, network_state state
         assert((l.nweights) > 0);
         cuda_convert_f32_to_f16(l.weight_updates_gpu, l.nweights, l.weight_updates_gpu16);
 
-        CHECK_CUDNN(cudnnConvolutionBackwardFilter(cudnn_handle(),
-            &one,
-            l.srcTensorDesc16,
-            input16, //state.input,
-            l.ddstTensorDesc16,
-            delta16, //l.delta_gpu,
-            l.convDesc,
-            l.bf_algo16,
-            state.workspace,
-            l.workspace_size,
-            &one,
-            l.dweightDesc16,
-            l.weight_updates_gpu16)); // l.weight_updates_gpu);
+        if (!state.net.adversarial) {
+            CHECK_CUDNN(cudnnConvolutionBackwardFilter(cudnn_handle(),
+                &one,
+                l.srcTensorDesc16,
+                input16, //state.input,
+                l.ddstTensorDesc16,
+                delta16, //l.delta_gpu,
+                l.convDesc,
+                l.bf_algo16,
+                state.workspace,
+                l.workspace_size,
+                &one,
+                l.dweightDesc16,
+                l.weight_updates_gpu16)); // l.weight_updates_gpu);
 
-        cuda_convert_f16_to_f32(l.weight_updates_gpu16, l.nweights, l.weight_updates_gpu);
+            cuda_convert_f16_to_f32(l.weight_updates_gpu16, l.nweights, l.weight_updates_gpu);
+        }
 
         if (state.delta) {
             if (l.binary || l.xnor) swap_binary(&l);
@@ -794,21 +796,23 @@ void backward_convolutional_layer_gpu(convolutional_layer l, network_state state
         backward_batchnorm_layer_gpu(l, state);
     }
 
-    // calculate conv weight updates
-    // if used: beta=1 then loss decreases faster
-    CHECK_CUDNN(cudnnConvolutionBackwardFilter(cudnn_handle(),
-        &one,
-        l.srcTensorDesc,
-        state.input,
-        l.ddstTensorDesc,
-        l.delta_gpu,
-        l.convDesc,
-        l.bf_algo,
-        state.workspace,
-        l.workspace_size,
-        &one,
-        l.dweightDesc,
-        l.weight_updates_gpu));
+    if (!state.net.adversarial) {
+        // calculate conv weight updates
+        // if used: beta=1 then loss decreases faster
+        CHECK_CUDNN(cudnnConvolutionBackwardFilter(cudnn_handle(),
+            &one,
+            l.srcTensorDesc,
+            state.input,
+            l.ddstTensorDesc,
+            l.delta_gpu,
+            l.convDesc,
+            l.bf_algo,
+            state.workspace,
+            l.workspace_size,
+            &one,
+            l.dweightDesc,
+            l.weight_updates_gpu));
+    }
 
     if (state.delta) {
         if (l.binary || l.xnor) swap_binary(&l);
diff --git a/src/detector.c b/src/detector.c
index 75066ca6..d7e5e578 100644
--- a/src/detector.c
+++ b/src/detector.c
@@ -1675,7 +1675,7 @@ void draw_object(char *datacfg, char *cfgfile, char *weightfile, char *filename,
         load_weights(&net, weightfile);
     }
     net.benchmark_layers = benchmark_layers;
-    fuse_conv_batchnorm(net);
+    //fuse_conv_batchnorm(net);
     //calculate_binary_weights(net);
     if (net.layers[net.n - 1].classes != names_size) {
         printf("\n Error: in the file %s number of names %d that isn't equal to classes=%d in the file %s \n",
diff --git a/src/image_opencv.cpp b/src/image_opencv.cpp
index c65eef61..90bdb37e 100644
--- a/src/image_opencv.cpp
+++ b/src/image_opencv.cpp
@@ -1362,11 +1362,11 @@ extern "C" void cv_draw_object(image sized, float *truth_cpu, int max_boxes, int
 
     cv::setMouseCallback(window_name, callback_mouse_click);
 
-    int it_trackbar_value = 50;
+    int it_trackbar_value = 200;
     std::string const it_trackbar_name = "iterations";
     int it_tb_res = cv::createTrackbar(it_trackbar_name, window_name, &it_trackbar_value, 1000);
 
-    int lr_trackbar_value = 12;
+    int lr_trackbar_value = 10;
     std::string const lr_trackbar_name = "learning_rate exp";
     int lr_tb_res = cv::createTrackbar(lr_trackbar_name, window_name, &lr_trackbar_value, 20);
 
diff --git a/src/network_kernels.cu b/src/network_kernels.cu
index ae1a5c71..15219011 100644
--- a/src/network_kernels.cu
+++ b/src/network_kernels.cu
@@ -333,6 +333,28 @@ void forward_backward_network_gpu(network net, float *x, float *y)
 float train_network_datum_gpu(network net, float *x, float *y)
 {
     *net.seen += net.batch;
+    if (net.adversarial_lr && rand_int(0, 1) == 1 && get_current_iteration(net) > net.burn_in) {
+        net.adversarial = 1;
+        float lr_old = net.learning_rate;
+        net.learning_rate = net.adversarial_lr;
+        layer l = net.layers[net.n - 1];
+        float *truth_cpu = (float *)xcalloc(l.truths * l.batch, sizeof(float));
+
+        printf("\n adversarial training, adversarial_lr = %f \n", net.adversarial_lr);
+
+        forward_backward_network_gpu(net, x, truth_cpu);
+
+        image im;
+        im.w = net.w;
+        im.h = net.h;
+        im.c = net.c;
+        im.data = x;
+        //show_image(im, "adversarial data augmentation");
+
+        free(truth_cpu);
+        net.learning_rate = lr_old;
+        net.adversarial = 0;
+    }
     forward_backward_network_gpu(net, x, y);
     float error = get_network_cost(net);
     //if (((*net.seen) / net.batch) % net.subdivisions == 0) update_network_gpu(net);
diff --git a/src/parser.c b/src/parser.c
index 41a9b688..b997e56f 100644
--- a/src/parser.c
+++ b/src/parser.c
@@ -1086,7 +1086,8 @@ void parse_net_options(list *options, network *net)
     net->letter_box = option_find_int_quiet(options, "letter_box", 0);
     net->label_smooth_eps = option_find_float_quiet(options, "label_smooth_eps", 0.0f);
     net->resize_step = option_find_float_quiet(options, "resize_step", 32);
-    net->adversarial = option_find_int_quiet(options, "adversarial", 0);
+    //net->adversarial = option_find_int_quiet(options, "adversarial", 0);
+    net->adversarial_lr = option_find_float_quiet(options, "adversarial_lr", 0);
     net->angle = option_find_float_quiet(options, "angle", 0);
     net->aspect = option_find_float_quiet(options, "aspect", 1);
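
Usage note (not part of the patch): the new option is read from the [net] section of the model .cfg by parse_net_options(). A minimal sketch with illustrative values; only adversarial_lr is introduced by this patch, the other keys are ordinary [net] settings:

    [net]
    learning_rate=0.001
    burn_in=1000
    # example value, not from the patch; 0 (the default) disables self-adversarial training
    adversarial_lr=0.05

When adversarial_lr is non-zero, train_network_datum_gpu() runs an extra forward/backward pass with a zeroed truth array on roughly half of the batches (rand_int(0, 1) == 1) once training passes burn_in iterations, temporarily substituting adversarial_lr for the learning rate and skipping the convolutional weight updates, before the normal training pass on the same batch.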