diff --git a/src/detector.c b/src/detector.c index ce259fd6..d860a6ad 100644 --- a/src/detector.c +++ b/src/detector.c @@ -552,9 +552,14 @@ void validate_detector_map(char *datacfg, char *cfgfile, char *weightfile) args.h = net.h; args.type = IMAGE_DATA; + const float thresh_calc_avg_iou = 0.24; + float avg_iou = 0; + int tp_for_thresh = 0; + int fp_for_thresh = 0; + box_prob *detections = calloc(1, sizeof(box_prob)); int detections_count = 0; - int unique_truth_index = 0; + int unique_truth_count = 0; int *truth_classes_count = calloc(classes, sizeof(int)); @@ -642,7 +647,7 @@ void validate_detector_map(char *datacfg, char *cfgfile, char *weightfile) if (current_iou > iou_thresh && class_id == truth[j].id) { if (current_iou > max_iou) { max_iou = current_iou; - truth_index = unique_truth_index + j; + truth_index = unique_truth_count + j; } } } @@ -659,14 +664,25 @@ void validate_detector_map(char *datacfg, char *cfgfile, char *weightfile) float current_iou = box_iou(boxes[i], t); if (current_iou > iou_thresh && class_id == truth_dif[j].id) { --detections_count; + break; } } } + + // calc avg IoU, true-positives, false-positives for required Threshold + if (prob > thresh_calc_avg_iou) { + if (truth_index > -1) { + avg_iou += max_iou; + ++tp_for_thresh; + } + else + fp_for_thresh++; + } } } } - unique_truth_index += num_labels; + unique_truth_count += num_labels; free(id); free_image(val[t]); @@ -674,6 +690,8 @@ void validate_detector_map(char *datacfg, char *cfgfile, char *weightfile) } } + avg_iou = avg_iou / (tp_for_thresh + fp_for_thresh); + // SORT(detections) qsort(detections, detections_count, sizeof(box_prob), detections_comparator); @@ -689,10 +707,10 @@ void validate_detector_map(char *datacfg, char *cfgfile, char *weightfile) for (i = 0; i < classes; ++i) { pr[i] = calloc(detections_count, sizeof(pr_t)); } - printf("detections_count = %d, unique_truth_index = %d \n", detections_count, unique_truth_index); + printf("detections_count = %d, unique_truth_count = %d \n", detections_count, unique_truth_count); - int *truth_flags = calloc(unique_truth_index, sizeof(int)); + int *truth_flags = calloc(unique_truth_count, sizeof(int)); int rank; for (rank = 0; rank < detections_count; ++rank) { @@ -763,6 +781,9 @@ void validate_detector_map(char *datacfg, char *cfgfile, char *weightfile) mean_average_precision += avg_precision; } + printf(" for thresh = %0.2f, TP = %d, FP = %d, FN = %d, average IoU = %2.2f %% \n", + thresh_calc_avg_iou, tp_for_thresh, fp_for_thresh, unique_truth_count - tp_for_thresh, avg_iou * 100); + mean_average_precision = mean_average_precision / classes; printf("\n mean average precision (mAP) = %f, or %2.2f %% \n", mean_average_precision, mean_average_precision*100);