Added DropBlock. Use [dropout] dropblock=1 dropblock_size=0.5 probability=0.1

pull/4540/head
AlexeyAB 6 years ago
parent a616fcd28e
commit 1df3ddc7d6
  1. 2
      include/darknet.h
  2. 19
      src/dropout_layer.c
  3. 2
      src/dropout_layer.h
  4. 102
      src/dropout_layer_kernels.cu
  5. 6
      src/gaussian_yolo_layer.c
  6. 6
      src/parser.c
  7. 2
      src/yolo_layer.c

@ -311,6 +311,8 @@ struct layer {
float temperature;
float probability;
float dropblock_size;
int dropblock;
float scale;
char * cweights;

@ -4,16 +4,28 @@
#include <stdlib.h>
#include <stdio.h>
dropout_layer make_dropout_layer(int batch, int inputs, float probability)
dropout_layer make_dropout_layer(int batch, int inputs, float probability, int dropblock, float dropblock_size, int w, int h, int c)
{
dropout_layer l = { (LAYER_TYPE)0 };
l.type = DROPOUT;
l.probability = probability;
l.dropblock = dropblock;
l.dropblock_size = dropblock_size;
if (l.dropblock) {
l.out_w = l.w = w;
l.out_h = l.h = h;
l.out_c = l.c = c;
if (l.w <= 0 || l.h <= 0 || l.c <= 0) {
printf(" Error: DropBlock - there must be positive values for: l.w=%d, l.h=%d, l.c=%d \n", l.w, l.h, l.c);
exit(0);
}
}
l.inputs = inputs;
l.outputs = inputs;
l.batch = batch;
l.rand = (float*)calloc(inputs * batch, sizeof(float));
l.scale = 1./(1.-probability);
l.scale = 1./(1.0 - probability);
l.forward = forward_dropout_layer;
l.backward = backward_dropout_layer;
#ifdef GPU
@ -21,7 +33,8 @@ dropout_layer make_dropout_layer(int batch, int inputs, float probability)
l.backward_gpu = backward_dropout_layer_gpu;
l.rand_gpu = cuda_make_array(l.rand, inputs*batch);
#endif
fprintf(stderr, "dropout p = %.2f %4d -> %4d\n", probability, inputs, inputs);
if(l.dropblock) fprintf(stderr, "dropblock p = %.2f block_size = %.2f %4d -> %4d\n", probability, l.dropblock_size, inputs, inputs);
else fprintf(stderr, "dropout p = %.2f %4d -> %4d\n", probability, inputs, inputs);
return l;
}

@ -9,7 +9,7 @@ typedef layer dropout_layer;
#ifdef __cplusplus
extern "C" {
#endif
dropout_layer make_dropout_layer(int batch, int inputs, float probability);
dropout_layer make_dropout_layer(int batch, int inputs, float probability, int dropblock, float dropblock_size, int w, int h, int c);
void forward_dropout_layer(dropout_layer l, network_state state);
void backward_dropout_layer(dropout_layer l, network_state state);

@ -1,6 +1,7 @@
#include <cuda_runtime.h>
#include <curand.h>
#include <cublas_v2.h>
#include <cstring>
#include "dropout_layer.h"
#include "dark_cuda.h"
@ -12,35 +13,104 @@ __global__ void yoloswag420blazeit360noscope(float *input, int size, float *rand
if(id < size) input[id] = (rand[id] < prob) ? 0 : input[id]*scale;
}
void forward_dropout_layer_gpu(dropout_layer layer, network_state state)
__global__ void drop_block_kernel(float *input, int size, float *mask, float scale)
{
int id = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
if (id < size) input[id] = (mask[id]) ? 0 : (input[id] * scale);
}
void forward_dropout_layer_gpu(dropout_layer l, network_state state)
{
if (!state.train) return;
int iteration_num = (*state.net.seen) / (state.net.batch*state.net.subdivisions);
//if (iteration_num < state.net.burn_in) return;
// dropblock
if (l.dropblock) {
//l.probability = 1 / keep_prob
const int max_blocks_per_channel = 3;
const float cur_prob = l.probability * (iteration_num / (float)state.net.max_batches);
const int block_width = l.dropblock_size * l.w;
const int block_height = l.dropblock_size * l.h;
const float prob_place_block = cur_prob / (l.dropblock_size * l.dropblock_size * max_blocks_per_channel);
memset(l.rand, 0, l.batch * l.outputs * sizeof(float));
float count_ones = 0;
int b, k, x, y, i;
for (b = 0; b < l.batch; b++) {
for (k = 0; k < l.c; k++) {
for (i = 0; i < max_blocks_per_channel; i++) {
float rnd = random_float();
//printf(" rnd = %f \n", rnd);
if (rnd < prob_place_block) {
//count_ones += block_width *block_height;
const int pre_index = k*l.w*l.h + b*l.w*l.h*l.c;
const int x_block = rand_int(0, l.w - block_width - 1);
const int y_block = rand_int(0, l.h - block_height - 1);
for (y = y_block; y < (y_block + block_height); y++) {
for (x = x_block; x < (x_block + block_width); x++) {
const int index = x + y*l.w + pre_index;
l.rand[index] = 1;
}
}
}
}
}
}
for (i = 0; i < (l.batch*l.outputs); ++i) if (l.rand[i]) count_ones++;
cuda_push_array(l.rand_gpu, l.rand, l.batch*l.outputs);
l.scale = (float)(l.batch*l.outputs) / (l.batch*l.outputs - count_ones);
//printf("\n l.scale = %f, cur_prob = %f, count_ones = %f, prob_place_block = %f, \n",
// l.scale, cur_prob, count_ones, prob_place_block);
int size = l.inputs*l.batch;
drop_block_kernel << <cuda_gridsize(size), BLOCK, 0, get_cuda_stream() >> > (state.input, size, l.rand_gpu, l.scale);
CHECK_CUDA(cudaPeekAtLastError());
int size = layer.inputs*layer.batch;
cuda_random(layer.rand_gpu, size);
/*
int i;
for(i = 0; i < size; ++i){
layer.rand[i] = rand_uniform();
}
cuda_push_array(layer.rand_gpu, layer.rand, size);
*/
// dropout
else {
int size = l.inputs*l.batch;
cuda_random(l.rand_gpu, size);
/*
int i;
for(i = 0; i < size; ++i){
layer.rand[i] = rand_uniform();
}
cuda_push_array(layer.rand_gpu, layer.rand, size);
*/
yoloswag420blazeit360noscope<<<cuda_gridsize(size), BLOCK, 0, get_cuda_stream() >>>(state.input, size, layer.rand_gpu, layer.probability, layer.scale);
CHECK_CUDA(cudaPeekAtLastError());
yoloswag420blazeit360noscope << <cuda_gridsize(size), BLOCK, 0, get_cuda_stream() >> > (state.input, size, l.rand_gpu, l.probability, l.scale);
CHECK_CUDA(cudaPeekAtLastError());
}
}
void backward_dropout_layer_gpu(dropout_layer layer, network_state state)
void backward_dropout_layer_gpu(dropout_layer l, network_state state)
{
if(!state.delta) return;
int iteration_num = (*state.net.seen) / (state.net.batch*state.net.subdivisions);
//int iteration_num = (*state.net.seen) / (state.net.batch*state.net.subdivisions);
//if (iteration_num < state.net.burn_in) return;
int size = layer.inputs*layer.batch;
int size = l.inputs*l.batch;
yoloswag420blazeit360noscope<<<cuda_gridsize(size), BLOCK, 0, get_cuda_stream() >>>(state.delta, size, layer.rand_gpu, layer.probability, layer.scale);
CHECK_CUDA(cudaPeekAtLastError());
// dropblock
if (l.dropblock) {
drop_block_kernel << <cuda_gridsize(size), BLOCK, 0, get_cuda_stream() >> > (state.delta, size, l.rand_gpu, l.scale);
CHECK_CUDA(cudaPeekAtLastError());
}
// dropout
else {
yoloswag420blazeit360noscope << <cuda_gridsize(size), BLOCK, 0, get_cuda_stream() >> > (state.delta, size, l.rand_gpu, l.probability, l.scale);
CHECK_CUDA(cudaPeekAtLastError());
}
}

@ -465,12 +465,10 @@ void forward_gaussian_yolo_layer(const layer l, network_state state)
if(!truth.x) break;
float best_iou = 0;
int best_n = 0;
//i = (truth.x * l.w);
//j = (truth.y * l.h);
i = (truth.x * l.w);
j = (truth.y * l.h);
if (l.yolo_point == YOLO_CENTER) {
i = (truth.x * l.w);
j = (truth.y * l.h);
}
else if (l.yolo_point == YOLO_LEFT_TOP) {
i = min_val_cmp(l.w-1, max_val_cmp(0, ((truth.x - truth.w / 2) * l.w)));

@ -701,8 +701,10 @@ avgpool_layer parse_avgpool(list *options, size_params params)
dropout_layer parse_dropout(list *options, size_params params)
{
float probability = option_find_float(options, "probability", .5);
dropout_layer layer = make_dropout_layer(params.batch, params.inputs, probability);
float probability = option_find_float(options, "probability", .2);
int dropblock = option_find_int_quiet(options, "dropblock", 0);
float dropblock_size = option_find_float_quiet(options, "dropblock_size", 0.5);
dropout_layer layer = make_dropout_layer(params.batch, params.inputs, probability, dropblock, dropblock_size, params.w, params.h, params.c);
layer.out_w = params.w;
layer.out_h = params.h;
layer.out_c = params.c;

@ -732,7 +732,7 @@ void forward_yolo_layer_gpu(const layer l, network_state state)
// if(y->1) x -> inf
// if(y->0) x -> -inf
activate_array_ongpu(l.output_gpu + index, 2*l.w*l.h, LOGISTIC); // x,y
scal_add_ongpu(2 * l.w*l.h, l.scale_x_y, -0.5*(l.scale_x_y - 1), l.output_gpu + index, 1); // scale x,y
if (l.scale_x_y != 1) scal_add_ongpu(2 * l.w*l.h, l.scale_x_y, -0.5*(l.scale_x_y - 1), l.output_gpu + index, 1); // scale x,y
index = entry_index(l, b, n*l.w*l.h, 4);
activate_array_ongpu(l.output_gpu + index, (1+l.classes)*l.w*l.h, LOGISTIC); // classes and objectness
}

Loading…
Cancel
Save