From a1af57d8d60b50e8188f36b7f74752c8cc124177 Mon Sep 17 00:00:00 2001 From: AlexeyAB Date: Thu, 15 Feb 2018 15:43:25 +0300 Subject: [PATCH] Added C implementation of calculation mAP (mean average precision) using Darknet --- build/darknet/x64/calc_mAP.cmd | 12 + .../{compute_mAP.cmd => calc_mAP_voc_py.cmd} | 4 +- src/detector.c | 284 ++++++++++++++++++ 3 files changed, 298 insertions(+), 2 deletions(-) create mode 100644 build/darknet/x64/calc_mAP.cmd rename build/darknet/x64/{compute_mAP.cmd => calc_mAP_voc_py.cmd} (72%) diff --git a/build/darknet/x64/calc_mAP.cmd b/build/darknet/x64/calc_mAP.cmd new file mode 100644 index 00000000..92f3719d --- /dev/null +++ b/build/darknet/x64/calc_mAP.cmd @@ -0,0 +1,12 @@ +rem # How to calculate mAP (mean average precision) + + +darknet.exe detector map data/voc.data tiny-yolo-voc.cfg tiny-yolo-voc.weights + + +rem darknet.exe detector map data/voc.data yolo-voc.cfg yolo-voc.weights + + + + +pause diff --git a/build/darknet/x64/compute_mAP.cmd b/build/darknet/x64/calc_mAP_voc_py.cmd similarity index 72% rename from build/darknet/x64/compute_mAP.cmd rename to build/darknet/x64/calc_mAP_voc_py.cmd index 8c5ba3cf..0267600c 100644 --- a/build/darknet/x64/compute_mAP.cmd +++ b/build/darknet/x64/calc_mAP_voc_py.cmd @@ -3,9 +3,9 @@ rem C:\Users\Alex\AppData\Local\Programs\Python\Python36\Scripts\pip install cPi rem C:\Users\Alex\AppData\Local\Programs\Python\Python36\Scripts\pip install _pickle -rem darknet.exe detector valid data/voc.data tiny-yolo-voc.cfg tiny-yolo-voc.weights +darknet.exe detector valid data/voc.data tiny-yolo-voc.cfg tiny-yolo-voc.weights -darknet.exe detector valid data/voc.data yolo-voc.cfg yolo-voc.weights +rem darknet.exe detector valid data/voc.data yolo-voc.cfg yolo-voc.weights reval_voc_py3.py --year 2007 --classes data\voc.names --image_set test --voc_dir E:\VOC2007_2012\VOCtrainval_11-May-2012\VOCdevkit results diff --git a/src/detector.c b/src/detector.c index 4e220559..3111c193 100644 --- a/src/detector.c +++ b/src/detector.c @@ -315,6 +315,8 @@ void validate_detector(char *datacfg, char *cfgfile, char *weightfile) float thresh = .005; float nms = .45; + int detection_count = 0; + int nthreads = 4; image *val = calloc(nthreads, sizeof(image)); image *val_resized = calloc(nthreads, sizeof(image)); @@ -356,6 +358,15 @@ void validate_detector(char *datacfg, char *cfgfile, char *weightfile) int h = val[t].h; get_region_boxes(l, w, h, thresh, probs, boxes, 0, map); if (nms) do_nms_sort(boxes, probs, l.w*l.h*l.n, classes, nms); + + int x, y; + for (x = 0; x < (l.w*l.h*l.n); ++x) { + for (y = 0; y < classes; ++y) + { + if (probs[x][y]) ++detection_count; + } + } + if (coco){ print_cocos(fp, path, boxes, probs, l.w*l.h*l.n, classes, w, h); } else if (imagenet){ @@ -376,6 +387,7 @@ void validate_detector(char *datacfg, char *cfgfile, char *weightfile) fprintf(fp, "\n]\n"); fclose(fp); } + printf("\n detection_count = %d \n", detection_count); fprintf(stderr, "Total Detection Time: %f Seconds\n", (double)(time(0) - start)); } @@ -409,6 +421,8 @@ void validate_detector_recall(char *datacfg, char *cfgfile, char *weightfile) float iou_thresh = .5; float nms = .4; + int detection_count = 0, truth_count = 0; + int total = 0; int correct = 0; int proposals = 0; @@ -432,6 +446,7 @@ void validate_detector_recall(char *datacfg, char *cfgfile, char *weightfile) int num_labels = 0; box_label *truth = read_boxes(labelpath, &num_labels); + truth_count += num_labels; for(k = 0; k < l.w*l.h*l.n; ++k){ if(probs[k][0] > thresh){ ++proposals; @@ -458,6 +473,274 @@ void validate_detector_recall(char *datacfg, char *cfgfile, char *weightfile) free_image(orig); free_image(sized); } + printf("\n truth_count = %d \n", truth_count); +} + +typedef struct { + box b; + float p; + int class_id; + int image_index; + int truth_flag; + int unique_truth_index; +} box_prob; + +int detections_comparator(const void *pa, const void *pb) +{ + box_prob a = *(box_prob *)pa; + box_prob b = *(box_prob *)pb; + float diff = a.p - b.p; + if (diff < 0) return 1; + else if (diff > 0) return -1; + return 0; +} + +void validate_detector_map(char *datacfg, char *cfgfile, char *weightfile) +{ + int j; + list *options = read_data_cfg(datacfg); + char *valid_images = option_find_str(options, "valid", "data/train.list"); + char *name_list = option_find_str(options, "names", "data/names.list"); + //char *prefix = option_find_str(options, "results", "results"); + char **names = get_labels(name_list); + char *mapf = option_find_str(options, "map", 0); + int *map = 0; + if (mapf) map = read_map(mapf); + + network net = parse_network_cfg_custom(cfgfile, 1); + if (weightfile) { + load_weights(&net, weightfile); + } + set_batch_network(&net, 1); + fprintf(stderr, "Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay); + srand(time(0)); + + char *base = "comp4_det_test_"; + list *plist = get_paths(valid_images); + char **paths = (char **)list_to_array(plist); + + layer l = net.layers[net.n - 1]; + int classes = l.classes; + + box *boxes = calloc(l.w*l.h*l.n, sizeof(box)); + float **probs = calloc(l.w*l.h*l.n, sizeof(float *)); + for (j = 0; j < l.w*l.h*l.n; ++j) probs[j] = calloc(classes, sizeof(float *)); + + int m = plist->size; + int i = 0; + int t; + + const float thresh = .005; + const float nms = .45; + const float iou_thresh = 0.5; + + int nthreads = 4; + image *val = calloc(nthreads, sizeof(image)); + image *val_resized = calloc(nthreads, sizeof(image)); + image *buf = calloc(nthreads, sizeof(image)); + image *buf_resized = calloc(nthreads, sizeof(image)); + pthread_t *thr = calloc(nthreads, sizeof(pthread_t)); + + load_args args = { 0 }; + args.w = net.w; + args.h = net.h; + args.type = IMAGE_DATA; + + box_prob *detections = calloc(1, sizeof(box_prob)); + int detections_count = 0; + int unique_truth_index = 0; + + int *truth_classes_count = calloc(classes, sizeof(int)); + + for (t = 0; t < nthreads; ++t) { + args.path = paths[i + t]; + args.im = &buf[t]; + args.resized = &buf_resized[t]; + thr[t] = load_data_in_thread(args); + } + time_t start = time(0); + for (i = nthreads; i < m + nthreads; i += nthreads) { + fprintf(stderr, "%d\n", i); + for (t = 0; t < nthreads && i + t - nthreads < m; ++t) { + pthread_join(thr[t], 0); + val[t] = buf[t]; + val_resized[t] = buf_resized[t]; + } + for (t = 0; t < nthreads && i + t < m; ++t) { + args.path = paths[i + t]; + args.im = &buf[t]; + args.resized = &buf_resized[t]; + thr[t] = load_data_in_thread(args); + } + for (t = 0; t < nthreads && i + t - nthreads < m; ++t) { + const int image_index = i + t - nthreads; + char *path = paths[i + t - nthreads]; + char *id = basecfg(path); + float *X = val_resized[t].data; + network_predict(net, X); + get_region_boxes(l, 1, 1, thresh, probs, boxes, 0, map); + if (nms) do_nms_sort(boxes, probs, l.w*l.h*l.n, classes, nms); + + char labelpath[4096]; + find_replace(path, "images", "labels", labelpath); + find_replace(labelpath, "JPEGImages", "labels", labelpath); + find_replace(labelpath, ".jpg", ".txt", labelpath); + find_replace(labelpath, ".JPEG", ".txt", labelpath); + find_replace(labelpath, ".png", ".txt", labelpath); + int num_labels = 0; + box_label *truth = read_boxes(labelpath, &num_labels); + int i, j; + for (j = 0; j < num_labels; ++j) { + truth_classes_count[truth[j].id]++; + } + + for (i = 0; i < (l.w*l.h*l.n); ++i) { + + int class_id; + for (class_id = 0; class_id < classes; ++class_id) { + float prob = probs[i][class_id]; + if (prob > 0) { + detections_count++; + detections = realloc(detections, detections_count * sizeof(box_prob)); + detections[detections_count - 1].b = boxes[i]; + detections[detections_count - 1].p = prob; + detections[detections_count - 1].image_index = image_index; + detections[detections_count - 1].class_id = class_id; + + int truth_index = -1; + float max_iou = 0; + for (j = 0; j < num_labels; ++j) + { + box t = { truth[j].x, truth[j].y, truth[j].w, truth[j].h }; + //printf(" IoU = %f, prob = %f, class_id = %d, truth[j].id = %d \n", + // box_iou(boxes[i], t), prob, class_id, truth[j].id); + float current_iou = box_iou(boxes[i], t); + if (current_iou > iou_thresh && class_id == truth[j].id) { + if (current_iou > max_iou) { + current_iou = max_iou; + truth_index = unique_truth_index + j; + } + } + } + // best IoU + if (truth_index > -1) { + detections[detections_count - 1].truth_flag = 1; + detections[detections_count - 1].unique_truth_index = truth_index; + } + } + } + } + + unique_truth_index += num_labels; + + free(id); + free_image(val[t]); + free_image(val_resized[t]); + } + } + + + // SORT(detections) + qsort(detections, detections_count, sizeof(box_prob), detections_comparator); + + typedef struct { + double precision; + double recall; + int tp, fp, fn; + } pr_t; + + // for PR-curve + pr_t **pr = calloc(classes, sizeof(pr_t*)); + for (i = 0; i < classes; ++i) { + pr[i] = calloc(detections_count, sizeof(pr_t)); + } + printf("detections_count = %d, unique_truth_index = %d \n", detections_count, unique_truth_index); + + + int *truth_flags = calloc(unique_truth_index, sizeof(int)); + + int rank; + for (rank = 0; rank < detections_count; ++rank) { + if(rank % 100 == 0) + printf(" rank = %d of ranks = %d \r", rank, detections_count); + + if (rank > 0) { + int class_id; + for (class_id = 0; class_id < classes; ++class_id) { + pr[class_id][rank].tp = pr[class_id][rank - 1].tp; + pr[class_id][rank].fp = pr[class_id][rank - 1].fp; + } + } + + box_prob d = detections[rank]; + // if (detected && isn't detected before) + if (d.truth_flag == 1) { + if (truth_flags[d.unique_truth_index] == 0) + { + truth_flags[d.unique_truth_index] = 1; + pr[d.class_id][rank].tp++; // true-positive + } + } + else { + pr[d.class_id][rank].fp++; // false-positive + } + + + for (i = 0; i < classes; ++i) + { + const int tp = pr[i][rank].tp; + const int fp = pr[i][rank].fp; + const int fn = truth_classes_count[i] - tp; // false-negative = objects - true-positive + pr[i][rank].fn = fn; + + if ((tp + fp) > 0) pr[i][rank].precision = (double)tp / (double)(tp + fp); + else pr[i][rank].precision = 0; + + if ((tp + fn) > 0) pr[i][rank].recall = (double)tp / (double)(tp + fn); + else pr[i][rank].recall = 0; + } + } + + free(truth_flags); + + + double mean_average_precision = 0; + + for (i = 0; i < classes; ++i) { + double avg_precision = 0; + int point; + for (point = 0; point < 11; ++point) { + double cur_recall = point * 0.1; + double cur_precision = 0; + for (rank = 0; rank < detections_count; ++rank) + { + if (pr[i][rank].recall >= cur_recall) { // > or >= + if (pr[i][rank].precision > cur_precision) { + cur_precision = pr[i][rank].precision; + } + } + } + //printf("point = %d, cur_recall = %.4f, cur_precision = %.4f \n", point, cur_recall, cur_precision); + + avg_precision += cur_precision; + } + avg_precision = avg_precision / 11; + printf("class = %d, name = %s, \t ap = %2.2f %% \n", i, names[i], avg_precision*100); + mean_average_precision += avg_precision; + } + + mean_average_precision = mean_average_precision / classes; + printf("\n mean average precision (mAP) = %f, or %2.2f %% \n", mean_average_precision, mean_average_precision*100); + + + for (i = 0; i < classes; ++i) { + free(pr[i]); + } + free(pr); + free(detections); + free(truth_classes_count); + + fprintf(stderr, "Total Detection Time: %f Seconds\n", (double)(time(0) - start)); } void test_detector(char *datacfg, char *cfgfile, char *weightfile, char *filename, float thresh) @@ -565,6 +848,7 @@ void run_detector(int argc, char **argv) else if(0==strcmp(argv[2], "train")) train_detector(datacfg, cfg, weights, gpus, ngpus, clear); else if(0==strcmp(argv[2], "valid")) validate_detector(datacfg, cfg, weights); else if(0==strcmp(argv[2], "recall")) validate_detector_recall(datacfg, cfg, weights); + else if(0==strcmp(argv[2], "map")) validate_detector_map(datacfg, cfg, weights); else if(0==strcmp(argv[2], "demo")) { list *options = read_data_cfg(datacfg); int classes = option_find_int(options, "classes", 20);