Added mosaic=1 data augmentation for Classifier

pull/6241/head
AlexeyAB 6 years ago
parent acb578cb37
commit e5d464d3d0
  1. 3
      build/darknet/x64/cfg/efficientnet_b0.cfg
  2. 3
      cfg/efficientnet_b0.cfg
  3. 64
      src/data.c
  4. 4
      src/parser.c

@ -11,8 +11,9 @@ channels=3
momentum=0.9
decay=0.0005
max_crop=256
#mixup=1
#mixup=4
cutmix=1
mosaic=1
burn_in=1000
#burn_in=100

@ -11,8 +11,9 @@ channels=3
momentum=0.9
decay=0.0005
max_crop=256
#mixup=1
#mixup=4
cutmix=1
mosaic=1
burn_in=1000
#burn_in=100

@ -1323,7 +1323,6 @@ data load_data_augment(char **paths, int n, int m, char **labels, int k, tree *h
d.shallow = 0;
d.X = load_image_augment_paths(paths, n, use_flip, min, max, w, h, angle, aspect, hue, saturation, exposure);
d.y = load_labels_paths(paths, n, labels, k, hierarchy);
if(m) free(paths);
if (mixup && rand_int(0, 1)) {
char **paths_mix = get_random_paths(paths_stored, n, m);
@ -1331,14 +1330,32 @@ data load_data_augment(char **paths, int n, int m, char **labels, int k, tree *h
d2.shallow = 0;
d2.X = load_image_augment_paths(paths_mix, n, use_flip, min, max, w, h, angle, aspect, hue, saturation, exposure);
d2.y = load_labels_paths(paths_mix, n, labels, k, hierarchy);
free(paths_mix);
data d3 = { 0 };
d3.shallow = 0;
data d4 = { 0 };
d4.shallow = 0;
if (mixup >= 3) {
char **paths_mix3 = get_random_paths(paths_stored, n, m);
d3.X = load_image_augment_paths(paths_mix3, n, use_flip, min, max, w, h, angle, aspect, hue, saturation, exposure);
d3.y = load_labels_paths(paths_mix3, n, labels, k, hierarchy);
free(paths_mix3);
char **paths_mix4 = get_random_paths(paths_stored, n, m);
d4.X = load_image_augment_paths(paths_mix4, n, use_flip, min, max, w, h, angle, aspect, hue, saturation, exposure);
d4.y = load_labels_paths(paths_mix4, n, labels, k, hierarchy);
free(paths_mix4);
}
// mix
int i, j;
for (i = 0; i < d2.X.rows; ++i) {
if (mixup == 3) mixup = rand_int(1, 2); // alternate MixUp and CutMix
if (mixup == 4) mixup = rand_int(2, 3); // alternate CutMix and Mosaic
// MixUp
// MixUp -----------------------------------
if (mixup == 1) {
// mix images
for (j = 0; j < d2.X.cols; ++j) {
@ -1350,8 +1367,8 @@ data load_data_augment(char **paths, int n, int m, char **labels, int k, tree *h
d.y.vals[i][j] = (d.y.vals[i][j] + d2.y.vals[i][j]) / 2.0f;
}
}
// CutMix
else {
// CutMix -----------------------------------
else if (mixup == 2) {
const float min = 0.3; // 0.3*0.3 = 9%
const float max = 0.8; // 0.8*0.8 = 64%
const int cut_w = rand_int(w*min, w*max);
@ -1382,10 +1399,43 @@ data load_data_augment(char **paths, int n, int m, char **labels, int k, tree *h
d.y.vals[i][j] = d.y.vals[i][j] * beta + d2.y.vals[i][j] * alpha;
}
}
// Mosaic -----------------------------------
else if (mixup == 3)
{
const float min_offset = 0.2; // 20%
const int cut_x = rand_int(w*min_offset, w*(1 - min_offset));
const int cut_y = rand_int(h*min_offset, h*(1 - min_offset));
float s1 = (float)(cut_x * cut_y) / (w*h);
float s2 = (float)((w - cut_x) * cut_y) / (w*h);
float s3 = (float)(cut_x * (h - cut_y)) / (w*h);
float s4 = (float)((w - cut_x) * (h - cut_y)) / (w*h);
int c, x, y;
for (c = 0; c < 3; ++c) {
for (y = 0; y < h; ++y) {
for (x = 0; x < w; ++x) {
int j = x + y*w + c*w*h;
if (x < cut_x && y < cut_y) d.X.vals[i][j] = d.X.vals[i][j];
if (x >= cut_x && y < cut_y) d.X.vals[i][j] = d2.X.vals[i][j];
if (x < cut_x && y >= cut_y) d.X.vals[i][j] = d3.X.vals[i][j];
if (x >= cut_x && y >= cut_y) d.X.vals[i][j] = d4.X.vals[i][j];
}
}
}
for (j = 0; j < d.y.cols; ++j) {
d.y.vals[i][j] = d.y.vals[i][j] * s1 + d2.y.vals[i][j] * s2 + d3.y.vals[i][j] * s3 + d4.y.vals[i][j] * s4;
}
}
}
free_data(d2);
free(paths_mix);
if (mixup == 3) {
free_data(d3);
free_data(d4);
}
}
if (show_imgs) {
@ -1416,6 +1466,8 @@ data load_data_augment(char **paths, int n, int m, char **labels, int k, tree *h
printf("\nYou use flag -show_imgs, so will be saved aug_...jpg images. Click on window and press ESC button \n");
}
if (m) free(paths);
return d;
}

@ -923,8 +923,10 @@ void parse_net_options(list *options, network *net)
net->blur = option_find_int_quiet(options, "blur", 0);
net->mixup = option_find_int_quiet(options, "mixup", 0);
int cutmix = option_find_int_quiet(options, "cutmix", 0);
if (net->mixup && cutmix) net->mixup = 3;
int mosaic = option_find_int_quiet(options, "mosaic", 0);
if (mosaic && cutmix) net->mixup = 4;
else if (cutmix) net->mixup = 2;
else if (mosaic) net->mixup = 3;
net->letter_box = option_find_int_quiet(options, "letter_box", 0);
net->angle = option_find_float_quiet(options, "angle", 0);

Loading…
Cancel
Save