diff --git a/src/data.c b/src/data.c index a15bc1d7..dc803f0a 100644 --- a/src/data.c +++ b/src/data.c @@ -687,8 +687,9 @@ data load_data_swag(char **paths, int n, int classes, float jitter) #include "http_stream.h" -data load_data_detection(int n, char **paths, int m, int w, int h, int boxes, int classes, int use_flip, float jitter, float hue, float saturation, float exposure, int small_object) +data load_data_detection(int n, char **paths, int m, int w, int h, int c, int boxes, int classes, int use_flip, float jitter, float hue, float saturation, float exposure, int small_object) { + c = c ? c : 3; char **random_paths = get_random_paths(paths, n, m); int i; data d = {0}; @@ -696,13 +697,13 @@ data load_data_detection(int n, char **paths, int m, int w, int h, int boxes, in d.X.rows = n; d.X.vals = calloc(d.X.rows, sizeof(float*)); - d.X.cols = h*w*3; + d.X.cols = h*w*c; d.y = make_matrix(n, 5*boxes); for(i = 0; i < n; ++i){ const char *filename = random_paths[i]; - int flag = 1; + int flag = (c >= 3); IplImage *src; if ((src = cvLoadImage(filename, flag)) == 0) { @@ -754,8 +755,9 @@ data load_data_detection(int n, char **paths, int m, int w, int h, int boxes, in return d; } #else // OPENCV -data load_data_detection(int n, char **paths, int m, int w, int h, int boxes, int classes, int use_flip, float jitter, float hue, float saturation, float exposure, int small_object) +data load_data_detection(int n, char **paths, int m, int w, int h, int c, int boxes, int classes, int use_flip, float jitter, float hue, float saturation, float exposure, int small_object) { + c = c ? c : 3; char **random_paths = get_random_paths(paths, n, m); int i; data d = { 0 }; @@ -763,11 +765,11 @@ data load_data_detection(int n, char **paths, int m, int w, int h, int boxes, in d.X.rows = n; d.X.vals = calloc(d.X.rows, sizeof(float*)); - d.X.cols = h*w * 3; + d.X.cols = h*w*c; d.y = make_matrix(n, 5 * boxes); for (i = 0; i < n; ++i) { - image orig = load_image_color(random_paths[i], 0, 0); + image orig = load_image(random_paths[i], 0, 0, c); int oh = orig.h; int ow = orig.w; @@ -827,16 +829,16 @@ void *load_thread(void *ptr) } else if (a.type == REGION_DATA){ *a.d = load_data_region(a.n, a.paths, a.m, a.w, a.h, a.num_boxes, a.classes, a.jitter, a.hue, a.saturation, a.exposure); } else if (a.type == DETECTION_DATA){ - *a.d = load_data_detection(a.n, a.paths, a.m, a.w, a.h, a.num_boxes, a.classes, a.flip, a.jitter, a.hue, a.saturation, a.exposure, a.small_object); + *a.d = load_data_detection(a.n, a.paths, a.m, a.w, a.h, a.c, a.num_boxes, a.classes, a.flip, a.jitter, a.hue, a.saturation, a.exposure, a.small_object); } else if (a.type == SWAG_DATA){ *a.d = load_data_swag(a.paths, a.n, a.classes, a.jitter); } else if (a.type == COMPARE_DATA){ *a.d = load_data_compare(a.n, a.paths, a.m, a.classes, a.w, a.h); } else if (a.type == IMAGE_DATA){ - *(a.im) = load_image_color(a.path, 0, 0); + *(a.im) = load_image(a.path, 0, 0, a.c); *(a.resized) = resize_image(*(a.im), a.w, a.h); }else if (a.type == LETTERBOX_DATA) { - *(a.im) = load_image_color(a.path, 0, 0); + *(a.im) = load_image(a.path, 0, 0, a.c); *(a.resized) = letterbox_image(*(a.im), a.w, a.h); } else if (a.type == TAG_DATA){ *a.d = load_data_tag(a.paths, a.n, a.m, a.classes, a.flip, a.min, a.max, a.size, a.angle, a.aspect, a.hue, a.saturation, a.exposure); diff --git a/src/data.h b/src/data.h index 57f4702e..b46143fb 100644 --- a/src/data.h +++ b/src/data.h @@ -44,7 +44,8 @@ typedef struct load_args{ char **labels; int h; int w; - int out_w; + int c; // color depth + int out_w; int out_h; int nh; int nw; @@ -84,7 +85,7 @@ void print_letters(float *pred, int n); data load_data_captcha(char **paths, int n, int m, int k, int w, int h); data load_data_captcha_encode(char **paths, int n, int m, int w, int h); data load_data_old(char **paths, int n, int m, char **labels, int k, int w, int h); -data load_data_detection(int n, char **paths, int m, int w, int h, int boxes, int classes, int use_flip, float jitter, float hue, float saturation, float exposure, int small_object); +data load_data_detection(int n, char **paths, int m, int w, int h, int c, int boxes, int classes, int use_flip, float jitter, float hue, float saturation, float exposure, int small_object); data load_data_tag(char **paths, int n, int m, int k, int use_flip, int min, int max, int size, float angle, float aspect, float hue, float saturation, float exposure); matrix load_image_augment_paths(char **paths, int n, int use_flip, int min, int max, int size, float angle, float aspect, float hue, float saturation, float exposure); data load_data_super(char **paths, int n, int m, int w, int h, int scale); diff --git a/src/demo.c b/src/demo.c index 81eddb29..72f9c03e 100644 --- a/src/demo.c +++ b/src/demo.c @@ -51,7 +51,7 @@ static float *avg; void draw_detections_cv(IplImage* show_img, int num, float thresh, box *boxes, float **probs, char **names, image **alphabet, int classes); void draw_detections_cv_v3(IplImage* show_img, detection *dets, int num, float thresh, char **names, image **alphabet, int classes, int ext_output); void show_image_cv_ipl(IplImage *disp, const char *name); -image get_image_from_stream_resize(CvCapture *cap, int w, int h, IplImage** in_img, int cpp_video_capture); +image get_image_from_stream_resize(CvCapture *cap, int w, int h, int c, IplImage** in_img, int cpp_video_capture); IplImage* in_img; IplImage* det_img; IplImage* show_img; @@ -61,7 +61,7 @@ static int flag_exit; void *fetch_in_thread(void *ptr) { //in = get_image_from_stream(cap); - in_s = get_image_from_stream_resize(cap, net.w, net.h, &in_img, cpp_video_capture); + in_s = get_image_from_stream_resize(cap, net.w, net.h, net.c, &in_img, cpp_video_capture); if(!in_s.data){ //error("Stream closed."); printf("Stream closed.\n"); diff --git a/src/detector.c b/src/detector.c index 6fc6b677..202fbf9b 100644 --- a/src/detector.c +++ b/src/detector.c @@ -87,7 +87,8 @@ void train_detector(char *datacfg, char *cfgfile, char *weightfile, int *gpus, i load_args args = {0}; args.w = net.w; args.h = net.h; - args.paths = paths; + args.c = net.c; + args.paths = paths; args.n = imgs; args.m = plist->size; args.classes = classes; @@ -388,6 +389,7 @@ void validate_detector(char *datacfg, char *cfgfile, char *weightfile, char *out load_args args = { 0 }; args.w = net.w; args.h = net.h; + args.c = net.c; args.type = IMAGE_DATA; //args.type = LETTERBOX_DATA; @@ -482,7 +484,7 @@ void validate_detector_recall(char *datacfg, char *cfgfile, char *weightfile) for (i = 0; i < m; ++i) { char *path = paths[i]; - image orig = load_image_color(path, 0, 0); + image orig = load_image(path, 0, 0, net.c); image sized = resize_image(orig, net.w, net.h); char *id = basecfg(path); network_predict(net, sized.data); @@ -595,6 +597,7 @@ void validate_detector_map(char *datacfg, char *cfgfile, char *weightfile, float load_args args = { 0 }; args.w = net.w; args.h = net.h; + args.c = net.c; args.type = IMAGE_DATA; //args.type = LETTERBOX_DATA; @@ -1093,7 +1096,7 @@ void test_detector(char *datacfg, char *cfgfile, char *weightfile, char *filenam if(!input) return; strtok(input, "\n"); } - image im = load_image_color(input,0,0); + image im = load_image(input,0,0,net.c); int letterbox = 0; //image sized = resize_image(im, net.w, net.h); image sized = letterbox_image(im, net.w, net.h); letterbox = 1; diff --git a/src/http_stream.cpp b/src/http_stream.cpp index 9192f75a..acb6c8ef 100644 --- a/src/http_stream.cpp +++ b/src/http_stream.cpp @@ -283,19 +283,26 @@ image image_data_augmentation(IplImage* ipl, int w, int h, // HSV augmentation // CV_BGR2HSV, CV_RGB2HSV, CV_HSV2BGR, CV_HSV2RGB - cv::Mat hsv_src; - cvtColor(sized, hsv_src, CV_BGR2HSV); // also BGR -> RGB + if (ipl->nChannels >= 3) + { + cv::Mat hsv_src; + cvtColor(sized, hsv_src, CV_BGR2HSV); // also BGR -> RGB - std::vector hsv; - cv::split(hsv_src, hsv); + std::vector hsv; + cv::split(hsv_src, hsv); - hsv[1] *= dsat; - hsv[2] *= dexp; - hsv[0] += 179 * dhue; + hsv[1] *= dsat; + hsv[2] *= dexp; + hsv[0] += 179 * dhue; - cv::merge(hsv, hsv_src); + cv::merge(hsv, hsv_src); - cvtColor(hsv_src, sized, CV_HSV2RGB); // now RGB instead of BGR + cvtColor(hsv_src, sized, CV_HSV2RGB); // now RGB instead of BGR + } + else + { + sized *= dexp; + } // Mat -> IplImage -> image IplImage src = sized; diff --git a/src/image.c b/src/image.c index 7545e7db..3ffd552e 100644 --- a/src/image.c +++ b/src/image.c @@ -957,7 +957,7 @@ image load_image_cv(char *filename, int channels) { IplImage* src = 0; int flag = -1; - if (channels == 0) flag = -1; + if (channels == 0) flag = 1; else if (channels == 1) flag = 0; else if (channels == 3) flag = 1; else { @@ -975,7 +975,8 @@ image load_image_cv(char *filename, int channels) } image out = ipl_to_image(src); cvReleaseImage(&src); - rgbgr_image(out); + if (out.c > 1) + rgbgr_image(out); return out; } @@ -1010,8 +1011,9 @@ image get_image_from_stream_cpp(CvCapture *cap) return im; } -image get_image_from_stream_resize(CvCapture *cap, int w, int h, IplImage** in_img, int cpp_video_capture) +image get_image_from_stream_resize(CvCapture *cap, int w, int h, int c, IplImage** in_img, int cpp_video_capture) { + c = c ? c : 3; IplImage* src; if (cpp_video_capture) { static int once = 1; @@ -1029,14 +1031,15 @@ image get_image_from_stream_resize(CvCapture *cap, int w, int h, IplImage** in_i if (!src) return make_empty_image(0, 0, 0); if (src->width < 1 || src->height < 1 || src->nChannels < 1) return make_empty_image(0, 0, 0); - IplImage* new_img = cvCreateImage(cvSize(w, h), IPL_DEPTH_8U, 3); - *in_img = cvCreateImage(cvSize(src->width, src->height), IPL_DEPTH_8U, 3); + IplImage* new_img = cvCreateImage(cvSize(w, h), IPL_DEPTH_8U, c); + *in_img = cvCreateImage(cvSize(src->width, src->height), IPL_DEPTH_8U, c); cvResize(src, *in_img, CV_INTER_LINEAR); cvResize(src, new_img, CV_INTER_LINEAR); image im = ipl_to_image(new_img); cvReleaseImage(&new_img); if (cpp_video_capture) cvReleaseImage(&src); - rgbgr_image(im); + if (c>1) + rgbgr_image(im); return im; } @@ -1589,16 +1592,23 @@ void exposure_image(image im, float sat) void distort_image(image im, float hue, float sat, float val) { - rgb_to_hsv(im); - scale_image_channel(im, 1, sat); - scale_image_channel(im, 2, val); - int i; - for(i = 0; i < im.w*im.h; ++i){ - im.data[i] = im.data[i] + hue; - if (im.data[i] > 1) im.data[i] -= 1; - if (im.data[i] < 0) im.data[i] += 1; - } - hsv_to_rgb(im); + if (im.c >= 3) + { + rgb_to_hsv(im); + scale_image_channel(im, 1, sat); + scale_image_channel(im, 2, val); + int i; + for(i = 0; i < im.w*im.h; ++i){ + im.data[i] = im.data[i] + hue; + if (im.data[i] > 1) im.data[i] -= 1; + if (im.data[i] < 0) im.data[i] += 1; + } + hsv_to_rgb(im); + } + else + { + scale_image_channel(im, 0, val); + } constrain_image(im); }