From e5d464d3d007fad372bdd0b2152374e009998995 Mon Sep 17 00:00:00 2001 From: AlexeyAB Date: Mon, 2 Dec 2019 17:41:45 +0300 Subject: [PATCH] Added mosaic=1 data augmentation for Classifier --- build/darknet/x64/cfg/efficientnet_b0.cfg | 3 +- cfg/efficientnet_b0.cfg | 3 +- src/data.c | 64 ++++++++++++++++++++--- src/parser.c | 4 +- 4 files changed, 65 insertions(+), 9 deletions(-) diff --git a/build/darknet/x64/cfg/efficientnet_b0.cfg b/build/darknet/x64/cfg/efficientnet_b0.cfg index 88ec3594..c959a610 100644 --- a/build/darknet/x64/cfg/efficientnet_b0.cfg +++ b/build/darknet/x64/cfg/efficientnet_b0.cfg @@ -11,8 +11,9 @@ channels=3 momentum=0.9 decay=0.0005 max_crop=256 -#mixup=1 +#mixup=4 cutmix=1 +mosaic=1 burn_in=1000 #burn_in=100 diff --git a/cfg/efficientnet_b0.cfg b/cfg/efficientnet_b0.cfg index 88ec3594..c959a610 100644 --- a/cfg/efficientnet_b0.cfg +++ b/cfg/efficientnet_b0.cfg @@ -11,8 +11,9 @@ channels=3 momentum=0.9 decay=0.0005 max_crop=256 -#mixup=1 +#mixup=4 cutmix=1 +mosaic=1 burn_in=1000 #burn_in=100 diff --git a/src/data.c b/src/data.c index 4323b84d..90e0348c 100644 --- a/src/data.c +++ b/src/data.c @@ -1323,7 +1323,6 @@ data load_data_augment(char **paths, int n, int m, char **labels, int k, tree *h d.shallow = 0; d.X = load_image_augment_paths(paths, n, use_flip, min, max, w, h, angle, aspect, hue, saturation, exposure); d.y = load_labels_paths(paths, n, labels, k, hierarchy); - if(m) free(paths); if (mixup && rand_int(0, 1)) { char **paths_mix = get_random_paths(paths_stored, n, m); @@ -1331,14 +1330,32 @@ data load_data_augment(char **paths, int n, int m, char **labels, int k, tree *h d2.shallow = 0; d2.X = load_image_augment_paths(paths_mix, n, use_flip, min, max, w, h, angle, aspect, hue, saturation, exposure); d2.y = load_labels_paths(paths_mix, n, labels, k, hierarchy); + free(paths_mix); + + data d3 = { 0 }; + d3.shallow = 0; + data d4 = { 0 }; + d4.shallow = 0; + if (mixup >= 3) { + char **paths_mix3 = get_random_paths(paths_stored, n, m); 
+ d3.X = load_image_augment_paths(paths_mix3, n, use_flip, min, max, w, h, angle, aspect, hue, saturation, exposure); + d3.y = load_labels_paths(paths_mix3, n, labels, k, hierarchy); + free(paths_mix3); + + char **paths_mix4 = get_random_paths(paths_stored, n, m); + d4.X = load_image_augment_paths(paths_mix4, n, use_flip, min, max, w, h, angle, aspect, hue, saturation, exposure); + d4.y = load_labels_paths(paths_mix4, n, labels, k, hierarchy); + free(paths_mix4); + } + // mix int i, j; for (i = 0; i < d2.X.rows; ++i) { - if (mixup == 3) mixup = rand_int(1, 2); // alternate MixUp and CutMix + if (mixup == 4) mixup = rand_int(2, 3); // alternate CutMix and Mosaic - // MixUp + // MixUp ----------------------------------- if (mixup == 1) { // mix images for (j = 0; j < d2.X.cols; ++j) { @@ -1350,8 +1367,8 @@ data load_data_augment(char **paths, int n, int m, char **labels, int k, tree *h d.y.vals[i][j] = (d.y.vals[i][j] + d2.y.vals[i][j]) / 2.0f; } } - // CutMix - else { + // CutMix ----------------------------------- + else if (mixup == 2) { const float min = 0.3; // 0.3*0.3 = 9% const float max = 0.8; // 0.8*0.8 = 64% const int cut_w = rand_int(w*min, w*max); @@ -1382,10 +1399,43 @@ data load_data_augment(char **paths, int n, int m, char **labels, int k, tree *h d.y.vals[i][j] = d.y.vals[i][j] * beta + d2.y.vals[i][j] * alpha; } } + // Mosaic ----------------------------------- + else if (mixup == 3) + { + const float min_offset = 0.2; // 20% + const int cut_x = rand_int(w*min_offset, w*(1 - min_offset)); + const int cut_y = rand_int(h*min_offset, h*(1 - min_offset)); + + float s1 = (float)(cut_x * cut_y) / (w*h); + float s2 = (float)((w - cut_x) * cut_y) / (w*h); + float s3 = (float)(cut_x * (h - cut_y)) / (w*h); + float s4 = (float)((w - cut_x) * (h - cut_y)) / (w*h); + + int c, x, y; + for (c = 0; c < 3; ++c) { + for (y = 0; y < h; ++y) { + for (x = 0; x < w; ++x) { + int j = x + y*w + c*w*h; + if (x < cut_x && y < cut_y) d.X.vals[i][j] = d.X.vals[i][j]; + if (x >= 
cut_x && y < cut_y) d.X.vals[i][j] = d2.X.vals[i][j]; + if (x < cut_x && y >= cut_y) d.X.vals[i][j] = d3.X.vals[i][j]; + if (x >= cut_x && y >= cut_y) d.X.vals[i][j] = d4.X.vals[i][j]; + } + } + } + + for (j = 0; j < d.y.cols; ++j) { + d.y.vals[i][j] = d.y.vals[i][j] * s1 + d2.y.vals[i][j] * s2 + d3.y.vals[i][j] * s3 + d4.y.vals[i][j] * s4; + } + } } free_data(d2); - free(paths_mix); + + if (mixup == 3) { + free_data(d3); + free_data(d4); + } } if (show_imgs) { @@ -1416,6 +1466,8 @@ data load_data_augment(char **paths, int n, int m, char **labels, int k, tree *h printf("\nYou use flag -show_imgs, so will be saved aug_...jpg images. Click on window and press ESC button \n"); } + if (m) free(paths); + return d; } diff --git a/src/parser.c b/src/parser.c index 8b1fc040..f1abe89d 100644 --- a/src/parser.c +++ b/src/parser.c @@ -923,8 +923,10 @@ void parse_net_options(list *options, network *net) net->blur = option_find_int_quiet(options, "blur", 0); net->mixup = option_find_int_quiet(options, "mixup", 0); int cutmix = option_find_int_quiet(options, "cutmix", 0); - if (net->mixup && cutmix) net->mixup = 3; + int mosaic = option_find_int_quiet(options, "mosaic", 0); + if (mosaic && cutmix) net->mixup = 4; else if (cutmix) net->mixup = 2; + else if (mosaic) net->mixup = 3; net->letter_box = option_find_int_quiet(options, "letter_box", 0); net->angle = option_find_float_quiet(options, "angle", 0);