Some conv-lstm, sgdr and other fixes

pull/749/head^2
AlexeyAB 6 years ago
parent 4f72fcc015
commit 038289eb7d
  1. 2
      README.md
  2. 2
      include/darknet.h
  3. 47
      src/activation_kernels.cu
  4. 2
      src/blas_kernels.cu
  5. 6
      src/conv_lstm_layer.c
  6. 11
      src/convolutional_layer.c
  7. 6
      src/detector.c
  8. 2
      src/image_opencv.cpp
  9. 15
      src/network.c
  10. 32
      src/parser.c

@ -518,7 +518,7 @@ Example of custom object detection: `darknet.exe detector test data/obj.data yol
* increase network resolution in your `.cfg`-file (`height=608`, `width=608` or any value multiple of 32) - it will increase precision
* check that each object is mandatory labeled in your dataset - no one object in your data set should not be without label. In the most training issues - there are wrong labels in your dataset (got labels by using some conversion script, marked with a third-party tool, ...). Always check your dataset by using: https://github.com/AlexeyAB/Yolo_mark
* check that each object that you want to detect is labeled in your dataset - no object in your dataset should be left without a label. Most training issues are caused by wrong labels in your dataset (labels produced by some conversion script, marked with a third-party tool, ...). Always check your dataset by using: https://github.com/AlexeyAB/Yolo_mark
* for each object which you want to detect - there must be at least 1 similar object in the Training dataset with about the same: shape, side of object, relative size, angle of rotation, tilt, illumination. So it is desirable that your training dataset includes images with objects at different: scales, rotations, lightings, from different sides, on different backgrounds - you should preferably have 2000 different images for each class or more, and you should train `2000*classes` iterations or more

@ -211,6 +211,7 @@ struct layer {
int peephole;
int use_bin_output;
int steps;
int state_constrain;
int hidden;
int truth;
float smooth;
@ -551,6 +552,7 @@ typedef struct network {
float learning_rate_min;
float learning_rate_max;
int batches_per_cycle;
int batches_cycle_mult;
float momentum;
float decay;
float gamma;

@ -210,6 +210,30 @@ __global__ void activate_array_logistic_kernel(float *x, int n)
}
}
// Applies tanh_activate_kernel() in place to the first n elements of x.
// Each thread handles the single element selected by its flat x-dimension
// index (blockIdx.x * blockDim.x + threadIdx.x).
//   x - array updated in place
//   n - number of elements to process
__global__ void activate_array_tanh_kernel(float *x, int n)
{
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    // Threads whose index falls past the end of the array do nothing.
    if (idx >= n) return;

    const float value = x[idx];
    x[idx] = tanh_activate_kernel(value);
}
// Applies hardtan_activate_kernel() in place to the first n elements of x.
// Each thread handles the single element selected by its flat x-dimension
// index (blockIdx.x * blockDim.x + threadIdx.x).
//   x - array updated in place
//   n - number of elements to process
__global__ void activate_array_hardtan_kernel(float *x, int n)
{
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    // Threads whose index falls past the end of the array do nothing.
    if (idx >= n) return;

    const float value = x[idx];
    x[idx] = hardtan_activate_kernel(value);
}
// Applies relu_activate_kernel() in place to the first n elements of x.
// Each thread handles the single element selected by its flat x-dimension
// index (blockIdx.x * blockDim.x + threadIdx.x).
//   x - array updated in place
//   n - number of elements to process
__global__ void activate_array_relu_kernel(float *x, int n)
{
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    // Threads whose index falls past the end of the array do nothing.
    if (idx >= n) return;

    const float value = x[idx];
    x[idx] = relu_activate_kernel(value);
}
__global__ void gradient_array_kernel(float *x, int n, ACTIVATION a, float *delta)
{
int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
@ -240,6 +264,14 @@ __global__ void gradient_array_logistic_kernel(float *x, int n, float *delta)
}
}
// Scales delta element-wise by tanh_gradient_kernel(x[i]) for the first n
// elements. Each thread handles the single element selected by its flat
// x-dimension index (blockIdx.x * blockDim.x + threadIdx.x).
//   x     - values passed to tanh_gradient_kernel (read only here)
//   n     - number of elements to process
//   delta - array multiplied in place by the gradient factor
__global__ void gradient_array_tanh_kernel(float *x, int n, float *delta)
{
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    // Threads whose index falls past the end of the arrays do nothing.
    if (idx >= n) return;

    delta[idx] *= tanh_gradient_kernel(x[idx]);
}
__global__ void gradient_array_hardtan_kernel(float *x, int n, float *delta)
{
int index = blockIdx.x*blockDim.x + threadIdx.x;
@ -248,12 +280,23 @@ __global__ void gradient_array_hardtan_kernel(float *x, int n, float *delta)
}
}
// Scales delta element-wise by relu_gradient_kernel(x[i]) for the first n
// elements. Each thread handles the single element selected by its flat
// x-dimension index (blockIdx.x * blockDim.x + threadIdx.x).
//   x     - values passed to relu_gradient_kernel (read only here)
//   n     - number of elements to process
//   delta - array multiplied in place by the gradient factor
__global__ void gradient_array_relu_kernel(float *x, int n, float *delta)
{
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    // Threads whose index falls past the end of the arrays do nothing.
    if (idx >= n) return;

    delta[idx] *= relu_gradient_kernel(x[idx]);
}
extern "C" void activate_array_ongpu(float *x, int n, ACTIVATION a)
{
const int num_blocks = get_number_of_blocks(n, BLOCK);
if (a == LINEAR) return;
else if(a == LEAKY) activate_array_leaky_kernel << <num_blocks, BLOCK, 0, get_cuda_stream() >> >(x, n);
else if (a == LOGISTIC) activate_array_logistic_kernel << <num_blocks, BLOCK, 0, get_cuda_stream() >> >(x, n);
else if (a == TANH) activate_array_tanh_kernel << <num_blocks, BLOCK, 0, get_cuda_stream() >> >(x, n);
else if (a == HARDTAN) activate_array_hardtan_kernel << <num_blocks, BLOCK, 0, get_cuda_stream() >> >(x, n);
else if (a == RELU) activate_array_relu_kernel << <num_blocks, BLOCK, 0, get_cuda_stream() >> >(x, n);
else if (a == SELU) activate_array_selu_kernel << <num_blocks, BLOCK, 0, get_cuda_stream() >> >(x, n);
else
activate_array_kernel<<<cuda_gridsize(n), BLOCK, 0, get_cuda_stream()>>>(x, n, a);
@ -266,8 +309,10 @@ extern "C" void gradient_array_ongpu(float *x, int n, ACTIVATION a, float *delta
if (a == LINEAR) return;
else if (a == LEAKY) gradient_array_leaky_kernel << <num_blocks, BLOCK, 0, get_cuda_stream() >> >(x, n, delta);
else if (a == LOGISTIC) gradient_array_logistic_kernel << <num_blocks, BLOCK, 0, get_cuda_stream() >> >(x, n, delta);
else if (a == SELU) gradient_array_selu_kernel << <num_blocks, BLOCK, 0, get_cuda_stream() >> >(x, n, delta);
else if (a == TANH) gradient_array_tanh_kernel << <num_blocks, BLOCK, 0, get_cuda_stream() >> >(x, n, delta);
else if (a == HARDTAN) gradient_array_hardtan_kernel << <num_blocks, BLOCK, 0, get_cuda_stream() >> >(x, n, delta);
else if (a == RELU) gradient_array_relu_kernel << <num_blocks, BLOCK, 0, get_cuda_stream() >> >(x, n, delta);
else if (a == SELU) gradient_array_selu_kernel << <num_blocks, BLOCK, 0, get_cuda_stream() >> >(x, n, delta);
else
gradient_array_kernel << <cuda_gridsize(n), BLOCK, 0, get_cuda_stream() >> > (x, n, a, delta);
CHECK_CUDA(cudaPeekAtLastError());

@ -983,7 +983,7 @@ __global__ void fix_nan_and_inf_kernel(float *input, size_t size)
if (index < size) {
float val = input[index];
if (isnan(val) || isinf(val))
input[index] = index; // pseudo random value
input[index] = 1.0f / index; // pseudo random value
}
}

@ -902,13 +902,14 @@ void forward_conv_lstm_layer_gpu(layer l, network_state state)
activate_array_ongpu(l.h_gpu, l.outputs*l.batch, TANH);
mul_ongpu(l.outputs*l.batch, l.o_gpu, 1, l.h_gpu, 1);
if(l.state_constrain) constrain_ongpu(l.outputs*l.batch, l.state_constrain, l.c_gpu, 1);
//constrain_ongpu(l.outputs*l.batch, 1, l.c_gpu, 1);
//constrain_ongpu(l.outputs*l.batch, 1, l.h_gpu, 1);
fix_nan_and_inf(l.c_gpu, l.outputs*l.batch);
fix_nan_and_inf(l.h_gpu, l.outputs*l.batch);
copy_ongpu(l.outputs*l.batch, l.c_gpu, 1, l.cell_gpu, 1);
copy_ongpu(l.outputs*l.batch, l.h_gpu, 1, l.output_gpu, 1);
copy_ongpu(l.outputs*l.batch, l.h_gpu, 1, l.output_gpu, 1); // required for both Detection and Training
state.input += l.inputs*l.batch;
l.output_gpu += l.outputs*l.batch;
@ -1159,5 +1160,8 @@ void backward_conv_lstm_layer_gpu(layer l, network_state state)
copy_ongpu(l.outputs*l.batch, last_output, 1, l.last_prev_state_gpu, 1);
copy_ongpu(l.outputs*l.batch, last_cell, 1, l.last_prev_cell_gpu, 1);
// free state after each 100 iterations
//if (get_current_batch(state.net) % 100) free_state_conv_lstm(l); // dont use
}
#endif

@ -1046,13 +1046,9 @@ void backward_convolutional_layer(convolutional_layer l, network_state state)
}
}
void update_convolutional_layer(convolutional_layer l, update_args a)
void update_convolutional_layer(convolutional_layer l, int batch, float learning_rate, float momentum, float decay)
{
float learning_rate = a.learning_rate*l.learning_rate_scale;
float momentum = a.momentum;
float decay = a.decay;
int batch = a.batch;
//int size = l.size*l.size*l.c*l.n;
axpy_cpu(l.n, learning_rate / batch, l.bias_updates, 1, l.biases, 1);
scal_cpu(l.n, momentum, l.bias_updates, 1);
@ -1067,6 +1063,7 @@ void update_convolutional_layer(convolutional_layer l, update_args a)
}
image get_convolutional_weight(convolutional_layer l, int i)
{
int h = l.size;
@ -1101,7 +1098,7 @@ void rescale_weights(convolutional_layer l, float scale, float trans)
image *get_weights(convolutional_layer l)
{
image *weights = calloc(l.n, sizeof(image));
image *weights = (image *)calloc(l.n, sizeof(image));
int i;
for (i = 0; i < l.n; ++i) {
weights[i] = copy_image(get_convolutional_weight(l, i));

@ -666,7 +666,7 @@ float validate_detector_map(char *datacfg, char *cfgfile, char *weightfile, floa
char *train_images = option_find_str(options, "train", "data/train.txt");
valid_images = option_find_str(options, "valid", train_images);
net = *existing_net;
//remember_network_recurrent_state(*existing_net);
remember_network_recurrent_state(*existing_net);
free_network_recurrent_state(*existing_net);
}
else {
@ -1077,8 +1077,8 @@ float validate_detector_map(char *datacfg, char *cfgfile, char *weightfile, floa
if (existing_net) {
//set_batch_network(&net, initial_batch);
//free_network_recurrent_state(*existing_net);
//restore_network_recurrent_state(*existing_net);
randomize_network_recurrent_state(*existing_net);
restore_network_recurrent_state(*existing_net);
//randomize_network_recurrent_state(*existing_net);
}
else {
free_network(net);

@ -1081,7 +1081,7 @@ void draw_train_loss(mat_cv* img_src, int img_size, float avg_loss, float max_im
cv::putText(img, char_buff, cv::Point(10, 28), cv::FONT_HERSHEY_COMPLEX_SMALL, 0.7, CV_RGB(255, 255, 255), 5, CV_AA);
cv::putText(img, char_buff, cv::Point(10, 28), cv::FONT_HERSHEY_COMPLEX_SMALL, 0.7, CV_RGB(200, 0, 0), 1, CV_AA);
if (((int)(old_precision * 10) != (int)(precision * 10)) || (max_precision < precision) || (current_batch - text_iteration_old) >= max_batches / 10) {
if ((std::fabs(old_precision - precision) > 0.1) || (max_precision < precision) || (current_batch - text_iteration_old) >= max_batches / 10) {
text_iteration_old = current_batch;
max_precision = std::max(max_precision, precision);
sprintf(char_buff, "%2.0f%% ", precision * 100);

@ -95,7 +95,7 @@ float get_current_seq_subdivisions(network net)
{
int sequence_subdivisions = net.init_sequential_subdivisions;
if (net.policy)
if (net.num_steps > 0)
{
int batch_num = get_current_batch(net);
int i;
@ -145,11 +145,20 @@ float get_current_rate(network net)
case SIG:
return net.learning_rate * (1./(1.+exp(net.gamma*(batch_num - net.step))));
case SGDR:
{
int last_iteration_start = 0;
int cycle_size = net.batches_per_cycle;
while ((last_iteration_start + cycle_size) < batch_num)
{
last_iteration_start += cycle_size;
cycle_size *= net.batches_cycle_mult;
}
rate = net.learning_rate_min +
0.5*(net.learning_rate-net.learning_rate_min)
* (1. + cos( (float) (batch_num % net.batches_per_cycle)*3.14159265 / net.batches_per_cycle));
0.5*(net.learning_rate - net.learning_rate_min)
* (1. + cos((float)(batch_num - last_iteration_start)*3.14159265 / cycle_size));
return rate;
}
default:
fprintf(stderr, "Policy is weird!\n");
return net.learning_rate;

@ -200,7 +200,7 @@ layer parse_crnn(list *options, size_params params)
int batch_normalize = option_find_int_quiet(options, "batch_normalize", 0);
int xnor = option_find_int_quiet(options, "xnor", 0);
layer l = make_crnn_layer(params.batch, params.w, params.h, params.c, hidden_filters, output_filters, groups, params.time_steps, size, stride, padding, activation, batch_normalize, xnor);
layer l = make_crnn_layer(params.batch, params.h, params.w, params.c, hidden_filters, output_filters, groups, params.time_steps, size, stride, padding, activation, batch_normalize, xnor);
l.shortcut = option_find_int_quiet(options, "shortcut", 0);
@ -260,8 +260,9 @@ layer parse_conv_lstm(list *options, size_params params)
int xnor = option_find_int_quiet(options, "xnor", 0);
int peephole = option_find_int_quiet(options, "peephole", 1);
layer l = make_conv_lstm_layer(params.batch, params.w, params.h, params.c, output_filters, groups, params.time_steps, size, stride, padding, activation, batch_normalize, peephole, xnor);
layer l = make_conv_lstm_layer(params.batch, params.h, params.w, params.c, output_filters, groups, params.time_steps, size, stride, padding, activation, batch_normalize, peephole, xnor);
l.state_constrain = option_find_int_quiet(options, "state_constrain", params.time_steps * 32);
l.shortcut = option_find_int_quiet(options, "shortcut", 0);
return l;
@ -668,7 +669,8 @@ void parse_net_options(list *options, network *net)
net->batch = option_find_int(options, "batch",1);
net->learning_rate = option_find_float(options, "learning_rate", .001);
net->learning_rate_min = option_find_float_quiet(options, "learning_rate_min", .00001);
net->batches_per_cycle = option_find_int_quiet(options, "sgdr_cycle", 500);
net->batches_per_cycle = option_find_int_quiet(options, "sgdr_cycle", 1000);
net->batches_cycle_mult = option_find_int_quiet(options, "sgdr_mult", 2);
net->momentum = option_find_float(options, "momentum", .9);
net->decay = option_find_float(options, "decay", .0001);
int subdivs = option_find_int(options, "subdivisions",1);
@ -721,40 +723,44 @@ void parse_net_options(list *options, network *net)
if(net->policy == STEP){
net->step = option_find_int(options, "step", 1);
net->scale = option_find_float(options, "scale", 1);
} else if (net->policy == STEPS){
} else if (net->policy == STEPS || net->policy == SGDR){
char *l = option_find(options, "steps");
char *p = option_find(options, "scales");
char *s = option_find(options, "seq_scales");
if(!l || !p) error("STEPS policy must have steps and scales in cfg file");
if(net->policy == STEPS && (!l || !p)) error("STEPS policy must have steps and scales in cfg file");
if (l) {
int len = strlen(l);
int n = 1;
int i;
for(i = 0; i < len; ++i){
for (i = 0; i < len; ++i) {
if (l[i] == ',') ++n;
}
int* steps = (int*)calloc(n, sizeof(int));
float* scales = (float*)calloc(n, sizeof(float));
float* seq_scales = (float*)calloc(n, sizeof(float));
for (i = 0; i < n; ++i) {
float scale = 1.0;
if (p) {
scale = atof(p);
p = strchr(p, ',') + 1;
}
float sequence_scale = 1.0;
if (s) {
seq_scales[i] = atof(s);
sequence_scale = atof(s);
s = strchr(s, ',') + 1;
} else
seq_scales[i] = 1;
}
for(i = 0; i < n; ++i){
int step = atoi(l);
float scale = atof(p);
l = strchr(l, ',')+1;
p = strchr(p, ',')+1;
l = strchr(l, ',') + 1;
steps[i] = step;
scales[i] = scale;
seq_scales[i] = sequence_scale;
}
net->scales = scales;
net->steps = steps;
net->seq_scales = seq_scales;
net->num_steps = n;
}
} else if (net->policy == EXP){
net->gamma = option_find_float(options, "gamma", 1);
} else if (net->policy == SIG){

Loading…
Cancel
Save