Added letter_box=1 param in [net] section (cfg-file) for keeping aspect ratio during training

pull/3501/head
AlexeyAB 6 years ago
parent 1ed71f4b29
commit c9129c2078
  1. 4
      include/darknet.h
  2. 72
      src/data.c
  3. 4
      src/data.h
  4. 19
      src/detector.c
  5. 2
      src/image_opencv.cpp
  6. 2
      src/image_opencv.h
  7. 2
      src/network.h
  8. 1
      src/parser.c

@ -600,6 +600,7 @@ typedef struct network {
int flip; // horizontal flip 50% probability augmentation for classifier training (default = 1)
int blur;
int mixup;
int letter_box;
float angle;
float aspect;
float exposure;
@ -760,6 +761,7 @@ typedef struct load_args {
int mini_batch;
int track;
int augment_speed;
int letter_box;
int show_imgs;
float jitter;
int flip;
@ -827,7 +829,7 @@ LIB_API layer* get_network_layer(network* net, int i);
LIB_API detection *make_network_boxes(network *net, float thresh, int *num);
LIB_API void reset_rnn(network *net);
LIB_API float *network_predict_image(network *net, image im);
LIB_API float validate_detector_map(char *datacfg, char *cfgfile, char *weightfile, float thresh_calc_avg_iou, const float iou_thresh, const int map_points, network *existing_net);
LIB_API float validate_detector_map(char *datacfg, char *cfgfile, char *weightfile, float thresh_calc_avg_iou, const float iou_thresh, const int map_points, int letter_box, network *existing_net);
LIB_API void train_detector(char *datacfg, char *cfgfile, char *weightfile, int *gpus, int ngpus, int clear, int dont_show, int calc_map, int mjpeg_port, int show_imgs);
LIB_API void test_detector(char *datacfg, char *cfgfile, char *weightfile, char *filename, float thresh,
float hier_thresh, int dont_show, int ext_output, int save_labels, char *outfile, int letter_box);

@ -804,8 +804,8 @@ void blend_truth(float *new_truth, int boxes, float *old_truth)
#include "http_stream.h"
data load_data_detection(int n, char **paths, int m, int w, int h, int c, int boxes, int classes, int use_flip, int use_blur, int use_mixup, float jitter,
float hue, float saturation, float exposure, int mini_batch, int track, int augment_speed, int show_imgs)
data load_data_detection(int n, char **paths, int m, int w, int h, int c, int boxes, int classes, int use_flip, int use_blur, int use_mixup,
float jitter, float hue, float saturation, float exposure, int mini_batch, int track, int augment_speed, int letter_box, int show_imgs)
{
const int random_index = random_gen();
c = c ? c : 3;
@ -828,7 +828,7 @@ data load_data_detection(int n, char **paths, int m, int w, int h, int c, int bo
d.X.vals = (float**)calloc(d.X.rows, sizeof(float*));
d.X.cols = h*w*c;
float r1 = 0, r2 = 0, r3 = 0, r4 = 0;
float r1 = 0, r2 = 0, r3 = 0, r4 = 0, r_scale = 0;
float dhue = 0, dsat = 0, dexp = 0, flip = 0, blur = 0;
int augmentation_calculated = 0;
@ -862,6 +862,8 @@ data load_data_detection(int n, char **paths, int m, int w, int h, int c, int bo
r3 = random_float();
r4 = random_float();
r_scale = random_float();
dhue = rand_uniform_strong(-hue, hue);
dsat = rand_scale(saturation);
dexp = rand_scale(exposure);
@ -874,6 +876,33 @@ data load_data_detection(int n, char **paths, int m, int w, int h, int c, int bo
int pright = rand_precalc_random(-dw, dw, r2);
int ptop = rand_precalc_random(-dh, dh, r3);
int pbot = rand_precalc_random(-dh, dh, r4);
//printf("\n pleft = %d, pright = %d, ptop = %d, pbot = %d, ow = %d, oh = %d \n", pleft, pright, ptop, pbot, ow, oh);
float scale = rand_precalc_random(.25, 2, r_scale); // unused currently
if (letter_box)
{
float img_ar = (float)ow / (float)oh;
float net_ar = (float)w / (float)h;
float result_ar = img_ar / net_ar;
//printf(" ow = %d, oh = %d, w = %d, h = %d, img_ar = %f, net_ar = %f, result_ar = %f \n", ow, oh, w, h, img_ar, net_ar, result_ar);
if (result_ar > 1) // sheight - should be increased
{
float oh_tmp = ow / net_ar;
float delta_h = (oh_tmp - oh)/2;
ptop = ptop - delta_h;
pbot = pbot - delta_h;
//printf(" result_ar = %f, oh_tmp = %f, delta_h = %f, ptop = %d, pbot = %d \n", result_ar, oh_tmp, delta_h, ptop, pbot);
}
else // swidth - should be increased
{
float ow_tmp = oh * net_ar;
float delta_w = (ow_tmp - ow)/2;
pleft = pleft - delta_w;
pright = pright - delta_w;
//printf(" result_ar = %f, ow_tmp = %f, delta_w = %f, pleft = %d, pright = %d \n", result_ar, ow_tmp, delta_w, pleft, pright);
}
}
int swidth = ow - pleft - pright;
int sheight = oh - ptop - pbot;
@ -884,9 +913,10 @@ data load_data_detection(int n, char **paths, int m, int w, int h, int c, int bo
float dx = ((float)pleft / ow) / sx;
float dy = ((float)ptop / oh) / sy;
fill_truth_detection(filename, boxes, truth, classes, flip, dx, dy, 1. / sx, 1. / sy, w, h);
image ai = image_data_augmentation(src, w, h, pleft, ptop, swidth, sheight, flip, jitter, dhue, dsat, dexp,
image ai = image_data_augmentation(src, w, h, pleft, ptop, swidth, sheight, flip, dhue, dsat, dexp,
blur, boxes, d.y.vals[i]);
if (i_mixup) {
@ -947,7 +977,7 @@ void blend_images(image new_img, float alpha, image old_img, float beta)
}
data load_data_detection(int n, char **paths, int m, int w, int h, int c, int boxes, int classes, int use_flip, int use_blur, int use_mixup, float jitter,
float hue, float saturation, float exposure, int mini_batch, int track, int augment_speed, int show_imgs)
float hue, float saturation, float exposure, int mini_batch, int track, int augment_speed, int letter_box, int show_imgs)
{
const int random_index = random_gen();
c = c ? c : 3;
@ -971,7 +1001,7 @@ data load_data_detection(int n, char **paths, int m, int w, int h, int c, int bo
d.X.vals = (float**)calloc(d.X.rows, sizeof(float*));
d.X.cols = h*w*c;
float r1 = 0, r2 = 0, r3 = 0, r4 = 0;
float r1 = 0, r2 = 0, r3 = 0, r4 = 0, r_scale;
float dhue = 0, dsat = 0, dexp = 0, flip = 0;
int augmentation_calculated = 0;
@ -999,6 +1029,8 @@ data load_data_detection(int n, char **paths, int m, int w, int h, int c, int bo
r3 = random_float();
r4 = random_float();
r_scale = random_float();
dhue = rand_uniform_strong(-hue, hue);
dsat = rand_scale(saturation);
dexp = rand_scale(exposure);
@ -1011,6 +1043,32 @@ data load_data_detection(int n, char **paths, int m, int w, int h, int c, int bo
int ptop = rand_precalc_random(-dh, dh, r3);
int pbot = rand_precalc_random(-dh, dh, r4);
float scale = rand_precalc_random(.25, 2, r_scale); // unused currently
if (letter_box)
{
float img_ar = (float)ow / (float)oh;
float net_ar = (float)w / (float)h;
float result_ar = img_ar / net_ar;
//printf(" ow = %d, oh = %d, w = %d, h = %d, img_ar = %f, net_ar = %f, result_ar = %f \n", ow, oh, w, h, img_ar, net_ar, result_ar);
if (result_ar > 1) // sheight - should be increased
{
float oh_tmp = ow / net_ar;
float delta_h = (oh_tmp - oh) / 2;
ptop = ptop - delta_h;
pbot = pbot - delta_h;
//printf(" result_ar = %f, oh_tmp = %f, delta_h = %f, ptop = %d, pbot = %d \n", result_ar, oh_tmp, delta_h, ptop, pbot);
}
else // swidth - should be increased
{
float ow_tmp = oh * net_ar;
float delta_w = (ow_tmp - ow) / 2;
pleft = pleft - delta_w;
pright = pright - delta_w;
//printf(" result_ar = %f, ow_tmp = %f, delta_w = %f, pleft = %d, pright = %d \n", result_ar, ow_tmp, delta_w, pleft, pright);
}
}
int swidth = ow - pleft - pright;
int sheight = oh - ptop - pbot;
@ -1100,7 +1158,7 @@ void *load_thread(void *ptr)
*a.d = load_data_region(a.n, a.paths, a.m, a.w, a.h, a.num_boxes, a.classes, a.jitter, a.hue, a.saturation, a.exposure);
} else if (a.type == DETECTION_DATA){
*a.d = load_data_detection(a.n, a.paths, a.m, a.w, a.h, a.c, a.num_boxes, a.classes, a.flip, a.blur, a.mixup, a.jitter,
a.hue, a.saturation, a.exposure, a.mini_batch, a.track, a.augment_speed, a.show_imgs);
a.hue, a.saturation, a.exposure, a.mini_batch, a.track, a.augment_speed, a.letter_box, a.show_imgs);
} else if (a.type == SWAG_DATA){
*a.d = load_data_swag(a.paths, a.n, a.classes, a.jitter);
} else if (a.type == COMPARE_DATA){

@ -86,8 +86,8 @@ void print_letters(float *pred, int n);
data load_data_captcha(char **paths, int n, int m, int k, int w, int h);
data load_data_captcha_encode(char **paths, int n, int m, int w, int h);
data load_data_old(char **paths, int n, int m, char **labels, int k, int w, int h);
data load_data_detection(int n, char **paths, int m, int w, int h, int c, int boxes, int classes, int use_flip, int use_blur, int use_mixup, float jitter,
float hue, float saturation, float exposure, int mini_batch, int track, int augment_speed, int show_imgs);
data load_data_detection(int n, char **paths, int m, int w, int h, int c, int boxes, int classes, int use_flip, int use_blur, int use_mixup,
float jitter, float hue, float saturation, float exposure, int mini_batch, int track, int augment_speed, int letter_box, int show_imgs);
data load_data_tag(char **paths, int n, int m, int k, int use_flip, int min, int max, int size, float angle, float aspect, float hue, float saturation, float exposure);
matrix load_image_augment_paths(char **paths, int n, int use_flip, int min, int max, int size, float angle, float aspect, float hue, float saturation, float exposure);
data load_data_super(char **paths, int n, int m, int w, int h, int scale);

@ -133,6 +133,7 @@ void train_detector(char *datacfg, char *cfgfile, char *weightfile, int *gpus, i
args.exposure = net.exposure;
args.saturation = net.saturation;
args.hue = net.hue;
args.letter_box = net.letter_box;
if (dont_show && show_imgs) show_imgs = 2;
args.show_imgs = show_imgs;
@ -275,7 +276,7 @@ void train_detector(char *datacfg, char *cfgfile, char *weightfile, int *gpus, i
//network net_combined = combine_train_valid_networks(net, net_map);
iter_map = i;
mean_average_precision = validate_detector_map(datacfg, cfgfile, weightfile, 0.25, 0.5, 0, &net_map);// &net_combined);
mean_average_precision = validate_detector_map(datacfg, cfgfile, weightfile, 0.25, 0.5, 0, net.letter_box, &net_map);// &net_combined);
printf("\n mean_average_precision (mAP@0.5) = %f \n", mean_average_precision);
if (mean_average_precision > best_map) {
best_map = mean_average_precision;
@ -660,7 +661,7 @@ int detections_comparator(const void *pa, const void *pb)
return 0;
}
float validate_detector_map(char *datacfg, char *cfgfile, char *weightfile, float thresh_calc_avg_iou, const float iou_thresh, const int map_points, network *existing_net)
float validate_detector_map(char *datacfg, char *cfgfile, char *weightfile, float thresh_calc_avg_iou, const float iou_thresh, const int map_points, int letter_box, network *existing_net)
{
int j;
list *options = read_data_cfg(datacfg);
@ -733,8 +734,8 @@ float validate_detector_map(char *datacfg, char *cfgfile, char *weightfile, floa
args.w = net.w;
args.h = net.h;
args.c = net.c;
args.type = IMAGE_DATA;
//args.type = LETTERBOX_DATA;
if (letter_box) args.type = LETTERBOX_DATA;
else args.type = IMAGE_DATA;
//const float thresh_calc_avg_iou = 0.24;
float avg_iou = 0;
@ -783,14 +784,12 @@ float validate_detector_map(char *datacfg, char *cfgfile, char *weightfile, floa
float hier_thresh = 0;
detection *dets;
if (args.type == LETTERBOX_DATA) {
int letterbox = 1;
dets = get_network_boxes(&net, val[t].w, val[t].h, thresh, hier_thresh, 0, 1, &nboxes, letterbox);
dets = get_network_boxes(&net, val[t].w, val[t].h, thresh, hier_thresh, 0, 1, &nboxes, letter_box);
}
else {
int letterbox = 0;
dets = get_network_boxes(&net, 1, 1, thresh, hier_thresh, 0, 0, &nboxes, letterbox);
dets = get_network_boxes(&net, 1, 1, thresh, hier_thresh, 0, 0, &nboxes, letter_box);
}
//detection *dets = get_network_boxes(&net, val[t].w, val[t].h, thresh, hier_thresh, 0, 1, &nboxes, letterbox); // for letterbox=1
//detection *dets = get_network_boxes(&net, val[t].w, val[t].h, thresh, hier_thresh, 0, 1, &nboxes, letter_box); // for letter_box=1
if (nms) do_nms_sort(dets, nboxes, l.classes, nms);
char labelpath[4096];
@ -1486,7 +1485,7 @@ void run_detector(int argc, char **argv)
else if (0 == strcmp(argv[2], "train")) train_detector(datacfg, cfg, weights, gpus, ngpus, clear, dont_show, calc_map, mjpeg_port, show_imgs);
else if (0 == strcmp(argv[2], "valid")) validate_detector(datacfg, cfg, weights, outfile);
else if (0 == strcmp(argv[2], "recall")) validate_detector_recall(datacfg, cfg, weights);
else if (0 == strcmp(argv[2], "map")) validate_detector_map(datacfg, cfg, weights, thresh, iou_thresh, map_points, NULL);
else if (0 == strcmp(argv[2], "map")) validate_detector_map(datacfg, cfg, weights, thresh, iou_thresh, map_points, letter_box, NULL);
else if (0 == strcmp(argv[2], "calc_anchors")) calc_anchors(datacfg, num_of_clusters, width, height, show);
else if (0 == strcmp(argv[2], "demo")) {
list *options = read_data_cfg(datacfg);

@ -1137,7 +1137,7 @@ static box float_to_box_stride(float *f, int stride)
image image_data_augmentation(mat_cv* mat, int w, int h,
int pleft, int ptop, int swidth, int sheight, int flip,
float jitter, float dhue, float dsat, float dexp,
float dhue, float dsat, float dexp,
int blur, int num_boxes, float *truth)
{
image out;

@ -95,7 +95,7 @@ void draw_train_loss(mat_cv* img, int img_size, float avg_loss, float max_img_lo
// Data augmentation
image image_data_augmentation(mat_cv* mat, int w, int h,
int pleft, int ptop, int swidth, int sheight, int flip,
float jitter, float dhue, float dsat, float dexp,
float dhue, float dsat, float dexp,
int blur, int num_boxes, float *truth);
// blend two images with (alpha and beta)

@ -153,7 +153,7 @@ float get_network_cost(network net);
//LIB_API network *load_network_custom(char *cfg, char *weights, int clear, int batch);
//LIB_API network *load_network(char *cfg, char *weights, int clear);
//LIB_API float *network_predict_image(network *net, image im);
//LIB_API float validate_detector_map(char *datacfg, char *cfgfile, char *weightfile, float thresh_calc_avg_iou, const float iou_thresh, network *existing_net);
//LIB_API float validate_detector_map(char *datacfg, char *cfgfile, char *weightfile, float thresh_calc_avg_iou, const float iou_thresh, int map_points, int letter_box, network *existing_net);
//LIB_API void train_detector(char *datacfg, char *cfgfile, char *weightfile, int *gpus, int ngpus, int clear, int dont_show, int calc_map, int mjpeg_port);
//LIB_API int network_width(network *net);
//LIB_API int network_height(network *net);

@ -738,6 +738,7 @@ void parse_net_options(list *options, network *net)
net->flip = option_find_int_quiet(options, "flip", 1);
net->blur = option_find_int_quiet(options, "blur", 0);
net->mixup = option_find_int_quiet(options, "mixup", 0);
net->letter_box = option_find_int_quiet(options, "letter_box", 0);
net->angle = option_find_float_quiet(options, "angle", 0);
net->aspect = option_find_float_quiet(options, "aspect", 1);

Loading…
Cancel
Save