From 6c54f5ffb240f0c97b7d33a55ce852716d422919 Mon Sep 17 00:00:00 2001 From: AlexeyAB Date: Thu, 7 Jun 2018 02:20:44 +0300 Subject: [PATCH] Show inconsistent information if it is present in .cfg and .names files --- src/data.c | 8 +++++++- src/data.h | 1 + src/detector.c | 8 +++++++- src/network.c | 6 ++++++ 4 files changed, 21 insertions(+), 2 deletions(-) diff --git a/src/data.c b/src/data.c index c98ede9b..aae86670 100644 --- a/src/data.c +++ b/src/data.c @@ -512,14 +512,20 @@ matrix load_tags_paths(char **paths, int n, int k) return y; } -char **get_labels(char *filename) +char **get_labels_custom(char *filename, int *size) { list *plist = get_paths(filename); + if(size) *size = plist->size; char **labels = (char **)list_to_array(plist); free_list(plist); return labels; } +char **get_labels(char *filename) +{ + return get_labels_custom(filename, NULL); +} + void free_data(data d) { if(!d.shallow){ diff --git a/src/data.h b/src/data.h index b46143fb..f7ab585c 100644 --- a/src/data.h +++ b/src/data.h @@ -100,6 +100,7 @@ data load_data_writing(char **paths, int n, int m, int w, int h, int out_w, int list *get_paths(char *filename); char **get_labels(char *filename); +char **get_labels_custom(char *filename, int *size); void get_random_batch(data d, int n, float *X, float *y); data get_data_part(data d, int part, int total); data get_random_data(data d, int num); diff --git a/src/detector.c b/src/detector.c index e099e91a..0c0b14d8 100644 --- a/src/detector.c +++ b/src/detector.c @@ -1069,7 +1069,8 @@ void test_detector(char *datacfg, char *cfgfile, char *weightfile, char *filenam { list *options = read_data_cfg(datacfg); char *name_list = option_find_str(options, "names", "data/names.list"); - char **names = get_labels(name_list); + int names_size = 0; + char **names = get_labels_custom(name_list, &names_size); //get_labels(name_list); image **alphabet = load_alphabet(); network net = parse_network_cfg_custom(cfgfile, 1); // set batch=1 @@ -1078,6 +1079,11 @@ void test_detector(char *datacfg, char *cfgfile, char *weightfile, char *filenam } //set_batch_network(&net, 1); fuse_conv_batchnorm(net); + if (net.layers[net.n - 1].classes != names_size) { + printf(" Error: in the file %s number of names %d that isn't equal to classes=%d in the file %s \n", + name_list, names_size, net.layers[net.n - 1].classes, datacfg); + if(net.layers[net.n - 1].classes > names_size) getchar(); + } srand(2222222); double time; char buff[256]; diff --git a/src/network.c b/src/network.c index 1498a2ad..b32dfc00 100644 --- a/src/network.c +++ b/src/network.c @@ -602,12 +602,18 @@ void custom_get_region_detections(layer l, int w, int h, int net_w, int net_h, f void fill_network_boxes(network *net, int w, int h, float thresh, float hier, int *map, int relative, detection *dets, int letter) { + int prev_classes = -1; int j; for (j = 0; j < net->n; ++j) { layer l = net->layers[j]; if (l.type == YOLO) { int count = get_yolo_detections(l, w, h, net->w, net->h, thresh, map, relative, dets, letter); dets += count; + if (prev_classes < 0) prev_classes = l.classes; + else if (prev_classes != l.classes) { + printf(" Error: Different [yolo] layers have different number of classes = %d and %d - check your cfg-file! \n", + prev_classes, l.classes); + } } if (l.type == REGION) { custom_get_region_detections(l, w, h, net->w, net->h, thresh, map, hier, relative, dets, letter);