mAP calculation during training, if is used flag -map

pull/2095/head
AlexeyAB 7 years ago
parent 742bb7c7ce
commit dc7f8a32ae
  1. 4
      src/classifier.c
  2. 78
      src/detector.c
  3. 34
      src/image.c
  4. 45
      src/maxpool_layer.c
  5. 5
      src/maxpool_layer.h
  6. 3
      src/network.c
  7. 2
      src/yolo_layer.c

@ -25,7 +25,7 @@ image get_image_from_stream_cpp(CvCapture *cap);
#include "http_stream.h"
IplImage* draw_train_chart(float max_img_loss, int max_batches, int number_of_lines, int img_size);
void draw_train_loss(IplImage* img, int img_size, float avg_loss, float max_img_loss, int current_batch, int max_batches);
void draw_train_loss(IplImage* img, int img_size, float avg_loss, float max_img_loss, int current_batch, int max_batches, float precision);
#endif
@ -153,7 +153,7 @@ void train_classifier(char *datacfg, char *cfgfile, char *weightfile, int *gpus,
printf("%d, %.3f: %f, %f avg, %f rate, %lf seconds, %d images\n", get_current_batch(net), (float)(*net.seen)/N, loss, avg_loss, get_current_rate(net), sec(clock()-time), *net.seen);
#ifdef OPENCV
if(!dont_show)
draw_train_loss(img, img_size, avg_loss, max_img_loss, i, net.max_batches);
draw_train_loss(img, img_size, avg_loss, max_img_loss, i, net.max_batches, -1);
#endif // OPENCV
if (i >= (iter_save + 100)) {

@ -26,7 +26,7 @@
#endif
IplImage* draw_train_chart(float max_img_loss, int max_batches, int number_of_lines, int img_size);
void draw_train_loss(IplImage* img, int img_size, float avg_loss, float max_img_loss, int current_batch, int max_batches);
void draw_train_loss(IplImage* img, int img_size, float avg_loss, float max_img_loss, int current_batch, int max_batches, float precision);
#define CV_RGB(r, g, b) cvScalar( (b), (g), (r), 0 )
#endif // OPENCV
@ -37,12 +37,29 @@ int check_mistakes;
static int coco_ids[] = {1,2,3,4,5,6,7,8,9,10,11,13,14,15,16,17,18,19,20,21,22,23,24,25,27,28,31,32,33,34,35,36,37,38,39,40,41,42,43,44,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,64,65,67,70,72,73,74,75,76,77,78,79,80,81,82,84,85,86,87,88,89,90};
void train_detector(char *datacfg, char *cfgfile, char *weightfile, int *gpus, int ngpus, int clear, int dont_show)
float validate_detector_map(char *datacfg, char *cfgfile, char *weightfile, float thresh_calc_avg_iou, const float iou_thresh, network *existing_net);
void train_detector(char *datacfg, char *cfgfile, char *weightfile, int *gpus, int ngpus, int clear, int dont_show, int calc_map)
{
list *options = read_data_cfg(datacfg);
char *train_images = option_find_str(options, "train", "data/train.list");
char *train_images = option_find_str(options, "train", "data/train.txt");
char *valid_images = option_find_str(options, "valid", train_images);
char *backup_directory = option_find_str(options, "backup", "/backup/");
int valid_images_num = 0;
if (calc_map) {
FILE* valid_file = fopen(valid_images, "r");
if (!valid_file) {
printf("\n Error: There is no %s file for mAP calculation!\n Don't use -map flag.\n Or set valid=%s in your %s file. \n", valid_images, train_images, datacfg);
getchar();
exit(-1);
}
else fclose(valid_file);
list *plist = get_paths(valid_images);
valid_images_num = plist->size;
free_list(plist);
}
srand(time(0));
char *base = basecfg(cfgfile);
printf("%s\n", base);
@ -91,9 +108,11 @@ void train_detector(char *datacfg, char *cfgfile, char *weightfile, int *gpus, i
int init_w = net.w;
int init_h = net.h;
int iter_save, iter_save_last;
int iter_save, iter_save_last, iter_map;
iter_save = get_current_batch(net);
iter_save_last = get_current_batch(net);
iter_map = get_current_batch(net);
float mean_average_precision = -1;
load_args args = {0};
args.w = net.w;
@ -205,8 +224,15 @@ void train_detector(char *datacfg, char *cfgfile, char *weightfile, int *gpus, i
printf("\n %d: %f, %f avg loss, %f rate, %lf seconds, %d images\n", get_current_batch(net), loss, avg_loss, get_current_rate(net), (what_time_is_it_now()-time), i*imgs);
#ifdef OPENCV
if(!dont_show)
draw_train_loss(img, img_size, avg_loss, max_img_loss, i, net.max_batches);
if (!dont_show) {
if (calc_map && (i >= (iter_map + valid_images_num/10) || i == net.max_batches) && i >= 1000) {
iter_map = i;
mean_average_precision = validate_detector_map(datacfg, cfgfile, weightfile, 0.25, 0.5, &net);
printf("\n mean_average_precision = %f \n", mean_average_precision);
}
draw_train_loss(img, img_size, avg_loss, max_img_loss, i, net.max_batches, mean_average_precision);
}
#endif // OPENCV
//if (i % 1000 == 0 || (i < 1000 && i % 100 == 0)) {
@ -567,7 +593,7 @@ int detections_comparator(const void *pa, const void *pb)
return 0;
}
void validate_detector_map(char *datacfg, char *cfgfile, char *weightfile, float thresh_calc_avg_iou, const float iou_thresh)
float validate_detector_map(char *datacfg, char *cfgfile, char *weightfile, float thresh_calc_avg_iou, const float iou_thresh, network *existing_net)
{
int j;
list *options = read_data_cfg(datacfg);
@ -580,14 +606,26 @@ void validate_detector_map(char *datacfg, char *cfgfile, char *weightfile, float
if (mapf) map = read_map(mapf);
FILE* reinforcement_fd = NULL;
network net = parse_network_cfg_custom(cfgfile, 1); // set batch=1
network net;
int initial_batch;
if (existing_net) {
char *train_images = option_find_str(options, "train", "data/train.txt");
char *valid_images = option_find_str(options, "valid", train_images);
net = *existing_net;
initial_batch = net.batch;
set_batch_network(&net, 1);
}
else {
net = parse_network_cfg_custom(cfgfile, 1); // set batch=1
if (weightfile) {
load_weights(&net, weightfile);
}
//set_batch_network(&net, 1);
fuse_conv_batchnorm(net);
calculate_binary_weights(net);
}
srand(time(0));
printf("\n calculation mAP (mean average precision)...\n");
list *plist = get_paths(valid_images);
char **paths = (char **)list_to_array(plist);
@ -611,6 +649,7 @@ void validate_detector_map(char *datacfg, char *cfgfile, char *weightfile, float
//const float iou_thresh = 0.5;
int nthreads = 4;
if (m < 4) nthreads = m;
image *val = calloc(nthreads, sizeof(image));
image *val_resized = calloc(nthreads, sizeof(image));
image *buf = calloc(nthreads, sizeof(image));
@ -643,7 +682,7 @@ void validate_detector_map(char *datacfg, char *cfgfile, char *weightfile, float
}
time_t start = time(0);
for (i = nthreads; i < m + nthreads; i += nthreads) {
fprintf(stderr, "%d\n", i);
fprintf(stderr, "\r%d", i);
for (t = 0; t < nthreads && i + t - nthreads < m; ++t) {
pthread_join(thr[t], 0);
val[t] = buf[t];
@ -803,7 +842,7 @@ void validate_detector_map(char *datacfg, char *cfgfile, char *weightfile, float
for (i = 0; i < classes; ++i) {
pr[i] = calloc(detections_count, sizeof(pr_t));
}
printf("detections_count = %d, unique_truth_count = %d \n", detections_count, unique_truth_count);
printf("\n detections_count = %d, unique_truth_count = %d \n", detections_count, unique_truth_count);
int *truth_flags = calloc(unique_truth_count, sizeof(int));
@ -904,6 +943,20 @@ void validate_detector_map(char *datacfg, char *cfgfile, char *weightfile, float
fprintf(stderr, "Total Detection Time: %f Seconds\n", (double)(time(0) - start));
if (reinforcement_fd != NULL) fclose(reinforcement_fd);
// free memory
free_ptrs(names, net.layers[net.n - 1].classes);
free_list_contents_kvp(options);
free_list(options);
if (existing_net) {
set_batch_network(&net, initial_batch);
}
else {
free_network(net);
}
return mean_average_precision;
}
#ifdef OPENCV
@ -1245,6 +1298,7 @@ void run_detector(int argc, char **argv)
{
int dont_show = find_arg(argc, argv, "-dont_show");
int show = find_arg(argc, argv, "-show");
int calc_map = find_arg(argc, argv, "-map");
check_mistakes = find_arg(argc, argv, "-check_mistakes");
int http_stream_port = find_int_arg(argc, argv, "-http_port", -1);
char *out_filename = find_char_arg(argc, argv, "-out_filename", 0);
@ -1299,10 +1353,10 @@ void run_detector(int argc, char **argv)
if (weights[strlen(weights) - 1] == 0x0d) weights[strlen(weights) - 1] = 0;
char *filename = (argc > 6) ? argv[6]: 0;
if(0==strcmp(argv[2], "test")) test_detector(datacfg, cfg, weights, filename, thresh, hier_thresh, dont_show, ext_output, save_labels);
else if(0==strcmp(argv[2], "train")) train_detector(datacfg, cfg, weights, gpus, ngpus, clear, dont_show);
else if(0==strcmp(argv[2], "train")) train_detector(datacfg, cfg, weights, gpus, ngpus, clear, dont_show, calc_map);
else if(0==strcmp(argv[2], "valid")) validate_detector(datacfg, cfg, weights, outfile);
else if(0==strcmp(argv[2], "recall")) validate_detector_recall(datacfg, cfg, weights);
else if(0==strcmp(argv[2], "map")) validate_detector_map(datacfg, cfg, weights, thresh, iou_thresh);
else if(0==strcmp(argv[2], "map")) validate_detector_map(datacfg, cfg, weights, thresh, iou_thresh, NULL);
else if(0==strcmp(argv[2], "calc_anchors")) calc_anchors(datacfg, num_of_clusters, width, height, show);
else if(0==strcmp(argv[2], "demo")) {
list *options = read_data_cfg(datacfg);

@ -704,6 +704,8 @@ IplImage* draw_train_chart(float max_img_loss, int max_batches, int number_of_li
cvLine(img, pt1, pt2, CV_RGB(128, 128, 128), 1, 8, 0);
}
}
cvPutText(img, "Loss", cvPoint(0, 35), &font, CV_RGB(0, 0, 255));
cvPutText(img, "Iteration number", cvPoint(draw_size / 2, img_size - 10), &font, CV_RGB(0, 0, 0));
char max_batches_buff[100];
sprintf(max_batches_buff, "in cfg max_batches=%d", max_batches);
@ -718,7 +720,7 @@ IplImage* draw_train_chart(float max_img_loss, int max_batches, int number_of_li
return img;
}
void draw_train_loss(IplImage* img, int img_size, float avg_loss, float max_img_loss, int current_batch, int max_batches)
void draw_train_loss(IplImage* img, int img_size, float avg_loss, float max_img_loss, int current_batch, int max_batches, float precision)
{
int img_offset = 50;
int draw_size = img_size - img_offset;
@ -731,12 +733,40 @@ void draw_train_loss(IplImage* img, int img_size, float avg_loss, float max_img_
if (pt1.y < 0) pt1.y = 1;
cvCircle(img, pt1, 1, CV_RGB(0, 0, 255), CV_FILLED, 8, 0);
// precision
if (precision >= 0) {
static float old_precision = 0;
static iteration_old = 0;
if (old_precision != precision) {
cvLine(img,
cvPoint(img_offset + draw_size * (float)iteration_old / max_batches, draw_size * (1 - old_precision)),
cvPoint(img_offset + draw_size * (float)current_batch / max_batches, draw_size * (1 - precision)),
CV_RGB(255, 0, 0), 1, 8, 0);
old_precision = precision;
iteration_old = current_batch;
sprintf(char_buff, "%2.0f%% ", precision * 100);
CvFont font3;
cvInitFont(&font3, CV_FONT_HERSHEY_COMPLEX_SMALL, 0.7, 0.7, 0, 5, CV_AA);
cvPutText(img, char_buff, cvPoint(pt1.x - 30, draw_size * (1 - precision) + 15), &font3, CV_RGB(255, 255, 255));
CvFont font2;
cvInitFont(&font2, CV_FONT_HERSHEY_COMPLEX_SMALL, 0.7, 0.7, 0, 1, CV_AA);
cvPutText(img, char_buff, cvPoint(pt1.x - 30, draw_size * (1 - precision) + 15), &font2, CV_RGB(200, 0, 0));
}
cvPutText(img, "mAP%", cvPoint(0, 12), &font, CV_RGB(255, 0, 0));
}
sprintf(char_buff, "current avg loss = %2.4f iteration = %d", avg_loss, current_batch);
pt1.x = img_size / 2, pt1.y = 30;
pt1.x = 55, pt1.y = 10;
pt2.x = pt1.x + 460, pt2.y = pt1.y + 20;
cvRectangle(img, pt1, pt2, CV_RGB(255, 255, 255), CV_FILLED, 8, 0);
pt1.y += 15;
cvPutText(img, char_buff, pt1, &font, CV_RGB(0, 0, 0));
cvShowImage("average loss", img);
int k = cvWaitKey(20);
if (k == 's' || current_batch == (max_batches - 1) || current_batch % 100 == 0) {

@ -19,6 +19,32 @@ image get_maxpool_delta(maxpool_layer l)
return float_to_image(w,h,c,l.delta);
}
void cudnn_maxpool_setup(layer *l)
{
#ifdef CUDNN
cudnnStatus_t maxpool_status;
maxpool_status = cudnnCreatePoolingDescriptor(&l->poolingDesc);
maxpool_status = cudnnSetPooling2dDescriptor(
l->poolingDesc,
CUDNN_POOLING_MAX,
CUDNN_PROPAGATE_NAN, // CUDNN_PROPAGATE_NAN, CUDNN_NOT_PROPAGATE_NAN
l->size,
l->size,
0, //l.pad,
0, //l.pad,
l->stride,
l->stride);
cudnnCreateTensorDescriptor(&l->srcTensorDesc);
cudnnCreateTensorDescriptor(&l->dstTensorDesc);
cudnnSetTensor4dDescriptor(l->srcTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l->batch, l->c, l->h, l->w);
cudnnSetTensor4dDescriptor(l->dstTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l->batch, l->out_c, l->out_h, l->out_w);
#endif // CUDNN
}
maxpool_layer make_maxpool_layer(int batch, int h, int w, int c, int size, int stride, int padding)
{
maxpool_layer l = {0};
@ -47,26 +73,9 @@ maxpool_layer make_maxpool_layer(int batch, int h, int w, int c, int size, int s
l.indexes_gpu = cuda_make_int_array(output_size);
l.output_gpu = cuda_make_array(l.output, output_size);
l.delta_gpu = cuda_make_array(l.delta, output_size);
#ifdef CUDNN
cudnnStatus_t maxpool_status;
maxpool_status = cudnnCreatePoolingDescriptor(&l.poolingDesc);
maxpool_status = cudnnSetPooling2dDescriptor(
l.poolingDesc,
CUDNN_POOLING_MAX,
CUDNN_PROPAGATE_NAN, // CUDNN_PROPAGATE_NAN, CUDNN_NOT_PROPAGATE_NAN
l.size,
l.size,
0, //l.pad,
0, //l.pad,
l.stride,
l.stride);
cudnn_maxpool_setup(&l);
cudnnCreateTensorDescriptor(&l.srcTensorDesc);
cudnnCreateTensorDescriptor(&l.dstTensorDesc);
cudnnSetTensor4dDescriptor(l.srcTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l.batch, l.c, l.h, l.w);
cudnnSetTensor4dDescriptor(l.dstTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l.batch, l.out_c, l.out_h, l.out_w);
#endif // CUDNN
#endif // GPU
l.bflops = (l.size*l.size*l.c * l.out_h*l.out_w) / 1000000000.;
fprintf(stderr, "max %d x %d / %d %4d x%4d x%4d -> %4d x%4d x%4d %5.3f BF\n", size, size, stride, w, h, c, l.out_w, l.out_h, l.out_c, l.bflops);

@ -17,7 +17,10 @@ void backward_maxpool_layer(const maxpool_layer l, network_state state);
#ifdef GPU
void forward_maxpool_layer_gpu(maxpool_layer l, network_state state);
void backward_maxpool_layer_gpu(maxpool_layer l, network_state state);
#endif
#ifdef CUDNN
void cudnn_maxpool_setup(maxpool_layer *l);
#endif // CUDNN
#endif // GPU
#endif

@ -377,6 +377,9 @@ void set_batch_network(network *net, int b)
}
*/
}
else if (net->layers[i].type == MAXPOOL) {
cudnn_maxpool_setup(net->layers + i);
}
#endif
}
}

@ -364,6 +364,7 @@ void avg_flipped_yolo(layer l)
int get_yolo_detections(layer l, int w, int h, int netw, int neth, float thresh, int *map, int relative, detection *dets, int letter)
{
//printf("\n l.batch = %d, l.w = %d, l.h = %d, l.n = %d \n", l.batch, l.w, l.h, l.n);
int i,j,n;
float *predictions = l.output;
if (l.batch == 2) avg_flipped_yolo(l);
@ -376,6 +377,7 @@ int get_yolo_detections(layer l, int w, int h, int netw, int neth, float thresh,
float objectness = predictions[obj_index];
//if(objectness <= thresh) continue; // incorrect behavior for Nan values
if (objectness > thresh) {
//printf("\n objectness = %f, thresh = %f, i = %d, n = %d \n", objectness, thresh, i, n);
int box_index = entry_index(l, 0, n*l.w*l.h + i, 0);
dets[count].bbox = get_yolo_box(predictions, l.biases, l.mask[n], box_index, col, row, l.w, l.h, netw, neth, l.w*l.h);
dets[count].objectness = objectness;

Loading…
Cancel
Save