Fixed mAP chart during training for CUDNN=1

pull/2111/head
AlexeyAB 6 years ago
parent 78d1ade380
commit f38d060137
  1. 225
      src/detector.c
  2. 2
      src/maxpool_layer.c
  3. 1
      src/network.h

@ -35,9 +35,7 @@ void draw_train_loss(IplImage* img, int img_size, float avg_loss, float max_img_
int check_mistakes; int check_mistakes;
static int coco_ids[] = {1,2,3,4,5,6,7,8,9,10,11,13,14,15,16,17,18,19,20,21,22,23,24,25,27,28,31,32,33,34,35,36,37,38,39,40,41,42,43,44,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,64,65,67,70,72,73,74,75,76,77,78,79,80,81,82,84,85,86,87,88,89,90}; static int coco_ids[] = { 1,2,3,4,5,6,7,8,9,10,11,13,14,15,16,17,18,19,20,21,22,23,24,25,27,28,31,32,33,34,35,36,37,38,39,40,41,42,43,44,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,64,65,67,70,72,73,74,75,76,77,78,79,80,81,82,84,85,86,87,88,89,90 };
float validate_detector_map(char *datacfg, char *cfgfile, char *weightfile, float thresh_calc_avg_iou, const float iou_thresh, network *existing_net);
void train_detector(char *datacfg, char *cfgfile, char *weightfile, int *gpus, int ngpus, int clear, int dont_show, int calc_map) void train_detector(char *datacfg, char *cfgfile, char *weightfile, int *gpus, int ngpus, int clear, int dont_show, int calc_map)
{ {
@ -46,20 +44,6 @@ void train_detector(char *datacfg, char *cfgfile, char *weightfile, int *gpus, i
char *valid_images = option_find_str(options, "valid", train_images); char *valid_images = option_find_str(options, "valid", train_images);
char *backup_directory = option_find_str(options, "backup", "/backup/"); char *backup_directory = option_find_str(options, "backup", "/backup/");
int train_images_num = 0;
if (calc_map) {
FILE* valid_file = fopen(valid_images, "r");
if (!valid_file) {
printf("\n Error: There is no %s file for mAP calculation!\n Don't use -map flag.\n Or set valid=%s in your %s file. \n", valid_images, train_images, datacfg);
getchar();
exit(-1);
}
else fclose(valid_file);
list *plist = get_paths(train_images);
train_images_num = plist->size;
free_list(plist);
}
srand(time(0)); srand(time(0));
char *base = basecfg(cfgfile); char *base = basecfg(cfgfile);
printf("%s\n", base); printf("%s\n", base);
@ -69,28 +53,65 @@ void train_detector(char *datacfg, char *cfgfile, char *weightfile, int *gpus, i
srand(time(0)); srand(time(0));
int seed = rand(); int seed = rand();
int i; int i;
for(i = 0; i < ngpus; ++i){ for (i = 0; i < ngpus; ++i) {
srand(seed); srand(seed);
#ifdef GPU #ifdef GPU
cuda_set_device(gpus[i]); cuda_set_device(gpus[i]);
#endif #endif
nets[i] = parse_network_cfg(cfgfile); nets[i] = parse_network_cfg(cfgfile);
if(weightfile){ if (weightfile) {
load_weights(&nets[i], weightfile); load_weights(&nets[i], weightfile);
} }
if(clear) *nets[i].seen = 0; if (clear) *nets[i].seen = 0;
nets[i].learning_rate *= ngpus; nets[i].learning_rate *= ngpus;
} }
srand(time(0)); srand(time(0));
network net = nets[0]; network net = nets[0];
int train_images_num = 0;
network net_map;
if (calc_map) {
FILE* valid_file = fopen(valid_images, "r");
if (!valid_file) {
printf("\n Error: There is no %s file for mAP calculation!\n Don't use -map flag.\n Or set valid=%s in your %s file. \n", valid_images, train_images, datacfg);
getchar();
exit(-1);
}
else fclose(valid_file);
list *plist = get_paths(train_images);
train_images_num = plist->size;
free_list(plist);
printf(" Prepare additional network for mAP calculation...\n");
net_map = parse_network_cfg_custom(cfgfile, 1);
int k;
for (k = 0; k < net.n; ++k) {
layer l = net.layers[k];
if (l.type == CONVOLUTIONAL) {
net_map.layers[k].biases = l.biases;
net_map.layers[k].scales = l.scales;
net_map.layers[k].rolling_mean = l.rolling_mean;
net_map.layers[k].rolling_variance = l.rolling_variance;
net_map.layers[k].weights = l.weights;
net_map.layers[k].biases_gpu = l.biases_gpu;
net_map.layers[k].scales_gpu = l.scales_gpu;
net_map.layers[k].rolling_mean_gpu = l.rolling_mean_gpu;
net_map.layers[k].rolling_variance_gpu = l.rolling_variance_gpu;
net_map.layers[k].weights_gpu = l.weights_gpu;
net_map.layers[k].weights_gpu16 = l.weights_gpu16;
}
}
}
const int actual_batch_size = net.batch * net.subdivisions; const int actual_batch_size = net.batch * net.subdivisions;
if (actual_batch_size == 1) { if (actual_batch_size == 1) {
printf("\n Error: You set incorrect value batch=1 for Training! You should set batch=64 subdivision=64 \n"); printf("\n Error: You set incorrect value batch=1 for Training! You should set batch=64 subdivision=64 \n");
getchar(); getchar();
} }
else if (actual_batch_size < 64) { else if (actual_batch_size < 64) {
printf("\n Warning: You set batch=%d lower than 64! It is recommended to set batch=64 subdivision=64 \n", actual_batch_size); printf("\n Warning: You set batch=%d lower than 64! It is recommended to set batch=64 subdivision=64 \n", actual_batch_size);
} }
int imgs = net.batch * net.subdivisions * ngpus; int imgs = net.batch * net.subdivisions * ngpus;
@ -114,7 +135,7 @@ void train_detector(char *datacfg, char *cfgfile, char *weightfile, int *gpus, i
iter_map = get_current_batch(net); iter_map = get_current_batch(net);
float mean_average_precision = -1; float mean_average_precision = -1;
load_args args = {0}; load_args args = { 0 };
args.w = net.w; args.w = net.w;
args.h = net.h; args.h = net.h;
args.c = net.c; args.c = net.c;
@ -137,7 +158,7 @@ void train_detector(char *datacfg, char *cfgfile, char *weightfile, int *gpus, i
#ifdef OPENCV #ifdef OPENCV
args.threads = 3 * ngpus; // Amazon EC2 Tesla V100: p3.2xlarge (8 logical cores) - p3.16xlarge args.threads = 3 * ngpus; // Amazon EC2 Tesla V100: p3.2xlarge (8 logical cores) - p3.16xlarge
//args.threads = 12 * ngpus; // Ryzen 7 2700X (16 logical cores) //args.threads = 12 * ngpus; // Ryzen 7 2700X (16 logical cores)
IplImage* img = NULL; IplImage* img = NULL;
float max_img_loss = 5; float max_img_loss = 5;
int number_of_lines = 100; int number_of_lines = 100;
@ -150,8 +171,8 @@ void train_detector(char *datacfg, char *cfgfile, char *weightfile, int *gpus, i
double time; double time;
int count = 0; int count = 0;
//while(i*imgs < N*120){ //while(i*imgs < N*120){
while(get_current_batch(net) < net.max_batches){ while (get_current_batch(net) < net.max_batches) {
if(l.random && count++%10 == 0){ if (l.random && count++ % 10 == 0) {
printf("Resizing\n"); printf("Resizing\n");
//int dim = (rand() % 12 + (init_w/32 - 5)) * 32; // +-160 //int dim = (rand() % 12 + (init_w/32 - 5)) * 32; // +-160
//int dim = (rand() % 4 + 16) * 32; //int dim = (rand() % 4 + 16) * 32;
@ -177,41 +198,42 @@ void train_detector(char *datacfg, char *cfgfile, char *weightfile, int *gpus, i
free_data(train); free_data(train);
load_thread = load_data(args); load_thread = load_data(args);
for(i = 0; i < ngpus; ++i){ for (i = 0; i < ngpus; ++i) {
resize_network(nets + i, dim_w, dim_h); resize_network(nets + i, dim_w, dim_h);
} }
net = nets[0]; net = nets[0];
} }
time=what_time_is_it_now(); time = what_time_is_it_now();
pthread_join(load_thread, 0); pthread_join(load_thread, 0);
train = buffer; train = buffer;
load_thread = load_data(args); load_thread = load_data(args);
/* /*
int k; int k;
for(k = 0; k < l.max_boxes; ++k){ for(k = 0; k < l.max_boxes; ++k){
box b = float_to_box(train.y.vals[10] + 1 + k*5); box b = float_to_box(train.y.vals[10] + 1 + k*5);
if(!b.x) break; if(!b.x) break;
printf("loaded: %f %f %f %f\n", b.x, b.y, b.w, b.h); printf("loaded: %f %f %f %f\n", b.x, b.y, b.w, b.h);
} }
image im = float_to_image(448, 448, 3, train.X.vals[10]); image im = float_to_image(448, 448, 3, train.X.vals[10]);
int k; int k;
for(k = 0; k < l.max_boxes; ++k){ for(k = 0; k < l.max_boxes; ++k){
box b = float_to_box(train.y.vals[10] + 1 + k*5); box b = float_to_box(train.y.vals[10] + 1 + k*5);
printf("%d %d %d %d\n", truth.x, truth.y, truth.w, truth.h); printf("%d %d %d %d\n", truth.x, truth.y, truth.w, truth.h);
draw_bbox(im, b, 8, 1,0,0); draw_bbox(im, b, 8, 1,0,0);
} }
save_image(im, "truth11"); save_image(im, "truth11");
*/ */
printf("Loaded: %lf seconds\n", (what_time_is_it_now()-time)); printf("Loaded: %lf seconds\n", (what_time_is_it_now() - time));
time=what_time_is_it_now(); time = what_time_is_it_now();
float loss = 0; float loss = 0;
#ifdef GPU #ifdef GPU
if(ngpus == 1){ if (ngpus == 1) {
loss = train_network(net, train); loss = train_network(net, train);
} else { }
else {
loss = train_networks(nets, ngpus, train, 4); loss = train_networks(nets, ngpus, train, 4);
} }
#else #else
@ -222,31 +244,18 @@ void train_detector(char *datacfg, char *cfgfile, char *weightfile, int *gpus, i
i = get_current_batch(net); i = get_current_batch(net);
if (net.cudnn_half) { if (net.cudnn_half) {
if (i < net.burn_in*3) printf("\n Tensor Cores are disabled until the first %d iterations are reached.", 3*net.burn_in); if (i < net.burn_in * 3) printf("\n Tensor Cores are disabled until the first %d iterations are reached.", 3 * net.burn_in);
else printf("\n Tensor Cores are used."); else printf("\n Tensor Cores are used.");
} }
printf("\n %d: %f, %f avg loss, %f rate, %lf seconds, %d images\n", get_current_batch(net), loss, avg_loss, get_current_rate(net), (what_time_is_it_now()-time), i*imgs); printf("\n %d: %f, %f avg loss, %f rate, %lf seconds, %d images\n", get_current_batch(net), loss, avg_loss, get_current_rate(net), (what_time_is_it_now() - time), i*imgs);
#ifdef OPENCV #ifdef OPENCV
if (!dont_show) { if (!dont_show) {
int draw_precision = 0; int draw_precision = 0;
int calc_map_for_each = 4 * train_images_num / (net.batch * net.subdivisions); int calc_map_for_each = 4 * train_images_num / (net.batch * net.subdivisions);
if (calc_map && (i >= (iter_map + calc_map_for_each) || i == net.max_batches) && i >= net.burn_in && i >= 1000) { if (calc_map && (i >= (iter_map + calc_map_for_each) || i == net.max_batches) && i >= net.burn_in && i >= 1000) {
if (l.random) {
printf("Resizing to initial size: %d x %d \n", init_w, init_h);
args.w = init_w;
args.h = init_h;
pthread_join(load_thread, 0);
train = buffer;
load_thread = load_data(args);
int k;
for (k = 0; k < ngpus; ++k) {
resize_network(nets + k, init_w, init_h);
}
net = nets[0];
}
iter_map = i; iter_map = i;
mean_average_precision = validate_detector_map(datacfg, cfgfile, weightfile, 0.25, 0.5, &net); mean_average_precision = validate_detector_map(datacfg, cfgfile, weightfile, 0.25, 0.5, &net_map);
printf("\n mean_average_precision = %f \n", mean_average_precision); printf("\n mean_average_precision = %f \n", mean_average_precision);
draw_precision = 1; draw_precision = 1;
} }
@ -257,7 +266,7 @@ void train_detector(char *datacfg, char *cfgfile, char *weightfile, int *gpus, i
//if (i % 1000 == 0 || (i < 1000 && i % 100 == 0)) { //if (i % 1000 == 0 || (i < 1000 && i % 100 == 0)) {
//if (i % 100 == 0) { //if (i % 100 == 0) {
if(i >= (iter_save + 1000)) { if (i >= (iter_save + 1000)) {
iter_save = i; iter_save = i;
#ifdef GPU #ifdef GPU
if (ngpus != 1) sync_nets(nets, ngpus, 0); if (ngpus != 1) sync_nets(nets, ngpus, 0);
@ -279,7 +288,7 @@ void train_detector(char *datacfg, char *cfgfile, char *weightfile, int *gpus, i
free_data(train); free_data(train);
} }
#ifdef GPU #ifdef GPU
if(ngpus != 1) sync_nets(nets, ngpus, 0); if (ngpus != 1) sync_nets(nets, ngpus, 0);
#endif #endif
char buff[256]; char buff[256];
sprintf(buff, "%s/%s_final.weights", backup_directory, base); sprintf(buff, "%s/%s_final.weights", backup_directory, base);
@ -632,8 +641,6 @@ float validate_detector_map(char *datacfg, char *cfgfile, char *weightfile, floa
char *train_images = option_find_str(options, "train", "data/train.txt"); char *train_images = option_find_str(options, "train", "data/train.txt");
char *valid_images = option_find_str(options, "valid", train_images); char *valid_images = option_find_str(options, "valid", train_images);
net = *existing_net; net = *existing_net;
initial_batch = net.batch;
set_batch_network(&net, 1);
} }
else { else {
net = parse_network_cfg_custom(cfgfile, 1); // set batch=1 net = parse_network_cfg_custom(cfgfile, 1); // set batch=1
@ -810,12 +817,12 @@ float validate_detector_map(char *datacfg, char *cfgfile, char *weightfile, floa
// calc avg IoU, true-positives, false-positives for required Threshold // calc avg IoU, true-positives, false-positives for required Threshold
if (prob > thresh_calc_avg_iou) { if (prob > thresh_calc_avg_iou) {
int z, found = 0; int z, found = 0;
for (z = checkpoint_detections_count; z < detections_count-1; ++z) for (z = checkpoint_detections_count; z < detections_count - 1; ++z)
if (detections[z].unique_truth_index == truth_index) { if (detections[z].unique_truth_index == truth_index) {
found = 1; break; found = 1; break;
} }
if(truth_index > -1 && found == 0) { if (truth_index > -1 && found == 0) {
avg_iou += max_iou; avg_iou += max_iou;
++tp_for_thresh; ++tp_for_thresh;
} }
@ -844,7 +851,7 @@ float validate_detector_map(char *datacfg, char *cfgfile, char *weightfile, floa
} }
} }
if((tp_for_thresh + fp_for_thresh) > 0) if ((tp_for_thresh + fp_for_thresh) > 0)
avg_iou = avg_iou / (tp_for_thresh + fp_for_thresh); avg_iou = avg_iou / (tp_for_thresh + fp_for_thresh);
@ -869,7 +876,7 @@ float validate_detector_map(char *datacfg, char *cfgfile, char *weightfile, floa
int rank; int rank;
for (rank = 0; rank < detections_count; ++rank) { for (rank = 0; rank < detections_count; ++rank) {
if(rank % 100 == 0) if (rank % 100 == 0)
printf(" rank = %d of ranks = %d \r", rank, detections_count); printf(" rank = %d of ranks = %d \r", rank, detections_count);
if (rank > 0) { if (rank > 0) {
@ -932,7 +939,7 @@ float validate_detector_map(char *datacfg, char *cfgfile, char *weightfile, floa
avg_precision += cur_precision; avg_precision += cur_precision;
} }
avg_precision = avg_precision / 11; avg_precision = avg_precision / 11;
printf("class_id = %d, name = %s, \t ap = %2.2f %% \n", i, names[i], avg_precision*100); printf("class_id = %d, name = %s, \t ap = %2.2f %% \n", i, names[i], avg_precision * 100);
mean_average_precision += avg_precision; mean_average_precision += avg_precision;
} }
@ -970,7 +977,7 @@ float validate_detector_map(char *datacfg, char *cfgfile, char *weightfile, floa
free_list(options); free_list(options);
if (existing_net) { if (existing_net) {
set_batch_network(&net, initial_batch); //set_batch_network(&net, initial_batch);
} }
else { else {
free_network(net); free_network(net);
@ -998,7 +1005,7 @@ int anchors_data_comparator(const float **pa, const float **pb)
{ {
float *a = (float *)*pa; float *a = (float *)*pa;
float *b = (float *)*pb; float *b = (float *)*pb;
float diff = b[0]*b[1] - a[0]*a[1]; float diff = b[0] * b[1] - a[0] * a[1];
if (diff < 0) return 1; if (diff < 0) return 1;
else if (diff > 0) return -1; else if (diff > 0) return -1;
return 0; return 0;
@ -1054,7 +1061,7 @@ void calc_anchors(char *datacfg, int num_of_clusters, int width, int height, int
rel_width_height_array = realloc(rel_width_height_array, 2 * number_of_boxes * sizeof(float)); rel_width_height_array = realloc(rel_width_height_array, 2 * number_of_boxes * sizeof(float));
rel_width_height_array[number_of_boxes * 2 - 2] = truth[j].w * width; rel_width_height_array[number_of_boxes * 2 - 2] = truth[j].w * width;
rel_width_height_array[number_of_boxes * 2 - 1] = truth[j].h * height; rel_width_height_array[number_of_boxes * 2 - 1] = truth[j].h * height;
printf("\r loaded \t image: %d \t box: %d", i+1, number_of_boxes); printf("\r loaded \t image: %d \t box: %d", i + 1, number_of_boxes);
} }
} }
printf("\n all loaded. \n"); printf("\n all loaded. \n");
@ -1086,7 +1093,7 @@ void calc_anchors(char *datacfg, int num_of_clusters, int width, int height, int
for (i = 0; i < number_of_boxes; ++i) { for (i = 0; i < number_of_boxes; ++i) {
float box_w = rel_width_height_array[i * 2]; //points->data.fl[i * 2]; float box_w = rel_width_height_array[i * 2]; //points->data.fl[i * 2];
float box_h = rel_width_height_array[i * 2 + 1]; //points->data.fl[i * 2 + 1]; float box_h = rel_width_height_array[i * 2 + 1]; //points->data.fl[i * 2 + 1];
//int cluster_idx = labels->data.i[i]; //int cluster_idx = labels->data.i[i];
int cluster_idx = 0; int cluster_idx = 0;
float min_dist = FLT_MAX; float min_dist = FLT_MAX;
float best_iou = 0; float best_iou = 0;
@ -1121,7 +1128,7 @@ void calc_anchors(char *datacfg, int num_of_clusters, int width, int height, int
for (i = 0; i < num_of_clusters; ++i) { for (i = 0; i < num_of_clusters; ++i) {
float anchor_w = anchors_data.centers.vals[i][0]; //centers->data.fl[i * 2]; float anchor_w = anchors_data.centers.vals[i][0]; //centers->data.fl[i * 2];
float anchor_h = anchors_data.centers.vals[i][1]; //centers->data.fl[i * 2 + 1]; float anchor_h = anchors_data.centers.vals[i][1]; //centers->data.fl[i * 2 + 1];
if(width > 32) sprintf(buff, "%3.0f,%3.0f", anchor_w, anchor_h); if (width > 32) sprintf(buff, "%3.0f,%3.0f", anchor_w, anchor_h);
else sprintf(buff, "%2.4f,%2.4f", anchor_w, anchor_h); else sprintf(buff, "%2.4f,%2.4f", anchor_w, anchor_h);
printf("%s", buff); printf("%s", buff);
fwrite(buff, sizeof(char), strlen(buff), fw); fwrite(buff, sizeof(char), strlen(buff), fw);
@ -1201,7 +1208,7 @@ void calc_anchors(char *datacfg, int num_of_clusters, int width, int height, int
//#endif // OPENCV //#endif // OPENCV
void test_detector(char *datacfg, char *cfgfile, char *weightfile, char *filename, float thresh, void test_detector(char *datacfg, char *cfgfile, char *weightfile, char *filename, float thresh,
float hier_thresh, int dont_show, int ext_output, int save_labels) float hier_thresh, int dont_show, int ext_output, int save_labels)
{ {
list *options = read_data_cfg(datacfg); list *options = read_data_cfg(datacfg);
char *name_list = option_find_str(options, "names", "data/names.list"); char *name_list = option_find_str(options, "names", "data/names.list");
@ -1210,7 +1217,7 @@ void test_detector(char *datacfg, char *cfgfile, char *weightfile, char *filenam
image **alphabet = load_alphabet(); image **alphabet = load_alphabet();
network net = parse_network_cfg_custom(cfgfile, 1); // set batch=1 network net = parse_network_cfg_custom(cfgfile, 1); // set batch=1
if(weightfile){ if (weightfile) {
load_weights(&net, weightfile); load_weights(&net, weightfile);
} }
//set_batch_network(&net, 1); //set_batch_network(&net, 1);
@ -1219,31 +1226,32 @@ void test_detector(char *datacfg, char *cfgfile, char *weightfile, char *filenam
if (net.layers[net.n - 1].classes != names_size) { if (net.layers[net.n - 1].classes != names_size) {
printf(" Error: in the file %s number of names %d that isn't equal to classes=%d in the file %s \n", printf(" Error: in the file %s number of names %d that isn't equal to classes=%d in the file %s \n",
name_list, names_size, net.layers[net.n - 1].classes, cfgfile); name_list, names_size, net.layers[net.n - 1].classes, cfgfile);
if(net.layers[net.n - 1].classes > names_size) getchar(); if (net.layers[net.n - 1].classes > names_size) getchar();
} }
srand(2222222); srand(2222222);
double time; double time;
char buff[256]; char buff[256];
char *input = buff; char *input = buff;
int j; int j;
float nms=.45; // 0.4F float nms = .45; // 0.4F
while(1){ while (1) {
if(filename){ if (filename) {
strncpy(input, filename, 256); strncpy(input, filename, 256);
if(strlen(input) > 0) if (strlen(input) > 0)
if (input[strlen(input) - 1] == 0x0d) input[strlen(input) - 1] = 0; if (input[strlen(input) - 1] == 0x0d) input[strlen(input) - 1] = 0;
} else { }
else {
printf("Enter Image Path: "); printf("Enter Image Path: ");
fflush(stdout); fflush(stdout);
input = fgets(input, 256, stdin); input = fgets(input, 256, stdin);
if(!input) return; if (!input) return;
strtok(input, "\n"); strtok(input, "\n");
} }
image im = load_image(input,0,0,net.c); image im = load_image(input, 0, 0, net.c);
int letterbox = 0; int letterbox = 0;
image sized = resize_image(im, net.w, net.h); image sized = resize_image(im, net.w, net.h);
//image sized = letterbox_image(im, net.w, net.h); letterbox = 1; //image sized = letterbox_image(im, net.w, net.h); letterbox = 1;
layer l = net.layers[net.n-1]; layer l = net.layers[net.n - 1];
//box *boxes = calloc(l.w*l.h*l.n, sizeof(box)); //box *boxes = calloc(l.w*l.h*l.n, sizeof(box));
//float **probs = calloc(l.w*l.h*l.n, sizeof(float *)); //float **probs = calloc(l.w*l.h*l.n, sizeof(float *));
@ -1268,7 +1276,7 @@ void test_detector(char *datacfg, char *cfgfile, char *weightfile, char *filenam
} }
// pseudo labeling concept - fast.ai // pseudo labeling concept - fast.ai
if(save_labels) if (save_labels)
{ {
char labelpath[4096]; char labelpath[4096];
replace_image_to_label(input, labelpath); replace_image_to_label(input, labelpath);
@ -1347,7 +1355,7 @@ void run_detector(int argc, char **argv)
// and for recall mode (extended output table-like format with results for best_class fit) // and for recall mode (extended output table-like format with results for best_class fit)
int ext_output = find_arg(argc, argv, "-ext_output"); int ext_output = find_arg(argc, argv, "-ext_output");
int save_labels = find_arg(argc, argv, "-save_labels"); int save_labels = find_arg(argc, argv, "-save_labels");
if(argc < 4){ if (argc < 4) {
fprintf(stderr, "usage: %s %s [train/test/valid/demo/map] [data] [cfg] [weights (optional)]\n", argv[0], argv[1]); fprintf(stderr, "usage: %s %s [train/test/valid/demo/map] [data] [cfg] [weights (optional)]\n", argv[0], argv[1]);
return; return;
} }
@ -1355,20 +1363,21 @@ void run_detector(int argc, char **argv)
int *gpus = 0; int *gpus = 0;
int gpu = 0; int gpu = 0;
int ngpus = 0; int ngpus = 0;
if(gpu_list){ if (gpu_list) {
printf("%s\n", gpu_list); printf("%s\n", gpu_list);
int len = strlen(gpu_list); int len = strlen(gpu_list);
ngpus = 1; ngpus = 1;
int i; int i;
for(i = 0; i < len; ++i){ for (i = 0; i < len; ++i) {
if (gpu_list[i] == ',') ++ngpus; if (gpu_list[i] == ',') ++ngpus;
} }
gpus = calloc(ngpus, sizeof(int)); gpus = calloc(ngpus, sizeof(int));
for(i = 0; i < ngpus; ++i){ for (i = 0; i < ngpus; ++i) {
gpus[i] = atoi(gpu_list); gpus[i] = atoi(gpu_list);
gpu_list = strchr(gpu_list, ',')+1; gpu_list = strchr(gpu_list, ',') + 1;
} }
} else { }
else {
gpu = gpu_index; gpu = gpu_index;
gpus = &gpu; gpus = &gpu;
ngpus = 1; ngpus = 1;
@ -1379,23 +1388,23 @@ void run_detector(int argc, char **argv)
char *datacfg = argv[3]; char *datacfg = argv[3];
char *cfg = argv[4]; char *cfg = argv[4];
char *weights = (argc > 5) ? argv[5] : 0; char *weights = (argc > 5) ? argv[5] : 0;
if(weights) if (weights)
if(strlen(weights) > 0) if (strlen(weights) > 0)
if (weights[strlen(weights) - 1] == 0x0d) weights[strlen(weights) - 1] = 0; if (weights[strlen(weights) - 1] == 0x0d) weights[strlen(weights) - 1] = 0;
char *filename = (argc > 6) ? argv[6]: 0; char *filename = (argc > 6) ? argv[6] : 0;
if(0==strcmp(argv[2], "test")) test_detector(datacfg, cfg, weights, filename, thresh, hier_thresh, dont_show, ext_output, save_labels); if (0 == strcmp(argv[2], "test")) test_detector(datacfg, cfg, weights, filename, thresh, hier_thresh, dont_show, ext_output, save_labels);
else if(0==strcmp(argv[2], "train")) train_detector(datacfg, cfg, weights, gpus, ngpus, clear, dont_show, calc_map); else if (0 == strcmp(argv[2], "train")) train_detector(datacfg, cfg, weights, gpus, ngpus, clear, dont_show, calc_map);
else if(0==strcmp(argv[2], "valid")) validate_detector(datacfg, cfg, weights, outfile); else if (0 == strcmp(argv[2], "valid")) validate_detector(datacfg, cfg, weights, outfile);
else if(0==strcmp(argv[2], "recall")) validate_detector_recall(datacfg, cfg, weights); else if (0 == strcmp(argv[2], "recall")) validate_detector_recall(datacfg, cfg, weights);
else if(0==strcmp(argv[2], "map")) validate_detector_map(datacfg, cfg, weights, thresh, iou_thresh, NULL); else if (0 == strcmp(argv[2], "map")) validate_detector_map(datacfg, cfg, weights, thresh, iou_thresh, NULL);
else if(0==strcmp(argv[2], "calc_anchors")) calc_anchors(datacfg, num_of_clusters, width, height, show); else if (0 == strcmp(argv[2], "calc_anchors")) calc_anchors(datacfg, num_of_clusters, width, height, show);
else if(0==strcmp(argv[2], "demo")) { else if (0 == strcmp(argv[2], "demo")) {
list *options = read_data_cfg(datacfg); list *options = read_data_cfg(datacfg);
int classes = option_find_int(options, "classes", 20); int classes = option_find_int(options, "classes", 20);
char *name_list = option_find_str(options, "names", "data/names.list"); char *name_list = option_find_str(options, "names", "data/names.list");
char **names = get_labels(name_list); char **names = get_labels(name_list);
if(filename) if (filename)
if(strlen(filename) > 0) if (strlen(filename) > 0)
if (filename[strlen(filename) - 1] == 0x0d) filename[strlen(filename) - 1] = 0; if (filename[strlen(filename) - 1] == 0x0d) filename[strlen(filename) - 1] = 0;
demo(cfg, weights, thresh, hier_thresh, cam_index, filename, names, classes, frame_skip, prefix, out_filename, demo(cfg, weights, thresh, hier_thresh, cam_index, filename, names, classes, frame_skip, prefix, out_filename,
http_stream_port, dont_show, ext_output); http_stream_port, dont_show, ext_output);

@ -104,6 +104,8 @@ void resize_maxpool_layer(maxpool_layer *l, int w, int h)
l->indexes_gpu = cuda_make_int_array(output_size); l->indexes_gpu = cuda_make_int_array(output_size);
l->output_gpu = cuda_make_array(l->output, output_size); l->output_gpu = cuda_make_array(l->output, output_size);
l->delta_gpu = cuda_make_array(l->delta, output_size); l->delta_gpu = cuda_make_array(l->delta, output_size);
cudnn_maxpool_setup(l);
#endif #endif
} }

@ -145,6 +145,7 @@ YOLODLL_API void reset_rnn(network *net);
YOLODLL_API network *load_network_custom(char *cfg, char *weights, int clear, int batch); YOLODLL_API network *load_network_custom(char *cfg, char *weights, int clear, int batch);
YOLODLL_API network *load_network(char *cfg, char *weights, int clear); YOLODLL_API network *load_network(char *cfg, char *weights, int clear);
YOLODLL_API float *network_predict_image(network *net, image im); YOLODLL_API float *network_predict_image(network *net, image im);
YOLODLL_API float validate_detector_map(char *datacfg, char *cfgfile, char *weightfile, float thresh_calc_avg_iou, const float iou_thresh, network *existing_net);
YOLODLL_API void train_detector(char *datacfg, char *cfgfile, char *weightfile, int *gpus, int ngpus, int clear, int dont_show, int calc_map); YOLODLL_API void train_detector(char *datacfg, char *cfgfile, char *weightfile, int *gpus, int ngpus, int clear, int dont_show, int calc_map);
YOLODLL_API int network_width(network *net); YOLODLL_API int network_width(network *net);
YOLODLL_API int network_height(network *net); YOLODLL_API int network_height(network *net);

Loading…
Cancel
Save