Added object detection & tracking using a conv-RNN layer on frames from video

Branch: pull/2514/head
Author: AlexeyAB, 6 years ago
Parent: 50956447f8
Commit: 75f2a3e7cf
Changed files (number of changed lines in parentheses):
  1. build/darknet/x64/cfg/yolov3-tiny_occlusion_track.cfg (218)
  2. cfg/yolov3-tiny_occlusion_track.cfg (218)
  3. include/darknet.h (8)
  4. src/batchnorm_layer.c (14)
  5. src/convolutional_kernels.cu (23)
  6. src/convolutional_layer.c (1)
  7. src/crnn_layer.c (104)
  8. src/crnn_layer.h (3)
  9. src/data.c (171)
  10. src/data.h (2)
  11. src/detector.c (56)
  12. src/layer.c (182)
  13. src/network.c (109)
  14. src/network.h (1)
  15. src/network_kernels.cu (11)
  16. src/parser.c (28)

build/darknet/x64/cfg/yolov3-tiny_occlusion_track.cfg
@@ -0,0 +1,218 @@
[net]
# Testing
#batch=1
#subdivisions=1
# Training
batch=8
subdivisions=4
width=416
height=416
channels=3
momentum=0.9
decay=0.0005
angle=0
saturation = 1.5
exposure = 1.5
hue=.1
track=1
time_steps=20
augment_speed=3
learning_rate=0.001
burn_in=1000
max_batches = 10000
policy=steps
steps=9000,9500
scales=.1,.1
[convolutional]
batch_normalize=1
filters=16
size=3
stride=1
pad=1
activation=leaky
[maxpool]
size=2
stride=2
[convolutional]
batch_normalize=1
filters=32
size=3
stride=1
pad=1
activation=leaky
[maxpool]
size=2
stride=2
[convolutional]
batch_normalize=1
filters=64
size=3
stride=1
pad=1
activation=leaky
[maxpool]
size=2
stride=2
[convolutional]
batch_normalize=1
filters=128
size=3
stride=1
pad=1
activation=leaky
[maxpool]
size=2
stride=2
[convolutional]
batch_normalize=1
filters=256
size=3
stride=1
pad=1
activation=leaky
[maxpool]
size=2
stride=2
[convolutional]
batch_normalize=1
filters=512
size=3
stride=1
pad=1
activation=leaky
[maxpool]
size=2
stride=1
[convolutional]
batch_normalize=1
filters=1024
size=3
stride=1
pad=1
activation=leaky
###########
[crnn]
batch_normalize=1
size=3
pad=1
output=512
hidden=256
activation=leaky
#[shortcut]
#from=-2
#activation=linear
###########
[convolutional]
batch_normalize=1
filters=256
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=512
size=3
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=512
size=3
stride=1
pad=1
activation=leaky
[convolutional]
size=1
stride=1
pad=1
filters=18
activation=linear
[yolo]
mask = 3,4,5
anchors = 10,14, 23,27, 37,58, 81,82, 135,169, 344,319
classes=1
num=6
jitter=.3
ignore_thresh = .7
truth_thresh = 1
random=0
[route]
layers = -4
[convolutional]
batch_normalize=1
filters=128
size=1
stride=1
pad=1
activation=leaky
[upsample]
stride=2
[route]
layers = -1, 8
[crnn]
batch_normalize=1
size=3
pad=1
output=256
hidden=128
activation=leaky
[convolutional]
batch_normalize=1
filters=256
size=3
stride=1
pad=1
activation=leaky
[convolutional]
size=1
stride=1
pad=1
filters=18
activation=linear
[yolo]
mask = 0,1,2
anchors = 10,14, 23,27, 37,58, 81,82, 135,169, 344,319
classes=1
num=6
jitter=.3
ignore_thresh = .7
truth_thresh = 1
random=0
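A note on the three new [net] keys above, grounded in the parser and loader changes later in this diff: track=1 makes the data loader draw sequential frames from the training list instead of random images, time_steps is the sequence length the [crnn] layers are unrolled over, and augment_speed bounds the random number of frames skipped between consecutive time steps. A sketch in C of how parse_net_options and train_detector (both below) turn this cfg into loader sizes:

    int batch = 8, subdivisions = 4, time_steps = 20; /* values from this cfg */
    batch /= subdivisions;                /* 2 */
    batch *= time_steps;                  /* 40 frames per forward pass */
    int mini_batch = batch / time_steps;  /* 2 parallel video timelines per load */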

cfg/yolov3-tiny_occlusion_track.cfg
@@ -0,0 +1,218 @@
[net]
# Testing
#batch=1
#subdivisions=1
# Training
batch=8
subdivisions=4
width=416
height=416
channels=3
momentum=0.9
decay=0.0005
angle=0
saturation = 1.5
exposure = 1.5
hue=.1
track=1
time_steps=20
augment_speed=3
learning_rate=0.001
burn_in=1000
max_batches = 10000
policy=steps
steps=9000,9500
scales=.1,.1
[convolutional]
batch_normalize=1
filters=16
size=3
stride=1
pad=1
activation=leaky
[maxpool]
size=2
stride=2
[convolutional]
batch_normalize=1
filters=32
size=3
stride=1
pad=1
activation=leaky
[maxpool]
size=2
stride=2
[convolutional]
batch_normalize=1
filters=64
size=3
stride=1
pad=1
activation=leaky
[maxpool]
size=2
stride=2
[convolutional]
batch_normalize=1
filters=128
size=3
stride=1
pad=1
activation=leaky
[maxpool]
size=2
stride=2
[convolutional]
batch_normalize=1
filters=256
size=3
stride=1
pad=1
activation=leaky
[maxpool]
size=2
stride=2
[convolutional]
batch_normalize=1
filters=512
size=3
stride=1
pad=1
activation=leaky
[maxpool]
size=2
stride=1
[convolutional]
batch_normalize=1
filters=1024
size=3
stride=1
pad=1
activation=leaky
###########
[crnn]
batch_normalize=1
size=3
pad=1
output=512
hidden=256
activation=leaky
#[shortcut]
#from=-2
#activation=linear
###########
[convolutional]
batch_normalize=1
filters=256
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=512
size=3
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=512
size=3
stride=1
pad=1
activation=leaky
[convolutional]
size=1
stride=1
pad=1
filters=18
activation=linear
[yolo]
mask = 3,4,5
anchors = 10,14, 23,27, 37,58, 81,82, 135,169, 344,319
classes=1
num=6
jitter=.3
ignore_thresh = .7
truth_thresh = 1
random=0
[route]
layers = -4
[convolutional]
batch_normalize=1
filters=128
size=1
stride=1
pad=1
activation=leaky
[upsample]
stride=2
[route]
layers = -1, 8
[crnn]
batch_normalize=1
size=3
pad=1
output=256
hidden=128
activation=leaky
[convolutional]
batch_normalize=1
filters=256
size=3
stride=1
pad=1
activation=leaky
[convolutional]
size=1
stride=1
pad=1
filters=18
activation=linear
[yolo]
mask = 0,1,2
anchors = 10,14, 23,27, 37,58, 81,82, 135,169, 344,319
classes=1
num=6
jitter=.3
ignore_thresh = .7
truth_thresh = 1
random=0

include/darknet.h
@@ -570,7 +570,9 @@ typedef struct network {
float saturation;
float hue;
int random;
int small_object;
int track;
int augment_speed;
int try_fix_nan;
int gpu_index;
tree *hierarchy;
@@ -698,7 +700,9 @@ typedef struct load_args {
int scale;
int center;
int coords;
int small_object;
int mini_batch;
int track;
int augment_speed;
float jitter;
int flip;
float angle;

src/batchnorm_layer.c
@@ -205,6 +205,15 @@ void forward_batchnorm_layer_gpu(layer l, network_state state)
.00001,
l.mean_gpu, // output (should be FP32)
l.variance_gpu); // output (should be FP32)
if (state.net.try_fix_nan) {
fix_nan_and_inf(l.scales_gpu, l.n);
fix_nan_and_inf(l.biases_gpu, l.n);
fix_nan_and_inf(l.mean_gpu, l.n);
fix_nan_and_inf(l.variance_gpu, l.n);
fix_nan_and_inf(l.rolling_mean_gpu, l.n);
fix_nan_and_inf(l.rolling_variance_gpu, l.n);
}
#else
fast_mean_gpu(l.output_gpu, l.batch, l.out_c, l.out_h*l.out_w, l.mean_gpu);
fast_variance_gpu(l.output_gpu, l.mean_gpu, l.batch, l.out_c, l.out_h*l.out_w, l.variance_gpu);
@@ -272,5 +281,10 @@ void backward_batchnorm_layer_gpu(layer l, network_state state)
#endif
if (l.type == BATCHNORM) simple_copy_ongpu(l.outputs*l.batch, l.delta_gpu, state.delta);
//copy_ongpu(l.outputs*l.batch, l.delta_gpu, 1, state.delta, 1);
if (state.net.try_fix_nan) {
fix_nan_and_inf(l.scale_updates_gpu, l.n);
fix_nan_and_inf(l.bias_updates_gpu, l.n);
}
}
#endif
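For context, try_fix_nan is a new [net] option (parsed in parser.c below) that sanitizes tensors in place after forward/backward steps. A minimal CPU sketch of the assumed semantics of fix_nan_and_inf; the real version is a GPU kernel in this repo, and the replacement value here is hypothetical:

    #include <math.h>
    #include <stddef.h>

    /* Overwrite any non-finite element so a single NaN/Inf cannot
     * poison the rest of training. */
    static void fix_nan_and_inf_cpu(float *x, size_t n)
    {
        size_t i;
        for (i = 0; i < n; ++i) {
            if (!isfinite(x[i])) x[i] = 1.0f / (float)(i + 1); /* hypothetical fill */
        }
    }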

src/convolutional_kernels.cu
@@ -519,6 +519,15 @@ void forward_convolutional_layer_gpu(convolutional_layer l, network_state state)
else {
//#else
/*
int input_nan_inf = is_nan_or_inf(state.input, l.inputs * l.batch);
printf("\n is_nan_or_inf(state.input) = %d \n", input_nan_inf);
if (input_nan_inf) getchar();
int weights_nan_inf = is_nan_or_inf(l.weights_gpu, l.size * l.size * l.c * l.n);
printf("\n is_nan_or_inf(l.weights_gpu) = %d \n", weights_nan_inf);
if (weights_nan_inf) getchar();
*/
CHECK_CUDNN(cudnnConvolutionForward(cudnn_handle(),
&alpha, //&one,
@@ -581,6 +590,10 @@ void forward_convolutional_layer_gpu(convolutional_layer l, network_state state)
//if(l.dot > 0) dot_error_gpu(l);
if(l.binary || l.xnor) swap_binary(&l);
//cudaDeviceSynchronize(); // for correct profiling of performance
if (state.net.try_fix_nan) {
fix_nan_and_inf(l.output_gpu, l.outputs*l.batch);
}
}
void backward_convolutional_layer_gpu(convolutional_layer l, network_state state)
@@ -609,7 +622,6 @@ void backward_convolutional_layer_gpu(convolutional_layer l, network_state state
if (state.index != 0 && state.net.cudnn_half && !l.xnor && (!state.train || iteration_num > 3*state.net.burn_in) &&
l.c % 8 == 0 && l.n % 8 == 0)
{
const size_t input16_size = l.batch*l.c*l.w*l.h;
const size_t delta16_size = l.batch*l.n*l.out_w*l.out_h;
@@ -762,6 +774,7 @@ void backward_convolutional_layer_gpu(convolutional_layer l, network_state state
&one,
l.dsrcTensorDesc,
state.delta));
if (l.binary || l.xnor) swap_binary(&l);
if (l.xnor) gradient_array_ongpu(original_input, l.batch*l.c*l.h*l.w, HARDTAN, state.delta);
}
@@ -803,6 +816,14 @@ void backward_convolutional_layer_gpu(convolutional_layer l, network_state state
}
}
#endif
if (state.net.try_fix_nan) {
if (state.delta) {
fix_nan_and_inf(state.delta, l.inputs * l.batch);
}
int size = l.size * l.size * l.c * l.n;
fix_nan_and_inf(l.weight_updates_gpu, size);
fix_nan_and_inf(l.weights_gpu, size);
}
}
void pull_convolutional_layer(convolutional_layer layer)

src/convolutional_layer.c
@@ -354,6 +354,7 @@ convolutional_layer make_convolutional_layer(int batch, int steps, int h, int w,
l.pad = padding;
l.batch_normalize = batch_normalize;
l.learning_rate_scale = 1;
l.nweights = l.c*l.n*l.size*l.size;
l.weights = (float*)calloc(c * n * size * size, sizeof(float));
l.weight_updates = (float*)calloc(c * n * size * size, sizeof(float));

src/crnn_layer.c
@@ -26,7 +26,7 @@ static void increment_layer(layer *l, int steps)
#endif
}
layer make_crnn_layer(int batch, int h, int w, int c, int hidden_filters, int output_filters, int steps, int size, int stride, int pad, ACTIVATION activation, int batch_normalize)
layer make_crnn_layer(int batch, int h, int w, int c, int hidden_filters, int output_filters, int steps, int size, int stride, int pad, ACTIVATION activation, int batch_normalize, int xnor)
{
fprintf(stderr, "CRNN Layer: %d x %d x %d image, %d filters\n", h,w,c,output_filters);
batch = batch / steps;
@@ -34,33 +34,41 @@ layer make_crnn_layer(int batch, int h, int w, int c, int hidden_filters, int ou
l.batch = batch;
l.type = CRNN;
l.steps = steps;
l.size = size;
l.stride = stride;
l.pad = pad;
l.h = h;
l.w = w;
l.c = c;
l.out_h = h;
l.out_w = w;
l.out_c = output_filters;
l.inputs = h*w*c;
l.inputs = h * w * c;
l.hidden = h * w * hidden_filters;
l.outputs = l.out_h * l.out_w * l.out_c;
l.xnor = xnor;
l.state = (float*)calloc(l.hidden * batch * (steps + 1), sizeof(float));
l.state = (float*)calloc(l.hidden * l.batch * (l.steps + 1), sizeof(float));
l.input_layer = (layer*)malloc(sizeof(layer));
*(l.input_layer) = make_convolutional_layer(batch, steps, h, w, c, hidden_filters, size, stride, pad, activation, batch_normalize, 0, 0, 0, 0, 0);
*(l.input_layer) = make_convolutional_layer(batch, steps, h, w, c, hidden_filters, size, stride, pad, activation, batch_normalize, 0, xnor, 0, 0, 0);
l.input_layer->batch = batch;
if (l.workspace_size < l.input_layer->workspace_size) l.workspace_size = l.input_layer->workspace_size;
l.self_layer = (layer*)malloc(sizeof(layer));
*(l.self_layer) = make_convolutional_layer(batch, steps, h, w, hidden_filters, hidden_filters, size, stride, pad, activation, batch_normalize, 0, 0, 0, 0, 0);
*(l.self_layer) = make_convolutional_layer(batch, steps, h, w, hidden_filters, hidden_filters, size, stride, pad, activation, batch_normalize, 0, xnor, 0, 0, 0);
l.self_layer->batch = batch;
if (l.workspace_size < l.self_layer->workspace_size) l.workspace_size = l.self_layer->workspace_size;
l.output_layer = (layer*)malloc(sizeof(layer));
*(l.output_layer) = make_convolutional_layer(batch, steps, h, w, hidden_filters, output_filters, size, stride, pad, activation, batch_normalize, 0, 0, 0, 0, 0);
*(l.output_layer) = make_convolutional_layer(batch, steps, h, w, hidden_filters, output_filters, size, stride, pad, activation, batch_normalize, 0, xnor, 0, 0, 0);
l.output_layer->batch = batch;
if (l.workspace_size < l.output_layer->workspace_size) l.workspace_size = l.output_layer->workspace_size;
l.out_h = l.output_layer->out_h;
l.out_w = l.output_layer->out_w;
l.outputs = l.output_layer->outputs;
assert(l.input_layer->outputs == l.self_layer->outputs);
assert(l.input_layer->outputs == l.output_layer->inputs);
l.output = l.output_layer->output;
l.delta = l.output_layer->delta;
@@ -72,7 +80,7 @@ layer make_crnn_layer(int batch, int h, int w, int c, int hidden_filters, int ou
l.forward_gpu = forward_crnn_layer_gpu;
l.backward_gpu = backward_crnn_layer_gpu;
l.update_gpu = update_crnn_layer_gpu;
l.state_gpu = cuda_make_array(l.state, batch*l.hidden*(steps + 1));
l.state_gpu = cuda_make_array(l.state, l.batch*l.hidden*(l.steps + 1));
l.output_gpu = l.output_layer->output_gpu;
l.delta_gpu = l.output_layer->delta_gpu;
#endif
@@ -80,6 +88,46 @@ layer make_crnn_layer(int batch, int h, int w, int c, int hidden_filters, int ou
return l;
}
void resize_crnn_layer(layer *l, int w, int h)
{
resize_convolutional_layer(l->input_layer, w, h);
if (l->workspace_size < l->input_layer->workspace_size) l->workspace_size = l->input_layer->workspace_size;
resize_convolutional_layer(l->self_layer, w, h);
if (l->workspace_size < l->self_layer->workspace_size) l->workspace_size = l->self_layer->workspace_size;
resize_convolutional_layer(l->output_layer, w, h);
if (l->workspace_size < l->output_layer->workspace_size) l->workspace_size = l->output_layer->workspace_size;
l->output = l->output_layer->output;
l->delta = l->output_layer->delta;
int hidden_filters = l->self_layer->c;
l->w = w;
l->h = h;
l->inputs = h * w * l->c;
l->hidden = h * w * hidden_filters;
l->out_h = l->output_layer->out_h;
l->out_w = l->output_layer->out_w;
l->outputs = l->output_layer->outputs;
assert(l->input_layer->inputs == l->inputs);
assert(l->self_layer->inputs == l->hidden);
assert(l->input_layer->outputs == l->self_layer->outputs);
assert(l->input_layer->outputs == l->output_layer->inputs);
l->state = (float*)realloc(l->state, l->batch*l->hidden*(l->steps + 1)*sizeof(float));
#ifdef GPU
if (l->state_gpu) cudaFree(l->state_gpu);
l->state_gpu = cuda_make_array(l->state, l->batch*l->hidden*(l->steps + 1));
l->output_gpu = l->output_layer->output_gpu;
l->delta_gpu = l->output_layer->delta_gpu;
#endif
}
void update_crnn_layer(layer l, int batch, float learning_rate, float momentum, float decay)
{
update_convolutional_layer(*(l.input_layer), batch, learning_rate, momentum, decay);
@@ -92,15 +140,19 @@ void forward_crnn_layer(layer l, network_state state)
network_state s = {0};
s.train = state.train;
s.workspace = state.workspace;
s.net = state.net;
//s.index = state.index;
int i;
layer input_layer = *(l.input_layer);
layer self_layer = *(l.self_layer);
layer output_layer = *(l.output_layer);
fill_cpu(l.outputs * l.batch * l.steps, 0, output_layer.delta, 1);
fill_cpu(l.hidden * l.batch * l.steps, 0, self_layer.delta, 1);
fill_cpu(l.hidden * l.batch * l.steps, 0, input_layer.delta, 1);
if(state.train) fill_cpu(l.hidden * l.batch, 0, l.state, 1);
if (state.train) {
fill_cpu(l.outputs * l.batch * l.steps, 0, output_layer.delta, 1);
fill_cpu(l.hidden * l.batch * l.steps, 0, self_layer.delta, 1);
fill_cpu(l.hidden * l.batch * l.steps, 0, input_layer.delta, 1);
fill_cpu(l.hidden * l.batch, 0, l.state, 1);
}
for (i = 0; i < l.steps; ++i) {
s.input = state.input;
@@ -134,6 +186,8 @@ void backward_crnn_layer(layer l, network_state state)
network_state s = {0};
s.train = state.train;
s.workspace = state.workspace;
s.net = state.net;
//s.index = state.index;
int i;
layer input_layer = *(l.input_layer);
layer self_layer = *(l.self_layer);
@@ -208,6 +262,8 @@ void forward_crnn_layer_gpu(layer l, network_state state)
network_state s = {0};
s.train = state.train;
s.workspace = state.workspace;
s.net = state.net;
if(!state.train) s.index = state.index; // use TC only for detection
int i;
layer input_layer = *(l.input_layer);
layer self_layer = *(l.self_layer);
@@ -223,10 +279,12 @@ void forward_crnn_layer_gpu(layer l, network_state state)
#endif //CUDNN_HALF
*/
fill_ongpu(l.outputs * l.batch * l.steps, 0, output_layer.delta_gpu, 1);
fill_ongpu(l.hidden * l.batch * l.steps, 0, self_layer.delta_gpu, 1);
fill_ongpu(l.hidden * l.batch * l.steps, 0, input_layer.delta_gpu, 1);
if(state.train) fill_ongpu(l.hidden * l.batch, 0, l.state_gpu, 1);
if (state.train) {
fill_ongpu(l.outputs * l.batch * l.steps, 0, output_layer.delta_gpu, 1);
fill_ongpu(l.hidden * l.batch * l.steps, 0, self_layer.delta_gpu, 1);
fill_ongpu(l.hidden * l.batch * l.steps, 0, input_layer.delta_gpu, 1);
fill_ongpu(l.hidden * l.batch, 0, l.state_gpu, 1);
}
for (i = 0; i < l.steps; ++i) {
s.input = state.input;
@@ -260,6 +318,8 @@ void backward_crnn_layer_gpu(layer l, network_state state)
network_state s = {0};
s.train = state.train;
s.workspace = state.workspace;
s.net = state.net;
//s.index = state.index;
int i;
layer input_layer = *(l.input_layer);
layer self_layer = *(l.self_layer);
@@ -267,6 +327,7 @@ void backward_crnn_layer_gpu(layer l, network_state state)
increment_layer(&input_layer, l.steps - 1);
increment_layer(&self_layer, l.steps - 1);
increment_layer(&output_layer, l.steps - 1);
float *init_state_gpu = l.state_gpu;
l.state_gpu += l.hidden*l.batch*l.steps;
for (i = l.steps-1; i >= 0; --i) {
//copy_ongpu(l.hidden * l.batch, input_layer.output_gpu, 1, l.state_gpu, 1); // commented in RNN
@@ -291,9 +352,16 @@ void backward_crnn_layer_gpu(layer l, network_state state)
else s.delta = 0;
backward_convolutional_layer_gpu(input_layer, s);
if (state.net.try_fix_nan) {
fix_nan_and_inf(output_layer.delta_gpu, output_layer.inputs * output_layer.batch);
fix_nan_and_inf(self_layer.delta_gpu, self_layer.inputs * self_layer.batch);
fix_nan_and_inf(input_layer.delta_gpu, input_layer.inputs * input_layer.batch);
}
increment_layer(&input_layer, -1);
increment_layer(&self_layer, -1);
increment_layer(&output_layer, -1);
}
fill_ongpu(l.hidden * l.batch, 0, init_state_gpu, 1); //clean l.state_gpu
}
#endif
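For orientation, the [crnn] layer this commit extends chains three convolutional sub-layers into a recurrent cell unrolled over time_steps frames. A sketch of the per-step update that forward_crnn_layer implements (the real code runs on GPU buffers and shares weights across steps via increment_layer):

    /* For each time step t:
     *   state_t  = input_layer(x_t) + self_layer(state_{t-1})
     *   output_t = output_layer(state_t)
     * self_layer convolves the previous hidden state, which is what lets
     * detections persist across occluded frames in a video. */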

src/crnn_layer.h
@@ -9,7 +9,8 @@
#ifdef __cplusplus
extern "C" {
#endif
layer make_crnn_layer(int batch, int h, int w, int c, int hidden_filters, int output_filters, int steps, int size, int stride, int pad, ACTIVATION activation, int batch_normalize);
layer make_crnn_layer(int batch, int h, int w, int c, int hidden_filters, int output_filters, int steps, int size, int stride, int pad, ACTIVATION activation, int batch_normalize, int xnor);
void resize_crnn_layer(layer *l, int w, int h);
void forward_crnn_layer(layer l, network_state state);
void backward_crnn_layer(layer l, network_state state);

src/data.c
@@ -41,6 +41,38 @@ char **get_random_paths_indexes(char **paths, int n, int m, int *indexes)
}
*/
char **get_sequential_paths(char **paths, int n, int m, int mini_batch, int augment_speed)
{
int speed = rand_int(1, augment_speed);
if (speed < 1) speed = 1;
char** sequentia_paths = (char**)calloc(n, sizeof(char*));
int i;
pthread_mutex_lock(&mutex);
//printf("n = %d, mini_batch = %d \n", n, mini_batch);
unsigned int *start_time_indexes = (unsigned int *)calloc(mini_batch, sizeof(unsigned int));
for (i = 0; i < mini_batch; ++i) {
start_time_indexes[i] = random_gen() % m;
//printf(" start_time_indexes[i] = %u, ", start_time_indexes[i]);
}
for (i = 0; i < n; ++i) {
do {
int time_line_index = i % mini_batch;
unsigned int index = start_time_indexes[time_line_index] % m;
start_time_indexes[time_line_index] += speed;
//int index = random_gen() % m;
sequentia_paths[i] = paths[index];
//if(i == 0) printf("%s\n", paths[index]);
//printf(" index = %u - grp: %s \n", index, paths[index]);
if (strlen(sequentia_paths[i]) <= 4) printf(" Very small path to the image: %s \n", sequentia_paths[i]);
} while (strlen(sequentia_paths[i]) == 0);
}
//free(start_time_indexes);
pthread_mutex_unlock(&mutex);
return sequentia_paths;
}
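To illustrate what get_sequential_paths returns (all values hypothetical): each call draws one random start index per timeline and a single speed in [1, augment_speed], then interleaves the timelines across the n requested paths:

    /* n = 8, mini_batch = 2, speed = 3, random starts s0 and s1:
     *   i:      0   1   2     3     4     5     6     7
     *   index:  s0  s1  s0+3  s1+3  s0+6  s1+6  s0+9  s1+9
     * i % mini_batch selects the timeline; each timeline advances by
     * `speed` frames per time step, wrapping modulo m. */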
char **get_random_paths(char **paths, int n, int m)
{
char** random_paths = (char**)calloc(n, sizeof(char*));
@@ -303,7 +335,7 @@ void fill_truth_region(char *path, float *truth, int classes, int num_boxes, int
}
void fill_truth_detection(const char *path, int num_boxes, float *truth, int classes, int flip, float dx, float dy, float sx, float sy,
int small_object, int net_w, int net_h)
int net_w, int net_h)
{
char labelpath[4096];
replace_image_to_label(path, labelpath);
@@ -313,12 +345,6 @@ void fill_truth_detection(const char *path, int num_boxes, float *truth, int cla
box_label *boxes = read_boxes(labelpath, &count);
float lowest_w = 1.F / net_w;
float lowest_h = 1.F / net_h;
if (small_object == 1) {
for (i = 0; i < count; ++i) {
if (boxes[i].w < lowest_w) boxes[i].w = lowest_w;
if (boxes[i].h < lowest_h) boxes[i].h = lowest_h;
}
}
randomize_boxes(boxes, count);
correct_boxes(boxes, count, dx, dy, sx, sy, flip);
if (count > num_boxes) count = num_boxes;
@@ -729,6 +755,16 @@ data load_data_swag(char **paths, int n, int classes, float jitter)
return d;
}
static box float_to_box_stride(float *f, int stride)
{
box b = { 0 };
b.x = f[0];
b.y = f[1 * stride];
b.w = f[2 * stride];
b.h = f[3 * stride];
return b;
}
#ifdef OPENCV
#include <opencv2/highgui/highgui_c.h>
#include <opencv2/imgproc/imgproc_c.h>
@@ -740,10 +776,12 @@ data load_data_swag(char **paths, int n, int classes, float jitter)
#include "http_stream.h"
data load_data_detection(int n, char **paths, int m, int w, int h, int c, int boxes, int classes, int use_flip, float jitter, float hue, float saturation, float exposure, int small_object)
data load_data_detection(int n, char **paths, int m, int w, int h, int c, int boxes, int classes, int use_flip, float jitter, float hue, float saturation, float exposure, int mini_batch, int track, int augment_speed)
{
c = c ? c : 3;
char **random_paths = get_random_paths(paths, n, m);
char **random_paths;
if (track) random_paths = get_sequential_paths(paths, n, m, mini_batch, augment_speed);
else random_paths = get_random_paths(paths, n, m);
int i;
data d = {0};
d.shallow = 0;
@@ -752,6 +790,10 @@ data load_data_detection(int n, char **paths, int m, int w, int h, int c, int bo
d.X.vals = (float**)calloc(d.X.rows, sizeof(float*));
d.X.cols = h*w*c;
float r1, r2, r3, r4;
float dhue, dsat, dexp, flip;
int augmentation_calculated = 0;
d.y = make_matrix(n, 5*boxes);
for(i = 0; i < n; ++i){
const char *filename = random_paths[i];
@@ -775,10 +817,25 @@ data load_data_detection(int n, char **paths, int m, int w, int h, int c, int bo
int dw = (ow*jitter);
int dh = (oh*jitter);
int pleft = rand_uniform_strong(-dw, dw);
int pright = rand_uniform_strong(-dw, dw);
int ptop = rand_uniform_strong(-dh, dh);
int pbot = rand_uniform_strong(-dh, dh);
if(!augmentation_calculated || !track)
{
augmentation_calculated = 1;
r1 = random_float();
r2 = random_float();
r3 = random_float();
r4 = random_float();
dhue = rand_uniform_strong(-hue, hue);
dsat = rand_scale(saturation);
dexp = rand_scale(exposure);
flip = use_flip ? random_gen() % 2 : 0;
}
int pleft = rand_precalc_random(-dw, dw, r1);
int pright = rand_precalc_random(-dw, dw, r2);
int ptop = rand_precalc_random(-dh, dh, r3);
int pbot = rand_precalc_random(-dh, dh, r4);
int swidth = ow - pleft - pright;
int sheight = oh - ptop - pbot;
@@ -786,22 +843,32 @@ data load_data_detection(int n, char **paths, int m, int w, int h, int c, int bo
float sx = (float)swidth / ow;
float sy = (float)sheight / oh;
int flip = use_flip ? random_gen()%2 : 0;
float dx = ((float)pleft/ow)/sx;
float dy = ((float)ptop /oh)/sy;
float dhue = rand_uniform_strong(-hue, hue);
float dsat = rand_scale(saturation);
float dexp = rand_scale(exposure);
image ai = image_data_augmentation(src, w, h, pleft, ptop, swidth, sheight, flip, jitter, dhue, dsat, dexp);
d.X.vals[i] = ai.data;
//show_image(ai, "aug");
//cvWaitKey(0);
fill_truth_detection(filename, boxes, d.y.vals[i], classes, flip, dx, dy, 1./sx, 1./sy, w, h);
fill_truth_detection(filename, boxes, d.y.vals[i], classes, flip, dx, dy, 1./sx, 1./sy, small_object, w, h);
/*
{
char buff[10];
sprintf(buff, "aug_%s_%d", random_paths[i], random_gen());
int t;
for (t = 0; t < boxes; ++t) {
box b = float_to_box_stride(d.y.vals[i] + t*(4 + 1), 1);
if (!b.x) break;
int left = (b.x - b.w / 2.)*ai.w;
int right = (b.x + b.w / 2.)*ai.w;
int top = (b.y - b.h / 2.)*ai.h;
int bot = (b.y + b.h / 2.)*ai.h;
draw_box_width(ai, left, top, right, bot, 3, 150, 100, 50); // 3 channels RGB
}
show_image(ai, buff);
cvWaitKey(0);
}*/
cvReleaseImage(&src);
}
@@ -809,10 +876,12 @@ data load_data_detection(int n, char **paths, int m, int w, int h, int c, int bo
return d;
}
#else // OPENCV
data load_data_detection(int n, char **paths, int m, int w, int h, int c, int boxes, int classes, int use_flip, float jitter, float hue, float saturation, float exposure, int small_object)
data load_data_detection(int n, char **paths, int m, int w, int h, int c, int boxes, int classes, int use_flip, float jitter, float hue, float saturation, float exposure, int mini_batch, int track, int augment_speed)
{
c = c ? c : 3;
char **random_paths = get_random_paths(paths, n, m);
char **random_paths;
if(track) random_paths = get_sequential_paths(paths, n, m, mini_batch, augment_speed);
else random_paths = get_random_paths(paths, n, m);
int i;
data d = { 0 };
d.shallow = 0;
@@ -821,6 +890,10 @@ data load_data_detection(int n, char **paths, int m, int w, int h, int c, int bo
d.X.vals = (float**)calloc(d.X.rows, sizeof(float*));
d.X.cols = h*w*c;
float r1, r2, r3, r4;
float dhue, dsat, dexp, flip;
int augmentation_calculated = 0;
d.y = make_matrix(n, 5 * boxes);
for (i = 0; i < n; ++i) {
image orig = load_image(random_paths[i], 0, 0, c);
@@ -831,10 +904,25 @@ data load_data_detection(int n, char **paths, int m, int w, int h, int c, int bo
int dw = (ow*jitter);
int dh = (oh*jitter);
int pleft = rand_uniform_strong(-dw, dw);
int pright = rand_uniform_strong(-dw, dw);
int ptop = rand_uniform_strong(-dh, dh);
int pbot = rand_uniform_strong(-dh, dh);
if (!augmentation_calculated || !track)
{
augmentation_calculated = 1;
r1 = random_float();
r2 = random_float();
r3 = random_float();
r4 = random_float();
dhue = rand_uniform_strong(-hue, hue);
dsat = rand_scale(saturation);
dexp = rand_scale(exposure);
flip = use_flip ? random_gen() % 2 : 0;
}
int pleft = rand_precalc_random(-dw, dw, r1);
int pright = rand_precalc_random(-dw, dw, r2);
int ptop = rand_precalc_random(-dh, dh, r3);
int pbot = rand_precalc_random(-dh, dh, r4);
int swidth = ow - pleft - pright;
int sheight = oh - ptop - pbot;
@@ -842,7 +930,6 @@ data load_data_detection(int n, char **paths, int m, int w, int h, int c, int bo
float sx = (float)swidth / ow;
float sy = (float)sheight / oh;
int flip = use_flip ? random_gen() % 2 : 0;
image cropped = crop_image(orig, pleft, ptop, swidth, sheight);
float dx = ((float)pleft / ow) / sx;
@@ -850,10 +937,30 @@ data load_data_detection(int n, char **paths, int m, int w, int h, int c, int bo
image sized = resize_image(cropped, w, h);
if (flip) flip_image(sized);
random_distort_image(sized, hue, saturation, exposure);
distort_image(sized, dhue, dsat, dexp);
//random_distort_image(sized, hue, saturation, exposure);
d.X.vals[i] = sized.data;
fill_truth_detection(random_paths[i], boxes, d.y.vals[i], classes, flip, dx, dy, 1. / sx, 1. / sy, small_object, w, h);
fill_truth_detection(random_paths[i], boxes, d.y.vals[i], classes, flip, dx, dy, 1. / sx, 1. / sy, w, h);
/*
{
char buff[10];
sprintf(buff, "aug_%s_%d", random_paths[i], random_gen());
int t;
for (t = 0; t < boxes; ++t) {
box b = float_to_box_stride(d.y.vals[i] + t*(4 + 1), 1);
if (!b.x) break;
int left = (b.x - b.w / 2.)*sized.w;
int right = (b.x + b.w / 2.)*sized.w;
int top = (b.y - b.h / 2.)*sized.h;
int bot = (b.y + b.h / 2.)*sized.h;
draw_box_width(sized, left, top, right, bot, 3, 150, 100, 50); // 3 channels RGB
}
show_image(sized, buff);
cvWaitKey(0);
}*/
free_image(orig);
free_image(cropped);
@@ -883,7 +990,7 @@ void *load_thread(void *ptr)
} else if (a.type == REGION_DATA){
*a.d = load_data_region(a.n, a.paths, a.m, a.w, a.h, a.num_boxes, a.classes, a.jitter, a.hue, a.saturation, a.exposure);
} else if (a.type == DETECTION_DATA){
*a.d = load_data_detection(a.n, a.paths, a.m, a.w, a.h, a.c, a.num_boxes, a.classes, a.flip, a.jitter, a.hue, a.saturation, a.exposure, a.small_object);
*a.d = load_data_detection(a.n, a.paths, a.m, a.w, a.h, a.c, a.num_boxes, a.classes, a.flip, a.jitter, a.hue, a.saturation, a.exposure, a.mini_batch, a.track, a.augment_speed);
} else if (a.type == SWAG_DATA){
*a.d = load_data_swag(a.paths, a.n, a.classes, a.jitter);
} else if (a.type == COMPARE_DATA){
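The key change in both load_data_detection variants above: with track=1 the random augmentation parameters (r1..r4 for the crop, dhue/dsat/dexp, flip) are drawn once per sequence and reused for every frame, so pixels and boxes stay geometrically consistent over the time steps. rand_precalc_random presumably just maps such a precomputed uniform value into a range; a sketch under that assumption:

    /* Assumed behavior (see src/utils.c): r in [0,1) is drawn once per
     * sequence via random_float(), so every frame gets the same crop. */
    static float rand_precalc_random_sketch(float min, float max, float r)
    {
        return min + r * (max - min);
    }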

src/data.h
@@ -86,7 +86,7 @@ void print_letters(float *pred, int n);
data load_data_captcha(char **paths, int n, int m, int k, int w, int h);
data load_data_captcha_encode(char **paths, int n, int m, int w, int h);
data load_data_old(char **paths, int n, int m, char **labels, int k, int w, int h);
data load_data_detection(int n, char **paths, int m, int w, int h, int c, int boxes, int classes, int use_flip, float jitter, float hue, float saturation, float exposure, int small_object);
data load_data_detection(int n, char **paths, int m, int w, int h, int c, int boxes, int classes, int use_flip, float jitter, float hue, float saturation, float exposure, int mini_batch, int track, int augment_speed);
data load_data_tag(char **paths, int n, int m, int k, int use_flip, int min, int max, int size, float angle, float aspect, float hue, float saturation, float exposure);
matrix load_image_augment_paths(char **paths, int n, int use_flip, int min, int max, int size, float angle, float aspect, float hue, float saturation, float exposure);
data load_data_super(char **paths, int n, int m, int w, int h, int scale);

src/detector.c
@@ -76,11 +76,14 @@ void train_detector(char *datacfg, char *cfgfile, char *weightfile, int *gpus, i
cuda_set_device(gpus[0]);
printf(" Prepare additional network for mAP calculation...\n");
net_map = parse_network_cfg_custom(cfgfile, 1, 0);
net_map = parse_network_cfg_custom(cfgfile, 1, 1);
int k; // free memory unnecessary arrays
for (k = 0; k < net_map.n; ++k) {
free_layer(net_map.layers[k]);
free_layer(net_map.layers[k]);
}
/*
#ifdef GPU
cuda_free(net_map.workspace);
cuda_free(net_map.input_state_gpu);
@@ -89,6 +92,7 @@ void train_detector(char *datacfg, char *cfgfile, char *weightfile, int *gpus, i
#else
free(net_map.workspace);
#endif
*/
}
srand(time(0));
@@ -156,7 +160,6 @@ void train_detector(char *datacfg, char *cfgfile, char *weightfile, int *gpus, i
args.flip = net.flip;
args.jitter = jitter;
args.num_boxes = l.max_boxes;
args.small_object = net.small_object;
args.d = &buffer;
args.type = DETECTION_DATA;
args.threads = 64; // 16 or 64
@@ -175,6 +178,14 @@ void train_detector(char *datacfg, char *cfgfile, char *weightfile, int *gpus, i
int img_size = 1000;
img = draw_train_chart(max_img_loss, net.max_batches, number_of_lines, img_size, dont_show);
#endif //OPENCV
if (net.track) {
args.track = net.track;
args.augment_speed = net.augment_speed;
args.threads = net.subdivisions * ngpus; // 2 * ngpus;
args.mini_batch = net.batch / net.time_steps;
printf("\n Tracking! batch = %d, subdiv = %d, time_steps = %d, mini_batch = %d \n", net.batch, net.subdivisions, net.time_steps, args.mini_batch);
}
//printf(" imgs = %d \n", imgs);
pthread_t load_thread = load_data(args);
double time;
@@ -183,14 +194,6 @@ void train_detector(char *datacfg, char *cfgfile, char *weightfile, int *gpus, i
while (get_current_batch(net) < net.max_batches) {
if (l.random && count++ % 10 == 0) {
printf("Resizing\n");
//int dim = (rand() % 12 + (init_w/32 - 5)) * 32; // +-160
//int dim = (rand() % 4 + 16) * 32;
//if (get_current_batch(net)+100 > net.max_batches) dim = 544;
//int random_val = rand() % 12;
//int dim_w = (random_val + (init_w / 32 - 5)) * 32; // +-160
//int dim_h = (random_val + (init_h / 32 - 5)) * 32; // +-160
float random_val = rand_scale(1.4); // *x or /x
int dim_w = roundl(random_val*init_w / 32 + 1) * 32;
int dim_h = roundl(random_val*init_h / 32 + 1) * 32;
@@ -259,11 +262,13 @@ void train_detector(char *datacfg, char *cfgfile, char *weightfile, int *gpus, i
i = get_current_batch(net);
int calc_map_for_each = iter_map + 4 * train_images_num / (net.batch * net.subdivisions); // calculate mAP for each 4 Epochs
calc_map_for_each = fmax(calc_map_for_each, net.burn_in);
calc_map_for_each = fmax(calc_map_for_each, 1000);
int calc_map_for_each = 4 * train_images_num / (net.batch * net.subdivisions); // calculate mAP for each 4 Epochs
calc_map_for_each = fmax(calc_map_for_each, 100);
int next_map_calc = iter_map + calc_map_for_each;
next_map_calc = fmax(next_map_calc, net.burn_in);
next_map_calc = fmax(next_map_calc, 1000);
if (calc_map) {
printf("\n (next mAP calculation at %d iterations) ", calc_map_for_each);
printf("\n (next mAP calculation at %d iterations) ", next_map_calc);
if (mean_average_precision > 0) printf("\n Last accuracy mAP@0.5 = %2.2f %% ", mean_average_precision * 100);
}
@@ -274,7 +279,7 @@ void train_detector(char *datacfg, char *cfgfile, char *weightfile, int *gpus, i
printf("\n %d: %f, %f avg loss, %f rate, %lf seconds, %d images\n", get_current_batch(net), loss, avg_loss, get_current_rate(net), (what_time_is_it_now() - time), i*imgs);
int draw_precision = 0;
if (calc_map && (i >= calc_map_for_each || i == net.max_batches)) {
if (calc_map && (i >= next_map_calc || i == net.max_batches)) {
if (l.random) {
printf("Resizing to initial size: %d x %d \n", init_w, init_h);
args.w = init_w;
@@ -289,11 +294,13 @@ void train_detector(char *datacfg, char *cfgfile, char *weightfile, int *gpus, i
net = nets[0];
}
copy_weights_net(net, &net_map);
// combine Training and Validation networks
network net_combined = combine_train_valid_networks(net, net_map);
//network net_combined = combine_train_valid_networks(net, net_map);
iter_map = i;
mean_average_precision = validate_detector_map(datacfg, cfgfile, weightfile, 0.25, 0.5, &net_combined);
mean_average_precision = validate_detector_map(datacfg, cfgfile, weightfile, 0.25, 0.5, &net_map);// &net_combined);
printf("\n mean_average_precision (mAP@0.5) = %f \n", mean_average_precision);
draw_precision = 1;
}
@@ -351,6 +358,11 @@ void train_detector(char *datacfg, char *cfgfile, char *weightfile, int *gpus, i
for (i = 0; i < ngpus; ++i) free_network(nets[i]);
free(nets);
//free_network(net);
if (calc_map) {
net_map.n = 0;
free_network(net_map);
}
}
@@ -443,7 +455,7 @@ void validate_detector(char *datacfg, char *cfgfile, char *weightfile, char *out
int *map = 0;
if (mapf) map = read_map(mapf);
network net = parse_network_cfg_custom(cfgfile, 1, 0); // set batch=1
network net = parse_network_cfg_custom(cfgfile, 1, 1); // set batch=1
if (weightfile) {
load_weights(&net, weightfile);
}
@@ -568,7 +580,7 @@ void validate_detector(char *datacfg, char *cfgfile, char *weightfile, char *out
void validate_detector_recall(char *datacfg, char *cfgfile, char *weightfile)
{
network net = parse_network_cfg_custom(cfgfile, 1, 0); // set batch=1
network net = parse_network_cfg_custom(cfgfile, 1, 1); // set batch=1
if (weightfile) {
load_weights(&net, weightfile);
}
@@ -682,7 +694,7 @@ float validate_detector_map(char *datacfg, char *cfgfile, char *weightfile, floa
net = *existing_net;
}
else {
net = parse_network_cfg_custom(cfgfile, 1, 0); // set batch=1
net = parse_network_cfg_custom(cfgfile, 1, 1); // set batch=1
if (weightfile) {
load_weights(&net, weightfile);
}
@@ -1251,7 +1263,7 @@ void test_detector(char *datacfg, char *cfgfile, char *weightfile, char *filenam
char **names = get_labels_custom(name_list, &names_size); //get_labels(name_list);
image **alphabet = load_alphabet();
network net = parse_network_cfg_custom(cfgfile, 1, 0); // set batch=1
network net = parse_network_cfg_custom(cfgfile, 1, 1); // set batch=1
if (weightfile) {
load_weights(&net, weightfile);
}
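The recurring edit above changes the last argument of parse_network_cfg_custom from 0 to 1; per its signature in parser.c (later in this diff) that argument is time_steps, so detection, validation, and mAP networks are now built with both batch and time_steps forced to 1:

    network net = parse_network_cfg_custom(cfgfile, 1, 1); /* batch=1, time_steps=1 */

Together with the new guard in parse_network_cfg_custom (if (net.batch < net.time_steps) net.batch = net.time_steps;), this keeps the CRNN unrolling consistent when the cfg values are overridden.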

src/layer.c
@@ -4,35 +4,45 @@
void free_layer(layer l)
{
if (l.type == DROPOUT) {
if (l.rand) free(l.rand);
// free layers: input_layer, self_layer, output_layer, ...
if (l.type == CRNN) {
if (l.input_layer) free_layer(*l.input_layer);
if (l.self_layer) free_layer(*l.self_layer);
if (l.output_layer) free_layer(*l.output_layer);
l.output = NULL;
l.delta = NULL;
l.output_gpu = NULL;
l.delta_gpu = NULL;
}
if (l.type == DROPOUT) {
if (l.rand) free(l.rand);
#ifdef GPU
if (l.rand_gpu) cuda_free(l.rand_gpu);
if (l.rand_gpu) cuda_free(l.rand_gpu);
#endif
return;
}
if (l.mask) free(l.mask);
if (l.cweights) free(l.cweights);
if (l.indexes) free(l.indexes);
if (l.input_layers) free(l.input_layers);
if (l.input_sizes) free(l.input_sizes);
if (l.map) free(l.map);
if (l.rand) free(l.rand);
if (l.cost) free(l.cost);
if (l.state) free(l.state);
if (l.prev_state) free(l.prev_state);
if (l.forgot_state) free(l.forgot_state);
if (l.forgot_delta) free(l.forgot_delta);
if (l.state_delta) free(l.state_delta);
if (l.concat) free(l.concat);
if (l.concat_delta) free(l.concat_delta);
if (l.binary_weights) free(l.binary_weights);
if (l.biases) free(l.biases);
if (l.bias_updates) free(l.bias_updates);
if (l.scales) free(l.scales);
if (l.scale_updates) free(l.scale_updates);
if (l.weights) free(l.weights);
if (l.weight_updates) free(l.weight_updates);
return;
}
if (l.mask) free(l.mask);
if (l.cweights) free(l.cweights);
if (l.indexes) free(l.indexes);
if (l.input_layers) free(l.input_layers);
if (l.input_sizes) free(l.input_sizes);
if (l.map) free(l.map);
if (l.rand) free(l.rand);
if (l.cost) free(l.cost);
if (l.state) free(l.state);
if (l.prev_state) free(l.prev_state);
if (l.forgot_state) free(l.forgot_state);
if (l.forgot_delta) free(l.forgot_delta);
if (l.state_delta) free(l.state_delta);
if (l.concat) free(l.concat);
if (l.concat_delta) free(l.concat_delta);
if (l.binary_weights) free(l.binary_weights);
if (l.biases) free(l.biases);
if (l.bias_updates) free(l.bias_updates);
if (l.scales) free(l.scales);
if (l.scale_updates) free(l.scale_updates);
if (l.weights) free(l.weights);
if (l.weight_updates) free(l.weight_updates);
if (l.align_bit_weights) free(l.align_bit_weights);
if (l.mean_arr) free(l.mean_arr);
#ifdef GPU
@@ -45,76 +55,76 @@ void free_layer(layer l)
l.output = NULL;
}
#endif // GPU
if (l.delta) free(l.delta);
if (l.output) free(l.output);
if (l.squared) free(l.squared);
if (l.norms) free(l.norms);
if (l.spatial_mean) free(l.spatial_mean);
if (l.mean) free(l.mean);
if (l.variance) free(l.variance);
if (l.mean_delta) free(l.mean_delta);
if (l.variance_delta) free(l.variance_delta);
if (l.rolling_mean) free(l.rolling_mean);
if (l.rolling_variance) free(l.rolling_variance);
if (l.x) free(l.x);
if (l.x_norm) free(l.x_norm);
if (l.m) free(l.m);
if (l.v) free(l.v);
if (l.z_cpu) free(l.z_cpu);
if (l.r_cpu) free(l.r_cpu);
if (l.h_cpu) free(l.h_cpu);
if (l.binary_input) free(l.binary_input);
if (l.delta) free(l.delta);
if (l.output) free(l.output);
if (l.squared) free(l.squared);
if (l.norms) free(l.norms);
if (l.spatial_mean) free(l.spatial_mean);
if (l.mean) free(l.mean);
if (l.variance) free(l.variance);
if (l.mean_delta) free(l.mean_delta);
if (l.variance_delta) free(l.variance_delta);
if (l.rolling_mean) free(l.rolling_mean);
if (l.rolling_variance) free(l.rolling_variance);
if (l.x) free(l.x);
if (l.x_norm) free(l.x_norm);
if (l.m) free(l.m);
if (l.v) free(l.v);
if (l.z_cpu) free(l.z_cpu);
if (l.r_cpu) free(l.r_cpu);
if (l.h_cpu) free(l.h_cpu);
if (l.binary_input) free(l.binary_input);
if (l.bin_re_packed_input) free(l.bin_re_packed_input);
if (l.t_bit_input) free(l.t_bit_input);
if (l.loss) free(l.loss);
#ifdef GPU
if (l.indexes_gpu) cuda_free((float *)l.indexes_gpu);
if (l.indexes_gpu) cuda_free((float *)l.indexes_gpu);
if (l.z_gpu) cuda_free(l.z_gpu);
if (l.r_gpu) cuda_free(l.r_gpu);
if (l.h_gpu) cuda_free(l.h_gpu);
if (l.m_gpu) cuda_free(l.m_gpu);
if (l.v_gpu) cuda_free(l.v_gpu);
if (l.prev_state_gpu) cuda_free(l.prev_state_gpu);
if (l.forgot_state_gpu) cuda_free(l.forgot_state_gpu);
if (l.forgot_delta_gpu) cuda_free(l.forgot_delta_gpu);
if (l.state_gpu) cuda_free(l.state_gpu);
if (l.state_delta_gpu) cuda_free(l.state_delta_gpu);
if (l.gate_gpu) cuda_free(l.gate_gpu);
if (l.gate_delta_gpu) cuda_free(l.gate_delta_gpu);
if (l.save_gpu) cuda_free(l.save_gpu);
if (l.save_delta_gpu) cuda_free(l.save_delta_gpu);
if (l.concat_gpu) cuda_free(l.concat_gpu);
if (l.concat_delta_gpu) cuda_free(l.concat_delta_gpu);
if (l.binary_input_gpu) cuda_free(l.binary_input_gpu);
if (l.binary_weights_gpu) cuda_free(l.binary_weights_gpu);
if (l.mean_gpu) cuda_free(l.mean_gpu);
if (l.variance_gpu) cuda_free(l.variance_gpu);
if (l.rolling_mean_gpu) cuda_free(l.rolling_mean_gpu);
if (l.rolling_variance_gpu) cuda_free(l.rolling_variance_gpu);
if (l.variance_delta_gpu) cuda_free(l.variance_delta_gpu);
if (l.mean_delta_gpu) cuda_free(l.mean_delta_gpu);
if (l.x_gpu) cuda_free(l.x_gpu);
if (l.x_norm_gpu) cuda_free(l.x_norm_gpu);
if (l.z_gpu) cuda_free(l.z_gpu);
if (l.r_gpu) cuda_free(l.r_gpu);
if (l.h_gpu) cuda_free(l.h_gpu);
if (l.m_gpu) cuda_free(l.m_gpu);
if (l.v_gpu) cuda_free(l.v_gpu);
if (l.prev_state_gpu) cuda_free(l.prev_state_gpu);
if (l.forgot_state_gpu) cuda_free(l.forgot_state_gpu);
if (l.forgot_delta_gpu) cuda_free(l.forgot_delta_gpu);
if (l.state_gpu) cuda_free(l.state_gpu);
if (l.state_delta_gpu) cuda_free(l.state_delta_gpu);
if (l.gate_gpu) cuda_free(l.gate_gpu);
if (l.gate_delta_gpu) cuda_free(l.gate_delta_gpu);
if (l.save_gpu) cuda_free(l.save_gpu);
if (l.save_delta_gpu) cuda_free(l.save_delta_gpu);
if (l.concat_gpu) cuda_free(l.concat_gpu);
if (l.concat_delta_gpu) cuda_free(l.concat_delta_gpu);
if (l.binary_input_gpu) cuda_free(l.binary_input_gpu);
if (l.binary_weights_gpu) cuda_free(l.binary_weights_gpu);
if (l.mean_gpu) cuda_free(l.mean_gpu);
if (l.variance_gpu) cuda_free(l.variance_gpu);
if (l.rolling_mean_gpu) cuda_free(l.rolling_mean_gpu);
if (l.rolling_variance_gpu) cuda_free(l.rolling_variance_gpu);
if (l.variance_delta_gpu) cuda_free(l.variance_delta_gpu);
if (l.mean_delta_gpu) cuda_free(l.mean_delta_gpu);
if (l.x_gpu) cuda_free(l.x_gpu); // dont free
if (l.x_norm_gpu) cuda_free(l.x_norm_gpu);
if (l.align_bit_weights_gpu) cuda_free((float *)l.align_bit_weights_gpu);
if (l.mean_arr_gpu) cuda_free(l.mean_arr_gpu);
if (l.align_workspace_gpu) cuda_free(l.align_workspace_gpu);
if (l.transposed_align_workspace_gpu) cuda_free(l.transposed_align_workspace_gpu);
if (l.weights_gpu) cuda_free(l.weights_gpu);
if (l.weight_updates_gpu) cuda_free(l.weight_updates_gpu);
if (l.weights_gpu16) cuda_free(l.weights_gpu16);
if (l.weight_updates_gpu16) cuda_free(l.weight_updates_gpu16);
if (l.biases_gpu) cuda_free(l.biases_gpu);
if (l.bias_updates_gpu) cuda_free(l.bias_updates_gpu);
if (l.scales_gpu) cuda_free(l.scales_gpu);
if (l.scale_updates_gpu) cuda_free(l.scale_updates_gpu);
if (l.output_gpu) cuda_free(l.output_gpu);
if (l.delta_gpu) cuda_free(l.delta_gpu);
if (l.rand_gpu) cuda_free(l.rand_gpu);
if (l.squared_gpu) cuda_free(l.squared_gpu);
if (l.norms_gpu) cuda_free(l.norms_gpu);
if (l.weights_gpu) cuda_free(l.weights_gpu);
if (l.weight_updates_gpu) cuda_free(l.weight_updates_gpu);
if (l.weights_gpu16) cuda_free(l.weights_gpu16);
if (l.weight_updates_gpu16) cuda_free(l.weight_updates_gpu16);
if (l.biases_gpu) cuda_free(l.biases_gpu);
if (l.bias_updates_gpu) cuda_free(l.bias_updates_gpu);
if (l.scales_gpu) cuda_free(l.scales_gpu);
if (l.scale_updates_gpu) cuda_free(l.scale_updates_gpu);
if (l.output_gpu) cuda_free(l.output_gpu);
if (l.delta_gpu) cuda_free(l.delta_gpu);
if (l.rand_gpu) cuda_free(l.rand_gpu);
if (l.squared_gpu) cuda_free(l.squared_gpu);
if (l.norms_gpu) cuda_free(l.norms_gpu);
#endif
}

src/network.c
@@ -452,6 +452,9 @@ int resize_network(network *net, int w, int h)
//printf(" %d: layer = %d,", i, l.type);
if(l.type == CONVOLUTIONAL){
resize_convolutional_layer(&l, w, h);
}
else if (l.type == CRNN) {
resize_crnn_layer(&l, w, h);
}else if(l.type == CROP){
resize_crop_layer(&l, w, h);
}else if(l.type == MAXPOOL){
@@ -1018,6 +1021,68 @@ void calculate_binary_weights(network net)
}
void copy_cudnn_descriptors(layer src, layer *dst)
{
dst->normTensorDesc = src.normTensorDesc;
dst->normDstTensorDesc = src.normDstTensorDesc;
dst->normDstTensorDescF16 = src.normDstTensorDescF16;
dst->srcTensorDesc = src.srcTensorDesc;
dst->dstTensorDesc = src.dstTensorDesc;
dst->srcTensorDesc16 = src.srcTensorDesc16;
dst->dstTensorDesc16 = src.dstTensorDesc16;
//dst->batch = 1;
//dst->steps = 1;
}
void copy_weights_pointers_gpu(layer src, layer *dst)
{
dst->weights_gpu = src.weights_gpu;
dst->weights_gpu16 = src.weights_gpu16;
dst->biases_gpu = src.biases_gpu;
dst->scales_gpu = src.scales_gpu;
dst->rolling_mean_gpu = src.rolling_mean_gpu;
dst->rolling_variance_gpu = src.rolling_variance_gpu;
dst->mean_gpu = src.mean_gpu;
dst->variance_gpu = src.variance_gpu;
//dst->align_bit_weights_gpu = src.align_bit_weights_gpu;
dst->x_gpu = src.x_gpu;
dst->output_gpu = src.output_gpu;
}
void copy_weights_net(network net_train, network *net_map)
{
int k;
for (k = 0; k < net_train.n; ++k) {
layer *l = &(net_train.layers[k]);
layer tmp_layer;
copy_cudnn_descriptors(net_map->layers[k], &tmp_layer);
net_map->layers[k] = net_train.layers[k];
copy_cudnn_descriptors(tmp_layer, &net_map->layers[k]);
if (l->type == CRNN) {
layer tmp_input_layer, tmp_self_layer, tmp_output_layer;
copy_cudnn_descriptors(*net_map->layers[k].input_layer, &tmp_input_layer);
copy_cudnn_descriptors(*net_map->layers[k].self_layer, &tmp_self_layer);
copy_cudnn_descriptors(*net_map->layers[k].output_layer, &tmp_output_layer);
net_map->layers[k].input_layer = net_train.layers[k].input_layer;
net_map->layers[k].self_layer = net_train.layers[k].self_layer;
net_map->layers[k].output_layer = net_train.layers[k].output_layer;
//net_map->layers[k].output_gpu = net_map->layers[k].output_layer->output_gpu; // already copied out of if()
copy_cudnn_descriptors(tmp_input_layer, net_map->layers[k].input_layer);
copy_cudnn_descriptors(tmp_self_layer, net_map->layers[k].self_layer);
copy_cudnn_descriptors(tmp_output_layer, net_map->layers[k].output_layer);
}
net_map->layers[k].batch = 1;
net_map->layers[k].steps = 1;
}
}
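In short, copy_weights_net lets mAP evaluation reuse the training weights in place: each layer struct is copied wholesale into the batch=1 net_map, then net_map's own cuDNN tensor descriptors (created for batch=1) are restored on top, with the same save/restore applied to the three CRNN sub-layers. Its call site in train_detector above:

    copy_weights_net(net, &net_map);  /* share training weights with the batch=1 net */
    mean_average_precision = validate_detector_map(datacfg, cfgfile, weightfile,
                                                   0.25, 0.5, &net_map);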
// combine Training and Validation networks
network combine_train_valid_networks(network net_train, network net_map)
{
@@ -1026,26 +1091,48 @@ network combine_train_valid_networks(network net_train, network net_map)
net_combined = net_train;
net_combined.layers = old_layers;
net_combined.batch = 1;
net_combined.time_steps = 1;
int k;
for (k = 0; k < net_train.n; ++k) {
layer *l = &(net_train.layers[k]);
net_combined.layers[k] = net_train.layers[k];
net_combined.layers[k].batch = 1;
if (l->type == CONVOLUTIONAL) {
#ifdef CUDNN
net_combined.layers[k].normTensorDesc = net_map.layers[k].normTensorDesc;
net_combined.layers[k].normDstTensorDesc = net_map.layers[k].normDstTensorDesc;
net_combined.layers[k].normDstTensorDescF16 = net_map.layers[k].normDstTensorDescF16;
if (l->type == CONVOLUTIONAL) {
/*
net_combined.layers[k] = net_train.layers[k];
net_combined.layers[k].batch = 1;
net_combined.layers[k].steps = 1;
copy_cudnn_descriptors(net_map.layers[k], &net_combined.layers[k]);
*/
net_combined.layers[k] = net_map.layers[k];
//net_combined.layers[k] = net_train.layers[k];
net_combined.layers[k].batch = 1;
net_combined.layers[k].steps = 1;
net_combined.layers[k].srcTensorDesc = net_map.layers[k].srcTensorDesc;
net_combined.layers[k].dstTensorDesc = net_map.layers[k].dstTensorDesc;
copy_weights_pointers_gpu(net_train.layers[k], &net_combined.layers[k]);
net_combined.layers[k].srcTensorDesc16 = net_map.layers[k].srcTensorDesc16;
net_combined.layers[k].dstTensorDesc16 = net_map.layers[k].dstTensorDesc16;
#endif // CUDNN
net_combined.layers[k].output_gpu = net_train.layers[k].output_gpu;
}
else if (l->type == CRNN) {
net_combined.layers[k] = net_map.layers[k];
net_combined.layers[k].batch = 1;
net_combined.layers[k].steps = 1;
// Don't use copy_cudnn_descriptors() here
copy_weights_pointers_gpu(*net_train.layers[k].input_layer, net_combined.layers[k].input_layer);
copy_weights_pointers_gpu(*net_train.layers[k].self_layer, net_combined.layers[k].self_layer);
copy_weights_pointers_gpu(*net_train.layers[k].output_layer, net_combined.layers[k].output_layer);
net_combined.layers[k].output_gpu = net_combined.layers[k].output_layer->output_gpu;
}
else {
net_combined.layers[k] = net_train.layers[k];
net_combined.layers[k].batch = 1;
net_combined.layers[k].steps = 1;
}
#endif // CUDNN
}
return net_combined;
}

src/network.h
@@ -161,6 +161,7 @@ int get_network_background(network net);
//LIB_API void fuse_conv_batchnorm(network net);
//LIB_API void calculate_binary_weights(network net);
network combine_train_valid_networks(network net_train, network net_map);
void copy_weights_net(network net_train, network *net_map);
#ifdef __cplusplus
}

src/network_kernels.cu
@@ -163,9 +163,16 @@ void forward_backward_network_gpu(network net, float *x, float *y)
int i;
for (i = 0; i < net.n; ++i) {
layer l = net.layers[i];
if (l.weights_gpu && l.weights_gpu16) {
if (l.weights_gpu && l.weights_gpu16 && net.cudnn_half){
assert((l.c*l.n*l.size*l.size) > 0);
cuda_convert_f32_to_f16(l.weights_gpu, l.c*l.n*l.size*l.size, l.weights_gpu16);
if (l.type == CONVOLUTIONAL) {
cuda_convert_f32_to_f16(l.weights_gpu, l.c*l.n*l.size*l.size, l.weights_gpu16);
}
else if (l.type == CRNN) {
//cuda_convert_f32_to_f16(l.input_layer->weights_gpu, l.input_layer->nweights, l.input_layer->weights_gpu16);
//cuda_convert_f32_to_f16(l.self_layer->weights_gpu, l.self_layer->nweights, l.self_layer->weights_gpu16);
//cuda_convert_f32_to_f16(l.output_layer->weights_gpu, l.output_layer->nweights, l.output_layer->weights_gpu16);
}
}
}
#endif

src/parser.c
@@ -184,15 +184,18 @@ layer parse_crnn(list *options, size_params params)
{
int size = option_find_int_quiet(options, "size", 3);
int stride = option_find_int_quiet(options, "stride", 1);
int pad = option_find_int_quiet(options, "pad", 1);
int pad = option_find_int_quiet(options, "pad", 0);
int padding = option_find_int_quiet(options, "padding", 0);
if (pad) padding = size / 2;
int output_filters = option_find_int(options, "output",1);
int hidden_filters = option_find_int(options, "hidden",1);
char *activation_s = option_find_str(options, "activation", "logistic");
ACTIVATION activation = get_activation(activation_s);
int batch_normalize = option_find_int_quiet(options, "batch_normalize", 0);
int xnor = option_find_int_quiet(options, "xnor", 0);
layer l = make_crnn_layer(params.batch, params.w, params.h, params.c, hidden_filters, output_filters, params.time_steps, size, stride, pad, activation, batch_normalize);
layer l = make_crnn_layer(params.batch, params.w, params.h, params.c, hidden_filters, output_filters, params.time_steps, size, stride, padding, activation, batch_normalize, xnor);
l.shortcut = option_find_int_quiet(options, "shortcut", 0);
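Note the padding semantics change in parse_crnn: previously the boolean pad key (default 1) was itself passed to make_crnn_layer as the pixel padding; now, as in [convolutional], pad=1 derives "same" padding from the kernel size, and an explicit padding key overrides it:

    /* size=3, pad=1  ->  padding = 3/2 = 1, so out_h == h and out_w == w */
    if (pad) padding = size / 2;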
@@ -638,6 +641,9 @@ void parse_net_options(list *options, network *net)
net->decay = option_find_float(options, "decay", .0001);
int subdivs = option_find_int(options, "subdivisions",1);
net->time_steps = option_find_int_quiet(options, "time_steps",1);
net->track = option_find_int_quiet(options, "track", 0);
net->augment_speed = option_find_int_quiet(options, "augment_speed", 2);
net->try_fix_nan = option_find_int_quiet(options, "try_fix_nan", 0);
net->batch /= subdivs;
net->batch *= net->time_steps;
net->subdivisions = subdivs;
@@ -657,7 +663,6 @@ void parse_net_options(list *options, network *net)
net->min_crop = option_find_int_quiet(options, "min_crop",net->w);
net->flip = option_find_int_quiet(options, "flip", 1);
net->small_object = option_find_int_quiet(options, "small_object", 0);
net->angle = option_find_float_quiet(options, "angle", 0);
net->aspect = option_find_float_quiet(options, "aspect", 1);
net->saturation = option_find_float_quiet(options, "saturation", 1);
@@ -748,6 +753,7 @@ network parse_network_cfg_custom(char *filename, int batch, int time_steps)
params.inputs = net.inputs;
if (batch > 0) net.batch = batch;
if (time_steps > 0) net.time_steps = time_steps;
if (net.batch < net.time_steps) net.batch = net.time_steps;
params.batch = net.batch;
params.time_steps = net.time_steps;
params.net = net;
@@ -984,10 +990,10 @@ void save_convolutional_weights(layer l, FILE *fp)
fwrite(l.rolling_variance, sizeof(float), l.n, fp);
}
fwrite(l.weights, sizeof(float), num, fp);
if(l.adam){
fwrite(l.m, sizeof(float), num, fp);
fwrite(l.v, sizeof(float), num, fp);
}
//if(l.adam){
// fwrite(l.m, sizeof(float), num, fp);
// fwrite(l.v, sizeof(float), num, fp);
//}
}
void save_batchnorm_weights(layer l, FILE *fp)
@@ -1198,10 +1204,10 @@ void load_convolutional_weights(layer l, FILE *fp)
}
}
fread(l.weights, sizeof(float), num, fp);
if(l.adam){
fread(l.m, sizeof(float), num, fp);
fread(l.v, sizeof(float), num, fp);
}
//if(l.adam){
// fread(l.m, sizeof(float), num, fp);
// fread(l.v, sizeof(float), num, fp);
//}
//if(l.c == 3) scal_cpu(num, 1./256, l.weights, 1);
if (l.flipped) {
transpose_matrix(l.weights, l.c*l.size*l.size, l.n);
