From 9ff40baa4f0b005d50c4e3302b4103551160943f Mon Sep 17 00:00:00 2001 From: AlexeyAB Date: Thu, 8 Feb 2018 01:04:40 +0300 Subject: [PATCH] Fixes for training Yolo on small objects --- src/data.c | 17 ++++++++++++----- src/data.h | 1 + src/detector.c | 1 + src/layer.h | 1 + src/parser.c | 1 + 5 files changed, 16 insertions(+), 5 deletions(-) diff --git a/src/data.c b/src/data.c index ad9ef8bf..71781aaf 100644 --- a/src/data.c +++ b/src/data.c @@ -292,7 +292,7 @@ void fill_truth_region(char *path, float *truth, int classes, int num_boxes, int free(boxes); } -void fill_truth_detection(char *path, int num_boxes, float *truth, int classes, int flip, float dx, float dy, float sx, float sy) +void fill_truth_detection(char *path, int num_boxes, float *truth, int classes, int flip, float dx, float dy, float sx, float sy, int small_object) { char labelpath[4096]; find_replace(path, "images", "labels", labelpath); @@ -305,6 +305,12 @@ void fill_truth_detection(char *path, int num_boxes, float *truth, int classes, find_replace(labelpath, ".JPEG", ".txt", labelpath); int count = 0; box_label *boxes = read_boxes(labelpath, &count); + if (small_object == 1) { + for (int i = 0; i < count; ++i) { + if (boxes[i].w < 0.01) boxes[i].w = 0.01; + if (boxes[i].h < 0.01) boxes[i].h = 0.01; + } + } randomize_boxes(boxes, count); correct_boxes(boxes, count, dx, dy, sx, sy, flip); if(count > num_boxes) count = num_boxes; @@ -319,7 +325,8 @@ void fill_truth_detection(char *path, int num_boxes, float *truth, int classes, h = boxes[i].h; id = boxes[i].id; - if ((w < .01 || h < .01)) continue; + // not detect small objects + if ((w < 0.01 || h < 0.01)) continue; truth[i*5+0] = x; truth[i*5+1] = y; @@ -661,7 +668,7 @@ data load_data_swag(char **paths, int n, int classes, float jitter) return d; } -data load_data_detection(int n, char **paths, int m, int w, int h, int boxes, int classes, float jitter, float hue, float saturation, float exposure) +data load_data_detection(int n, char **paths, int m, int w, int h, int boxes, int classes, float jitter, float hue, float saturation, float exposure, int small_object) { char **random_paths = get_random_paths(paths, n, m); int i; @@ -704,7 +711,7 @@ data load_data_detection(int n, char **paths, int m, int w, int h, int boxes, in random_distort_image(sized, hue, saturation, exposure); d.X.vals[i] = sized.data; - fill_truth_detection(random_paths[i], boxes, d.y.vals[i], classes, flip, dx, dy, 1./sx, 1./sy); + fill_truth_detection(random_paths[i], boxes, d.y.vals[i], classes, flip, dx, dy, 1./sx, 1./sy, small_object); free_image(orig); free_image(cropped); @@ -734,7 +741,7 @@ void *load_thread(void *ptr) } else if (a.type == REGION_DATA){ *a.d = load_data_region(a.n, a.paths, a.m, a.w, a.h, a.num_boxes, a.classes, a.jitter, a.hue, a.saturation, a.exposure); } else if (a.type == DETECTION_DATA){ - *a.d = load_data_detection(a.n, a.paths, a.m, a.w, a.h, a.num_boxes, a.classes, a.jitter, a.hue, a.saturation, a.exposure); + *a.d = load_data_detection(a.n, a.paths, a.m, a.w, a.h, a.num_boxes, a.classes, a.jitter, a.hue, a.saturation, a.exposure, a.small_object); } else if (a.type == SWAG_DATA){ *a.d = load_data_swag(a.paths, a.n, a.classes, a.jitter); } else if (a.type == COMPARE_DATA){ diff --git a/src/data.h b/src/data.h index 618d89cf..7838681a 100644 --- a/src/data.h +++ b/src/data.h @@ -53,6 +53,7 @@ typedef struct load_args{ int classes; int background; int scale; + int small_object; float jitter; float angle; float aspect; diff --git a/src/detector.c b/src/detector.c index d79fbcce..4e220559 100644 --- a/src/detector.c +++ b/src/detector.c @@ -82,6 +82,7 @@ void train_detector(char *datacfg, char *cfgfile, char *weightfile, int *gpus, i args.classes = classes; args.jitter = jitter; args.num_boxes = l.max_boxes; + args.small_object = l.small_object; args.d = &buffer; args.type = DETECTION_DATA; args.threads = 4;// 8; diff --git a/src/layer.h b/src/layer.h index 285abe3e..db012f1f 100644 --- a/src/layer.h +++ b/src/layer.h @@ -63,6 +63,7 @@ struct layer{ int out_h, out_w, out_c; int n; int max_boxes; + int small_object; int groups; int size; int side; diff --git a/src/parser.c b/src/parser.c index 2a1ea3b1..67b4bfb2 100644 --- a/src/parser.c +++ b/src/parser.c @@ -245,6 +245,7 @@ layer parse_region(list *options, size_params params) l.log = option_find_int_quiet(options, "log", 0); l.sqrt = option_find_int_quiet(options, "sqrt", 0); + l.small_object = option_find_int(options, "small_object", 0); l.softmax = option_find_int(options, "softmax", 0); l.max_boxes = option_find_int_quiet(options, "max",30); l.jitter = option_find_float(options, "jitter", .2);