diff --git a/README.md b/README.md index 92cfe1a5..47ec2752 100644 --- a/README.md +++ b/README.md @@ -518,9 +518,9 @@ Example of custom object detection: `darknet.exe detector test data/obj.data yol * increase network resolution in your `.cfg`-file (`height=608`, `width=608` or any value multiple of 32) - it will increase precision - * check that each object is mandatory labeled in your dataset - no one object in your data set should not be without label. In the most training issues - there are wrong labels in your dataset (got labels by using some conversion script, marked with a third-party tool, ...). Always check your dataset by using: https://github.com/AlexeyAB/Yolo_mark + * check that each object that you want to detect is mandatory labeled in your dataset - no one object in your data set should not be without label. In the most training issues - there are wrong labels in your dataset (got labels by using some conversion script, marked with a third-party tool, ...). Always check your dataset by using: https://github.com/AlexeyAB/Yolo_mark - * for each object which you want to detect - there must be at least 1 similar object in the Training dataset with about the same: shape, side of object, relative size, angle of rotation, tilt, illumination. So desirable that your training dataset include images with objects at diffrent: scales, rotations, lightings, from different sides, on different backgrounds - you should preferably have 2000 different images for each class or more, and you should train `2000*classes` iterations or more + * for each object which you want to detect - there must be at least 1 similar object in the Training dataset with about the same: shape, side of object, relative size, angle of rotation, tilt, illumination. 
So desirable that your training dataset include images with objects at different: scales, rotations, lightings, from different sides, on different backgrounds - you should preferably have 2000 different images for each class or more, and you should train `2000*classes` iterations or more * desirable that your training dataset include images with non-labeled objects that you do not want to detect - negative samples without bounded box (empty `.txt` files) - use as many images of negative samples as there are images with objects diff --git a/include/darknet.h b/include/darknet.h index fbd7e27d..68ab1fc7 100644 --- a/include/darknet.h +++ b/include/darknet.h @@ -211,6 +211,7 @@ struct layer { int peephole; int use_bin_output; int steps; + int state_constrain; int hidden; int truth; float smooth; @@ -551,6 +552,7 @@ typedef struct network { float learning_rate_min; float learning_rate_max; int batches_per_cycle; + int batches_cycle_mult; float momentum; float decay; float gamma; diff --git a/src/activation_kernels.cu b/src/activation_kernels.cu index 6c9445a6..4e5fa9f8 100644 --- a/src/activation_kernels.cu +++ b/src/activation_kernels.cu @@ -210,6 +210,30 @@ __global__ void activate_array_logistic_kernel(float *x, int n) } } +__global__ void activate_array_tanh_kernel(float *x, int n) +{ + int index = blockIdx.x*blockDim.x + threadIdx.x; + if (index < n) { + x[index] = tanh_activate_kernel(x[index]); + } +} + +__global__ void activate_array_hardtan_kernel(float *x, int n) +{ + int index = blockIdx.x*blockDim.x + threadIdx.x; + if (index < n) { + x[index] = hardtan_activate_kernel(x[index]); + } +} + +__global__ void activate_array_relu_kernel(float *x, int n) +{ + int index = blockIdx.x*blockDim.x + threadIdx.x; + if (index < n) { + x[index] = relu_activate_kernel(x[index]); + } +} + __global__ void gradient_array_kernel(float *x, int n, ACTIVATION a, float *delta) { int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x; @@ -240,6 +264,14 @@ __global__ 
void gradient_array_logistic_kernel(float *x, int n, float *delta) } } +__global__ void gradient_array_tanh_kernel(float *x, int n, float *delta) +{ + int index = blockIdx.x*blockDim.x + threadIdx.x; + if (index < n) { + delta[index] *= tanh_gradient_kernel(x[index]); + } +} + __global__ void gradient_array_hardtan_kernel(float *x, int n, float *delta) { int index = blockIdx.x*blockDim.x + threadIdx.x; @@ -248,12 +280,23 @@ __global__ void gradient_array_hardtan_kernel(float *x, int n, float *delta) } } +__global__ void gradient_array_relu_kernel(float *x, int n, float *delta) +{ + int index = blockIdx.x*blockDim.x + threadIdx.x; + if (index < n) { + delta[index] *= relu_gradient_kernel(x[index]); + } +} + extern "C" void activate_array_ongpu(float *x, int n, ACTIVATION a) { const int num_blocks = get_number_of_blocks(n, BLOCK); if (a == LINEAR) return; else if(a == LEAKY) activate_array_leaky_kernel << > >(x, n); else if (a == LOGISTIC) activate_array_logistic_kernel << > >(x, n); + else if (a == TANH) activate_array_tanh_kernel << > >(x, n); + else if (a == HARDTAN) activate_array_hardtan_kernel << > >(x, n); + else if (a == RELU) activate_array_relu_kernel << > >(x, n); else if (a == SELU) activate_array_selu_kernel << > >(x, n); else activate_array_kernel<<>>(x, n, a); @@ -266,8 +309,10 @@ extern "C" void gradient_array_ongpu(float *x, int n, ACTIVATION a, float *delta if (a == LINEAR) return; else if (a == LEAKY) gradient_array_leaky_kernel << > >(x, n, delta); else if (a == LOGISTIC) gradient_array_logistic_kernel << > >(x, n, delta); - else if (a == SELU) gradient_array_selu_kernel << > >(x, n, delta); + else if (a == TANH) gradient_array_tanh_kernel << > >(x, n, delta); else if (a == HARDTAN) gradient_array_hardtan_kernel << > >(x, n, delta); + else if (a == RELU) gradient_array_relu_kernel << > >(x, n, delta); + else if (a == SELU) gradient_array_selu_kernel << > >(x, n, delta); else gradient_array_kernel << > > (x, n, a, delta); 
CHECK_CUDA(cudaPeekAtLastError()); diff --git a/src/blas_kernels.cu b/src/blas_kernels.cu index 2070bc1f..66905127 100644 --- a/src/blas_kernels.cu +++ b/src/blas_kernels.cu @@ -983,7 +983,7 @@ __global__ void fix_nan_and_inf_kernel(float *input, size_t size) if (index < size) { float val = input[index]; if (isnan(val) || isinf(val)) - input[index] = index; // pseudo random value + input[index] = 1.0f / (index + 1.0f); // pseudo random finite value; the +1 keeps index 0 from producing 1/0 = inf again } } diff --git a/src/conv_lstm_layer.c b/src/conv_lstm_layer.c index 1764cb90..b51a3aed 100644 --- a/src/conv_lstm_layer.c +++ b/src/conv_lstm_layer.c @@ -902,13 +902,14 @@ void forward_conv_lstm_layer_gpu(layer l, network_state state) activate_array_ongpu(l.h_gpu, l.outputs*l.batch, TANH); mul_ongpu(l.outputs*l.batch, l.o_gpu, 1, l.h_gpu, 1); + if(l.state_constrain) constrain_ongpu(l.outputs*l.batch, l.state_constrain, l.c_gpu, 1); //constrain_ongpu(l.outputs*l.batch, 1, l.c_gpu, 1); //constrain_ongpu(l.outputs*l.batch, 1, l.h_gpu, 1); fix_nan_and_inf(l.c_gpu, l.outputs*l.batch); fix_nan_and_inf(l.h_gpu, l.outputs*l.batch); copy_ongpu(l.outputs*l.batch, l.c_gpu, 1, l.cell_gpu, 1); - copy_ongpu(l.outputs*l.batch, l.h_gpu, 1, l.output_gpu, 1); + copy_ongpu(l.outputs*l.batch, l.h_gpu, 1, l.output_gpu, 1); // required for both Detection and Training state.input += l.inputs*l.batch; l.output_gpu += l.outputs*l.batch; @@ -1159,5 +1160,8 @@ void backward_conv_lstm_layer_gpu(layer l, network_state state) copy_ongpu(l.outputs*l.batch, last_output, 1, l.last_prev_state_gpu, 1); copy_ongpu(l.outputs*l.batch, last_cell, 1, l.last_prev_cell_gpu, 1); + + // free state after each 100 iterations + //if (get_current_batch(state.net) % 100) free_state_conv_lstm(l); // dont use } #endif diff --git a/src/convolutional_layer.c b/src/convolutional_layer.c index f9c0ee01..d983ab61 100644 --- a/src/convolutional_layer.c +++ b/src/convolutional_layer.c @@ -1046,13 +1046,9 @@ void backward_convolutional_layer(convolutional_layer l, network_state state) } } 
-void update_convolutional_layer(convolutional_layer l, update_args a) +void update_convolutional_layer(convolutional_layer l, int batch, float learning_rate, float momentum, float decay) { - float learning_rate = a.learning_rate*l.learning_rate_scale; - float momentum = a.momentum; - float decay = a.decay; - int batch = a.batch; - + //int size = l.size*l.size*l.c*l.n; axpy_cpu(l.n, learning_rate / batch, l.bias_updates, 1, l.biases, 1); scal_cpu(l.n, momentum, l.bias_updates, 1); @@ -1067,6 +1063,7 @@ void update_convolutional_layer(convolutional_layer l, update_args a) } + image get_convolutional_weight(convolutional_layer l, int i) { int h = l.size; @@ -1101,7 +1098,7 @@ void rescale_weights(convolutional_layer l, float scale, float trans) image *get_weights(convolutional_layer l) { - image *weights = calloc(l.n, sizeof(image)); + image *weights = (image *)calloc(l.n, sizeof(image)); int i; for (i = 0; i < l.n; ++i) { weights[i] = copy_image(get_convolutional_weight(l, i)); diff --git a/src/detector.c b/src/detector.c index 4fc74585..14ad3618 100644 --- a/src/detector.c +++ b/src/detector.c @@ -666,7 +666,7 @@ float validate_detector_map(char *datacfg, char *cfgfile, char *weightfile, floa char *train_images = option_find_str(options, "train", "data/train.txt"); valid_images = option_find_str(options, "valid", train_images); net = *existing_net; - //remember_network_recurrent_state(*existing_net); + remember_network_recurrent_state(*existing_net); free_network_recurrent_state(*existing_net); } else { @@ -1077,8 +1077,8 @@ float validate_detector_map(char *datacfg, char *cfgfile, char *weightfile, floa if (existing_net) { //set_batch_network(&net, initial_batch); //free_network_recurrent_state(*existing_net); - //restore_network_recurrent_state(*existing_net); - randomize_network_recurrent_state(*existing_net); + restore_network_recurrent_state(*existing_net); + //randomize_network_recurrent_state(*existing_net); } else { free_network(net); diff --git 
a/src/image_opencv.cpp b/src/image_opencv.cpp index 1387b633..d7e9a50b 100644 --- a/src/image_opencv.cpp +++ b/src/image_opencv.cpp @@ -1081,7 +1081,7 @@ void draw_train_loss(mat_cv* img_src, int img_size, float avg_loss, float max_im cv::putText(img, char_buff, cv::Point(10, 28), cv::FONT_HERSHEY_COMPLEX_SMALL, 0.7, CV_RGB(255, 255, 255), 5, CV_AA); cv::putText(img, char_buff, cv::Point(10, 28), cv::FONT_HERSHEY_COMPLEX_SMALL, 0.7, CV_RGB(200, 0, 0), 1, CV_AA); - if (((int)(old_precision * 10) != (int)(precision * 10)) || (max_precision < precision) || (current_batch - text_iteration_old) >= max_batches / 10) { + if ((std::fabs(old_precision - precision) > 0.1) || (max_precision < precision) || (current_batch - text_iteration_old) >= max_batches / 10) { text_iteration_old = current_batch; max_precision = std::max(max_precision, precision); sprintf(char_buff, "%2.0f%% ", precision * 100); diff --git a/src/network.c b/src/network.c index 32f2b9e2..a46d15a0 100644 --- a/src/network.c +++ b/src/network.c @@ -95,7 +95,7 @@ float get_current_seq_subdivisions(network net) { int sequence_subdivisions = net.init_sequential_subdivisions; - if (net.policy) + if (net.num_steps > 0) { int batch_num = get_current_batch(net); int i; @@ -145,11 +145,20 @@ float get_current_rate(network net) case SIG: return net.learning_rate * (1./(1.+exp(net.gamma*(batch_num - net.step)))); case SGDR: + { + int last_iteration_start = 0; + int cycle_size = net.batches_per_cycle; + while ((last_iteration_start + cycle_size) < batch_num) + { + last_iteration_start += cycle_size; + cycle_size *= net.batches_cycle_mult; + } rate = net.learning_rate_min + - 0.5*(net.learning_rate-net.learning_rate_min) - * (1. + cos( (float) (batch_num % net.batches_per_cycle)*3.14159265 / net.batches_per_cycle)); + 0.5*(net.learning_rate - net.learning_rate_min) + * (1. 
+ cos((float)(batch_num - last_iteration_start)*3.14159265 / cycle_size)); return rate; + } default: fprintf(stderr, "Policy is weird!\n"); return net.learning_rate; diff --git a/src/parser.c b/src/parser.c index cca3641b..95c92250 100644 --- a/src/parser.c +++ b/src/parser.c @@ -200,7 +200,7 @@ layer parse_crnn(list *options, size_params params) int batch_normalize = option_find_int_quiet(options, "batch_normalize", 0); int xnor = option_find_int_quiet(options, "xnor", 0); - layer l = make_crnn_layer(params.batch, params.w, params.h, params.c, hidden_filters, output_filters, groups, params.time_steps, size, stride, padding, activation, batch_normalize, xnor); + layer l = make_crnn_layer(params.batch, params.h, params.w, params.c, hidden_filters, output_filters, groups, params.time_steps, size, stride, padding, activation, batch_normalize, xnor); l.shortcut = option_find_int_quiet(options, "shortcut", 0); @@ -260,8 +260,9 @@ layer parse_conv_lstm(list *options, size_params params) int xnor = option_find_int_quiet(options, "xnor", 0); int peephole = option_find_int_quiet(options, "peephole", 1); - layer l = make_conv_lstm_layer(params.batch, params.w, params.h, params.c, output_filters, groups, params.time_steps, size, stride, padding, activation, batch_normalize, peephole, xnor); + layer l = make_conv_lstm_layer(params.batch, params.h, params.w, params.c, output_filters, groups, params.time_steps, size, stride, padding, activation, batch_normalize, peephole, xnor); + l.state_constrain = option_find_int_quiet(options, "state_constrain", params.time_steps * 32); l.shortcut = option_find_int_quiet(options, "shortcut", 0); return l; @@ -668,7 +669,8 @@ void parse_net_options(list *options, network *net) net->batch = option_find_int(options, "batch",1); net->learning_rate = option_find_float(options, "learning_rate", .001); net->learning_rate_min = option_find_float_quiet(options, "learning_rate_min", .00001); - net->batches_per_cycle = option_find_int_quiet(options, 
"sgdr_cycle", 500); + net->batches_per_cycle = option_find_int_quiet(options, "sgdr_cycle", 1000); + net->batches_cycle_mult = option_find_int_quiet(options, "sgdr_mult", 2); net->momentum = option_find_float(options, "momentum", .9); net->decay = option_find_float(options, "decay", .0001); int subdivs = option_find_int(options, "subdivisions",1); @@ -721,40 +723,44 @@ void parse_net_options(list *options, network *net) if(net->policy == STEP){ net->step = option_find_int(options, "step", 1); net->scale = option_find_float(options, "scale", 1); - } else if (net->policy == STEPS){ + } else if (net->policy == STEPS || net->policy == SGDR){ char *l = option_find(options, "steps"); char *p = option_find(options, "scales"); char *s = option_find(options, "seq_scales"); - if(!l || !p) error("STEPS policy must have steps and scales in cfg file"); + if(net->policy == STEPS && (!l || !p)) error("STEPS policy must have steps and scales in cfg file"); - int len = strlen(l); - int n = 1; - int i; - for(i = 0; i < len; ++i){ - if (l[i] == ',') ++n; - } - int* steps = (int*)calloc(n, sizeof(int)); - float* scales = (float*)calloc(n, sizeof(float)); - float* seq_scales = (float*)calloc(n, sizeof(float)); - for (i = 0; i < n; ++i) { - if (s) { - seq_scales[i] = atof(s); - s = strchr(s, ',') + 1; - } else - seq_scales[i] = 1; - } - for(i = 0; i < n; ++i){ - int step = atoi(l); - float scale = atof(p); - l = strchr(l, ',')+1; - p = strchr(p, ',')+1; - steps[i] = step; - scales[i] = scale; + if (l) { + int len = strlen(l); + int n = 1; + int i; + for (i = 0; i < len; ++i) { + if (l[i] == ',') ++n; + } + int* steps = (int*)calloc(n, sizeof(int)); + float* scales = (float*)calloc(n, sizeof(float)); + float* seq_scales = (float*)calloc(n, sizeof(float)); + for (i = 0; i < n; ++i) { + float scale = 1.0; + if (p) { + scale = atof(p); + p = strchr(p, ','); if (p) ++p; // stays NULL when scales has fewer entries than steps, so remaining steps keep scale 1.0 + } + float sequence_scale = 1.0; + if (s) { + sequence_scale = atof(s); + s = strchr(s, ','); if (s) ++s; // stays NULL when seq_scales has fewer entries than steps, so remaining steps keep 1.0 + } + int step = atoi(l); + l = 
strchr(l, ',') + 1; + steps[i] = step; + scales[i] = scale; + seq_scales[i] = sequence_scale; + } + net->scales = scales; + net->steps = steps; + net->seq_scales = seq_scales; + net->num_steps = n; } - net->scales = scales; - net->steps = steps; - net->seq_scales = seq_scales; - net->num_steps = n; } else if (net->policy == EXP){ net->gamma = option_find_float(options, "gamma", 1); } else if (net->policy == SIG){