Take TopK from obj.data file for Classifier

pull/4302/head
AlexeyAB 6 years ago
parent 3652d7d374
commit c516b6cb0a
  1. 9
      src/classifier.c
  2. 13
      src/parser.c

@ -63,6 +63,9 @@ void train_classifier(char *datacfg, char *cfgfile, char *weightfile, int *gpus,
char *label_list = option_find_str(options, "labels", "data/labels.list");
char *train_list = option_find_str(options, "train", "data/train.list");
int classes = option_find_int(options, "classes", 2);
int topk_data = option_find_int(options, "top", 5);
char topk_buff[10];
sprintf(topk_buff, "top%d", topk_data);
char **labels = get_labels(label_list);
list *plist = get_paths(train_list);
@ -157,14 +160,14 @@ void train_classifier(char *datacfg, char *cfgfile, char *weightfile, int *gpus,
int draw_precision = 0;
if (calc_topk && (i >= calc_topk_for_each || i == net.max_batches)) {
iter_topk = i;
topk = validate_classifier_single(datacfg, cfgfile, weightfile, &net, 5); // calc TOP5
printf("\n accuracy TOP5 = %f \n", topk);
topk = validate_classifier_single(datacfg, cfgfile, weightfile, &net, topk_data); // calc TOP5
printf("\n accuracy %s = %f \n", topk_buff, topk);
draw_precision = 1;
}
printf("%d, %.3f: %f, %f avg, %f rate, %lf seconds, %ld images\n", get_current_batch(net), (float)(*net.seen)/ train_images_num, loss, avg_loss, get_current_rate(net), sec(clock()-time), *net.seen);
#ifdef OPENCV
draw_train_loss(img, img_size, avg_loss, max_img_loss, i, net.max_batches, topk, draw_precision, "top5", dont_show, mjpeg_port);
draw_train_loss(img, img_size, avg_loss, max_img_loss, i, net.max_batches, topk, draw_precision, topk_buff, dont_show, mjpeg_port);
#endif // OPENCV
if (i >= (iter_save + 1000)) {

@ -801,11 +801,14 @@ route_layer parse_route(list *options, size_params params)
layer.h = first.h;
layer.c = layer.out_c;
if (n > 3) fprintf(stderr, " \t ");
else if (n > 1) fprintf(stderr, " \t ");
else fprintf(stderr, " \t\t ");
fprintf(stderr, " -> %4d x%4d x%4d \n", layer.w, layer.h, layer.c, layer.out_w, layer.out_h, layer.out_c);
if (n > 3) fprintf(stderr, " \t ");
else if (n > 1) fprintf(stderr, " \t ");
else fprintf(stderr, " \t\t ");
fprintf(stderr, " ");
if (layer.groups > 1) fprintf(stderr, "%d/%d", layer.group_id, layer.groups);
else fprintf(stderr, " ");
fprintf(stderr, " -> %4d x%4d x%4d \n", layer.out_w, layer.out_h, layer.out_c);
return layer;
}

Loading…
Cancel
Save