Added LSTM sequence detector, and blur data augmentation (for OpenCV only)

pull/749/head^2
AlexeyAB 6 years ago
parent cce34712f6
commit 31dc6c8680
  1. 2
      Makefile
  2. 8
      README.md
  3. 4
      build/darknet/darknet.vcxproj
  4. 2
      build/darknet/darknet_no_gpu.vcxproj
  5. 3
      build/darknet/x64/partial.cmd
  6. 2
      build/darknet/yolo_cpp_dll.vcxproj
  7. 2
      build/darknet/yolo_cpp_dll_no_gpu.vcxproj
  8. 18
      include/darknet.h
  9. 4
      src/convolutional_layer.c
  10. 12
      src/crnn_layer.c
  11. 1
      src/crnn_layer.h
  12. 28
      src/data.c
  13. 2
      src/data.h
  14. 2
      src/demo.c
  15. 29
      src/detector.c
  16. 38
      src/image_opencv.cpp
  17. 3
      src/image_opencv.h
  18. 83
      src/layer.c
  19. 1
      src/lstm_layer.c
  20. 1
      src/lstm_layer.h
  21. 31
      src/network.c
  22. 3
      src/network.h
  23. 16
      src/network_kernels.cu
  24. 58
      src/parser.c
  25. 1
      src/yolo_v2_class.cpp

@ -118,7 +118,7 @@ LDFLAGS+= -L/usr/local/zed/lib -lsl_core -lsl_input -lsl_zed
#-lstdc++ -D_GLIBCXX_USE_CXX11_ABI=0
endif
OBJ=image_opencv.o http_stream.o gemm.o utils.o dark_cuda.o convolutional_layer.o list.o image.o activations.o im2col.o col2im.o blas.o crop_layer.o dropout_layer.o maxpool_layer.o softmax_layer.o data.o matrix.o network.o connected_layer.o cost_layer.o parser.o option_list.o darknet.o detection_layer.o captcha.o route_layer.o writing.o box.o nightmare.o normalization_layer.o avgpool_layer.o coco.o dice.o yolo.o detector.o layer.o compare.o classifier.o local_layer.o swag.o shortcut_layer.o activation_layer.o rnn_layer.o gru_layer.o rnn.o rnn_vid.o crnn_layer.o demo.o tag.o cifar.o go.o batchnorm_layer.o art.o region_layer.o reorg_layer.o reorg_old_layer.o super.o voxel.o tree.o yolo_layer.o upsample_layer.o lstm_layer.o
OBJ=image_opencv.o http_stream.o gemm.o utils.o dark_cuda.o convolutional_layer.o list.o image.o activations.o im2col.o col2im.o blas.o crop_layer.o dropout_layer.o maxpool_layer.o softmax_layer.o data.o matrix.o network.o connected_layer.o cost_layer.o parser.o option_list.o darknet.o detection_layer.o captcha.o route_layer.o writing.o box.o nightmare.o normalization_layer.o avgpool_layer.o coco.o dice.o yolo.o detector.o layer.o compare.o classifier.o local_layer.o swag.o shortcut_layer.o activation_layer.o rnn_layer.o gru_layer.o rnn.o rnn_vid.o crnn_layer.o demo.o tag.o cifar.o go.o batchnorm_layer.o art.o region_layer.o reorg_layer.o reorg_old_layer.o super.o voxel.o tree.o yolo_layer.o upsample_layer.o lstm_layer.o conv_lstm_layer.o
ifeq ($(GPU), 1)
LDFLAGS+= -lstdc++
OBJ+=convolutional_kernels.o activation_kernels.o im2col_kernels.o col2im_kernels.o blas_kernels.o crop_layer_kernels.o dropout_layer_kernels.o maxpool_layer_kernels.o network_kernels.o avgpool_layer_kernels.o

@ -14,7 +14,7 @@ More details: http://pjreddie.com/darknet/yolo/
* [Requirements (and how to install dependecies)](#requirements)
* [Pre-trained models](#pre-trained-models)
* [Explanations in issues](https://github.com/AlexeyAB/darknet/issues?q=is%3Aopen+is%3Aissue+label%3AExplanations)
* [Yolo v3 in other frameworks (TensorFlow, OpenVINO, OpenCV-dnn, ...)](#yolo-v3-in-other-frameworks)
* [Yolo v3 in other frameworks (TensorFlow, PyTorch, OpenVINO, OpenCV-dnn,...)](#yolo-v3-in-other-frameworks)
0. [Improvements in this repository](#improvements-in-this-repository)
1. [How to use](#how-to-use-on-the-command-line)
@ -75,9 +75,9 @@ You can get cfg-files by path: `darknet/cfg/`
#### Yolo v3 in other frameworks
* **TensorFlow:** convert `yolov3.weights`/`cfg` files to `yolov3.ckpt`/`pb/meta`: by using [mystic123](https://github.com/mystic123/tensorflow-yolo-v3) or [jinyu121](https://github.com/jinyu121/DW2TF) projects, and [TensorFlow-lite](https://www.tensorflow.org/lite/guide/get_started#2_convert_the_model_format)
* **Intel OpenVINO:** (Myriad X / USB Neural Compute Stick / Arria FPGA): read this [manual](https://software.intel.com/en-us/articles/OpenVINO-Using-TensorFlow#converting-a-darknet-yolo-model)
* **OpenCV-dnn** is very fast DNN implementation on CPU (x86/ARM-Android), use `yolov3.weights`/`cfg` with: [C++ example](https://github.com/opencv/opencv/blob/8c25a8eb7b10fb50cda323ee6bec68aa1a9ce43c/samples/dnn/object_detection.cpp#L192-L221), [Python example](https://github.com/opencv/opencv/blob/8c25a8eb7b10fb50cda323ee6bec68aa1a9ce43c/samples/dnn/object_detection.py#L129-L150)
* **PyTorch > ONNX > CoreML > iOS** how to convert cfg/weights-files to pt-file: [ultralytics/yolov3](https://github.com/ultralytics/yolov3#darknet-conversion)
* **Intel OpenVINO 2019 R1:** (Myriad X / USB Neural Compute Stick / Arria FPGA): read this [manual](https://software.intel.com/en-us/articles/OpenVINO-Using-TensorFlow#converting-a-darknet-yolo-model)
* **OpenCV-dnn** is a very fast DNN implementation on CPU (x86/ARM-Android), use `yolov3.weights`/`cfg` with: [C++ example](https://github.com/opencv/opencv/blob/8c25a8eb7b10fb50cda323ee6bec68aa1a9ce43c/samples/dnn/object_detection.cpp#L192-L221), [Python example](https://github.com/opencv/opencv/blob/8c25a8eb7b10fb50cda323ee6bec68aa1a9ce43c/samples/dnn/object_detection.py#L129-L150)
* **PyTorch > ONNX > CoreML > iOS** how to convert cfg/weights-files to pt-file: [ultralytics/yolov3](https://github.com/ultralytics/yolov3#darknet-conversion) and [iOS App](https://itunes.apple.com/app/id1452689527)
##### Examples of results

@ -140,6 +140,8 @@
<CompileAs>Default</CompileAs>
<UndefinePreprocessorDefinitions>NDEBUG</UndefinePreprocessorDefinitions>
<MultiProcessorCompilation>true</MultiProcessorCompilation>
<AdditionalUsingDirectories>
</AdditionalUsingDirectories>
</ClCompile>
<Link>
<GenerateDebugInformation>true</GenerateDebugInformation>
@ -183,6 +185,7 @@
<ClCompile Include="..\..\src\compare.c" />
<ClCompile Include="..\..\src\connected_layer.c" />
<ClCompile Include="..\..\src\convolutional_layer.c" />
<ClCompile Include="..\..\src\conv_lstm_layer.c" />
<ClCompile Include="..\..\src\cost_layer.c" />
<ClCompile Include="..\..\src\cpu_gemm.c" />
<ClCompile Include="..\..\src\crnn_layer.c" />
@ -248,6 +251,7 @@
<ClInclude Include="..\..\src\col2im.h" />
<ClInclude Include="..\..\src\connected_layer.h" />
<ClInclude Include="..\..\src\convolutional_layer.h" />
<ClInclude Include="..\..\src\conv_lstm_layer.h" />
<ClInclude Include="..\..\src\cost_layer.h" />
<ClInclude Include="..\..\src\crnn_layer.h" />
<ClInclude Include="..\..\src\crop_layer.h" />

@ -189,6 +189,7 @@
<ClCompile Include="..\..\src\compare.c" />
<ClCompile Include="..\..\src\connected_layer.c" />
<ClCompile Include="..\..\src\convolutional_layer.c" />
<ClCompile Include="..\..\src\conv_lstm_layer.c" />
<ClCompile Include="..\..\src\cost_layer.c" />
<ClCompile Include="..\..\src\cpu_gemm.c" />
<ClCompile Include="..\..\src\crnn_layer.c" />
@ -254,6 +255,7 @@
<ClInclude Include="..\..\src\col2im.h" />
<ClInclude Include="..\..\src\connected_layer.h" />
<ClInclude Include="..\..\src\convolutional_layer.h" />
<ClInclude Include="..\..\src\conv_lstm_layer.h" />
<ClInclude Include="..\..\src\cost_layer.h" />
<ClInclude Include="..\..\src\crnn_layer.h" />
<ClInclude Include="..\..\src\crop_layer.h" />

@ -33,6 +33,9 @@ darknet.exe partial cfg/yolov3-spp.cfg yolov3-spp.weights yolov3-spp.conv.85 85
darknet.exe partial cfg/yolov3-tiny.cfg yolov3-tiny.weights yolov3-tiny.conv.15 15
darknet.exe partial cfg/yolov3-tiny.cfg yolov3-tiny.weights yolov3-tiny.conv.14 14
darknet.exe partial cfg/yolo9000.cfg yolo9000.weights yolo9000.conv.22 22

@ -187,6 +187,7 @@
<ClCompile Include="..\..\src\compare.c" />
<ClCompile Include="..\..\src\connected_layer.c" />
<ClCompile Include="..\..\src\convolutional_layer.c" />
<ClCompile Include="..\..\src\conv_lstm_layer.c" />
<ClCompile Include="..\..\src\cost_layer.c" />
<ClCompile Include="..\..\src\cpu_gemm.c" />
<ClCompile Include="..\..\src\crnn_layer.c" />
@ -254,6 +255,7 @@
<ClInclude Include="..\..\src\col2im.h" />
<ClInclude Include="..\..\src\connected_layer.h" />
<ClInclude Include="..\..\src\convolutional_layer.h" />
<ClInclude Include="..\..\src\conv_lstm_layer.h" />
<ClInclude Include="..\..\src\cost_layer.h" />
<ClInclude Include="..\..\src\crnn_layer.h" />
<ClInclude Include="..\..\src\crop_layer.h" />

@ -173,6 +173,7 @@
<ClCompile Include="..\..\src\compare.c" />
<ClCompile Include="..\..\src\connected_layer.c" />
<ClCompile Include="..\..\src\convolutional_layer.c" />
<ClCompile Include="..\..\src\conv_lstm_layer.c" />
<ClCompile Include="..\..\src\cost_layer.c" />
<ClCompile Include="..\..\src\cpu_gemm.c" />
<ClCompile Include="..\..\src\crnn_layer.c" />
@ -240,6 +241,7 @@
<ClInclude Include="..\..\src\col2im.h" />
<ClInclude Include="..\..\src\connected_layer.h" />
<ClInclude Include="..\..\src\convolutional_layer.h" />
<ClInclude Include="..\..\src\conv_lstm_layer.h" />
<ClInclude Include="..\..\src\cost_layer.h" />
<ClInclude Include="..\..\src\crnn_layer.h" />
<ClInclude Include="..\..\src\crop_layer.h" />

@ -32,7 +32,6 @@
#endif
#endif
#define NFRAMES 3
#define SECRET_NUM -1234
#ifdef GPU
@ -136,6 +135,7 @@ typedef enum {
RNN,
GRU,
LSTM,
CONV_LSTM,
CRNN,
BATCHNORM,
NETWORK,
@ -208,6 +208,7 @@ struct layer {
int index;
int binary;
int xnor;
int peephole;
int use_bin_output;
int steps;
int hidden;
@ -354,6 +355,7 @@ struct layer {
float *z_cpu;
float *r_cpu;
float *h_cpu;
float *stored_h_cpu;
float * prev_state_cpu;
float *temp_cpu;
@ -369,6 +371,7 @@ struct layer {
float *g_cpu;
float *o_cpu;
float *c_cpu;
float *stored_c_cpu;
float *dc_cpu;
float *binary_input;
@ -407,10 +410,13 @@ struct layer {
struct layer *uh;
struct layer *uo;
struct layer *wo;
struct layer *vo;
struct layer *uf;
struct layer *wf;
struct layer *vf;
struct layer *ui;
struct layer *wi;
struct layer *vi;
struct layer *ug;
struct layer *wg;
@ -424,6 +430,7 @@ struct layer {
float *z_gpu;
float *r_gpu;
float *h_gpu;
float *stored_h_gpu;
float *temp_gpu;
float *temp2_gpu;
@ -432,12 +439,16 @@ struct layer {
float *dh_gpu;
float *hh_gpu;
float *prev_cell_gpu;
float *prev_state_gpu;
float *last_prev_state_gpu;
float *last_prev_cell_gpu;
float *cell_gpu;
float *f_gpu;
float *i_gpu;
float *g_gpu;
float *o_gpu;
float *c_gpu;
float *stored_c_gpu;
float *dc_gpu;
// adam
@ -451,7 +462,6 @@ struct layer {
float * combine_gpu;
float * combine_delta_gpu;
float * prev_state_gpu;
float * forgot_state_gpu;
float * forgot_delta_gpu;
float * state_gpu;
@ -571,6 +581,7 @@ typedef struct network {
float min_ratio;
int center;
int flip; // horizontal flip 50% probability augmentaiont for classifier training (default = 1)
int blur;
float angle;
float aspect;
float exposure;
@ -579,6 +590,8 @@ typedef struct network {
int random;
int track;
int augment_speed;
int sequential_subdivisions;
int current_subdivision;
int try_fix_nan;
int gpu_index;
@ -713,6 +726,7 @@ typedef struct load_args {
int show_imgs;
float jitter;
int flip;
int blur;
float angle;
float aspect;
float saturation;

@ -446,8 +446,8 @@ convolutional_layer make_convolutional_layer(int batch, int steps, int h, int w,
l.weights_gpu = cuda_make_array(l.weights, c*n*size*size);
l.weight_updates_gpu = cuda_make_array(l.weight_updates, c*n*size*size);
#ifdef CUDNN_HALF
l.weights_gpu16 = cuda_make_array(NULL, c*n*size*size / 2); //cuda_make_array(l.weights, c*n*size*size / 2);
l.weight_updates_gpu16 = cuda_make_array(NULL, c*n*size*size / 2); //cuda_make_array(l.weight_updates, c*n*size*size / 2);
l.weights_gpu16 = cuda_make_array(NULL, c*n*size*size / 2 + 1); //cuda_make_array(l.weights, c*n*size*size / 2);
l.weight_updates_gpu16 = cuda_make_array(NULL, c*n*size*size / 2 + 1); //cuda_make_array(l.weight_updates, c*n*size*size / 2);
#endif
l.biases_gpu = cuda_make_array(l.biases, n);

@ -85,6 +85,8 @@ layer make_crnn_layer(int batch, int h, int w, int c, int hidden_filters, int ou
l.delta_gpu = l.output_layer->delta_gpu;
#endif
l.bflops = l.input_layer->bflops + l.self_layer->bflops + l.output_layer->bflops;
return l;
}
@ -128,6 +130,16 @@ void resize_crnn_layer(layer *l, int w, int h)
#endif
}
void free_state_crnn(layer l)
{
int i;
for (i = 0; i < l.outputs * l.batch; ++i) l.self_layer->output[i] = rand_uniform(-1, 1);
#ifdef GPU
cuda_push_array(l.self_layer->output_gpu, l.self_layer->output, l.outputs * l.batch);
#endif // GPU
}
void update_crnn_layer(layer l, int batch, float learning_rate, float momentum, float decay)
{
update_convolutional_layer(*(l.input_layer), batch, learning_rate, momentum, decay);

@ -11,6 +11,7 @@ extern "C" {
#endif
layer make_crnn_layer(int batch, int h, int w, int c, int hidden_filters, int output_filters, int steps, int size, int stride, int pad, ACTIVATION activation, int batch_normalize, int xnor);
void resize_crnn_layer(layer *l, int w, int h);
void free_state_crnn(layer l);
void forward_crnn_layer(layer l, network_state state);
void backward_crnn_layer(layer l, network_state state);

@ -231,6 +231,15 @@ void correct_boxes(box_label *boxes, int n, float dx, float dy, float sx, float
boxes[i].h = 999999;
continue;
}
if ((boxes[i].x + boxes[i].w / 2) < 0 || (boxes[i].y + boxes[i].h / 2) < 0 ||
(boxes[i].x - boxes[i].w / 2) > 1 || (boxes[i].y - boxes[i].h / 2) > 1)
{
boxes[i].x = 999999;
boxes[i].y = 999999;
boxes[i].w = 999999;
boxes[i].h = 999999;
continue;
}
boxes[i].left = boxes[i].left * sx - dx;
boxes[i].right = boxes[i].right * sx - dx;
boxes[i].top = boxes[i].top * sy - dy;
@ -378,7 +387,7 @@ void fill_truth_detection(const char *path, int num_boxes, float *truth, int cla
continue;
}
if (x == 999999 || y == 999999) {
printf("\n Wrong annotation: x = 0, y = 0 \n");
printf("\n Wrong annotation: x = 0, y = 0, < 0 or > 1 \n");
sprintf(buff, "echo %s \"Wrong annotation: x = 0 or y = 0\" >> bad_label.list", labelpath);
system(buff);
++sub;
@ -769,7 +778,7 @@ static box float_to_box_stride(float *f, int stride)
#include "http_stream.h"
data load_data_detection(int n, char **paths, int m, int w, int h, int c, int boxes, int classes, int use_flip, float jitter,
data load_data_detection(int n, char **paths, int m, int w, int h, int c, int boxes, int classes, int use_flip, int use_blur, float jitter,
float hue, float saturation, float exposure, int mini_batch, int track, int augment_speed, int show_imgs)
{
c = c ? c : 3;
@ -785,7 +794,7 @@ data load_data_detection(int n, char **paths, int m, int w, int h, int c, int bo
d.X.cols = h*w*c;
float r1 = 0, r2 = 0, r3 = 0, r4 = 0;
float dhue = 0, dsat = 0, dexp = 0, flip = 0;
float dhue = 0, dsat = 0, dexp = 0, flip = 0, blur = 0;
int augmentation_calculated = 0;
d.y = make_matrix(n, 5*boxes);
@ -819,6 +828,7 @@ data load_data_detection(int n, char **paths, int m, int w, int h, int c, int bo
dexp = rand_scale(exposure);
flip = use_flip ? random_gen() % 2 : 0;
blur = rand_int(0, 1) ? (use_blur) : 0;
}
int pleft = rand_precalc_random(-dw, dw, r1);
@ -835,10 +845,12 @@ data load_data_detection(int n, char **paths, int m, int w, int h, int c, int bo
float dx = ((float)pleft/ow)/sx;
float dy = ((float)ptop /oh)/sy;
image ai = image_data_augmentation(src, w, h, pleft, ptop, swidth, sheight, flip, jitter, dhue, dsat, dexp);
d.X.vals[i] = ai.data;
fill_truth_detection(filename, boxes, d.y.vals[i], classes, flip, dx, dy, 1. / sx, 1. / sy, w, h);
fill_truth_detection(filename, boxes, d.y.vals[i], classes, flip, dx, dy, 1./sx, 1./sy, w, h);
image ai = image_data_augmentation(src, w, h, pleft, ptop, swidth, sheight, flip, jitter, dhue, dsat, dexp,
blur, boxes, d.y.vals[i]);
d.X.vals[i] = ai.data;
if(show_imgs)
{
@ -869,7 +881,7 @@ data load_data_detection(int n, char **paths, int m, int w, int h, int c, int bo
return d;
}
#else // OPENCV
data load_data_detection(int n, char **paths, int m, int w, int h, int c, int boxes, int classes, int use_flip, float jitter,
data load_data_detection(int n, char **paths, int m, int w, int h, int c, int boxes, int classes, int use_flip, int use_blur, float jitter,
float hue, float saturation, float exposure, int mini_batch, int track, int augment_speed, int show_imgs)
{
c = c ? c : 3;
@ -989,7 +1001,7 @@ void *load_thread(void *ptr)
} else if (a.type == REGION_DATA){
*a.d = load_data_region(a.n, a.paths, a.m, a.w, a.h, a.num_boxes, a.classes, a.jitter, a.hue, a.saturation, a.exposure);
} else if (a.type == DETECTION_DATA){
*a.d = load_data_detection(a.n, a.paths, a.m, a.w, a.h, a.c, a.num_boxes, a.classes, a.flip, a.jitter,
*a.d = load_data_detection(a.n, a.paths, a.m, a.w, a.h, a.c, a.num_boxes, a.classes, a.flip, a.blur, a.jitter,
a.hue, a.saturation, a.exposure, a.mini_batch, a.track, a.augment_speed, a.show_imgs);
} else if (a.type == SWAG_DATA){
*a.d = load_data_swag(a.paths, a.n, a.classes, a.jitter);

@ -86,7 +86,7 @@ void print_letters(float *pred, int n);
data load_data_captcha(char **paths, int n, int m, int k, int w, int h);
data load_data_captcha_encode(char **paths, int n, int m, int w, int h);
data load_data_old(char **paths, int n, int m, char **labels, int k, int w, int h);
data load_data_detection(int n, char **paths, int m, int w, int h, int c, int boxes, int classes, int use_flip, float jitter,
data load_data_detection(int n, char **paths, int m, int w, int h, int c, int boxes, int classes, int use_flip, int use_blur, float jitter,
float hue, float saturation, float exposure, int mini_batch, int track, int augment_speed, int show_imgs);
data load_data_tag(char **paths, int n, int m, int k, int use_flip, int min, int max, int size, float angle, float aspect, float hue, float saturation, float exposure);
matrix load_image_augment_paths(char **paths, int n, int use_flip, int min, int max, int size, float angle, float aspect, float hue, float saturation, float exposure);

@ -37,6 +37,8 @@ static int demo_ext_output = 0;
static long long int frame_id = 0;
static int demo_json_port = -1;
#define NFRAMES 3
static float* predictions[NFRAMES];
static int demo_index = 0;
static image images[NFRAMES];

@ -47,6 +47,15 @@ void train_detector(char *datacfg, char *cfgfile, char *weightfile, int *gpus, i
for (k = 0; k < net_map.n; ++k) {
free_layer(net_map.layers[k]);
}
char *name_list = option_find_str(options, "names", "data/names.list");
int names_size = 0;
char **names = get_labels_custom(name_list, &names_size);
if (net_map.layers[net_map.n - 1].classes != names_size) {
printf(" Error: in the file %s number of names %d that isn't equal to classes=%d in the file %s \n",
name_list, names_size, net_map.layers[net_map.n - 1].classes, cfgfile);
if (net_map.layers[net_map.n - 1].classes > names_size) getchar();
}
}
srand(time(0));
@ -119,6 +128,7 @@ void train_detector(char *datacfg, char *cfgfile, char *weightfile, int *gpus, i
args.threads = 64; // 16 or 64
args.angle = net.angle;
args.blur = net.blur;
args.exposure = net.exposure;
args.saturation = net.saturation;
args.hue = net.hue;
@ -137,7 +147,8 @@ void train_detector(char *datacfg, char *cfgfile, char *weightfile, int *gpus, i
if (net.track) {
args.track = net.track;
args.augment_speed = net.augment_speed;
args.threads = net.subdivisions * ngpus; // 2 * ngpus;
if (net.sequential_subdivisions) args.threads = net.sequential_subdivisions * ngpus;
else args.threads = net.subdivisions * ngpus; // 2 * ngpus;
args.mini_batch = net.batch / net.time_steps;
printf("\n Tracking! batch = %d, subdiv = %d, time_steps = %d, mini_batch = %d \n", net.batch, net.subdivisions, net.time_steps, args.mini_batch);
}
@ -223,7 +234,7 @@ void train_detector(char *datacfg, char *cfgfile, char *weightfile, int *gpus, i
calc_map_for_each = fmax(calc_map_for_each, 100);
int next_map_calc = iter_map + calc_map_for_each;
next_map_calc = fmax(next_map_calc, net.burn_in);
next_map_calc = fmax(next_map_calc, 1000);
next_map_calc = fmax(next_map_calc, 400);
if (calc_map) {
printf("\n (next mAP calculation at %d iterations) ", next_map_calc);
if (mean_average_precision > 0) printf("\n Last accuracy mAP@0.5 = %2.2f %% ", mean_average_precision * 100);
@ -638,7 +649,8 @@ float validate_detector_map(char *datacfg, char *cfgfile, char *weightfile, floa
char *valid_images = option_find_str(options, "valid", "data/train.txt");
char *difficult_valid_images = option_find_str(options, "difficult", NULL);
char *name_list = option_find_str(options, "names", "data/names.list");
char **names = get_labels(name_list);
int names_size = 0;
char **names = get_labels_custom(name_list, &names_size); //get_labels(name_list);
//char *mapf = option_find_str(options, "map", 0);
//int *map = 0;
//if (mapf) map = read_map(mapf);
@ -650,6 +662,8 @@ float validate_detector_map(char *datacfg, char *cfgfile, char *weightfile, floa
char *train_images = option_find_str(options, "train", "data/train.txt");
valid_images = option_find_str(options, "valid", train_images);
net = *existing_net;
remember_network_recurrent_state(*existing_net);
free_network_recurrent_state(*existing_net);
}
else {
net = parse_network_cfg_custom(cfgfile, 1, 1); // set batch=1
@ -660,6 +674,11 @@ float validate_detector_map(char *datacfg, char *cfgfile, char *weightfile, floa
fuse_conv_batchnorm(net);
calculate_binary_weights(net);
}
if (net.layers[net.n - 1].classes != names_size) {
printf(" Error: in the file %s number of names %d that isn't equal to classes=%d in the file %s \n",
name_list, names_size, net.layers[net.n - 1].classes, cfgfile);
getchar();
}
srand(time(0));
printf("\n calculation mAP (mean average precision)...\n");
@ -1053,6 +1072,8 @@ float validate_detector_map(char *datacfg, char *cfgfile, char *weightfile, floa
if (existing_net) {
//set_batch_network(&net, initial_batch);
//free_network_recurrent_state(*existing_net);
restore_network_recurrent_state(*existing_net);
}
else {
free_network(net);
@ -1220,7 +1241,7 @@ void calc_anchors(char *datacfg, int num_of_clusters, int width, int height, int
if (show) {
#ifdef OPENCV
//show_acnhors(number_of_boxes, num_of_clusters, rel_width_height_array, anchors_data, width, height);
show_acnhors(number_of_boxes, num_of_clusters, rel_width_height_array, anchors_data, width, height);
#endif // OPENCV
}
free(rel_width_height_array);

@ -1125,9 +1125,20 @@ void draw_train_loss(mat_cv* img_src, int img_size, float avg_loss, float max_im
// ====================================================================
// Data augmentation
// ====================================================================
static box float_to_box_stride(float *f, int stride)
{
box b = { 0 };
b.x = f[0];
b.y = f[1 * stride];
b.w = f[2 * stride];
b.h = f[3 * stride];
return b;
}
image image_data_augmentation(mat_cv* mat, int w, int h,
int pleft, int ptop, int swidth, int sheight, int flip,
float jitter, float dhue, float dsat, float dexp)
float jitter, float dhue, float dsat, float dexp,
int blur, int num_boxes, float *truth)
{
image out;
try {
@ -1192,6 +1203,31 @@ image image_data_augmentation(mat_cv* mat, int w, int h,
//cv::imshow(window_name.str(), sized);
//cv::waitKey(0);
if (blur) {
cv::Mat dst(sized.size(), sized.type());
if(blur == 1) cv::GaussianBlur(sized, dst, cv::Size(31, 31), 0);
else cv::GaussianBlur(sized, dst, cv::Size((blur / 2) * 2 + 1, (blur / 2) * 2 + 1), 0);
cv::Rect img_rect(0, 0, sized.cols, sized.rows);
//std::cout << " blur num_boxes = " << num_boxes << std::endl;
if (blur == 1) {
int t;
for (t = 0; t < num_boxes; ++t) {
box b = float_to_box_stride(truth + t*(4 + 1), 1);
if (!b.x) break;
int left = (b.x - b.w / 2.)*sized.cols;
int width = b.w*sized.cols;
int top = (b.y - b.h / 2.)*sized.rows;
int height = b.h*sized.rows;
cv::Rect roi(left, top, width, height);
roi = roi & img_rect;
sized(roi).copyTo(dst(roi));
}
}
dst.copyTo(sized);
}
// Mat -> image
out = mat_to_image(sized);
}

@ -95,7 +95,8 @@ void draw_train_loss(mat_cv* img, int img_size, float avg_loss, float max_img_lo
// Data augmentation
image image_data_augmentation(mat_cv* mat, int w, int h,
int pleft, int ptop, int swidth, int sheight, int flip,
float jitter, float dhue, float dsat, float dexp);
float jitter, float dhue, float dsat, float dexp,
int blur, int num_boxes, float *truth);
// Show Anchors
void show_acnhors(int number_of_boxes, int num_of_clusters, float *rel_width_height_array, model anchors_data, int width, int height);

@ -2,22 +2,40 @@
#include "dark_cuda.h"
#include <stdlib.h>
void free_sublayer(layer *l)
{
if (l) {
free_layer(*l);
free(l);
}
}
void free_layer(layer l)
{
// free layers: input_layer, self_layer, output_layer, ...
if (l.type == CRNN) {
if (l.input_layer) {
free_layer(*l.input_layer);
free(l.input_layer);
}
if (l.self_layer) {
free_layer(*l.self_layer);
free(l.self_layer);
if (l.type == CONV_LSTM) {
if (l.peephole) {
free_sublayer(l.vf);
free_sublayer(l.vi);
free_sublayer(l.vo);
}
if (l.output_layer) {
free_layer(*l.output_layer);
free(l.output_layer);
else {
free(l.vf);
free(l.vi);
free(l.vo);
}
free_sublayer(l.wf);
free_sublayer(l.wi);
free_sublayer(l.wg);
free_sublayer(l.wo);
free_sublayer(l.uf);
free_sublayer(l.ui);
free_sublayer(l.ug);
free_sublayer(l.uo);
}
if (l.type == CRNN) {
free_sublayer(l.input_layer);
free_sublayer(l.self_layer);
free_sublayer(l.output_layer);
l.output = NULL;
l.delta = NULL;
#ifdef GPU
@ -83,21 +101,36 @@ void free_layer(layer l)
if (l.v) free(l.v);
if (l.z_cpu) free(l.z_cpu);
if (l.r_cpu) free(l.r_cpu);
if (l.h_cpu) free(l.h_cpu);
if (l.binary_input) free(l.binary_input);
if (l.bin_re_packed_input) free(l.bin_re_packed_input);
if (l.t_bit_input) free(l.t_bit_input);
if (l.loss) free(l.loss);
// CONV-LSTM
if (l.f_cpu) free(l.f_cpu);
if (l.i_cpu) free(l.i_cpu);
if (l.g_cpu) free(l.g_cpu);
if (l.o_cpu) free(l.o_cpu);
if (l.c_cpu) free(l.c_cpu);
if (l.h_cpu) free(l.h_cpu);
if (l.temp_cpu) free(l.temp_cpu);
if (l.temp2_cpu) free(l.temp2_cpu);
if (l.temp3_cpu) free(l.temp3_cpu);
if (l.dc_cpu) free(l.dc_cpu);
if (l.dh_cpu) free(l.dh_cpu);
if (l.prev_state_cpu) free(l.prev_state_cpu);
if (l.prev_cell_cpu) free(l.prev_cell_cpu);
if (l.stored_c_cpu) free(l.stored_c_cpu);
if (l.stored_h_cpu) free(l.stored_h_cpu);
if (l.cell_cpu) free(l.cell_cpu);
#ifdef GPU
if (l.indexes_gpu) cuda_free((float *)l.indexes_gpu);
if (l.z_gpu) cuda_free(l.z_gpu);
if (l.r_gpu) cuda_free(l.r_gpu);
if (l.h_gpu) cuda_free(l.h_gpu);
if (l.m_gpu) cuda_free(l.m_gpu);
if (l.v_gpu) cuda_free(l.v_gpu);
if (l.prev_state_gpu) cuda_free(l.prev_state_gpu);
if (l.forgot_state_gpu) cuda_free(l.forgot_state_gpu);
if (l.forgot_delta_gpu) cuda_free(l.forgot_delta_gpu);
if (l.state_gpu) cuda_free(l.state_gpu);
@ -137,5 +170,25 @@ void free_layer(layer l)
if (l.rand_gpu) cuda_free(l.rand_gpu);
if (l.squared_gpu) cuda_free(l.squared_gpu);
if (l.norms_gpu) cuda_free(l.norms_gpu);
// CONV-LSTM
if (l.f_gpu) cuda_free(l.f_gpu);
if (l.i_gpu) cuda_free(l.i_gpu);
if (l.g_gpu) cuda_free(l.g_gpu);
if (l.o_gpu) cuda_free(l.o_gpu);
if (l.c_gpu) cuda_free(l.c_gpu);
if (l.h_gpu) cuda_free(l.h_gpu);
if (l.temp_gpu) cuda_free(l.temp_gpu);
if (l.temp2_gpu) cuda_free(l.temp2_gpu);
if (l.temp3_gpu) cuda_free(l.temp3_gpu);
if (l.dc_gpu) cuda_free(l.dc_gpu);
if (l.dh_gpu) cuda_free(l.dh_gpu);
if (l.prev_state_gpu) cuda_free(l.prev_state_gpu);
if (l.prev_cell_gpu) cuda_free(l.prev_cell_gpu);
if (l.stored_c_gpu) cuda_free(l.stored_c_gpu);
if (l.stored_h_gpu) cuda_free(l.stored_h_gpu);
if (l.last_prev_state_gpu) cuda_free(l.last_prev_state_gpu);
if (l.last_prev_cell_gpu) cuda_free(l.last_prev_cell_gpu);
if (l.cell_gpu) cuda_free(l.cell_gpu);
#endif
}

@ -95,6 +95,7 @@ layer make_lstm_layer(int batch, int inputs, int outputs, int steps, int batch_n
l.forward = forward_lstm_layer;
l.update = update_lstm_layer;
l.backward = backward_lstm_layer;
l.prev_state_cpu = (float*)calloc(batch*outputs, sizeof(float));
l.prev_cell_cpu = (float*)calloc(batch*outputs, sizeof(float));

@ -12,6 +12,7 @@ extern "C" {
layer make_lstm_layer(int batch, int inputs, int outputs, int steps, int batch_normalize);
void forward_lstm_layer(layer l, network_state state);
void backward_lstm_layer(layer l, network_state state);
void update_lstm_layer(layer l, int batch, float learning_rate, float momentum, float decay);
#ifdef GPU

@ -15,6 +15,7 @@
#include "gru_layer.h"
#include "rnn_layer.h"
#include "crnn_layer.h"
#include "conv_lstm_layer.h"
#include "local_layer.h"
#include "convolutional_layer.h"
#include "activation_layer.h"
@ -315,6 +316,7 @@ float train_network_sgd(network net, data d, int n)
float sum = 0;
for(i = 0; i < n; ++i){
get_random_batch(d, batch, X, y);
net.current_subdivision = i;
float err = train_network_datum(net, X, y);
sum += err;
}
@ -340,6 +342,7 @@ float train_network_waitkey(network net, data d, int wait_key)
float sum = 0;
for(i = 0; i < n; ++i){
get_next_batch(d, batch, i*batch, X, y);
net.current_subdivision = i;
float err = train_network_datum(net, X, y);
sum += err;
if(wait_key) wait_key_cv(5);
@ -1111,3 +1114,31 @@ network combine_train_valid_networks(network net_train, network net_map)
}
return net_combined;
}
void free_network_recurrent_state(network net)
{
int k;
for (k = 0; k < net.n; ++k) {
if (net.layers[k].type == CONV_LSTM) free_state_conv_lstm(net.layers[k]);
if (net.layers[k].type == CRNN) free_state_crnn(net.layers[k]);
}
}
void remember_network_recurrent_state(network net)
{
int k;
for (k = 0; k < net.n; ++k) {
if (net.layers[k].type == CONV_LSTM) remember_state_conv_lstm(net.layers[k]);
//if (net.layers[k].type == CRNN) free_state_crnn(net.layers[k]);
}
}
void restore_network_recurrent_state(network net)
{
int k;
for (k = 0; k < net.n; ++k) {
if (net.layers[k].type == CONV_LSTM) restore_state_conv_lstm(net.layers[k]);
if (net.layers[k].type == CRNN) free_state_crnn(net.layers[k]);
}
}

@ -163,6 +163,9 @@ int get_network_background(network net);
//LIB_API void calculate_binary_weights(network net);
network combine_train_valid_networks(network net_train, network net_map);
void copy_weights_net(network net_train, network *net_map);
void free_network_recurrent_state(network net);
void remember_network_recurrent_state(network net);
void restore_network_recurrent_state(network net);
#ifdef __cplusplus
}

@ -171,6 +171,22 @@ void forward_backward_network_gpu(network net, float *x, float *y)
cuda_convert_f32_to_f16(l.self_layer->weights_gpu, l.self_layer->nweights, l.self_layer->weights_gpu16);
cuda_convert_f32_to_f16(l.output_layer->weights_gpu, l.output_layer->nweights, l.output_layer->weights_gpu16);
}
else if (l.type == CONV_LSTM && l.wf->weights_gpu && l.wf->weights_gpu16) {
assert((l.wf->c * l.wf->n * l.wf->size * l.wf->size) > 0);
if (l.peephole) {
cuda_convert_f32_to_f16(l.vf->weights_gpu, l.vf->nweights, l.vf->weights_gpu16);
cuda_convert_f32_to_f16(l.vi->weights_gpu, l.vi->nweights, l.vi->weights_gpu16);
cuda_convert_f32_to_f16(l.vo->weights_gpu, l.vo->nweights, l.vo->weights_gpu16);
}
cuda_convert_f32_to_f16(l.wf->weights_gpu, l.wf->nweights, l.wf->weights_gpu16);
cuda_convert_f32_to_f16(l.wi->weights_gpu, l.wi->nweights, l.wi->weights_gpu16);
cuda_convert_f32_to_f16(l.wg->weights_gpu, l.wg->nweights, l.wg->weights_gpu16);
cuda_convert_f32_to_f16(l.wo->weights_gpu, l.wo->nweights, l.wo->weights_gpu16);
cuda_convert_f32_to_f16(l.uf->weights_gpu, l.uf->nweights, l.uf->weights_gpu16);
cuda_convert_f32_to_f16(l.ui->weights_gpu, l.ui->nweights, l.ui->weights_gpu16);
cuda_convert_f32_to_f16(l.ug->weights_gpu, l.ug->nweights, l.ug->weights_gpu16);
cuda_convert_f32_to_f16(l.uo->weights_gpu, l.uo->nweights, l.uo->weights_gpu16);
}
}
}
#endif

@ -20,6 +20,7 @@
#include "list.h"
#include "local_layer.h"
#include "lstm_layer.h"
#include "conv_lstm_layer.h"
#include "maxpool_layer.h"
#include "normalization_layer.h"
#include "option_list.h"
@ -61,6 +62,7 @@ LAYER_TYPE string_to_layer_type(char * type)
if (strcmp(type, "[crnn]")==0) return CRNN;
if (strcmp(type, "[gru]")==0) return GRU;
if (strcmp(type, "[lstm]")==0) return LSTM;
if (strcmp(type, "[conv_lstm]") == 0) return CONV_LSTM;
if (strcmp(type, "[rnn]")==0) return RNN;
if (strcmp(type, "[conn]")==0
|| strcmp(type, "[connected]")==0) return CONNECTED;
@ -239,6 +241,29 @@ layer parse_lstm(list *options, size_params params)
return l;
}
layer parse_conv_lstm(list *options, size_params params)
{
// a ConvLSTM with a larger transitional kernel should be able to capture faster motions
int size = option_find_int_quiet(options, "size", 3);
int stride = option_find_int_quiet(options, "stride", 1);
int pad = option_find_int_quiet(options, "pad", 0);
int padding = option_find_int_quiet(options, "padding", 0);
if (pad) padding = size / 2;
int output_filters = option_find_int(options, "output", 1);
char *activation_s = option_find_str(options, "activation", "LINEAR");
ACTIVATION activation = get_activation(activation_s);
int batch_normalize = option_find_int_quiet(options, "batch_normalize", 0);
int xnor = option_find_int_quiet(options, "xnor", 0);
int peephole = option_find_int_quiet(options, "peephole", 1);
layer l = make_conv_lstm_layer(params.batch, params.w, params.h, params.c, output_filters, params.time_steps, size, stride, padding, activation, batch_normalize, peephole, xnor);
l.shortcut = option_find_int_quiet(options, "shortcut", 0);
return l;
}
connected_layer parse_connected(list *options, size_params params)
{
int output = option_find_int(options, "output",1);
@ -647,6 +672,7 @@ void parse_net_options(list *options, network *net)
net->time_steps = option_find_int_quiet(options, "time_steps",1);
net->track = option_find_int_quiet(options, "track", 0);
net->augment_speed = option_find_int_quiet(options, "augment_speed", 2);
net->sequential_subdivisions = option_find_int_quiet(options, "sequential_subdivisions", 0);
net->try_fix_nan = option_find_int_quiet(options, "try_fix_nan", 0);
net->batch /= subdivs;
net->batch *= net->time_steps;
@ -666,6 +692,7 @@ void parse_net_options(list *options, network *net)
net->max_crop = option_find_int_quiet(options, "max_crop",net->w*2);
net->min_crop = option_find_int_quiet(options, "min_crop",net->w);
net->flip = option_find_int_quiet(options, "flip", 1);
net->blur = option_find_int_quiet(options, "blur", 0);
net->angle = option_find_float_quiet(options, "angle", 0);
net->aspect = option_find_float_quiet(options, "aspect", 1);
@ -789,6 +816,8 @@ network parse_network_cfg_custom(char *filename, int batch, int time_steps)
l = parse_gru(options, params);
}else if(lt == LSTM){
l = parse_lstm(options, params);
}else if (lt == CONV_LSTM) {
l = parse_conv_lstm(options, params);
}else if(lt == CRNN){
l = parse_crnn(options, params);
}else if(lt == CONNECTED){
@ -1076,6 +1105,20 @@ void save_weights_upto(network net, char *filename, int cutoff)
save_connected_weights(*(l.ui), fp);
save_connected_weights(*(l.ug), fp);
save_connected_weights(*(l.uo), fp);
} if (l.type == CONV_LSTM) {
if (l.peephole) {
save_convolutional_weights(*(l.vf), fp);
save_convolutional_weights(*(l.vi), fp);
save_convolutional_weights(*(l.vo), fp);
}
save_convolutional_weights(*(l.wf), fp);
save_convolutional_weights(*(l.wi), fp);
save_convolutional_weights(*(l.wg), fp);
save_convolutional_weights(*(l.wo), fp);
save_convolutional_weights(*(l.uf), fp);
save_convolutional_weights(*(l.ui), fp);
save_convolutional_weights(*(l.ug), fp);
save_convolutional_weights(*(l.uo), fp);
} if(l.type == CRNN){
save_convolutional_weights(*(l.input_layer), fp);
save_convolutional_weights(*(l.self_layer), fp);
@ -1298,6 +1341,21 @@ void load_weights_upto(network *net, char *filename, int cutoff)
load_connected_weights(*(l.ug), fp, transpose);
load_connected_weights(*(l.uo), fp, transpose);
}
if (l.type == CONV_LSTM) {
if (l.peephole) {
load_convolutional_weights(*(l.vf), fp);
load_convolutional_weights(*(l.vi), fp);
load_convolutional_weights(*(l.vo), fp);
}
load_convolutional_weights(*(l.wf), fp);
load_convolutional_weights(*(l.wi), fp);
load_convolutional_weights(*(l.wg), fp);
load_convolutional_weights(*(l.wo), fp);
load_convolutional_weights(*(l.uf), fp);
load_convolutional_weights(*(l.ui), fp);
load_convolutional_weights(*(l.ug), fp);
load_convolutional_weights(*(l.uo), fp);
}
if(l.type == LOCAL){
int locations = l.out_w*l.out_h;
int size = l.size*l.size*l.c*l.n*locations;

@ -22,6 +22,7 @@ extern "C" {
#include <algorithm>
#include <cmath>
#define NFRAMES 3
//static Detector* detector = NULL;
static std::unique_ptr<Detector> detector;

Loading…
Cancel
Save