Averaging mean/variance over several mini-batches inside one batch. Path to CBN (Cross-Iteration Batch Normalization).

Ref: pull/4907/head
Author: AlexeyAB (5 years ago)
Parent: f106f07d0d
Commit: 64fb042c63
Changed files:
  1. include/darknet.h (2 changed lines)
  2. src/batchnorm_layer.c (30 changed lines)
  3. src/blas.h (7 changed lines)
  4. src/blas_kernels.cu (122 changed lines)
  5. src/convolutional_kernels.cu (6 changed lines)
  6. src/convolutional_layer.c (10 changed lines)
  7. src/layer.c (2 changed lines)
  8. src/network.c (4 changed lines)
  9. src/parser.c (2 changed lines)
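
In outline, the commit adds a second batch-normalization mode (batch_normalize == 2, enabled through the new cbn cfg option) in which the per-filter mean and variance are averaged over the mini-batches (subdivisions) already processed inside the current batch, new m_cbn_avg_gpu / v_cbn_avg_gpu buffers to hold those running statistics, the fast_v_cbn_gpu and normalize_scale_bias_gpu kernels, a compare_2_arrays_gpu debug helper, an epsilon change from 1e-6 to 1e-5, and padding scaled by dilation in the im2col/col2im calls. The weight 1/minibatch_index used by the new kernel turns the running value into a plain average of all mini-batch statistics seen so far; a minimal, hypothetical C check of that identity (not part of the commit):

    /* Hypothetical check: with weight 1/k for the k-th mini-batch, the running
       value equals the plain average of all mini-batch means seen so far
       (same rule as alpha_cbn in fast_v_cbn_kernel). */
    #include <stdio.h>

    int main(void)
    {
        float means[4] = { 1.0f, 3.0f, 5.0f, 7.0f };   /* made-up per-mini-batch means */
        float m_avg = 0, sum = 0;
        int k;
        for (k = 1; k <= 4; ++k) {
            float a = 1.0f / k;                        /* alpha_cbn = 1 / minibatch_index */
            m_avg = a * means[k - 1] + (1 - a) * m_avg;
            sum += means[k - 1];
            printf("k=%d  running=%g  plain average=%g\n", k, m_avg, sum / k);
        }
        return 0;
    }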

include/darknet.h
@@ -551,6 +551,8 @@ struct layer {
float * mean_gpu;
float * variance_gpu;
float * m_cbn_avg_gpu;
float * v_cbn_avg_gpu;
float * rolling_mean_gpu;
float * rolling_variance_gpu;

src/batchnorm_layer.c
@@ -251,6 +251,33 @@ void forward_batchnorm_layer_gpu(layer l, network_state state)
if (state.train) {
simple_copy_ongpu(l.outputs*l.batch, l.output_gpu, l.x_gpu);
// cbn
if (l.batch_normalize == 2) {
fast_mean_gpu(l.output_gpu, l.batch, l.out_c, l.out_h*l.out_w, l.mean_gpu);
//fast_v_gpu(l.output_gpu, l.mean_gpu, l.batch, l.out_c, l.out_h*l.out_w, l.v_cbn_gpu);
int minibatch_index = state.net.current_subdivision + 1;
float alpha = 0.01;
int inverse_variance = 0;
#ifdef CUDNN
inverse_variance = 1;
#endif // CUDNN
fast_v_cbn_gpu(l.output_gpu, l.mean_gpu, l.batch, l.out_c, l.out_h*l.out_w, minibatch_index, l.m_cbn_avg_gpu, l.v_cbn_avg_gpu, l.variance_gpu,
alpha, l.rolling_mean_gpu, l.rolling_variance_gpu, inverse_variance, .00001);
normalize_scale_bias_gpu(l.output_gpu, l.mean_gpu, l.variance_gpu, l.scales_gpu, l.biases_gpu, l.batch, l.out_c, l.out_h*l.out_w, .00001f);
#ifndef CUDNN
simple_copy_ongpu(l.outputs*l.batch, l.output_gpu, l.x_norm_gpu);
#endif // CUDNN
//printf("\n CBN \n");
}
else {
#ifdef CUDNN
float one = 1;
float zero = 0;
@@ -280,6 +307,8 @@ void forward_batchnorm_layer_gpu(layer l, network_state state)
fix_nan_and_inf(l.rolling_mean_gpu, l.n);
fix_nan_and_inf(l.rolling_variance_gpu, l.n);
}
//simple_copy_ongpu(l.outputs*l.batch, l.output_gpu, l.x_norm_gpu);
#else // CUDNN
fast_mean_gpu(l.output_gpu, l.batch, l.out_c, l.out_h*l.out_w, l.mean_gpu);
fast_variance_gpu(l.output_gpu, l.mean_gpu, l.batch, l.out_c, l.out_h*l.out_w, l.variance_gpu);
@@ -297,6 +326,7 @@ void forward_batchnorm_layer_gpu(layer l, network_state state)
add_bias_gpu(l.output_gpu, l.biases_gpu, l.batch, l.out_c, l.out_w*l.out_h);
#endif // CUDNN
}
}
else {
normalize_gpu(l.output_gpu, l.rolling_mean_gpu, l.rolling_variance_gpu, l.batch, l.out_c, l.out_h*l.out_w);
scale_bias_gpu(l.output_gpu, l.scales_gpu, l.batch, l.out_c, l.out_h*l.out_w);
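
Read sequentially, the batch_normalize == 2 branch above computes the current mini-batch mean, folds it and E[x^2] into the running averages with weight 1/minibatch_index, derives the variance, updates the rolling statistics with alpha = 0.01, and normalizes with scale and bias. A hedged per-filter CPU sketch of the same computation (contiguous per-filter buffer, cuDNN inverse-variance variant ignored; a reading aid, not code from the commit):

    /* Hedged sketch: per-filter CPU equivalent of the batch_normalize == 2 path,
       for a contiguous buffer x holding n = batch*spatial activations of one filter. */
    #include <math.h>

    void cbn_forward_filter(float *x, int batch, int spatial, int minibatch_index,
                            float *m_avg, float *v_avg,
                            float *rolling_mean, float *rolling_variance,
                            float scale, float bias)
    {
        const float eps = 0.00001f, alpha = 0.01f;
        const int n = batch * spatial;
        int i;
        float mean = 0, sq = 0;
        for (i = 0; i < n; ++i) { mean += x[i]; sq += x[i] * x[i]; }
        mean /= n;

        float v_tmp = sq / (n - 1);                     /* E[x^2], same denominator as the kernel */
        if (v_tmp < mean * mean) v_tmp = mean * mean;   /* fmax(v_tmp, mean^2) in the kernel */

        const float a = 1.0f / minibatch_index;         /* equal weight per mini-batch seen so far */
        *m_avg = a * mean + (1 - a) * (*m_avg);
        *v_avg = a * v_tmp + (1 - a) * (*v_avg);
        float variance = *v_avg - (*m_avg) * (*m_avg);  /* Var = E[x^2] - E[x]^2 */
        if (variance < 0) variance = 0;

        *rolling_mean     = alpha * (*m_avg)  + (1 - alpha) * (*rolling_mean);
        *rolling_variance = alpha * variance + (1 - alpha) * (*rolling_variance);

        for (i = 0; i < n; ++i)                         /* normalize_scale_bias_gpu equivalent */
            x[i] = (x[i] - *m_avg) / sqrtf(variance + eps) * scale + bias;
    }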

src/blas.h
@@ -85,14 +85,17 @@ void normalize_delta_gpu(float *x, float *mean, float *variance, float *mean_del
void fast_mean_delta_gpu(float *delta, float *variance, int batch, int filters, int spatial, float *mean_delta);
void fast_variance_delta_gpu(float *x, float *delta, float *mean, float *variance, int batch, int filters, int spatial, float *variance_delta);
void fast_variance_gpu(float *x, float *mean, int batch, int filters, int spatial, float *variance);
void fast_mean_gpu(float *x, int batch, int filters, int spatial, float *mean);
void fast_variance_gpu(float *x, float *mean, int batch, int filters, int spatial, float *variance);
void fast_v_cbn_gpu(const float *x, float *mean, int batch, int filters, int spatial, int minibatch_index, float *m_avg, float *v_avg, float *variance,
const float alpha, float *rolling_mean_gpu, float *rolling_variance_gpu, int inverse_variance, float epsilon);
void normalize_scale_bias_gpu(float *x, float *mean, float *variance, float *scales, float *biases, int batch, int filters, int spatial, float epsilon);
void compare_2_arrays_gpu(float *one, float *two, int size);
void shortcut_gpu(int batch, int w1, int h1, int c1, float *add, int w2, int h2, int c2, float *out);
void shortcut_multilayer_gpu(int src_outputs, int batch, int n, int *outputs_of_layers_gpu, float **layers_output_gpu, float *out, float *in, float *weights_gpu, int nweights, WEIGHTS_NORMALIZATION_T weights_normalizion);
void backward_shortcut_multilayer_gpu(int src_outputs, int batch, int n, int *outputs_of_layers_gpu, float **layers_delta_gpu, float *delta_out, float *delta_in,
float *weights, float *weight_updates, int nweights, float *in, float **layers_output, WEIGHTS_NORMALIZATION_T weights_normalizion);
void input_shortcut_gpu(float *in, int batch, int w1, int h1, int c1, float *add, int w2, int h2, int c2, float *out);
void scale_bias_gpu(float *output, float *biases, int batch, int n, int size);
void backward_scale_gpu(float *x_norm, float *delta, int batch, int n, int size, float *scale_updates);
void scale_bias_gpu(float *output, float *biases, int batch, int n, int size);
void add_bias_gpu(float *output, float *biases, int batch, int n, int size);
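
Besides the CBN declarations, blas.h gains compare_2_arrays_gpu, a debugging helper whose kernel (defined below in blas_kernels.cu) prints every index where two device arrays differ and then synchronizes the device. A hypothetical usage sketch inside a GPU build of darknet, using the existing cuda_make_array / cuda_free helpers:

    /* Hypothetical debug call: upload the same host buffer twice and confirm
       that both device copies match. */
    float host[3] = { 1.0f, 2.0f, 3.0f };
    float *a_gpu = cuda_make_array(host, 3);
    float *b_gpu = cuda_make_array(host, 3);
    compare_2_arrays_gpu(a_gpu, b_gpu, 3);   /* prints nothing when the arrays are equal */
    cuda_free(a_gpu);
    cuda_free(b_gpu);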

src/blas_kernels.cu
@@ -10,6 +10,24 @@
#include "tree.h"
__global__ void compare_2_arrays_kernel(float *one, float *two, int size)
{
const int index = blockIdx.x*blockDim.x + threadIdx.x;
if (index >= size) return;
if (one[index] != two[index]) printf(" i: %d - one = %f, two = %f \n", index, one[index], two[index]);
}
void compare_2_arrays_gpu(float *one, float *two, int size)
{
const int num_blocks = get_number_of_blocks(size, BLOCK);
compare_2_arrays_kernel << <num_blocks, BLOCK, 0, get_cuda_stream() >> >(one, two, size);
CHECK_CUDA(cudaPeekAtLastError());
CHECK_CUDA(cudaDeviceSynchronize());
}
__global__ void scale_bias_kernel(float *output, float *scale, int batch, int filters, int spatial, int current_size)
{
const int index = blockIdx.x*blockDim.x + threadIdx.x;
@@ -179,7 +197,7 @@ __global__ void normalize_kernel(int N, float *x, float *mean, float *variance,
if (index >= N) return;
int f = (index / spatial) % filters;
x[index] = (x[index] - mean[f]) / (sqrtf(variance[f] + .000001f));
x[index] = (x[index] - mean[f]) / (sqrtf(variance[f] + .00001f));
}
extern "C" void normalize_gpu(float *x, float *mean, float *variance, int batch, int filters, int spatial)
@@ -470,8 +488,6 @@ __global__ void mul_kernel(int N, float *X, int INCX, float *Y, int INCY)
}
__global__ void fast_mean_kernel(float *x, int batch, int filters, int spatial, float *mean)
{
const int threads = BLOCK;
@@ -492,14 +508,21 @@ __global__ void fast_mean_kernel(float *x, int batch, int filters, int spatial,
__syncthreads();
if(id == 0){
mean[filter] = 0;
float mean_tmp = 0;
for(i = 0; i < threads; ++i){
mean[filter] += local[i];
mean_tmp += local[i];
}
mean[filter] /= spatial * batch;
mean_tmp /= spatial * batch;
mean[filter] = mean_tmp;
}
}
extern "C" void fast_mean_gpu(float *x, int batch, int filters, int spatial, float *mean)
{
fast_mean_kernel << <filters, BLOCK, 0, get_cuda_stream() >> >(x, batch, filters, spatial, mean);
CHECK_CUDA(cudaPeekAtLastError());
}
__global__ void fast_variance_kernel(float *x, float *mean, int batch, int filters, int spatial, float *variance)
{
const int threads = BLOCK;
@@ -521,27 +544,100 @@ __global__ void fast_variance_kernel(float *x, float *mean, int batch, int filt
__syncthreads();
if(id == 0){
variance[filter] = 0;
float variance_tmp = 0;
for(i = 0; i < threads; ++i){
variance[filter] += local[i];
variance_tmp += local[i];
}
variance[filter] /= (spatial * batch - 1);
variance_tmp /= (spatial * batch);// -1);
variance[filter] = variance_tmp;
}
}
extern "C" void fast_mean_gpu(float *x, int batch, int filters, int spatial, float *mean)
extern "C" void fast_variance_gpu(float *x, float *mean, int batch, int filters, int spatial, float *variance)
{
fast_mean_kernel<<<filters, BLOCK, 0, get_cuda_stream()>>>(x, batch, filters, spatial, mean);
fast_variance_kernel<<<filters, BLOCK, 0, get_cuda_stream() >>>(x, mean, batch, filters, spatial, variance);
CHECK_CUDA(cudaPeekAtLastError());
}
extern "C" void fast_variance_gpu(float *x, float *mean, int batch, int filters, int spatial, float *variance)
__global__ void fast_v_cbn_kernel(const float *x, float *mean, int batch, int filters, int spatial, int minibatch_index, float *m_avg, float *v_avg, float *variance,
const float alpha, float *rolling_mean_gpu, float *rolling_variance_gpu, int inverse_variance, float epsilon)
{
fast_variance_kernel<<<filters, BLOCK, 0, get_cuda_stream() >>>(x, mean, batch, filters, spatial, variance);
const int threads = BLOCK;
__shared__ float local[threads];
int id = threadIdx.x;
local[id] = 0;
int filter = blockIdx.x;
int i, j;
for (j = 0; j < batch; ++j) {
for (i = 0; i < spatial; i += threads) {
int index = j*spatial*filters + filter*spatial + i + id;
local[id] += (i + id < spatial) ? powf(x[index], 2) : 0;
}
}
__syncthreads();
if (id == 0) {
float v_tmp = 0;
v_tmp = 0;
for (i = 0; i < threads; ++i) {
v_tmp += local[i];
}
v_tmp /= (spatial * batch - 1);
v_tmp = fmax(v_tmp, powf(mean[filter], 2));
const float alpha_cbn = 1.0f / minibatch_index;
m_avg[filter] = alpha_cbn * mean[filter] + (1 - alpha_cbn) * m_avg[filter];
mean[filter] = m_avg[filter];
v_avg[filter] = alpha_cbn * v_tmp + (1 - alpha_cbn) * v_avg[filter];
float variance_tmp = fmax(0.0f, v_avg[filter] - powf(m_avg[filter], 2));
if (inverse_variance) variance_tmp = 1.0f / sqrtf(variance_tmp + epsilon);
variance[filter] = variance_tmp;
rolling_mean_gpu[filter] = alpha * mean[filter] + (1 - alpha) * rolling_mean_gpu[filter];
rolling_variance_gpu[filter] = alpha * variance[filter] + (1 - alpha) * rolling_variance_gpu[filter];
}
}
extern "C" void fast_v_cbn_gpu(const float *x, float *mean, int batch, int filters, int spatial, int minibatch_index, float *m_avg, float *v_avg, float *variance,
const float alpha, float *rolling_mean_gpu, float *rolling_variance_gpu, int inverse_variance, float epsilon)
{
fast_v_cbn_kernel << <filters, BLOCK, 0, get_cuda_stream() >> >(x, mean, batch, filters, spatial, minibatch_index, m_avg, v_avg, variance, alpha, rolling_mean_gpu, rolling_variance_gpu, inverse_variance, epsilon);
CHECK_CUDA(cudaPeekAtLastError());
}
__global__ void normalize_scale_bias_kernel(int N, float *x, float *mean, float *variance, float *scales, float *biases, int batch, int filters, int spatial, float epsilon)
{
const int index = blockIdx.x*blockDim.x + threadIdx.x;
if (index >= N) return;
int f = (index / spatial) % filters;
float val = (x[index] - mean[f]) / (sqrtf(variance[f] + epsilon)) * scales[f] + biases[f];
if (!isnan(val) && !isinf(val))
x[index] = val;
}
extern "C" void normalize_scale_bias_gpu(float *x, float *mean, float *variance, float *scales, float *biases, int batch, int filters, int spatial, float epsilon)
{
const int current_size = batch * filters * spatial;
const int num_blocks = get_number_of_blocks(current_size, BLOCK);
normalize_scale_bias_kernel << <num_blocks, BLOCK, 0, get_cuda_stream() >> >(current_size, x, mean, variance, scales, biases, batch, filters, spatial, epsilon);
CHECK_CUDA(cudaPeekAtLastError());
}
extern "C" void mean_gpu(float *x, int batch, int filters, int spatial, float *mean)
{
mean_kernel<<<cuda_gridsize(filters), BLOCK, 0, get_cuda_stream() >>>(x, batch, filters, spatial, mean);
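
Two smaller changes are visible in the hunks above: fast_mean_kernel and fast_variance_kernel now accumulate into a local temporary and write the per-filter result to global memory once, and the variance denominator changes from spatial*batch - 1 to spatial*batch. For the per-filter sample sizes involved, the biased and unbiased estimates are nearly identical; a hypothetical arithmetic check (not part of the commit):

    /* Hypothetical check: dividing by N (biased) instead of N-1 (unbiased)
       changes the variance negligibly for the large N = spatial*batch used per filter. */
    #include <stdio.h>

    int main(void)
    {
        float sum_sq_dev = 1000.0f;   /* made-up sum of squared deviations */
        int n;
        for (n = 16; n <= 16384; n *= 32)
            printf("N=%5d  biased=%f  unbiased=%f\n",
                   n, sum_sq_dev / n, sum_sq_dev / (n - 1));
        return 0;
    }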

src/convolutional_kernels.cu
@@ -578,7 +578,7 @@ void forward_convolutional_layer_gpu(convolutional_layer l, network_state state)
l.c / l.groups, // input channels
l.h, l.w, // input size (h, w)
l.size, l.size, // kernel size (h, w)
l.pad, l.pad, // padding (h, w)
l.pad * l.dilation, l.pad * l.dilation, // padding (h, w)
l.stride_y, l.stride_x, // stride (h, w)
l.dilation, l.dilation, // dilation (h, w)
state.workspace); // output
@@ -858,7 +858,7 @@ void backward_convolutional_layer_gpu(convolutional_layer l, network_state state
l.c / l.groups, // input channels
l.h, l.w, // input size (h, w)
l.size, l.size, // kernel size (h, w)
l.pad, l.pad, // padding (h, w)
l.pad * l.dilation, l.pad * l.dilation, // padding (h, w)
l.stride_y, l.stride_x, // stride (h, w)
l.dilation, l.dilation, // dilation (h, w)
state.workspace); // output
@@ -883,7 +883,7 @@ void backward_convolutional_layer_gpu(convolutional_layer l, network_state state
l.c / l.groups, // input channels
l.h, l.w, // input size (h, w)
l.size, l.size, // kernel size (h, w)
l.pad, l.pad, // padding size (h, w)
l.pad * l.dilation, l.pad * l.dilation, // padding size (h, w)
l.stride_y, l.stride_x, // stride size (h, w)
l.dilation, l.dilation, // dilation size (h, w)
delta); // output (delta)
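
The only change to convolutional_kernels.cu is passing l.pad * l.dilation instead of l.pad to the im2col/col2im calls. Assuming the standard dilated-convolution output formula out = (w + 2*pad - (dilation*(size-1)+1))/stride + 1, darknet's pad = size/2 convention only preserves the spatial size under dilation when the padding is scaled by the dilation; a hypothetical check (not part of the commit):

    /* Hypothetical check of the pad -> pad*dilation change. */
    #include <stdio.h>

    static int out_width(int w, int size, int pad, int stride, int dilation)
    {
        int extent = dilation * (size - 1) + 1;   /* effective (dilated) kernel width */
        return (w + 2 * pad - extent) / stride + 1;
    }

    int main(void)
    {
        int w = 13, size = 3, dilation = 2, pad = size / 2;
        printf("pad            = %d -> out_w = %d\n", pad, out_width(w, size, pad, 1, dilation));
        printf("pad * dilation = %d -> out_w = %d\n", pad * dilation, out_width(w, size, pad * dilation, 1, dilation));
        return 0;
    }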

src/convolutional_layer.c
@@ -416,7 +416,7 @@ convolutional_layer make_convolutional_layer(int batch, int steps, int h, int w,
if (l.share_layer) {
if (l.size != l.share_layer->size || l.nweights != l.share_layer->nweights || l.c != l.share_layer->c || l.n != l.share_layer->n) {
printf("Layer size, nweights, channels or filters don't match for the share_layer");
printf(" Layer size, nweights, channels or filters don't match for the share_layer");
getchar();
}
@@ -611,6 +611,8 @@ convolutional_layer make_convolutional_layer(int batch, int steps, int h, int w,
l.mean_gpu = cuda_make_array(l.mean, n);
l.variance_gpu = cuda_make_array(l.variance, n);
l.m_cbn_avg_gpu = cuda_make_array(l.mean, n);
l.v_cbn_avg_gpu = cuda_make_array(l.variance, n);
#ifndef CUDNN
l.mean_delta_gpu = cuda_make_array(l.mean, n);
l.variance_delta_gpu = cuda_make_array(l.variance, n);
@@ -1238,7 +1240,7 @@ void forward_convolutional_layer(convolutional_layer l, network_state state)
l.c / l.groups, // input channels
l.h, l.w, // input size (h, w)
l.size, l.size, // kernel size (h, w)
l.pad, l.pad, // padding (h, w)
l.pad * l.dilation, l.pad * l.dilation, // padding (h, w)
l.stride_y, l.stride_x, // stride (h, w)
l.dilation, l.dilation, // dilation (h, w)
b); // output
@@ -1429,7 +1431,7 @@ void backward_convolutional_layer(convolutional_layer l, network_state state)
l.c / l.groups, // input channels
l.h, l.w, // input size (h, w)
l.size, l.size, // kernel size (h, w)
l.pad, l.pad, // padding (h, w)
l.pad * l.dilation, l.pad * l.dilation, // padding (h, w)
l.stride_y, l.stride_x, // stride (h, w)
l.dilation, l.dilation, // dilation (h, w)
b); // output
@@ -1451,7 +1453,7 @@ void backward_convolutional_layer(convolutional_layer l, network_state state)
l.c / l.groups, // input channels (h, w)
l.h, l.w, // input size (h, w)
l.size, l.size, // kernel size (h, w)
l.pad, l.pad, // padding (h, w)
l.pad * l.dilation, l.pad * l.dilation, // padding (h, w)
l.stride_y, l.stride_x, // stride (h, w)
l.dilation, l.dilation, // dilation (h, w)
state.delta + (i*l.groups + j)* (l.c / l.groups)*l.h*l.w); // output (delta)

src/layer.c
@@ -158,6 +158,8 @@ void free_layer_custom(layer l, int keep_cudnn_desc)
if (l.binary_weights_gpu) cuda_free(l.binary_weights_gpu);
if (l.mean_gpu) cuda_free(l.mean_gpu), l.mean_gpu = NULL;
if (l.variance_gpu) cuda_free(l.variance_gpu), l.variance_gpu = NULL;
if (l.m_cbn_avg_gpu) cuda_free(l.m_cbn_avg_gpu), l.m_cbn_avg_gpu = NULL;
if (l.v_cbn_avg_gpu) cuda_free(l.v_cbn_avg_gpu), l.v_cbn_avg_gpu = NULL;
if (l.rolling_mean_gpu) cuda_free(l.rolling_mean_gpu), l.rolling_mean_gpu = NULL;
if (l.rolling_variance_gpu) cuda_free(l.rolling_variance_gpu), l.rolling_variance_gpu = NULL;
if (l.variance_delta_gpu) cuda_free(l.variance_delta_gpu), l.variance_delta_gpu = NULL;

src/network.c
@@ -1089,14 +1089,14 @@ void fuse_conv_batchnorm(network net)
int f;
for (f = 0; f < l->n; ++f)
{
l->biases[f] = l->biases[f] - (double)l->scales[f] * l->rolling_mean[f] / (sqrt((double)l->rolling_variance[f] + .000001));
l->biases[f] = l->biases[f] - (double)l->scales[f] * l->rolling_mean[f] / (sqrt((double)l->rolling_variance[f] + .00001));
const size_t filter_size = l->size*l->size*l->c / l->groups;
int i;
for (i = 0; i < filter_size; ++i) {
int w_index = f*filter_size + i;
l->weights[w_index] = (double)l->weights[w_index] * l->scales[f] / (sqrt((double)l->rolling_variance[f] + .000001));
l->weights[w_index] = (double)l->weights[w_index] * l->scales[f] / (sqrt((double)l->rolling_variance[f] + .00001));
}
}
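
fuse_conv_batchnorm only changes its epsilon from 1e-6 to 1e-5, matching the value now used by the normalization kernels. For reference, a hedged per-filter sketch of the fold performed by the loop above (hypothetical helper, not code from the commit):

    /* Hedged sketch of the batch-norm fold for one filter. */
    #include <math.h>

    void fold_bn_filter(float *bias, float *weights, int filter_size,
                        float scale, float rolling_mean, float rolling_variance)
    {
        const double inv_std = 1.0 / sqrt((double)rolling_variance + .00001);
        int i;
        *bias = *bias - (float)((double)scale * rolling_mean * inv_std);
        for (i = 0; i < filter_size; ++i)
            weights[i] = (float)((double)weights[i] * scale * inv_std);
    }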

src/parser.c
@@ -198,6 +198,8 @@ convolutional_layer parse_convolutional(list *options, size_params params)
batch=params.batch;
if(!(h && w && c)) error("Layer before convolutional layer must output image.");
int batch_normalize = option_find_int_quiet(options, "batch_normalize", 0);
int cbn = option_find_int_quiet(options, "cbn", 0);
if (cbn) batch_normalize = 2;
int binary = option_find_int_quiet(options, "binary", 0);
int xnor = option_find_int_quiet(options, "xnor", 0);
int use_bin_output = option_find_int_quiet(options, "bin_output", 0);
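
With this parser change, CBN is presumably enabled per layer from the cfg file: cbn=1 forces batch_normalize to 2, which selects the new averaging path in forward_batchnorm_layer_gpu. A hypothetical cfg fragment (field values are illustrative only):

    # hypothetical fragment: cbn=1 makes parse_convolutional set batch_normalize=2
    [convolutional]
    cbn=1
    filters=256
    size=3
    stride=1
    pad=1
    activation=leaky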
