Merge branch 'master' into monitor_training

pull/4976/head
commit 02e7109b8b by Muhammad Maaz, committed via GitHub, 5 years ago
46 changed files (changed lines per file):

  1. README.md (5)
  2. include/darknet.h (35)
  3. src/activation_kernels.cu (2)
  4. src/batchnorm_layer.c (12)
  5. src/batchnorm_layer.h (2)
  6. src/blas.h (3)
  7. src/blas_kernels.cu (73)
  8. src/classifier.c (18)
  9. src/connected_layer.c (11)
  10. src/connected_layer.h (2)
  11. src/conv_lstm_layer.c (24)
  12. src/conv_lstm_layer.h (2)
  13. src/convolutional_kernels.cu (25)
  14. src/convolutional_layer.c (2)
  15. src/convolutional_layer.h (2)
  16. src/crnn_layer.c (8)
  17. src/crnn_layer.h (2)
  18. src/dark_cuda.c (20)
  19. src/dark_cuda.h (2)
  20. src/darknet.c (15)
  21. src/data.c (18)
  22. src/detector.c (109)
  23. src/dropout_layer.c (11)
  24. src/dropout_layer_kernels.cu (27)
  25. src/gaussian_yolo_layer.c (11)
  26. src/gru_layer.c (14)
  27. src/gru_layer.h (2)
  28. src/http_stream.cpp (14)
  29. src/image_opencv.cpp (2)
  30. src/local_layer.c (2)
  31. src/local_layer.h (2)
  32. src/lstm_layer.c (18)
  33. src/lstm_layer.h (2)
  34. src/network.c (22)
  35. src/network.h (1)
  36. src/network_kernels.cu (16)
  37. src/parser.c (40)
  38. src/region_layer.c (3)
  39. src/rnn.c (5)
  40. src/rnn_layer.c (8)
  41. src/rnn_layer.h (2)
  42. src/shortcut_layer.c (30)
  43. src/shortcut_layer.h (2)
  44. src/tag.c (5)
  45. src/yolo_console_dll.cpp (25)
  46. src/yolo_layer.c (25)

README.md
@@ -1,5 +1,5 @@
 # Yolo-v3 and Yolo-v2 for Windows and Linux
-### (neural network for object detection) - Tensor Cores can be used on [Linux](https://github.com/AlexeyAB/darknet#how-to-compile-on-linux) and [Windows](https://github.com/AlexeyAB/darknet#how-to-compile-on-windows-using-vcpkg)
+### (neural network for object detection) - Tensor Cores can be used on [Linux](https://github.com/AlexeyAB/darknet#how-to-compile-on-linux) and [Windows](https://github.com/AlexeyAB/darknet#how-to-compile-on-windows-using-cmake-gui)
 More details: http://pjreddie.com/darknet/yolo/
@@ -9,6 +9,7 @@ More details: http://pjreddie.com/darknet/yolo/
 [![AppveyorCI](https://ci.appveyor.com/api/projects/status/594bwb5uoc1fxwiu/branch/master?svg=true)](https://ci.appveyor.com/project/AlexeyAB/darknet/branch/master)
 [![Contributors](https://img.shields.io/github/contributors/AlexeyAB/Darknet.svg)](https://github.com/AlexeyAB/darknet/graphs/contributors)
 [![License: Unlicense](https://img.shields.io/badge/license-Unlicense-blue.svg)](https://github.com/AlexeyAB/darknet/blob/master/LICENSE)
+[![DOI](https://zenodo.org/badge/75388965.svg)](https://zenodo.org/badge/latestdoi/75388965)
 * [Requirements (and how to install dependecies)](#requirements)
@@ -395,7 +396,7 @@ Training Yolo v3:
 * https://github.com/AlexeyAB/darknet/blob/0039fd26786ab5f71d5af725fc18b3f521e7acfd/cfg/yolov3.cfg#L610
 * https://github.com/AlexeyAB/darknet/blob/0039fd26786ab5f71d5af725fc18b3f521e7acfd/cfg/yolov3.cfg#L696
 * https://github.com/AlexeyAB/darknet/blob/0039fd26786ab5f71d5af725fc18b3f521e7acfd/cfg/yolov3.cfg#L783
-* change [`filters=255`] to filters=(classes + 5)x3 in the 3 `[convolutional]` before each `[yolo]` layer
+* change [`filters=255`] to filters=(classes + 5)x3 in the 3 `[convolutional]` before each `[yolo]` layer, keep in mind that it only has to be the last `[convolutional]` before each of the `[yolo]` layers.
 * https://github.com/AlexeyAB/darknet/blob/0039fd26786ab5f71d5af725fc18b3f521e7acfd/cfg/yolov3.cfg#L603
 * https://github.com/AlexeyAB/darknet/blob/0039fd26786ab5f71d5af725fc18b3f521e7acfd/cfg/yolov3.cfg#L689
 * https://github.com/AlexeyAB/darknet/blob/0039fd26786ab5f71d5af725fc18b3f521e7acfd/cfg/yolov3.cfg#L776
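A quick sanity check of the filters=(classes + 5)x3 rule changed above, as a minimal C sketch (the class counts below are only illustrative):

    #include <stdio.h>

    /* filters for the last [convolutional] before each [yolo] layer:
       3 anchors per scale, each predicting classes + 4 box coords + 1 objectness */
    static int yolo_filters(int classes) { return (classes + 5) * 3; }

    int main(void) {
        printf("classes=80 -> filters=%d\n", yolo_filters(80));  /* 255, the default */
        printf("classes=2  -> filters=%d\n", yolo_filters(2));   /* 21 */
        return 0;
    }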

include/darknet.h
@@ -34,6 +34,8 @@
 #define SECRET_NUM -1234
+typedef enum { UNUSED_DEF_VAL } UNUSED_ENUM_TYPE;
 #ifdef GPU
 #include <cuda_runtime.h>
@@ -42,8 +44,8 @@
 #ifdef CUDNN
 #include <cudnn.h>
-#endif
-#endif
+#endif // CUDNN
+#endif // GPU
 #ifdef __cplusplus
 extern "C" {
@@ -209,13 +211,14 @@ struct layer {
 void(*update) (struct layer, int, float, float, float);
 void(*forward_gpu) (struct layer, struct network_state);
 void(*backward_gpu) (struct layer, struct network_state);
-void(*update_gpu) (struct layer, int, float, float, float);
+void(*update_gpu) (struct layer, int, float, float, float, float);
 layer *share_layer;
 int train;
 int avgpool;
 int batch_normalize;
 int shortcut;
 int batch;
+int dynamic_minibatch;
 int forced;
 int flipped;
 int inputs;
@@ -321,6 +324,8 @@ struct layer {
 int onlyforward;
 int stopbackward;
+int dont_update;
+int burnin_update;
 int dontload;
 int dontsave;
 int dontloadscales;
@@ -495,7 +500,7 @@ struct layer {
 size_t workspace_size;
-#ifdef GPU
+//#ifdef GPU
 int *indexes_gpu;
 float *z_gpu;
@@ -610,8 +615,21 @@ struct layer {
 cudnnConvolutionBwdDataAlgo_t bd_algo, bd_algo16;
 cudnnConvolutionBwdFilterAlgo_t bf_algo, bf_algo16;
 cudnnPoolingDescriptor_t poolingDesc;
+#else // CUDNN
+void* srcTensorDesc, *dstTensorDesc;
+void* srcTensorDesc16, *dstTensorDesc16;
+void* dsrcTensorDesc, *ddstTensorDesc;
+void* dsrcTensorDesc16, *ddstTensorDesc16;
+void* normTensorDesc, *normDstTensorDesc, *normDstTensorDescF16;
+void* weightDesc, *weightDesc16;
+void* dweightDesc, *dweightDesc16;
+void* convDesc;
+UNUSED_ENUM_TYPE fw_algo, fw_algo16;
+UNUSED_ENUM_TYPE bd_algo, bd_algo16;
+UNUSED_ENUM_TYPE bf_algo, bf_algo16;
+void* poolingDesc;
 #endif // CUDNN
-#endif // GPU
+//#endif // GPU
 };
@@ -625,6 +643,8 @@ typedef struct network {
 int n;
 int batch;
 uint64_t *seen;
+int *cur_iteration;
+float loss_scale;
 int *t;
 float epoch;
 int subdivisions;
@@ -701,7 +721,7 @@ typedef struct network {
 float *cost;
 float clip;
-#ifdef GPU
+//#ifdef GPU
 //float *input_gpu;
 //float *truth_gpu;
 float *delta_gpu;
@@ -722,8 +742,9 @@ typedef struct network {
 float *global_delta_gpu;
 float *state_delta_gpu;
 size_t max_delta_gpu_size;
-#endif
+//#endif // GPU
 int optimized_memory;
+int dynamic_minibatch;
 size_t workspace_size_limit;
 } network;

src/activation_kernels.cu
@@ -493,6 +493,7 @@ __global__ void activate_array_normalize_channels_softmax_kernel(float *x, int s
 for (k = 0; k < channels; ++k) {
 float val = x[wh_i + k * wh_step + b*wh_step*channels];
 val = expf(val - max_val) / sum;
+if (isnan(val) || isinf(val)) val = 0;
 output_gpu[wh_i + k * wh_step + b*wh_step*channels] = val;
 }
 }
@@ -535,6 +536,7 @@ __global__ void gradient_array_normalize_channels_softmax_kernel(float *x, int s
 float delta = delta_gpu[index];
 float grad = x[index] * (1 - x[index]);
 delta = delta * grad;
+if (isnan(delta) || isinf(delta)) delta = 0;
 delta_gpu[index] = delta;
 }
 }
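For reference, a minimal CPU-side sketch (plain C, hypothetical helper name) of the guarded channel softmax that the kernel change above produces; the new guard simply zeroes non-finite outputs:

    #include <math.h>

    /* Softmax over `channels` values with the NaN/Inf guard added above. */
    static void softmax_channels_guarded(const float *in, float *out, int channels)
    {
        float max_val = -INFINITY;
        for (int k = 0; k < channels; ++k) if (in[k] > max_val) max_val = in[k];

        float sum = 0.f;
        for (int k = 0; k < channels; ++k) sum += expf(in[k] - max_val);

        for (int k = 0; k < channels; ++k) {
            float val = expf(in[k] - max_val) / sum;
            if (isnan(val) || isinf(val)) val = 0;   /* the guard added in this commit */
            out[k] = val;
        }
    }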

src/batchnorm_layer.c
@@ -258,15 +258,17 @@ void forward_batchnorm_layer_gpu(layer l, network_state state)
 fast_mean_gpu(l.output_gpu, l.batch, l.out_c, l.out_h*l.out_w, l.mean_gpu);
 //fast_v_gpu(l.output_gpu, l.mean_gpu, l.batch, l.out_c, l.out_h*l.out_w, l.v_cbn_gpu);
-int minibatch_index = state.net.current_subdivision + 1;
-float alpha = 0.01;
+const int minibatch_index = state.net.current_subdivision + 1;
+const int max_minibatch_index = state.net.subdivisions;
+//printf("\n minibatch_index = %d, max_minibatch_index = %d \n", minibatch_index, max_minibatch_index);
+const float alpha = 0.01;
 int inverse_variance = 0;
 #ifdef CUDNN
 inverse_variance = 1;
 #endif // CUDNN
-fast_v_cbn_gpu(l.output_gpu, l.mean_gpu, l.batch, l.out_c, l.out_h*l.out_w, minibatch_index, l.m_cbn_avg_gpu, l.v_cbn_avg_gpu, l.variance_gpu,
+fast_v_cbn_gpu(l.output_gpu, l.mean_gpu, l.batch, l.out_c, l.out_h*l.out_w, minibatch_index, max_minibatch_index, l.m_cbn_avg_gpu, l.v_cbn_avg_gpu, l.variance_gpu,
 alpha, l.rolling_mean_gpu, l.rolling_variance_gpu, inverse_variance, .00001);
 normalize_scale_bias_gpu(l.output_gpu, l.mean_gpu, l.variance_gpu, l.scales_gpu, l.biases_gpu, l.batch, l.out_c, l.out_h*l.out_w, inverse_variance, .00001f);
@@ -385,9 +387,9 @@ void backward_batchnorm_layer_gpu(layer l, network_state state)
 }
 }
-void update_batchnorm_layer_gpu(layer l, int batch, float learning_rate_init, float momentum, float decay)
+void update_batchnorm_layer_gpu(layer l, int batch, float learning_rate_init, float momentum, float decay, float loss_scale)
 {
-float learning_rate = learning_rate_init*l.learning_rate_scale;
+float learning_rate = learning_rate_init * l.learning_rate_scale / loss_scale;
 //float momentum = a.momentum;
 //float decay = a.decay;
 //int batch = a.batch;
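The rolling mean/variance buffers passed to fast_v_cbn_gpu above are maintained as an exponential moving average with alpha = 0.01; a minimal CPU sketch of that update (hypothetical helper name, plain C):

    /* Exponential moving average for the rolling batch-norm statistics. */
    static void update_rolling_stats(float *rolling_mean, float *rolling_variance,
                                     const float *mean, const float *variance,
                                     int filters, float alpha)
    {
        for (int f = 0; f < filters; ++f) {
            rolling_mean[f]     = alpha * mean[f]     + (1 - alpha) * rolling_mean[f];
            rolling_variance[f] = alpha * variance[f] + (1 - alpha) * rolling_variance[f];
        }
    }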

src/batchnorm_layer.h
@@ -18,7 +18,7 @@ void resize_batchnorm_layer(layer *l, int w, int h);
 #ifdef GPU
 void forward_batchnorm_layer_gpu(layer l, network_state state);
 void backward_batchnorm_layer_gpu(layer l, network_state state);
-void update_batchnorm_layer_gpu(layer l, int batch, float learning_rate_init, float momentum, float decay);
+void update_batchnorm_layer_gpu(layer l, int batch, float learning_rate_init, float momentum, float decay, float loss_scale);
 void pull_batchnorm_layer(layer l);
 void push_batchnorm_layer(layer l);
 #endif

src/blas.h
@@ -60,6 +60,7 @@ void fix_nan_and_inf_cpu(float *input, size_t size);
 #ifdef GPU
+void constrain_weight_updates_ongpu(int N, float coef, float *weights_gpu, float *weight_updates_gpu);
 void axpy_ongpu(int N, float ALPHA, float * X, int INCX, float * Y, int INCY);
 void axpy_ongpu_offset(int N, float ALPHA, float * X, int OFFX, int INCX, float * Y, int OFFY, int INCY);
 void simple_copy_ongpu(int size, float *src, float *dst);
@@ -87,7 +88,7 @@ void fast_variance_delta_gpu(float *x, float *delta, float *mean, float *varianc
 void fast_mean_gpu(float *x, int batch, int filters, int spatial, float *mean);
 void fast_variance_gpu(float *x, float *mean, int batch, int filters, int spatial, float *variance);
-void fast_v_cbn_gpu(const float *x, float *mean, int batch, int filters, int spatial, int minibatch_index, float *m_avg, float *v_avg, float *variance,
+void fast_v_cbn_gpu(const float *x, float *mean, int batch, int filters, int spatial, int minibatch_index, int max_minibatch_index, float *m_avg, float *v_avg, float *variance,
 const float alpha, float *rolling_mean_gpu, float *rolling_variance_gpu, int inverse_variance, float epsilon);
 void normalize_scale_bias_gpu(float *x, float *mean, float *variance, float *scales, float *biases, int batch, int filters, int spatial, int inverse_variance, float epsilon);
 void compare_2_arrays_gpu(float *one, float *two, int size);

src/blas_kernels.cu
@@ -418,6 +418,24 @@ __global__ void reorg_kernel(int N, float *x, int w, int h, int c, int batch, in
 //else out[0] = x[0];
 }
+__global__ void constrain_weight_updates_kernel(int N, float coef, float *weights_gpu, float *weight_updates_gpu)
+{
+    int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
+    if (i < N) {
+        const float w = weights_gpu[i];
+        const float wu = weight_updates_gpu[i];
+        const float wu_sign = (wu == 0) ? 0 : (fabs(wu) / wu);
+        const float abs_limit = fabs(w * coef);
+        if (fabs(wu) > abs_limit) weight_updates_gpu[i] = abs_limit * wu_sign;
+    }
+}
+extern "C" void constrain_weight_updates_ongpu(int N, float coef, float *weights_gpu, float *weight_updates_gpu)
+{
+    constrain_weight_updates_kernel << <cuda_gridsize(N), BLOCK, 0, get_cuda_stream() >> >(N, coef, weights_gpu, weight_updates_gpu);
+    CHECK_CUDA(cudaPeekAtLastError());
+}
 __global__ void axpy_kernel(int N, float ALPHA, float *X, int OFFX, int INCX, float *Y, int OFFY, int INCY)
 {
 int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
@@ -572,7 +590,7 @@ extern "C" void fast_variance_gpu(float *x, float *mean, int batch, int filters,
 }
-__global__ void fast_v_cbn_kernel(const float *x, float *mean, int batch, int filters, int spatial, int minibatch_index, float *m_avg, float *v_avg, float *variance,
+__global__ void fast_v_cbn_kernel(const float *x, float *mean, int batch, int filters, int spatial, int minibatch_index, int max_minibatch_index, float *m_avg, float *v_avg, float *variance,
 const float alpha, float *rolling_mean_gpu, float *rolling_variance_gpu, int inverse_variance, float epsilon)
 {
 const int threads = BLOCK;
@@ -615,16 +633,19 @@ __global__ void fast_v_cbn_kernel(const float *x, float *mean, int batch, int f
 if (inverse_variance) variance[filter] = 1.0f / sqrtf(variance_tmp + epsilon);
 else variance[filter] = variance_tmp;
+//if (max_minibatch_index == minibatch_index)
+{
 rolling_mean_gpu[filter] = alpha * mean[filter] + (1 - alpha) * rolling_mean_gpu[filter];
 rolling_variance_gpu[filter] = alpha * variance_tmp + (1 - alpha) * rolling_variance_gpu[filter];
+}
 }
 }
-extern "C" void fast_v_cbn_gpu(const float *x, float *mean, int batch, int filters, int spatial, int minibatch_index, float *m_avg, float *v_avg, float *variance,
+extern "C" void fast_v_cbn_gpu(const float *x, float *mean, int batch, int filters, int spatial, int minibatch_index, int max_minibatch_index, float *m_avg, float *v_avg, float *variance,
 const float alpha, float *rolling_mean_gpu, float *rolling_variance_gpu, int inverse_variance, float epsilon)
 {
-fast_v_cbn_kernel << <filters, BLOCK, 0, get_cuda_stream() >> >(x, mean, batch, filters, spatial, minibatch_index, m_avg, v_avg, variance, alpha, rolling_mean_gpu, rolling_variance_gpu, inverse_variance, epsilon);
+fast_v_cbn_kernel << <filters, BLOCK, 0, get_cuda_stream() >> >(x, mean, batch, filters, spatial, minibatch_index, max_minibatch_index, m_avg, v_avg, variance, alpha, rolling_mean_gpu, rolling_variance_gpu, inverse_variance, epsilon);
 CHECK_CUDA(cudaPeekAtLastError());
 }
@@ -798,6 +819,12 @@ __device__ float relu(float src) {
 return 0;
 }
+__device__ float lrelu(float src) {
+    const float eps = 0.001;
+    if (src > eps) return src;
+    return eps;
+}
 __global__ void shortcut_singlelayer_simple_kernel(int size, int src_outputs, int batch, int n, int *outputs_of_layers_gpu, float **layers_output_gpu, float *out, float *in, float *weights_gpu, int nweights, WEIGHTS_NORMALIZATION_T weights_normalizion)
 {
 const int id = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
@@ -849,7 +876,7 @@ __global__ void shortcut_multilayer_kernel(int size, int src_outputs, int batch,
 for (int i = 0; i < (n + 1); ++i) {
 const int weights_index = src_i / step + i*layer_step; // [0 or c or (c, h ,w)]
 const float w = weights_gpu[weights_index];
-if (weights_normalizion == RELU_NORMALIZATION) sum += relu(w);
+if (weights_normalizion == RELU_NORMALIZATION) sum += lrelu(w);
 else if (weights_normalizion == SOFTMAX_NORMALIZATION) sum += expf(w - max_val);
 }
 }
@@ -858,7 +885,7 @@ __global__ void shortcut_multilayer_kernel(int size, int src_outputs, int batch,
 if (weights_gpu) {
 float w = weights_gpu[src_i / step];
-if (weights_normalizion == RELU_NORMALIZATION) w = relu(w) / sum;
+if (weights_normalizion == RELU_NORMALIZATION) w = lrelu(w) / sum;
 else if (weights_normalizion == SOFTMAX_NORMALIZATION) w = expf(w - max_val) / sum;
 out_val = in[id] * w; // [0 or c or (c, h ,w)]
@@ -876,7 +903,7 @@ __global__ void shortcut_multilayer_kernel(int size, int src_outputs, int batch,
 if (weights_gpu) {
 const int weights_index = src_i / step + (i + 1)*layer_step; // [0 or c or (c, h ,w)]
 float w = weights_gpu[weights_index];
-if (weights_normalizion == RELU_NORMALIZATION) w = relu(w) / sum;
+if (weights_normalizion == RELU_NORMALIZATION) w = lrelu(w) / sum;
 else if (weights_normalizion == SOFTMAX_NORMALIZATION) w = expf(w - max_val) / sum;
 out_val += add[add_index] * w; // [0 or c or (c, h ,w)]
@@ -932,34 +959,27 @@ __global__ void backward_shortcut_multilayer_kernel(int size, int src_outputs, i
 for (i = 0; i < (n + 1); ++i) {
 const int weights_index = src_i / step + i*layer_step; // [0 or c or (c, h ,w)]
 const float w = weights_gpu[weights_index];
-if (weights_normalizion == RELU_NORMALIZATION) sum += relu(w);
+if (weights_normalizion == RELU_NORMALIZATION) sum += lrelu(w);
 else if (weights_normalizion == SOFTMAX_NORMALIZATION) sum += expf(w - max_val);
 }
-/*
-grad = 0;
-for (i = 0; i < (n + 1); ++i) {
-const int weights_index = src_i / step + i*layer_step; // [0 or c or (c, h ,w)]
-const float delta_w = delta_in[id] * in[id];
-const float w = weights_gpu[weights_index];
-if (weights_normalizion == RELU_NORMALIZATION) grad += delta_w * relu(w) / sum;
-else if (weights_normalizion == SOFTMAX_NORMALIZATION) grad += delta_w * expf(w - max_val) / sum;
-}
-*/
 }
 if (weights_gpu) {
 float w = weights_gpu[src_i / step];
-if (weights_normalizion == RELU_NORMALIZATION) w = relu(w) / sum;
+if (weights_normalizion == RELU_NORMALIZATION) w = lrelu(w) / sum;
 else if (weights_normalizion == SOFTMAX_NORMALIZATION) w = expf(w - max_val) / sum;
 if (weights_normalizion == RELU_NORMALIZATION) grad = w;
 else if (weights_normalizion == SOFTMAX_NORMALIZATION) grad = w*(1-w);
 delta_out[id] += delta_in[id] * w; // [0 or c or (c, h ,w)]
-float weights_update_tmp = delta_in[id] * in[id] * grad;
+float weights_update_tmp = delta_in[id] * in[id] * grad;// / step;
 if (layer_step == 1 && (size/32) > (id/32 + 1)) {
+if (isnan(weights_update_tmp) || isinf(weights_update_tmp)) {
+    weights_update_tmp = 0;
+}
 float wu = warpAllReduceSum(weights_update_tmp);
 if (threadIdx.x % 32 == 0) {
 if (!isnan(wu) && !isinf(wu))
@@ -994,13 +1014,18 @@ __global__ void backward_shortcut_multilayer_kernel(int size, int src_outputs, i
 else if (weights_normalizion == SOFTMAX_NORMALIZATION) grad = w*(1 - w);
 layer_delta[add_index] += delta_in[id] * w;
-float weights_update_tmp = delta_in[id] * add[add_index] * grad;
+float weights_update_tmp = delta_in[id] * add[add_index] * grad;// / step;
 if (layer_step == 1 && (size / 32) > (id / 32 + 1)) {
+if (isnan(weights_update_tmp) || isinf(weights_update_tmp)) {
+    weights_update_tmp = 0;
+}
 float wu = warpAllReduceSum(weights_update_tmp);
 if (threadIdx.x % 32 == 0) {
 if (!isnan(wu) && !isinf(wu))
 atomicAdd(&weight_updates_gpu[weights_index], wu);
+//if(weights_gpu[weights_index] != 1) printf(" wu = %f, weights_update_tmp = %f, w = %f, weights_gpu[weights_index] = %f, grad = %f, weights_normalizion = %d ",
+//    wu, weights_update_tmp, w, weights_gpu[weights_index], grad, weights_normalizion);
 }
 }
 else {
@@ -1357,8 +1382,9 @@ __global__ void fix_nan_and_inf_kernel(float *input, size_t size)
 const int index = blockIdx.x*blockDim.x + threadIdx.x;
 if (index < size) {
 float val = input[index];
-if (isnan(val) || isinf(val))
-input[index] = 1.0f / index; // pseudo random value
+if (isnan(val) || isinf(val)) {
+input[index] = 1.0f / (fabs((float)index) + 1); // pseudo random value
+}
 }
 }
@@ -1377,10 +1403,11 @@ __global__ void reset_nan_and_inf_kernel(float *input, size_t size)
 const int index = blockIdx.x*blockDim.x + threadIdx.x;
 if (index < size) {
 float val = input[index];
-if (isnan(val) || isinf(val))
+if (isnan(val) || isinf(val)) {
 input[index] = 0;
+}
 }
 }
 extern "C" void reset_nan_and_inf(float *input, size_t size)
 {
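A CPU-side sketch of what the new constrain_weight_updates_kernel does per element: cap each weight update at coef times the magnitude of its weight while preserving the update's sign (plain C, hypothetical helper name):

    #include <math.h>

    static void constrain_weight_updates_cpu(int n, float coef,
                                             const float *weights, float *weight_updates)
    {
        for (int i = 0; i < n; ++i) {
            const float w = weights[i];
            const float wu = weight_updates[i];
            const float wu_sign = (wu == 0) ? 0 : (wu > 0 ? 1.f : -1.f);
            const float abs_limit = fabsf(w * coef);
            if (fabsf(wu) > abs_limit) weight_updates[i] = abs_limit * wu_sign;
        }
    }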

src/classifier.c
@@ -48,7 +48,10 @@ void train_classifier(char *datacfg, char *cfgfile, char *weightfile, int *gpus,
 if(weightfile){
 load_weights(&nets[i], weightfile);
 }
-if(clear) *nets[i].seen = 0;
+if (clear) {
+    *nets[i].seen = 0;
+    *nets[i].cur_iteration = 0;
+}
 nets[i].learning_rate *= ngpus;
 }
 srand(time(0));
@@ -166,8 +169,8 @@ void train_classifier(char *datacfg, char *cfgfile, char *weightfile, int *gpus,
 calc_topk_for_each = fmax(calc_topk_for_each, 100);
 if (i % 10 == 0) {
 if (calc_topk) {
-fprintf(stderr, "\n (next TOP5 calculation at %d iterations) ", calc_topk_for_each);
-if (topk > 0) fprintf(stderr, " Last accuracy TOP5 = %2.2f %% \n", topk * 100);
+fprintf(stderr, "\n (next TOP%d calculation at %d iterations) ", topk_data, calc_topk_for_each);
+if (topk > 0) fprintf(stderr, " Last accuracy TOP%d = %2.2f %% \n", topk_data, topk * 100);
 }
 if (net.cudnn_half) {
@@ -179,7 +182,7 @@ void train_classifier(char *datacfg, char *cfgfile, char *weightfile, int *gpus,
 int draw_precision = 0;
 if (calc_topk && (i >= calc_topk_for_each || i == net.max_batches)) {
 iter_topk = i;
-topk = validate_classifier_single(datacfg, cfgfile, weightfile, &net, topk_data); // calc TOP5
+topk = validate_classifier_single(datacfg, cfgfile, weightfile, &net, topk_data); // calc TOP-n
 printf("\n accuracy %s = %f \n", topk_buff, topk);
 draw_precision = 1;
 }
@@ -767,13 +770,14 @@ void try_classifier(char *datacfg, char *cfgfile, char *weightfile, char *filena
 float *predictions = network_predict(net, X);
 layer l = net.layers[layer_num];
-for(int i = 0; i < l.c; ++i){
+int i;
+for(i = 0; i < l.c; ++i){
 if(l.rolling_mean) printf("%f %f %f\n", l.rolling_mean[i], l.rolling_variance[i], l.scales[i]);
 }
 #ifdef GPU
 cuda_pull_array(l.output_gpu, l.output, l.outputs);
 #endif
-for(int i = 0; i < l.outputs; ++i){
+for(i = 0; i < l.outputs; ++i){
 printf("%f\n", l.output[i]);
 }
 /*
@@ -791,7 +795,7 @@ void try_classifier(char *datacfg, char *cfgfile, char *weightfile, char *filena
 top_predictions(net, top, indexes);
 printf("%s: Predicted in %f seconds.\n", input, sec(clock()-time));
-for(int i = 0; i < top; ++i){
+for(i = 0; i < top; ++i){
 int index = indexes[i];
 printf("%s: %f\n", names[index], predictions[index]);
 }

src/connected_layer.c
@@ -308,8 +308,17 @@ void push_connected_layer(connected_layer l)
 CHECK_CUDA(cudaPeekAtLastError());
 }
-void update_connected_layer_gpu(connected_layer l, int batch, float learning_rate, float momentum, float decay)
+void update_connected_layer_gpu(connected_layer l, int batch, float learning_rate_init, float momentum, float decay, float loss_scale)
 {
+float learning_rate = learning_rate_init * l.learning_rate_scale;
+// Loss scale for Mixed-Precision on Tensor-Cores
+if (loss_scale != 1.0) {
+    scal_ongpu(l.inputs*l.outputs, 1.0 / loss_scale, l.weight_updates_gpu, 1);
+    scal_ongpu(l.outputs, 1.0 / loss_scale, l.bias_updates_gpu, 1);
+    scal_ongpu(l.outputs, 1.0 / loss_scale, l.scale_updates_gpu, 1);
+}
 axpy_ongpu(l.outputs, learning_rate/batch, l.bias_updates_gpu, 1, l.biases_gpu, 1);
 scal_ongpu(l.outputs, momentum, l.bias_updates_gpu, 1);
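The loss_scale parameter threaded through these update_* functions compensates for mixed-precision loss scaling: gradients are accumulated under a loss multiplied by loss_scale, so the updates are divided back by the same factor before being applied. A minimal CPU sketch of that compensation step (plain C, hypothetical helper name):

    /* Undo mixed-precision loss scaling before applying updates;
       a no-op when loss_scale == 1. */
    static void unscale_updates(float *updates, int n, float loss_scale)
    {
        if (loss_scale == 1.0f) return;
        for (int i = 0; i < n; ++i) updates[i] /= loss_scale;
    }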

src/connected_layer.h
@@ -22,7 +22,7 @@ void statistics_connected_layer(layer l);
 #ifdef GPU
 void forward_connected_layer_gpu(connected_layer layer, network_state state);
 void backward_connected_layer_gpu(connected_layer layer, network_state state);
-void update_connected_layer_gpu(connected_layer layer, int batch, float learning_rate, float momentum, float decay);
+void update_connected_layer_gpu(connected_layer layer, int batch, float learning_rate, float momentum, float decay, float loss_scale);
 void push_connected_layer(connected_layer layer);
 void pull_connected_layer(connected_layer layer);
 #endif

src/conv_lstm_layer.c
@@ -791,21 +791,21 @@ void push_conv_lstm_layer(layer l)
 push_convolutional_layer(*(l.uo));
 }
-void update_conv_lstm_layer_gpu(layer l, int batch, float learning_rate, float momentum, float decay)
+void update_conv_lstm_layer_gpu(layer l, int batch, float learning_rate, float momentum, float decay, float loss_scale)
 {
 if (l.peephole) {
-update_convolutional_layer_gpu(*(l.vf), batch, learning_rate, momentum, decay);
-update_convolutional_layer_gpu(*(l.vi), batch, learning_rate, momentum, decay);
-update_convolutional_layer_gpu(*(l.vo), batch, learning_rate, momentum, decay);
+update_convolutional_layer_gpu(*(l.vf), batch, learning_rate, momentum, decay, loss_scale);
+update_convolutional_layer_gpu(*(l.vi), batch, learning_rate, momentum, decay, loss_scale);
+update_convolutional_layer_gpu(*(l.vo), batch, learning_rate, momentum, decay, loss_scale);
 }
-update_convolutional_layer_gpu(*(l.wf), batch, learning_rate, momentum, decay);
-update_convolutional_layer_gpu(*(l.wi), batch, learning_rate, momentum, decay);
-update_convolutional_layer_gpu(*(l.wg), batch, learning_rate, momentum, decay);
-update_convolutional_layer_gpu(*(l.wo), batch, learning_rate, momentum, decay);
-update_convolutional_layer_gpu(*(l.uf), batch, learning_rate, momentum, decay);
-update_convolutional_layer_gpu(*(l.ui), batch, learning_rate, momentum, decay);
-update_convolutional_layer_gpu(*(l.ug), batch, learning_rate, momentum, decay);
-update_convolutional_layer_gpu(*(l.uo), batch, learning_rate, momentum, decay);
+update_convolutional_layer_gpu(*(l.wf), batch, learning_rate, momentum, decay, loss_scale);
+update_convolutional_layer_gpu(*(l.wi), batch, learning_rate, momentum, decay, loss_scale);
+update_convolutional_layer_gpu(*(l.wg), batch, learning_rate, momentum, decay, loss_scale);
+update_convolutional_layer_gpu(*(l.wo), batch, learning_rate, momentum, decay, loss_scale);
+update_convolutional_layer_gpu(*(l.uf), batch, learning_rate, momentum, decay, loss_scale);
+update_convolutional_layer_gpu(*(l.ui), batch, learning_rate, momentum, decay, loss_scale);
+update_convolutional_layer_gpu(*(l.ug), batch, learning_rate, momentum, decay, loss_scale);
+update_convolutional_layer_gpu(*(l.uo), batch, learning_rate, momentum, decay, loss_scale);
 }
 void forward_conv_lstm_layer_gpu(layer l, network_state state)

src/conv_lstm_layer.h
@@ -23,7 +23,7 @@ void update_conv_lstm_layer(layer l, int batch, float learning_rate, float momen
 #ifdef GPU
 void forward_conv_lstm_layer_gpu(layer l, network_state state);
 void backward_conv_lstm_layer_gpu(layer l, network_state state);
-void update_conv_lstm_layer_gpu(layer l, int batch, float learning_rate, float momentum, float decay);
+void update_conv_lstm_layer_gpu(layer l, int batch, float learning_rate, float momentum, float decay, float loss_scale);
 #endif
 #ifdef __cplusplus

src/convolutional_kernels.cu
@@ -419,7 +419,7 @@ void forward_convolutional_layer_gpu(convolutional_layer l, network_state state)
 //#ifdef CUDNN_HALF
 //if (state.use_mixed_precision) {
-int iteration_num = (*state.net.seen) / (state.net.batch*state.net.subdivisions);
+int iteration_num = get_current_iteration(state.net); // (*state.net.seen) / (state.net.batch*state.net.subdivisions);
 if (state.index != 0 && state.net.cudnn_half && !l.xnor && (!state.train || iteration_num > 3 * state.net.burn_in) &&
 (l.c / l.groups) % 8 == 0 && l.n % 8 == 0 && !state.train && l.groups <= 1 && l.size > 1)
 {
@@ -671,7 +671,7 @@ void backward_convolutional_layer_gpu(convolutional_layer l, network_state state
 float alpha = 1, beta = 0;
 //#ifdef CUDNN_HALF
-int iteration_num = (*state.net.seen) / (state.net.batch*state.net.subdivisions);
+int iteration_num = get_current_iteration(state.net); //(*state.net.seen) / (state.net.batch*state.net.subdivisions);
 if (state.index != 0 && state.net.cudnn_half && !l.xnor && (!state.train || iteration_num > 3 * state.net.burn_in) &&
 (l.c / l.groups) % 8 == 0 && l.n % 8 == 0 && !state.train && l.groups <= 1 && l.size > 1)
 {
@@ -717,7 +717,7 @@ void backward_convolutional_layer_gpu(convolutional_layer l, network_state state
 l.normDstTensorDescF16,
 delta16, // input
 l.normDstTensorDescF16,
-l.x_norm_gpu, // output (new delta)
+l.output_gpu, //l.x_norm_gpu, // output (new delta)
 l.normTensorDesc,
 l.scales_gpu, // input (should be FP32)
 l.scale_updates_gpu, // output (should be FP32)
@@ -726,7 +726,7 @@ void backward_convolutional_layer_gpu(convolutional_layer l, network_state state
 l.mean_gpu, // input (should be FP32)
 l.variance_gpu)); // input (should be FP32)
-simple_copy_ongpu(l.outputs*l.batch / 2, l.x_norm_gpu, delta16);
+simple_copy_ongpu(l.outputs*l.batch / 2, l.output_gpu, delta16);
 //copy_ongpu(l.outputs*l.batch / 2, l.x_norm_gpu, 1, delta16, 1);
 //cudaMemcpyAsync(delta16, l.x_norm_gpu, l.outputs*l.batch * sizeof(half), cudaMemcpyDefault, get_cuda_stream());
 }
@@ -978,7 +978,7 @@ void assisted_activation2_gpu(float alpha, float *output, float *gt_gpu, float *
 void assisted_excitation_forward_gpu(convolutional_layer l, network_state state)
 {
-const int iteration_num = (*state.net.seen) / (state.net.batch*state.net.subdivisions);
+const int iteration_num = get_current_iteration(state.net); //(*state.net.seen) / (state.net.batch*state.net.subdivisions);
 // epoch
 //const float epoch = (float)(*state.net.seen) / state.net.train_images_num;
@@ -1188,7 +1188,7 @@ void push_convolutional_layer(convolutional_layer l)
 CHECK_CUDA(cudaPeekAtLastError());
 }
-void update_convolutional_layer_gpu(layer l, int batch, float learning_rate_init, float momentum, float decay)
+void update_convolutional_layer_gpu(layer l, int batch, float learning_rate_init, float momentum, float decay, float loss_scale)
 {
 /*
@@ -1232,6 +1232,13 @@ void update_convolutional_layer_gpu(layer l, int batch, float learning_rate_init
 //float decay = a.decay;
 //int batch = a.batch;
+// Loss scale for Mixed-Precision on Tensor-Cores
+if (loss_scale != 1.0) {
+    if (l.weight_updates_gpu && l.nweights > 0) scal_ongpu(l.nweights, 1.0 / loss_scale, l.weight_updates_gpu, 1);
+    if (l.bias_updates_gpu && l.n > 0) scal_ongpu(l.n, 1.0 / loss_scale, l.bias_updates_gpu, 1);
+    if (l.scale_updates_gpu && l.n > 0) scal_ongpu(l.n, 1.0 / loss_scale, l.scale_updates_gpu, 1);
+}
 reset_nan_and_inf(l.weight_updates_gpu, l.nweights);
 fix_nan_and_inf(l.weights_gpu, l.nweights);
@@ -1280,9 +1287,9 @@ void update_convolutional_layer_gpu(layer l, int batch, float learning_rate_init
 //}
 }
-//if (l.clip) {
-//    constrain_gpu(l.nweights, l.clip, l.weights_gpu, 1);
-//}
+if (l.clip) {
+    constrain_ongpu(l.nweights, l.clip, l.weights_gpu, 1);
+}
 }
 /*
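The re-enabled l.clip branch above constrains the weights themselves to [-clip, clip] after the update; a minimal CPU sketch of that constraint (plain C, hypothetical helper name):

    static void constrain_cpu(int n, float limit, float *weights)
    {
        for (int i = 0; i < n; ++i) {
            if (weights[i] >  limit) weights[i] =  limit;
            if (weights[i] < -limit) weights[i] = -limit;
        }
    }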

src/convolutional_layer.c
@@ -786,7 +786,7 @@ void resize_convolutional_layer(convolutional_layer *l, int w, int h)
 if (l->activation == SWISH || l->activation == MISH) l->activation_input = (float*)realloc(l->activation_input, total_batch*l->outputs * sizeof(float));
 #ifdef GPU
-if (old_w < w || old_h < h) {
+if (old_w < w || old_h < h || l->dynamic_minibatch) {
 if (l->train) {
 cuda_free(l->delta_gpu);
 l->delta_gpu = cuda_make_array(l->delta, total_batch*l->outputs);

src/convolutional_layer.h
@@ -15,7 +15,7 @@ extern "C" {
 #ifdef GPU
 void forward_convolutional_layer_gpu(convolutional_layer layer, network_state state);
 void backward_convolutional_layer_gpu(convolutional_layer layer, network_state state);
-void update_convolutional_layer_gpu(convolutional_layer layer, int batch, float learning_rate, float momentum, float decay);
+void update_convolutional_layer_gpu(convolutional_layer layer, int batch, float learning_rate, float momentum, float decay, float loss_scale);
 void push_convolutional_layer(convolutional_layer layer);
 void pull_convolutional_layer(convolutional_layer layer);

src/crnn_layer.c
@@ -265,11 +265,11 @@ void push_crnn_layer(layer l)
 push_convolutional_layer(*(l.output_layer));
 }
-void update_crnn_layer_gpu(layer l, int batch, float learning_rate, float momentum, float decay)
+void update_crnn_layer_gpu(layer l, int batch, float learning_rate, float momentum, float decay, float loss_scale)
 {
-update_convolutional_layer_gpu(*(l.input_layer), batch, learning_rate, momentum, decay);
-update_convolutional_layer_gpu(*(l.self_layer), batch, learning_rate, momentum, decay);
-update_convolutional_layer_gpu(*(l.output_layer), batch, learning_rate, momentum, decay);
+update_convolutional_layer_gpu(*(l.input_layer), batch, learning_rate, momentum, decay, loss_scale);
+update_convolutional_layer_gpu(*(l.self_layer), batch, learning_rate, momentum, decay, loss_scale);
+update_convolutional_layer_gpu(*(l.output_layer), batch, learning_rate, momentum, decay, loss_scale);
 }
 void forward_crnn_layer_gpu(layer l, network_state state)

src/crnn_layer.h
@@ -20,7 +20,7 @@ void update_crnn_layer(layer l, int batch, float learning_rate, float momentum,
 #ifdef GPU
 void forward_crnn_layer_gpu(layer l, network_state state);
 void backward_crnn_layer_gpu(layer l, network_state state);
-void update_crnn_layer_gpu(layer l, int batch, float learning_rate, float momentum, float decay);
+void update_crnn_layer_gpu(layer l, int batch, float learning_rate, float momentum, float decay, float loss_scale);
 void push_crnn_layer(layer l);
 void pull_crnn_layer(layer l);
 #endif

src/dark_cuda.c
@@ -1,6 +1,7 @@
 #ifdef __cplusplus
 extern "C" {
 #endif
+int cuda_debug_sync = 0;
 int gpu_index = 0;
 #ifdef __cplusplus
 }
@@ -19,6 +20,7 @@ int gpu_index = 0;
 #pragma comment(lib, "cuda.lib")
 #ifdef CUDNN
 #ifndef USE_CMAKE_LIBS
 #pragma comment(lib, "cudnn.lib")
@@ -29,6 +31,7 @@ int gpu_index = 0;
 #error "If you set CUDNN_HALF=1 then you must set CUDNN=1"
 #endif
 void cuda_set_device(int n)
 {
 gpu_index = n;
@@ -86,10 +89,13 @@ void check_error_extended(cudaError_t status, const char *file, int line, const
 check_error(status);
 }
 #if defined(DEBUG) || defined(CUDA_DEBUG)
+cuda_debug_sync = 1;
+#endif
+if (cuda_debug_sync) {
 status = cudaDeviceSynchronize();
 if (status != cudaSuccess)
 printf("CUDA status = cudaDeviceSynchronize() Error: file: %s() : line: %d : build time: %s \n", file, line, date_time);
-#endif
+}
 check_error(status);
 }
@@ -173,6 +179,9 @@ void cudnn_check_error(cudnnStatus_t status)
 #if defined(DEBUG) || defined(CUDA_DEBUG)
 cudaDeviceSynchronize();
 #endif
+if (cuda_debug_sync) {
+    cudaDeviceSynchronize();
+}
 cudnnStatus_t status2 = CUDNN_STATUS_SUCCESS;
 #ifdef CUDNN_ERRQUERY_RAWCODE
 cudnnStatus_t status_tmp = cudnnQueryRuntimeError(cudnn_handle(), &status2, CUDNN_ERRQUERY_RAWCODE, NULL);
@@ -208,10 +217,13 @@ void cudnn_check_error_extended(cudnnStatus_t status, const char *file, int line
 cudnn_check_error(status);
 }
 #if defined(DEBUG) || defined(CUDA_DEBUG)
-status = cudaDeviceSynchronize();
-if (status != CUDNN_STATUS_SUCCESS)
-printf("\n cuDNN status = cudaDeviceSynchronize() Error in: file: %s() : line: %d : build time: %s \n", file, line, date_time);
+cuda_debug_sync = 1;
 #endif
+if (cuda_debug_sync) {
+    cudaError_t status = cudaDeviceSynchronize();
+    if (status != CUDNN_STATUS_SUCCESS)
+        printf("\n cudaError_t status = cudaDeviceSynchronize() Error in: file: %s() : line: %d : build time: %s \n", file, line, date_time);
+}
 cudnn_check_error(status);
 }
 #endif
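A condensed sketch of the pattern these dark_cuda.c changes introduce: synchronization after CUDA calls is now gated by a runtime flag (settable with -cuda_debug_sync) instead of only by compile-time DEBUG defines. The sketch assumes the CUDA toolkit is available and uses illustrative names:

    #include <cuda_runtime.h>
    #include <stdio.h>

    int cuda_debug_sync = 0;   /* set from the command line, forced on in debug builds */

    void check_cuda_extended(cudaError_t status, const char *file, int line)
    {
        if (status != cudaSuccess)
            printf("CUDA Error: %s at %s:%d\n", cudaGetErrorString(status), file, line);
    #if defined(DEBUG) || defined(CUDA_DEBUG)
        cuda_debug_sync = 1;               /* debug builds always synchronize */
    #endif
        if (cuda_debug_sync) {             /* costly, so off by default at runtime */
            status = cudaDeviceSynchronize();
            if (status != cudaSuccess)
                printf("cudaDeviceSynchronize() error at %s:%d\n", file, line);
        }
    }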

src/dark_cuda.h
@@ -6,6 +6,8 @@
 extern "C" {
 #endif
+extern int cuda_debug_sync;
 extern int gpu_index;
 #ifdef __cplusplus
 }

src/darknet.c
@@ -165,17 +165,19 @@ void oneoff(char *cfgfile, char *weightfile, char *outfile)
 copy_cpu(l.n/3*l.c, l.weights, 1, l.weights + l.n/3*l.c, 1);
 copy_cpu(l.n/3*l.c, l.weights, 1, l.weights + 2*l.n/3*l.c, 1);
 *net.seen = 0;
+*net.cur_iteration = 0;
 save_weights(net, outfile);
 }
 void partial(char *cfgfile, char *weightfile, char *outfile, int max)
 {
 gpu_index = -1;
-network net = parse_network_cfg(cfgfile);
+network net = parse_network_cfg_custom(cfgfile, 1, 1);
 if(weightfile){
 load_weights_upto(&net, weightfile, max);
 }
 *net.seen = 0;
+*net.cur_iteration = 0;
 save_weights_upto(net, outfile, max);
 }
@@ -455,15 +457,22 @@ int main(int argc, char **argv)
 #ifndef GPU
 gpu_index = -1;
+printf(" GPU isn't used \n");
 init_cpu();
-#else
+#else // GPU
 if(gpu_index >= 0){
 cuda_set_device(gpu_index);
 CHECK_CUDA(cudaSetDeviceFlags(cudaDeviceScheduleBlockingSync));
 }
 show_cuda_cudnn_info();
-#endif
+cuda_debug_sync = find_arg(argc, argv, "-cuda_debug_sync");
+#ifdef CUDNN_HALF
+printf(" CUDNN_HALF=1 \n");
+#endif // CUDNN_HALF
+#endif // GPU
 show_opencv_info();

src/data.c
@@ -8,6 +8,8 @@
 #include <stdlib.h>
 #include <string.h>
+extern int check_mistakes;
 #define NUMCHARS 37
 pthread_mutex_t mutex = PTHREAD_MUTEX_INITIALIZER;
@@ -178,7 +180,6 @@ matrix load_image_augment_paths(char **paths, int n, int use_flip, int min, int
 return X;
 }
-extern int check_mistakes;
 box_label *read_boxes(char *filename, int *n)
 {
@@ -950,7 +951,10 @@ data load_data_detection(int n, char **paths, int m, int w, int h, int c, int bo
 const int random_index = random_gen();
 c = c ? c : 3;
-assert(use_mixup != 2);
+if (use_mixup == 2) {
+    printf("\n cutmix=1 - isn't supported for Detector \n");
+    exit(0);
+}
 if (use_mixup == 3 && letter_box) {
 printf("\n Combination: letter_box=1 & mosaic=1 - isn't supported, use only 1 of these parameters \n");
 exit(0);
@@ -1211,7 +1215,15 @@ data load_data_detection(int n, char **paths, int m, int w, int h, int c, int bo
 if(track) random_paths = get_sequential_paths(paths, n, m, mini_batch, augment_speed);
 else random_paths = get_random_paths(paths, n, m);
-assert(use_mixup < 2);
+//assert(use_mixup < 2);
+if (use_mixup == 2) {
+    printf("\n cutmix=1 - isn't supported for Detector \n");
+    exit(0);
+}
+if (use_mixup == 3) {
+    printf("\n mosaic=1 - compile Darknet with OpenCV for using mosaic=1 \n");
+    exit(0);
+}
 int mixup = use_mixup ? random_gen() % 2 : 0;
 //printf("\n mixup = %d \n", mixup);
 if (mixup) {

@ -66,19 +66,22 @@ void train_detector(char *datacfg, char *cfgfile, char *weightfile, int *gpus, i
srand(time(0)); srand(time(0));
int seed = rand(); int seed = rand();
int i; int k;
for (i = 0; i < ngpus; ++i) { for (k = 0; k < ngpus; ++k) {
srand(seed); srand(seed);
#ifdef GPU #ifdef GPU
cuda_set_device(gpus[i]); cuda_set_device(gpus[k]);
#endif #endif
nets[i] = parse_network_cfg(cfgfile); nets[k] = parse_network_cfg(cfgfile);
nets[i].benchmark_layers = benchmark_layers; nets[k].benchmark_layers = benchmark_layers;
if (weightfile) { if (weightfile) {
load_weights(&nets[i], weightfile); load_weights(&nets[k], weightfile);
} }
if (clear) *nets[i].seen = 0; if (clear) {
nets[i].learning_rate *= ngpus; *nets[k].seen = 0;
*nets[k].cur_iteration = 0;
}
nets[k].learning_rate *= ngpus;
} }
srand(time(0)); srand(time(0));
network net = nets[0]; network net = nets[0];
@ -105,12 +108,13 @@ void train_detector(char *datacfg, char *cfgfile, char *weightfile, int *gpus, i
int train_images_num = plist->size; int train_images_num = plist->size;
char **paths = (char **)list_to_array(plist); char **paths = (char **)list_to_array(plist);
int init_w = net.w; const int init_w = net.w;
int init_h = net.h; const int init_h = net.h;
const int init_b = net.batch;
int iter_save, iter_save_last, iter_map; int iter_save, iter_save_last, iter_map;
iter_save = get_current_batch(net); iter_save = get_current_iteration(net);
iter_save_last = get_current_batch(net); iter_save_last = get_current_iteration(net);
iter_map = get_current_batch(net); iter_map = get_current_iteration(net);
float mean_average_precision = -1; float mean_average_precision = -1;
float best_map = mean_average_precision; float best_map = mean_average_precision;
@ -169,7 +173,7 @@ void train_detector(char *datacfg, char *cfgfile, char *weightfile, int *gpus, i
start = what_time_is_it_now(); start = what_time_is_it_now();
//while(i*imgs < N*120){ //while(i*imgs < N*120){
while (get_current_batch(net) < net.max_batches) { while (get_current_iteration(net) < net.max_batches) {
if (l.random && count++ % 10 == 0) { if (l.random && count++ % 10 == 0) {
float rand_coef = 1.4; float rand_coef = 1.4;
if (l.random != 1.0) rand_coef = l.random; if (l.random != 1.0) rand_coef = l.random;
@ -179,26 +183,48 @@ void train_detector(char *datacfg, char *cfgfile, char *weightfile, int *gpus, i
int dim_h = roundl(random_val*init_h / net.resize_step + 1) * net.resize_step; int dim_h = roundl(random_val*init_h / net.resize_step + 1) * net.resize_step;
if (random_val < 1 && (dim_w > init_w || dim_h > init_h)) dim_w = init_w, dim_h = init_h; if (random_val < 1 && (dim_w > init_w || dim_h > init_h)) dim_w = init_w, dim_h = init_h;
// at the beginning int max_dim_w = roundl(rand_coef*init_w / net.resize_step + 1) * net.resize_step;
if (avg_loss < 0) { int max_dim_h = roundl(rand_coef*init_h / net.resize_step + 1) * net.resize_step;
dim_w = roundl(rand_coef*init_w / net.resize_step + 1) * net.resize_step;
dim_h = roundl(rand_coef*init_h / net.resize_step + 1) * net.resize_step; // at the beginning (check if enough memory) and at the end (calc rolling mean/variance)
if (avg_loss < 0 || get_current_iteration(net) > net.max_batches - 100) {
dim_w = max_dim_w;
dim_h = max_dim_h;
} }
if (dim_w < net.resize_step) dim_w = net.resize_step; if (dim_w < net.resize_step) dim_w = net.resize_step;
if (dim_h < net.resize_step) dim_h = net.resize_step; if (dim_h < net.resize_step) dim_h = net.resize_step;
int dim_b = (init_b * max_dim_w * max_dim_h) / (dim_w * dim_h);
int new_dim_b = (int)(dim_b * 0.8);
if (new_dim_b > init_b) dim_b = new_dim_b;
printf("%d x %d \n", dim_w, dim_h);
args.w = dim_w; args.w = dim_w;
args.h = dim_h; args.h = dim_h;
int k;
if (net.dynamic_minibatch) {
for (k = 0; k < ngpus; ++k) {
(*nets[k].seen) = init_b * net.subdivisions * get_current_iteration(net); // remove this line, when you will save to weights-file both: seen & cur_iteration
nets[k].batch = dim_b;
int j;
for (j = 0; j < nets[k].n; ++j)
nets[k].layers[j].batch = dim_b;
}
net.batch = dim_b;
imgs = net.batch * net.subdivisions * ngpus;
args.n = imgs;
printf("\n %d x %d (batch = %d) \n", dim_w, dim_h, net.batch);
}
else
printf("\n %d x %d \n", dim_w, dim_h);
pthread_join(load_thread, 0); pthread_join(load_thread, 0);
train = buffer; train = buffer;
free_data(train); free_data(train);
load_thread = load_data(args); load_thread = load_data(args);
for (i = 0; i < ngpus; ++i) { for (k = 0; k < ngpus; ++k) {
resize_network(nets + i, dim_w, dim_h); resize_network(nets + k, dim_w, dim_h);
} }
net = nets[0]; net = nets[0];
} }
@ -250,7 +276,8 @@ void train_detector(char *datacfg, char *cfgfile, char *weightfile, int *gpus, i
if (avg_loss < 0 || avg_loss != avg_loss) avg_loss = loss; // if(-inf or nan) if (avg_loss < 0 || avg_loss != avg_loss) avg_loss = loss; // if(-inf or nan)
avg_loss = avg_loss*.9 + loss*.1; avg_loss = avg_loss*.9 + loss*.1;
i = get_current_batch(net); const int iteration = get_current_iteration(net);
//i = get_current_batch(net);
int calc_map_for_each = 4 * train_images_num / (net.batch * net.subdivisions); // calculate mAP for each 4 Epochs int calc_map_for_each = 4 * train_images_num / (net.batch * net.subdivisions); // calculate mAP for each 4 Epochs
calc_map_for_each = fmax(calc_map_for_each, 100); calc_map_for_each = fmax(calc_map_for_each, 100);
@ -263,22 +290,36 @@ void train_detector(char *datacfg, char *cfgfile, char *weightfile, int *gpus, i
} }
if (net.cudnn_half) { if (net.cudnn_half) {
if (i < net.burn_in * 3) fprintf(stderr, "\n Tensor Cores are disabled until the first %d iterations are reached.", 3 * net.burn_in); if (iteration < net.burn_in * 3) fprintf(stderr, "\n Tensor Cores are disabled until the first %d iterations are reached.", 3 * net.burn_in);
else fprintf(stderr, "\n Tensor Cores are used."); else fprintf(stderr, "\n Tensor Cores are used.");
} }
printf("\n %d: %f, %f avg loss, %f rate, %lf seconds, %d images\n", get_current_batch(net), loss, avg_loss, get_current_rate(net), (what_time_is_it_now() - time), i*imgs); printf("\n %d: %f, %f avg loss, %f rate, %lf seconds, %d images\n", iteration, loss, avg_loss, get_current_rate(net), (what_time_is_it_now() - time), iteration*imgs);
int draw_precision = 0; int draw_precision = 0;
if (calc_map && (i >= next_map_calc || i == net.max_batches)) { if (calc_map && (iteration >= next_map_calc || iteration == net.max_batches)) {
if (l.random) { if (l.random) {
printf("Resizing to initial size: %d x %d \n", init_w, init_h); printf("Resizing to initial size: %d x %d ", init_w, init_h);
args.w = init_w; args.w = init_w;
args.h = init_h; args.h = init_h;
int k;
if (net.dynamic_minibatch) {
for (k = 0; k < ngpus; ++k) {
for (k = 0; k < ngpus; ++k) {
nets[k].batch = init_b;
int j;
for (j = 0; j < nets[k].n; ++j)
nets[k].layers[j].batch = init_b;
}
}
net.batch = init_b;
imgs = init_b * net.subdivisions * ngpus;
args.n = imgs;
printf("\n %d x %d (batch = %d) \n", init_w, init_h, init_b);
}
pthread_join(load_thread, 0); pthread_join(load_thread, 0);
free_data(train); free_data(train);
train = buffer; train = buffer;
load_thread = load_data(args); load_thread = load_data(args);
int k;
for (k = 0; k < ngpus; ++k) { for (k = 0; k < ngpus; ++k) {
resize_network(nets + k, init_w, init_h); resize_network(nets + k, init_w, init_h);
} }
@ -290,7 +331,7 @@ void train_detector(char *datacfg, char *cfgfile, char *weightfile, int *gpus, i
// combine Training and Validation networks // combine Training and Validation networks
//network net_combined = combine_train_valid_networks(net, net_map); //network net_combined = combine_train_valid_networks(net, net_map);
iter_map = i; iter_map = iteration;
mean_average_precision = validate_detector_map(datacfg, cfgfile, weightfile, 0.25, 0.5, 0, net.letter_box, &net_map);// &net_combined); mean_average_precision = validate_detector_map(datacfg, cfgfile, weightfile, 0.25, 0.5, 0, net.letter_box, &net_map);// &net_combined);
printf("\n mean_average_precision (mAP@0.5) = %f \n", mean_average_precision); printf("\n mean_average_precision (mAP@0.5) = %f \n", mean_average_precision);
if (mean_average_precision > best_map) { if (mean_average_precision > best_map) {
@ -319,18 +360,18 @@ void train_detector(char *datacfg, char *cfgfile, char *weightfile, int *gpus, i
//if (i % 1000 == 0 || (i < 1000 && i % 100 == 0)) { //if (i % 1000 == 0 || (i < 1000 && i % 100 == 0)) {
//if (i % 100 == 0) { //if (i % 100 == 0) {
if (i >= (iter_save + 1000) || i % 1000 == 0) { if (iteration >= (iter_save + 1000) || iteration % 1000 == 0) {
iter_save = i; iter_save = iteration;
#ifdef GPU #ifdef GPU
if (ngpus != 1) sync_nets(nets, ngpus, 0); if (ngpus != 1) sync_nets(nets, ngpus, 0);
#endif #endif
char buff[256]; char buff[256];
sprintf(buff, "%s/%s_%d.weights", backup_directory, base, i); sprintf(buff, "%s/%s_%d.weights", backup_directory, base, iteration);
save_weights(net, buff); save_weights(net, buff);
} }
if (i >= (iter_save_last + 100) || i % 100 == 0) { if (iteration >= (iter_save_last + 100) || (iteration % 100 == 0 && iteration > 1)) {
iter_save_last = i; iter_save_last = iteration;
#ifdef GPU #ifdef GPU
if (ngpus != 1) sync_nets(nets, ngpus, 0); if (ngpus != 1) sync_nets(nets, ngpus, 0);
#endif #endif
@ -364,7 +405,7 @@ void train_detector(char *datacfg, char *cfgfile, char *weightfile, int *gpus, i
free_list_contents_kvp(options); free_list_contents_kvp(options);
free_list(options); free_list(options);
for (i = 0; i < ngpus; ++i) free_network(nets[i]); for (k = 0; k < ngpus; ++k) free_network(nets[k]);
free(nets); free(nets);
//free_network(net); //free_network(net);
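Note: the training loop above switches from get_current_batch(net) to a dedicated iteration counter and, when dynamic_minibatch is enabled, restores the cfg-time batch size together with the initial resolution before the mAP pass. A condensed sketch of that restore step (not the exact darknet code), assuming init_w, init_h and init_b hold the values read from the cfg at startup:

    /* sketch: undo random-resize and dynamic minibatch before validate_detector_map() */
    int k, j;
    for (k = 0; k < ngpus; ++k) {
        nets[k].batch = init_b;                        /* restore cfg-time mini-batch */
        for (j = 0; j < nets[k].n; ++j) nets[k].layers[j].batch = init_b;
        resize_network(nets + k, init_w, init_h);      /* restore cfg-time resolution */
    }
    net = nets[0];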

@ -33,8 +33,10 @@ dropout_layer make_dropout_layer(int batch, int inputs, float probability, int d
l.forward_gpu = forward_dropout_layer_gpu; l.forward_gpu = forward_dropout_layer_gpu;
l.backward_gpu = backward_dropout_layer_gpu; l.backward_gpu = backward_dropout_layer_gpu;
l.rand_gpu = cuda_make_array(l.rand, inputs*batch); l.rand_gpu = cuda_make_array(l.rand, inputs*batch);
if (l.dropblock) {
l.drop_blocks_scale = cuda_make_array_pinned(l.rand, l.batch); l.drop_blocks_scale = cuda_make_array_pinned(l.rand, l.batch);
l.drop_blocks_scale_gpu = cuda_make_array(l.rand, l.batch); l.drop_blocks_scale_gpu = cuda_make_array(l.rand, l.batch);
}
#endif #endif
if (l.dropblock) { if (l.dropblock) {
if(l.dropblock_size_abs) fprintf(stderr, "dropblock p = %.3f l.dropblock_size_abs = %d %4d -> %4d\n", probability, l.dropblock_size_abs, inputs, inputs); if(l.dropblock_size_abs) fprintf(stderr, "dropblock p = %.3f l.dropblock_size_abs = %d %4d -> %4d\n", probability, l.dropblock_size_abs, inputs, inputs);
@ -50,8 +52,15 @@ void resize_dropout_layer(dropout_layer *l, int inputs)
l->rand = (float*)xrealloc(l->rand, l->inputs * l->batch * sizeof(float)); l->rand = (float*)xrealloc(l->rand, l->inputs * l->batch * sizeof(float));
#ifdef GPU #ifdef GPU
cuda_free(l->rand_gpu); cuda_free(l->rand_gpu);
l->rand_gpu = cuda_make_array(l->rand, l->inputs*l->batch); l->rand_gpu = cuda_make_array(l->rand, l->inputs*l->batch);
if (l->dropblock) {
cudaFreeHost(l->drop_blocks_scale);
l->drop_blocks_scale = cuda_make_array_pinned(l->rand, l->batch);
cuda_free(l->drop_blocks_scale_gpu);
l->drop_blocks_scale_gpu = cuda_make_array(l->rand, l->batch);
}
#endif #endif
} }
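Note: drop_blocks_scale is now allocated (and reallocated on resize) only when the layer actually uses DropBlock, as a page-locked host buffer of one float per image plus a matching device array, so the per-image scales can move between host and device without staging through pageable memory. A minimal sketch of such a pinned allocation, assuming a plain CUDA runtime call rather than darknet's cuda_make_array_pinned helper:

    /* sketch (assumption): one pinned float per image in the batch */
    float *scales_host = NULL;
    if (cudaHostAlloc((void**)&scales_host, batch * sizeof(float), cudaHostAllocDefault) != cudaSuccess)
        scales_host = (float*)calloc(batch, sizeof(float));   /* fall back to pageable memory */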

@ -74,6 +74,15 @@ __global__ void dropblock_fast_kernel(float *rand, float prob, int w, int h, int
} }
__global__ void set_scales_dropblock_kernel(float *drop_blocks_scale, int block_size_w, int block_size_h, int outputs, int batch)
{
const int index = blockIdx.x*blockDim.x + threadIdx.x;
if (index >= batch) return;
const float prob = drop_blocks_scale[index] * block_size_w * block_size_h / (float)outputs;
const float scale = 1.0f / (1.0f - prob);
drop_blocks_scale[index] = scale;
}
__global__ void scale_dropblock_kernel(float *output, int size, int outputs, float *drop_blocks_scale) __global__ void scale_dropblock_kernel(float *output, int size, int outputs, float *drop_blocks_scale)
{ {
@ -95,7 +104,7 @@ __global__ void yoloswag420blazeit360noscope(float *input, int size, float *rand
void forward_dropout_layer_gpu(dropout_layer l, network_state state) void forward_dropout_layer_gpu(dropout_layer l, network_state state)
{ {
if (!state.train) return; if (!state.train) return;
int iteration_num = (*state.net.seen) / (state.net.batch*state.net.subdivisions); int iteration_num = get_current_iteration(state.net); // (*state.net.seen) / (state.net.batch*state.net.subdivisions);
//if (iteration_num < state.net.burn_in) return; //if (iteration_num < state.net.burn_in) return;
// We gradually increase the block size and the probability of dropout - during the first half of the training // We gradually increase the block size and the probability of dropout - during the first half of the training
@ -136,18 +145,24 @@ void forward_dropout_layer_gpu(dropout_layer l, network_state state)
dropblock_fast_kernel << <num_blocks, BLOCK, 0, get_cuda_stream() >> > (l.rand_gpu, block_prob, l.w, l.h, l.w*l.h, l.c, block_size, l.drop_blocks_scale_gpu, state.input); dropblock_fast_kernel << <num_blocks, BLOCK, 0, get_cuda_stream() >> > (l.rand_gpu, block_prob, l.w, l.h, l.w*l.h, l.c, block_size, l.drop_blocks_scale_gpu, state.input);
CHECK_CUDA(cudaPeekAtLastError()); CHECK_CUDA(cudaPeekAtLastError());
num_blocks = get_number_of_blocks(l.batch, BLOCK);
set_scales_dropblock_kernel << <num_blocks, BLOCK, 0, get_cuda_stream() >> > (l.drop_blocks_scale_gpu, block_size, block_size, l.outputs, l.batch);
CHECK_CUDA(cudaPeekAtLastError());
/*
cuda_pull_array(l.drop_blocks_scale_gpu, l.drop_blocks_scale, l.batch); cuda_pull_array(l.drop_blocks_scale_gpu, l.drop_blocks_scale, l.batch);
for (int b = 0; b < l.batch; ++b) { for (int b = 0; b < l.batch; ++b) {
const float prob = l.drop_blocks_scale[b] * block_size * block_size / (float)l.outputs; const float prob = l.drop_blocks_scale[b] * block_size * block_size / (float)l.outputs;
const float scale = 1.0f / (1.0f - prob); const float scale = 1.0f / (1.0f - prob);
printf(" %d x %d - block_size = %d, block_size*block_size = %d , ", l.w, l.h, block_size, block_size*block_size); //printf(" %d x %d - block_size = %d, block_size*block_size = %d , ", l.w, l.h, block_size, block_size*block_size);
printf(" , l.drop_blocks_scale[b] = %f, prob = %f, calc scale = %f \t cur_prob = %f, cur_scale = %f \n", //printf(" , l.drop_blocks_scale[b] = %f, prob = %f, calc scale = %f \t cur_prob = %f, cur_scale = %f \n",
l.drop_blocks_scale[b], prob, scale, cur_prob, cur_scale); // l.drop_blocks_scale[b], prob, scale, cur_prob, cur_scale);
l.drop_blocks_scale[b] = scale; l.drop_blocks_scale[b] = scale;
} }
cuda_push_array(l.drop_blocks_scale_gpu, l.drop_blocks_scale, l.batch); cuda_push_array(l.drop_blocks_scale_gpu, l.drop_blocks_scale, l.batch);
*/
num_blocks = get_number_of_blocks(l.outputs * l.batch, BLOCK); num_blocks = get_number_of_blocks(l.outputs * l.batch, BLOCK);
scale_dropblock_kernel << <num_blocks, BLOCK, 0, get_cuda_stream() >> > (state.input, l.outputs * l.batch, l.outputs, l.drop_blocks_scale_gpu); scale_dropblock_kernel << <num_blocks, BLOCK, 0, get_cuda_stream() >> > (state.input, l.outputs * l.batch, l.outputs, l.drop_blocks_scale_gpu);
@ -176,14 +191,14 @@ void forward_dropout_layer_gpu(dropout_layer l, network_state state)
void backward_dropout_layer_gpu(dropout_layer l, network_state state) void backward_dropout_layer_gpu(dropout_layer l, network_state state)
{ {
if(!state.delta) return; if(!state.delta) return;
//int iteration_num = (*state.net.seen) / (state.net.batch*state.net.subdivisions); //int iteration_num = get_current_iteration(state.net); //(*state.net.seen) / (state.net.batch*state.net.subdivisions);
//if (iteration_num < state.net.burn_in) return; //if (iteration_num < state.net.burn_in) return;
int size = l.inputs*l.batch; int size = l.inputs*l.batch;
// dropblock // dropblock
if (l.dropblock) { if (l.dropblock) {
int iteration_num = (*state.net.seen) / (state.net.batch*state.net.subdivisions); int iteration_num = get_current_iteration(state.net); //(*state.net.seen) / (state.net.batch*state.net.subdivisions);
float multiplier = 1.0; float multiplier = 1.0;
if (iteration_num < (state.net.max_batches*0.85)) if (iteration_num < (state.net.max_batches*0.85))
multiplier = (iteration_num / (float)(state.net.max_batches*0.85)); multiplier = (iteration_num / (float)(state.net.max_batches*0.85));
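Note: forward_dropout_layer_gpu now derives its schedule from get_current_iteration(state.net), and the per-image rescaling of surviving activations runs entirely on the GPU: set_scales_dropblock_kernel turns the per-image value accumulated by dropblock_fast_kernel into scale = 1 / (1 - dropped_fraction), with dropped_fraction = count * block_size^2 / outputs, and scale_dropblock_kernel multiplies the layer input by that scale. The same arithmetic on the host, with purely illustrative numbers:

    /* sketch: DropBlock scale for one image, illustrative numbers only */
    const int   outputs      = 52 * 52 * 128;   /* activations in the layer */
    const float blocks_count = 6.0f;            /* per-image value written by dropblock_fast_kernel */
    const int   block_size   = 7;
    const float dropped_frac = blocks_count * block_size * block_size / (float)outputs;
    const float scale        = 1.0f / (1.0f - dropped_frac);  /* applied to surviving activations */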

@ -371,13 +371,20 @@ void delta_gaussian_yolo_class(float *output, float *delta, int index, int class
{ {
int n; int n;
if (delta[index]){ if (delta[index]){
delta[index + stride*class_id] = (1 - label_smooth_eps) - output[index + stride*class_id]; float y_true = 1;
if (label_smooth_eps) y_true = y_true * (1 - label_smooth_eps) + 0.5*label_smooth_eps;
delta[index + stride*class_id] = y_true - output[index + stride*class_id];
//delta[index + stride*class_id] = 1 - output[index + stride*class_id];
if (classes_multipliers) delta[index + stride*class_id] *= classes_multipliers[class_id]; if (classes_multipliers) delta[index + stride*class_id] *= classes_multipliers[class_id];
if(avg_cat) *avg_cat += output[index + stride*class_id]; if(avg_cat) *avg_cat += output[index + stride*class_id];
return; return;
} }
for(n = 0; n < classes; ++n){ for(n = 0; n < classes; ++n){
delta[index + stride*n] = ((n == class_id) ? (1 - label_smooth_eps) : (0 + label_smooth_eps/classes)) - output[index + stride*n]; float y_true = ((n == class_id) ? 1 : 0);
if (label_smooth_eps) y_true = y_true * (1 - label_smooth_eps) + 0.5*label_smooth_eps;
delta[index + stride*n] = y_true - output[index + stride*n];
if (classes_multipliers && n == class_id) delta[index + stride*class_id] *= classes_multipliers[class_id]; if (classes_multipliers && n == class_id) delta[index + stride*class_id] *= classes_multipliers[class_id];
if(n == class_id && avg_cat) *avg_cat += output[index + stride*n]; if(n == class_id && avg_cat) *avg_cat += output[index + stride*n];
} }
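Note: the smoothed class target changes from (1 - eps) for the positive class and eps/classes for the negatives to the symmetric binary form y_true*(1 - eps) + eps/2, which matches the independent per-class sigmoid outputs used here: each class is its own binary problem, so the smoothed negative target is eps/2 rather than eps/classes. For example:

    /* sketch: smoothed sigmoid targets for label_smooth_eps = 0.1 */
    float eps = 0.1f;
    float pos = 1.0f * (1 - eps) + 0.5f * eps;   /* positive class target -> 0.95 */
    float neg = 0.0f * (1 - eps) + 0.5f * eps;   /* every other class target -> 0.05 */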

@ -206,14 +206,14 @@ void push_gru_layer(layer l)
{ {
} }
void update_gru_layer_gpu(layer l, int batch, float learning_rate, float momentum, float decay) void update_gru_layer_gpu(layer l, int batch, float learning_rate, float momentum, float decay, float loss_scale)
{ {
update_connected_layer_gpu(*(l.input_r_layer), batch, learning_rate, momentum, decay); update_connected_layer_gpu(*(l.input_r_layer), batch, learning_rate, momentum, decay, loss_scale);
update_connected_layer_gpu(*(l.input_z_layer), batch, learning_rate, momentum, decay); update_connected_layer_gpu(*(l.input_z_layer), batch, learning_rate, momentum, decay, loss_scale);
update_connected_layer_gpu(*(l.input_h_layer), batch, learning_rate, momentum, decay); update_connected_layer_gpu(*(l.input_h_layer), batch, learning_rate, momentum, decay, loss_scale);
update_connected_layer_gpu(*(l.state_r_layer), batch, learning_rate, momentum, decay); update_connected_layer_gpu(*(l.state_r_layer), batch, learning_rate, momentum, decay, loss_scale);
update_connected_layer_gpu(*(l.state_z_layer), batch, learning_rate, momentum, decay); update_connected_layer_gpu(*(l.state_z_layer), batch, learning_rate, momentum, decay, loss_scale);
update_connected_layer_gpu(*(l.state_h_layer), batch, learning_rate, momentum, decay); update_connected_layer_gpu(*(l.state_h_layer), batch, learning_rate, momentum, decay, loss_scale);
} }
void forward_gru_layer_gpu(layer l, network_state state) void forward_gru_layer_gpu(layer l, network_state state)

@ -18,7 +18,7 @@ void update_gru_layer(layer l, int batch, float learning_rate, float momentum, f
#ifdef GPU #ifdef GPU
void forward_gru_layer_gpu(layer l, network_state state); void forward_gru_layer_gpu(layer l, network_state state);
void backward_gru_layer_gpu(layer l, network_state state); void backward_gru_layer_gpu(layer l, network_state state);
void update_gru_layer_gpu(layer l, int batch, float learning_rate, float momentum, float decay); void update_gru_layer_gpu(layer l, int batch, float learning_rate, float momentum, float decay, float loss_scale);
void push_gru_layer(layer l); void push_gru_layer(layer l);
void pull_gru_layer(layer l); void pull_gru_layer(layer l);
#endif #endif

@ -258,10 +258,11 @@ public:
//_write(s, head, 0); //_write(s, head, 0);
if (!close_all_sockets) _write(s, ", \n", 0); if (!close_all_sockets) _write(s, ", \n", 0);
int n = _write(s, outputbuf, outlen); int n = _write(s, outputbuf, outlen);
if (n < outlen) if (n < (int)outlen)
{ {
cerr << "JSON_sender: kill client " << s << endl; cerr << "JSON_sender: kill client " << s << endl;
::shutdown(s, 2); close_socket(s);
//::shutdown(s, 2);
FD_CLR(s, &master); FD_CLR(s, &master);
} }
@ -448,7 +449,7 @@ public:
cv::imencode(".jpg", frame, outbuf, params); //REMOVED FOR COMPATIBILITY cv::imencode(".jpg", frame, outbuf, params); //REMOVED FOR COMPATIBILITY
// https://docs.opencv.org/3.4/d4/da8/group__imgcodecs.html#ga292d81be8d76901bff7988d18d2b42ac // https://docs.opencv.org/3.4/d4/da8/group__imgcodecs.html#ga292d81be8d76901bff7988d18d2b42ac
//std::cerr << "cv::imencode call disabled!" << std::endl; //std::cerr << "cv::imencode call disabled!" << std::endl;
size_t outlen = outbuf.size(); int outlen = static_cast<int>(outbuf.size());
#ifdef _WIN32 #ifdef _WIN32
for (unsigned i = 0; i<rread.fd_count; i++) for (unsigned i = 0; i<rread.fd_count; i++)
@ -504,11 +505,12 @@ public:
sprintf(head, "--mjpegstream\r\nContent-Type: image/jpeg\r\nContent-Length: %zu\r\n\r\n", outlen); sprintf(head, "--mjpegstream\r\nContent-Type: image/jpeg\r\nContent-Length: %zu\r\n\r\n", outlen);
_write(s, head, 0); _write(s, head, 0);
int n = _write(s, (char*)(&outbuf[0]), outlen); int n = _write(s, (char*)(&outbuf[0]), outlen);
//cerr << "known client " << s << " " << n << endl; cerr << "known client: " << s << ", sent = " << n << ", must be sent outlen = " << outlen << endl;
if (n < outlen) if (n < (int)outlen)
{ {
cerr << "MJPG_sender: kill client " << s << endl; cerr << "MJPG_sender: kill client " << s << endl;
::shutdown(s, 2); //::shutdown(s, 2);
close_socket(s);
FD_CLR(s, &master); FD_CLR(s, &master);
} }
} }
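Note: in the JSON/MJPEG streamers a stalled client is now dropped with close_socket() instead of ::shutdown(s, 2), so the descriptor itself is released, and outlen is kept as int so the comparison against the int returned by _write() is no longer signed-vs-unsigned. A portable close of the kind close_socket() presumably wraps (an assumption, since only the call site is shown here):

    /* sketch (assumption): fully close a socket on both platforms */
    #ifdef _WIN32
        closesocket(s);
    #else
        close(s);      /* <unistd.h> */
    #endif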

@ -1344,7 +1344,7 @@ void show_opencv_info()
#else // OPENCV #else // OPENCV
extern "C" void show_opencv_info() extern "C" void show_opencv_info()
{ {
std::cerr << " OpenCV isn't used \n"; std::cerr << " OpenCV isn't used - data increase will run slowly \n";
} }
extern "C" int wait_key_cv(int delay) { return 0; } extern "C" int wait_key_cv(int delay) { return 0; }
extern "C" int wait_until_press_key_cv() { return 0; } extern "C" int wait_until_press_key_cv() { return 0; }

@ -253,7 +253,7 @@ void backward_local_layer_gpu(local_layer l, network_state state)
} }
} }
void update_local_layer_gpu(local_layer l, int batch, float learning_rate, float momentum, float decay) void update_local_layer_gpu(local_layer l, int batch, float learning_rate, float momentum, float decay, float loss_scale)
{ {
int locations = l.out_w*l.out_h; int locations = l.out_w*l.out_h;
int size = l.size*l.size*l.c*l.n*locations; int size = l.size*l.size*l.c*l.n*locations;

@ -15,7 +15,7 @@ extern "C" {
#ifdef GPU #ifdef GPU
void forward_local_layer_gpu(local_layer layer, network_state state); void forward_local_layer_gpu(local_layer layer, network_state state);
void backward_local_layer_gpu(local_layer layer, network_state state); void backward_local_layer_gpu(local_layer layer, network_state state);
void update_local_layer_gpu(local_layer layer, int batch, float learning_rate, float momentum, float decay); void update_local_layer_gpu(local_layer layer, int batch, float learning_rate, float momentum, float decay, float loss_scale);
void push_local_layer(local_layer layer); void push_local_layer(local_layer layer);
void pull_local_layer(local_layer layer); void pull_local_layer(local_layer layer);

@ -401,16 +401,16 @@ void backward_lstm_layer(layer l, network_state state)
} }
#ifdef GPU #ifdef GPU
void update_lstm_layer_gpu(layer l, int batch, float learning_rate, float momentum, float decay) void update_lstm_layer_gpu(layer l, int batch, float learning_rate, float momentum, float decay, float loss_scale)
{ {
update_connected_layer_gpu(*(l.wf), batch, learning_rate, momentum, decay); update_connected_layer_gpu(*(l.wf), batch, learning_rate, momentum, decay, loss_scale);
update_connected_layer_gpu(*(l.wi), batch, learning_rate, momentum, decay); update_connected_layer_gpu(*(l.wi), batch, learning_rate, momentum, decay, loss_scale);
update_connected_layer_gpu(*(l.wg), batch, learning_rate, momentum, decay); update_connected_layer_gpu(*(l.wg), batch, learning_rate, momentum, decay, loss_scale);
update_connected_layer_gpu(*(l.wo), batch, learning_rate, momentum, decay); update_connected_layer_gpu(*(l.wo), batch, learning_rate, momentum, decay, loss_scale);
update_connected_layer_gpu(*(l.uf), batch, learning_rate, momentum, decay); update_connected_layer_gpu(*(l.uf), batch, learning_rate, momentum, decay, loss_scale);
update_connected_layer_gpu(*(l.ui), batch, learning_rate, momentum, decay); update_connected_layer_gpu(*(l.ui), batch, learning_rate, momentum, decay, loss_scale);
update_connected_layer_gpu(*(l.ug), batch, learning_rate, momentum, decay); update_connected_layer_gpu(*(l.ug), batch, learning_rate, momentum, decay, loss_scale);
update_connected_layer_gpu(*(l.uo), batch, learning_rate, momentum, decay); update_connected_layer_gpu(*(l.uo), batch, learning_rate, momentum, decay, loss_scale);
} }
void forward_lstm_layer_gpu(layer l, network_state state) void forward_lstm_layer_gpu(layer l, network_state state)

@ -18,7 +18,7 @@ void update_lstm_layer(layer l, int batch, float learning_rate, float momentum,
#ifdef GPU #ifdef GPU
void forward_lstm_layer_gpu(layer l, network_state state); void forward_lstm_layer_gpu(layer l, network_state state);
void backward_lstm_layer_gpu(layer l, network_state state); void backward_lstm_layer_gpu(layer l, network_state state);
void update_lstm_layer_gpu(layer l, int batch, float learning_rate, float momentum, float decay); void update_lstm_layer_gpu(layer l, int batch, float learning_rate, float momentum, float decay, float loss_scale);
#endif #endif
#ifdef __cplusplus #ifdef __cplusplus

@ -57,6 +57,11 @@ load_args get_base_args(network *net)
return args; return args;
} }
int64_t get_current_iteration(network net)
{
return *net.cur_iteration;
}
int get_current_batch(network net) int get_current_batch(network net)
{ {
int batch_num = (*net.seen)/(net.batch*net.subdivisions); int batch_num = (*net.seen)/(net.batch*net.subdivisions);
@ -240,6 +245,7 @@ network make_network(int n)
net.n = n; net.n = n;
net.layers = (layer*)xcalloc(net.n, sizeof(layer)); net.layers = (layer*)xcalloc(net.n, sizeof(layer));
net.seen = (uint64_t*)xcalloc(1, sizeof(uint64_t)); net.seen = (uint64_t*)xcalloc(1, sizeof(uint64_t));
net.cur_iteration = (int*)xcalloc(1, sizeof(int));
#ifdef GPU #ifdef GPU
net.input_gpu = (float**)xcalloc(1, sizeof(float*)); net.input_gpu = (float**)xcalloc(1, sizeof(float*));
net.truth_gpu = (float**)xcalloc(1, sizeof(float*)); net.truth_gpu = (float**)xcalloc(1, sizeof(float*));
@ -359,7 +365,7 @@ float train_network_datum(network net, float *x, float *y)
forward_network(net, state); forward_network(net, state);
backward_network(net, state); backward_network(net, state);
float error = get_network_cost(net); float error = get_network_cost(net);
if(((*net.seen)/net.batch)%net.subdivisions == 0) update_network(net); //if(((*net.seen)/net.batch)%net.subdivisions == 0) update_network(net);
return error; return error;
} }
@ -404,6 +410,12 @@ float train_network_waitkey(network net, data d, int wait_key)
sum += err; sum += err;
if(wait_key) wait_key_cv(5); if(wait_key) wait_key_cv(5);
} }
(*net.cur_iteration) += 1;
#ifdef GPU
update_network_gpu(net);
#else // GPU
update_network(net);
#endif // GPU
free(X); free(X);
free(y); free(y);
return (float)sum/(n*batch); return (float)sum/(n*batch);
@ -523,7 +535,7 @@ int resize_network(network *net, int w, int h)
//fflush(stderr); //fflush(stderr);
for (i = 0; i < net->n; ++i){ for (i = 0; i < net->n; ++i){
layer l = net->layers[i]; layer l = net->layers[i];
//printf(" %d: layer = %d,", i, l.type); //printf(" (resize %d: layer = %d) , ", i, l.type);
if(l.type == CONVOLUTIONAL){ if(l.type == CONVOLUTIONAL){
resize_convolutional_layer(&l, w, h); resize_convolutional_layer(&l, w, h);
} }
@ -1048,6 +1060,7 @@ void free_network(network net)
free(net.scales); free(net.scales);
free(net.steps); free(net.steps);
free(net.seen); free(net.seen);
free(net.cur_iteration);
#ifdef GPU #ifdef GPU
if (gpu_index >= 0) cuda_free(net.workspace); if (gpu_index >= 0) cuda_free(net.workspace);
@ -1120,8 +1133,9 @@ void fuse_conv_batchnorm(network net)
{ {
if (l->nweights > 0) { if (l->nweights > 0) {
//cuda_pull_array(l.weights_gpu, l.weights, l.nweights); //cuda_pull_array(l.weights_gpu, l.weights, l.nweights);
for (int i = 0; i < l->nweights; ++i) printf(" w = %f,", l->weights[i]); int i;
printf(" l->nweights = %d \n", l->nweights); for (i = 0; i < l->nweights; ++i) printf(" w = %f,", l->weights[i]);
printf(" l->nweights = %d, j = %d \n", l->nweights, j);
} }
// nweights - l.n or l.n*l.c or (l.n*l.c*l.h*l.w) // nweights - l.n or l.n*l.c or (l.n*l.c*l.h*l.w)

@ -108,6 +108,7 @@ float get_current_seq_subdivisions(network net);
int get_sequence_value(network net); int get_sequence_value(network net);
float get_current_rate(network net); float get_current_rate(network net);
int get_current_batch(network net); int get_current_batch(network net);
int64_t get_current_iteration(network net);
//void free_network(network net); // darknet.h //void free_network(network net); // darknet.h
void compare_networks(network n1, network n2, data d); void compare_networks(network n1, network n2, data d);
char *get_layer_string(LAYER_TYPE a); char *get_layer_string(LAYER_TYPE a);
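Note: make_network() now allocates net.cur_iteration alongside net.seen, and train_network_waitkey() increments it and calls update_network()/update_network_gpu() itself once per full iteration, instead of train_network_datum() updating whenever (*net.seen)/batch crosses a subdivision boundary; this keeps the optimizer-step count meaningful even when dynamic_minibatch changes net.batch at runtime. A simplified sketch of the flow after the change:

    /* sketch of one training iteration (simplified) */
    for (i = 0; i < n; ++i) {                       /* n mini-batches = one full iteration */
        get_next_batch(d, batch, i*batch, X, y);
        sum += train_network_datum(net, X, y);      /* forward + backward only, no update */
    }
    (*net.cur_iteration) += 1;                      /* one optimizer step per iteration */
    update_network(net);                            /* update_network_gpu(net) in GPU builds */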

@ -169,7 +169,8 @@ void backward_network_gpu(network net, network_state state)
for(i = net.n-1; i >= 0; --i){ for(i = net.n-1; i >= 0; --i){
state.index = i; state.index = i;
layer l = net.layers[i]; layer l = net.layers[i];
if (l.stopbackward) break; if (l.stopbackward == 1) break;
if (l.stopbackward > get_current_iteration(net)) break;
if(i == 0){ if(i == 0){
state.input = original_input; state.input = original_input;
state.delta = original_delta; state.delta = original_delta;
@ -247,8 +248,9 @@ void update_network_gpu(network net)
layer l = net.layers[i]; layer l = net.layers[i];
l.t = get_current_batch(net); l.t = get_current_batch(net);
if (iteration_num > (net.max_batches * 1 / 2)) l.deform = 0; if (iteration_num > (net.max_batches * 1 / 2)) l.deform = 0;
if(l.update_gpu){ if (l.burnin_update && (l.burnin_update*net.burn_in > iteration_num)) continue;
l.update_gpu(l, update_batch, rate, net.momentum, net.decay); if(l.update_gpu && l.dont_update < iteration_num){
l.update_gpu(l, update_batch, rate, net.momentum, net.decay, net.loss_scale);
} }
} }
} }
@ -318,7 +320,7 @@ float train_network_datum_gpu(network net, float *x, float *y)
float error = get_network_cost(net); float error = get_network_cost(net);
//if (((*net.seen) / net.batch) % net.subdivisions == 0) update_network_gpu(net); //if (((*net.seen) / net.batch) % net.subdivisions == 0) update_network_gpu(net);
const int sequence = get_sequence_value(net); const int sequence = get_sequence_value(net);
if (((*net.seen) / net.batch) % (net.subdivisions*sequence) == 0) update_network_gpu(net); //if (((*net.seen) / net.batch) % (net.subdivisions*sequence) == 0) update_network_gpu(net);
return error; return error;
} }
@ -379,7 +381,7 @@ void update_layer(layer l, network net)
float rate = get_current_rate(net); float rate = get_current_rate(net);
l.t = get_current_batch(net); l.t = get_current_batch(net);
if(l.update_gpu){ if(l.update_gpu){
l.update_gpu(l, update_batch, rate, net.momentum, net.decay); l.update_gpu(l, update_batch, rate, net.momentum, net.decay, net.loss_scale);
} }
} }
@ -564,7 +566,9 @@ float train_networks(network *nets, int n, data d, int interval)
sum += errors[i]; sum += errors[i];
} }
//cudaDeviceSynchronize(); //cudaDeviceSynchronize();
if (get_current_batch(nets[0]) % interval == 0) { *nets[0].cur_iteration += (n - 1);
if (get_current_iteration(nets[0]) % interval == 0)
{
printf("Syncing... "); printf("Syncing... ");
fflush(stdout); fflush(stdout);
sync_nets(nets, n, interval); sync_nets(nets, n, interval);
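Note: every layer's update_gpu callback now receives net.loss_scale, layers can postpone or skip their own updates via burnin_update and dont_update, and multi-GPU syncing is keyed to get_current_iteration(). loss_scale is the usual mixed-precision trick: the loss (and therefore every gradient) is multiplied by a constant so FP16 gradients stay out of the underflow range, and the update step divides the accumulated weight updates back by the same constant before applying them, as update_shortcut_layer_gpu does with scal_ongpu(..., 1.0/loss_scale, ...). A CPU-side sketch of the same idea, assuming plain SGD with momentum and weight decay:

    /* sketch (assumption): SGD step that first undoes the loss scale */
    void sgd_step(float *w, float *dw, int n, int batch,
                  float lr, float momentum, float decay, float loss_scale)
    {
        int i;
        if (loss_scale != 1.0f) for (i = 0; i < n; ++i) dw[i] /= loss_scale; /* gradients were scaled up */
        for (i = 0; i < n; ++i) dw[i] -= decay * batch * w[i];               /* weight decay */
        for (i = 0; i < n; ++i) w[i]  += (lr / batch) * dw[i];               /* apply update */
        for (i = 0; i < n; ++i) dw[i] *= momentum;                           /* keep momentum term */
    }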

@ -997,6 +997,7 @@ route_layer parse_route(list *options, size_params params)
if(next.out_w == first.out_w && next.out_h == first.out_h){ if(next.out_w == first.out_w && next.out_h == first.out_h){
layer.out_c += next.out_c; layer.out_c += next.out_c;
}else{ }else{
fprintf(stderr, " The width and height of the input layers are different. \n");
layer.out_h = layer.out_w = layer.out_c = 0; layer.out_h = layer.out_w = layer.out_c = 0;
} }
} }
@ -1053,6 +1054,10 @@ void parse_net_options(list *options, network *net)
net->batch *= net->time_steps; net->batch *= net->time_steps;
net->subdivisions = subdivs; net->subdivisions = subdivs;
*net->seen = 0;
*net->cur_iteration = 0;
net->loss_scale = option_find_float_quiet(options, "loss_scale", 1);
net->dynamic_minibatch = option_find_int_quiet(options, "dynamic_minibatch", 0);
net->optimized_memory = option_find_int_quiet(options, "optimized_memory", 0); net->optimized_memory = option_find_int_quiet(options, "optimized_memory", 0);
net->workspace_size_limit = (size_t)1024*1024 * option_find_float_quiet(options, "workspace_size_limit_MB", 1024); // 1024 MB by default net->workspace_size_limit = (size_t)1024*1024 * option_find_float_quiet(options, "workspace_size_limit_MB", 1024); // 1024 MB by default
@ -1093,15 +1098,17 @@ void parse_net_options(list *options, network *net)
char *policy_s = option_find_str(options, "policy", "constant"); char *policy_s = option_find_str(options, "policy", "constant");
net->policy = get_policy(policy_s); net->policy = get_policy(policy_s);
net->burn_in = option_find_int_quiet(options, "burn_in", 0); net->burn_in = option_find_int_quiet(options, "burn_in", 0);
#ifdef CUDNN_HALF #ifdef GPU
if (net->gpu_index >= 0) { if (net->gpu_index >= 0) {
int compute_capability = get_gpu_compute_capability(net->gpu_index); int compute_capability = get_gpu_compute_capability(net->gpu_index);
if (get_gpu_compute_capability(net->gpu_index) >= 700) net->cudnn_half = 1; #ifdef CUDNN_HALF
if (compute_capability >= 700) net->cudnn_half = 1;
else net->cudnn_half = 0; else net->cudnn_half = 0;
#endif// CUDNN_HALF
fprintf(stderr, " compute_capability = %d, cudnn_half = %d \n", compute_capability, net->cudnn_half); fprintf(stderr, " compute_capability = %d, cudnn_half = %d \n", compute_capability, net->cudnn_half);
} }
else fprintf(stderr, " GPU isn't used \n"); else fprintf(stderr, " GPU isn't used \n");
#endif #endif// GPU
if(net->policy == STEP){ if(net->policy == STEP){
net->step = option_find_int(options, "step", 1); net->step = option_find_int(options, "step", 1);
net->scale = option_find_float(options, "scale", 1); net->scale = option_find_float(options, "scale", 1);
@ -1201,7 +1208,7 @@ network parse_network_cfg_custom(char *filename, int batch, int time_steps)
params.batch = net.batch; params.batch = net.batch;
params.time_steps = net.time_steps; params.time_steps = net.time_steps;
params.net = net; params.net = net;
printf("batch = %d, time_steps = %d, train = %d \n", net.batch, net.time_steps, params.train); printf("mini_batch = %d, batch = %d, time_steps = %d, train = %d \n", net.batch, net.batch * net.subdivisions, net.time_steps, params.train);
int avg_outputs = 0; int avg_outputs = 0;
float bflops = 0; float bflops = 0;
@ -1357,7 +1364,11 @@ network parse_network_cfg_custom(char *filename, int batch, int time_steps)
} }
#endif // GPU #endif // GPU
l.clip = option_find_float_quiet(options, "clip", 0);
l.dynamic_minibatch = net.dynamic_minibatch;
l.onlyforward = option_find_int_quiet(options, "onlyforward", 0); l.onlyforward = option_find_int_quiet(options, "onlyforward", 0);
l.dont_update = option_find_int_quiet(options, "dont_update", 0);
l.burnin_update = option_find_int_quiet(options, "burnin_update", 0);
l.stopbackward = option_find_int_quiet(options, "stopbackward", 0); l.stopbackward = option_find_int_quiet(options, "stopbackward", 0);
l.dontload = option_find_int_quiet(options, "dontload", 0); l.dontload = option_find_int_quiet(options, "dontload", 0);
l.dontloadscales = option_find_int_quiet(options, "dontloadscales", 0); l.dontloadscales = option_find_int_quiet(options, "dontloadscales", 0);
@ -1549,8 +1560,15 @@ void save_shortcut_weights(layer l, FILE *fp)
#ifdef GPU #ifdef GPU
if (gpu_index >= 0) { if (gpu_index >= 0) {
pull_shortcut_layer(l); pull_shortcut_layer(l);
printf("\n pull_shortcut_layer \n");
} }
#endif #endif
int i;
for (i = 0; i < l.nweights; ++i) printf(" %f, ", l.weight_updates[i]);
printf(" l.nweights = %d - update \n", l.nweights);
for (i = 0; i < l.nweights; ++i) printf(" %f, ", l.weights[i]);
printf(" l.nweights = %d \n\n", l.nweights);
int num = l.nweights; int num = l.nweights;
fwrite(l.weights, sizeof(float), num, fp); fwrite(l.weights, sizeof(float), num, fp);
} }
@ -1626,6 +1644,7 @@ void save_weights_upto(network net, char *filename, int cutoff)
fwrite(&major, sizeof(int), 1, fp); fwrite(&major, sizeof(int), 1, fp);
fwrite(&minor, sizeof(int), 1, fp); fwrite(&minor, sizeof(int), 1, fp);
fwrite(&revision, sizeof(int), 1, fp); fwrite(&revision, sizeof(int), 1, fp);
//(*net.seen) = (*net.cur_iteration) * net.batch * net.subdivisions;
fwrite(net.seen, sizeof(uint64_t), 1, fp); fwrite(net.seen, sizeof(uint64_t), 1, fp);
int i; int i;
@ -1835,7 +1854,7 @@ void load_shortcut_weights(layer l, FILE *fp)
read_bytes = fread(l.weights, sizeof(float), num, fp); read_bytes = fread(l.weights, sizeof(float), num, fp);
if (read_bytes > 0 && read_bytes < num) printf("\n Warning: Unexpected end of wights-file! l.weights - l.index = %d \n", l.index); if (read_bytes > 0 && read_bytes < num) printf("\n Warning: Unexpected end of wights-file! l.weights - l.index = %d \n", l.index);
//for (int i = 0; i < l.nweights; ++i) printf(" %f, ", l.weights[i]); //for (int i = 0; i < l.nweights; ++i) printf(" %f, ", l.weights[i]);
//printf("\n\n"); //printf(" read_bytes = %d \n\n", read_bytes);
#ifdef GPU #ifdef GPU
if (gpu_index >= 0) { if (gpu_index >= 0) {
push_shortcut_layer(l); push_shortcut_layer(l);
@ -1873,6 +1892,7 @@ void load_weights_upto(network *net, char *filename, int cutoff)
fread(&iseen, sizeof(uint32_t), 1, fp); fread(&iseen, sizeof(uint32_t), 1, fp);
*net->seen = iseen; *net->seen = iseen;
} }
*net->cur_iteration = get_current_batch(*net);
printf(", trained: %.0f K-images (%.0f Kilo-batches_64) \n", (float)(*net->seen / 1000), (float)(*net->seen / 64000)); printf(", trained: %.0f K-images (%.0f Kilo-batches_64) \n", (float)(*net->seen / 1000), (float)(*net->seen / 64000));
int transpose = (major > 1000) || (minor > 1000); int transpose = (major > 1000) || (minor > 1000);
@ -1968,7 +1988,10 @@ network *load_network_custom(char *cfg, char *weights, int clear, int batch)
load_weights(net, weights); load_weights(net, weights);
} }
fuse_conv_batchnorm(*net); fuse_conv_batchnorm(*net);
if (clear) (*net->seen) = 0; if (clear) {
(*net->seen) = 0;
(*net->cur_iteration) = 0;
}
return net; return net;
} }
@ -1982,6 +2005,9 @@ network *load_network(char *cfg, char *weights, int clear)
printf(" Try to load weights: %s \n", weights); printf(" Try to load weights: %s \n", weights);
load_weights(net, weights); load_weights(net, weights);
} }
if (clear) (*net->seen) = 0; if (clear) {
(*net->seen) = 0;
(*net->cur_iteration) = 0;
}
return net; return net;
} }
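Note: parse_net_options() now reads two new [net] keys, loss_scale (float, default 1) and dynamic_minibatch (int, default 0), resets both *net->seen and *net->cur_iteration, and probes the GPU compute capability under #ifdef GPU (with the Tensor-Core switch still behind CUDNN_HALF); per-layer options gain dont_update and burnin_update, and stopbackward can now hold an iteration number. A cfg fragment with purely illustrative values:

    [net]
    # illustrative values, not tuned recommendations
    # scale the loss/gradients for FP16 Tensor-Core training
    loss_scale=128
    # allow the mini-batch size to change together with random resizing
    dynamic_minibatch=1

    [convolutional]
    # skip this layer's weight updates for the first 4*burn_in iterations
    burnin_update=4
    # keep the layers below frozen until iteration 10000
    stopbackward=10000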

@ -67,7 +67,8 @@ void resize_region_layer(layer *l, int w, int h)
l->delta = (float*)xrealloc(l->delta, l->batch * l->outputs * sizeof(float)); l->delta = (float*)xrealloc(l->delta, l->batch * l->outputs * sizeof(float));
#ifdef GPU #ifdef GPU
if (old_w < w || old_h < h) { //if (old_w < w || old_h < h)
{
cuda_free(l->delta_gpu); cuda_free(l->delta_gpu);
cuda_free(l->output_gpu); cuda_free(l->output_gpu);

@ -155,7 +155,10 @@ void train_char_rnn(char *cfgfile, char *weightfile, char *filename, int clear,
fprintf(stderr, "Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay); fprintf(stderr, "Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
int batch = net.batch; int batch = net.batch;
int steps = net.time_steps; int steps = net.time_steps;
if(clear) *net.seen = 0; if (clear) {
*net.seen = 0;
*net.cur_iteration = 0;
}
int i = (*net.seen)/net.batch; int i = (*net.seen)/net.batch;
int streams = batch/steps; int streams = batch/steps;

@ -196,11 +196,11 @@ void push_rnn_layer(layer l)
push_connected_layer(*(l.output_layer)); push_connected_layer(*(l.output_layer));
} }
void update_rnn_layer_gpu(layer l, int batch, float learning_rate, float momentum, float decay) void update_rnn_layer_gpu(layer l, int batch, float learning_rate, float momentum, float decay, float loss_scale)
{ {
update_connected_layer_gpu(*(l.input_layer), batch, learning_rate, momentum, decay); update_connected_layer_gpu(*(l.input_layer), batch, learning_rate, momentum, decay, loss_scale);
update_connected_layer_gpu(*(l.self_layer), batch, learning_rate, momentum, decay); update_connected_layer_gpu(*(l.self_layer), batch, learning_rate, momentum, decay, loss_scale);
update_connected_layer_gpu(*(l.output_layer), batch, learning_rate, momentum, decay); update_connected_layer_gpu(*(l.output_layer), batch, learning_rate, momentum, decay, loss_scale);
} }
void forward_rnn_layer_gpu(layer l, network_state state) void forward_rnn_layer_gpu(layer l, network_state state)

@ -19,7 +19,7 @@ void update_rnn_layer(layer l, int batch, float learning_rate, float momentum, f
#ifdef GPU #ifdef GPU
void forward_rnn_layer_gpu(layer l, network_state state); void forward_rnn_layer_gpu(layer l, network_state state);
void backward_rnn_layer_gpu(layer l, network_state state); void backward_rnn_layer_gpu(layer l, network_state state);
void update_rnn_layer_gpu(layer l, int batch, float learning_rate, float momentum, float decay); void update_rnn_layer_gpu(layer l, int batch, float learning_rate, float momentum, float decay, float loss_scale);
void push_rnn_layer(layer l); void push_rnn_layer(layer l);
void pull_rnn_layer(layer l); void pull_rnn_layer(layer l);
#endif #endif

@ -53,7 +53,7 @@ layer make_shortcut_layer(int batch, int n, int *input_layers, int* input_sizes,
if (l.nweights > 0) { if (l.nweights > 0) {
l.weights = (float*)calloc(l.nweights, sizeof(float)); l.weights = (float*)calloc(l.nweights, sizeof(float));
float scale = sqrt(2. / l.nweights); float scale = sqrt(2. / l.nweights);
for (i = 0; i < l.nweights; ++i) l.weights[i] = 1 + 0.01*rand_uniform(-1, 1);// scale*rand_uniform(-1, 1); // rand_normal(); for (i = 0; i < l.nweights; ++i) l.weights[i] = 1;// +0.01*rand_uniform(-1, 1);// scale*rand_uniform(-1, 1); // rand_normal();
if (train) l.weight_updates = (float*)calloc(l.nweights, sizeof(float)); if (train) l.weight_updates = (float*)calloc(l.nweights, sizeof(float));
l.update = update_shortcut_layer; l.update = update_shortcut_layer;
@ -240,7 +240,7 @@ void backward_shortcut_layer_gpu(const layer l, network_state state)
//shortcut_gpu(l.batch, l.out_w, l.out_h, l.out_c, l.delta_gpu, l.w, l.h, l.c, state.net.layers[l.index].delta_gpu); //shortcut_gpu(l.batch, l.out_w, l.out_h, l.out_c, l.delta_gpu, l.w, l.h, l.c, state.net.layers[l.index].delta_gpu);
} }
void update_shortcut_layer_gpu(layer l, int batch, float learning_rate_init, float momentum, float decay) void update_shortcut_layer_gpu(layer l, int batch, float learning_rate_init, float momentum, float decay, float loss_scale)
{ {
if (l.nweights > 0) { if (l.nweights > 0) {
float learning_rate = learning_rate_init*l.learning_rate_scale; float learning_rate = learning_rate_init*l.learning_rate_scale;
@ -248,21 +248,43 @@ void update_shortcut_layer_gpu(layer l, int batch, float learning_rate_init, flo
//float decay = a.decay; //float decay = a.decay;
//int batch = a.batch; //int batch = a.batch;
// Loss scale for Mixed-Precision on Tensor-Cores
if (loss_scale != 1.0) {
if(l.weight_updates_gpu && l.nweights > 0) scal_ongpu(l.nweights, 1.0 / loss_scale, l.weight_updates_gpu, 1);
}
reset_nan_and_inf(l.weight_updates_gpu, l.nweights); reset_nan_and_inf(l.weight_updates_gpu, l.nweights);
fix_nan_and_inf(l.weights_gpu, l.nweights); fix_nan_and_inf(l.weights_gpu, l.nweights);
axpy_ongpu(l.nweights, -decay*batch, l.weights_gpu, 1, l.weight_updates_gpu, 1); //constrain_weight_updates_ongpu(l.nweights, 1, l.weights_gpu, l.weight_updates_gpu);
constrain_ongpu(l.nweights, 1, l.weight_updates_gpu, 1);
/*
cuda_pull_array_async(l.weights_gpu, l.weights, l.nweights);
cuda_pull_array_async(l.weight_updates_gpu, l.weight_updates, l.nweights);
CHECK_CUDA(cudaStreamSynchronize(get_cuda_stream()));
for (int i = 0; i < l.nweights; ++i) printf(" %f, ", l.weight_updates[i]);
printf(" l.nweights = %d - updates \n", l.nweights);
for (int i = 0; i < l.nweights; ++i) printf(" %f, ", l.weights[i]);
printf(" l.nweights = %d \n\n", l.nweights);
*/
//axpy_ongpu(l.nweights, -decay*batch, l.weights_gpu, 1, l.weight_updates_gpu, 1);
axpy_ongpu(l.nweights, learning_rate / batch, l.weight_updates_gpu, 1, l.weights_gpu, 1); axpy_ongpu(l.nweights, learning_rate / batch, l.weight_updates_gpu, 1, l.weights_gpu, 1);
scal_ongpu(l.nweights, momentum, l.weight_updates_gpu, 1); scal_ongpu(l.nweights, momentum, l.weight_updates_gpu, 1);
//fill_ongpu(l.nweights, 0, l.weight_updates_gpu, 1);
//if (l.clip) { //if (l.clip) {
// constrain_gpu(l.nweights, l.clip, l.weights_gpu, 1); // constrain_ongpu(l.nweights, l.clip, l.weights_gpu, 1);
//} //}
} }
} }
void pull_shortcut_layer(layer l) void pull_shortcut_layer(layer l)
{ {
constrain_ongpu(l.nweights, 1, l.weight_updates_gpu, 1);
cuda_pull_array_async(l.weight_updates_gpu, l.weight_updates, l.nweights);
cuda_pull_array_async(l.weights_gpu, l.weights, l.nweights); cuda_pull_array_async(l.weights_gpu, l.weights, l.nweights);
CHECK_CUDA(cudaPeekAtLastError()); CHECK_CUDA(cudaPeekAtLastError());
CHECK_CUDA(cudaStreamSynchronize(get_cuda_stream())); CHECK_CUDA(cudaStreamSynchronize(get_cuda_stream()));

@ -18,7 +18,7 @@ void resize_shortcut_layer(layer *l, int w, int h, network *net);
#ifdef GPU #ifdef GPU
void forward_shortcut_layer_gpu(const layer l, network_state state); void forward_shortcut_layer_gpu(const layer l, network_state state);
void backward_shortcut_layer_gpu(const layer l, network_state state); void backward_shortcut_layer_gpu(const layer l, network_state state);
void update_shortcut_layer_gpu(layer l, int batch, float learning_rate_init, float momentum, float decay); void update_shortcut_layer_gpu(layer l, int batch, float learning_rate_init, float momentum, float decay, float loss_scale);
void pull_shortcut_layer(layer l); void pull_shortcut_layer(layer l);
void push_shortcut_layer(layer l); void push_shortcut_layer(layer l);
#endif #endif

@ -13,7 +13,10 @@ void train_tag(char *cfgfile, char *weightfile, int clear)
if(weightfile){ if(weightfile){
load_weights(&net, weightfile); load_weights(&net, weightfile);
} }
if(clear) *net.seen = 0; if (clear) {
*net.seen = 0;
*net.cur_iteration = 0;
}
printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay); printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
int imgs = 1024; int imgs = 1024;
list* plist = get_paths("tag/train.list"); list* plist = get_paths("tag/train.list");

@ -151,11 +151,13 @@ std::vector<bbox_t> get_3d_coordinates(std::vector<bbox_t> bbox_vect, cv::Mat xy
#ifndef USE_CMAKE_LIBS #ifndef USE_CMAKE_LIBS
#pragma comment(lib, "opencv_world" OPENCV_VERSION ".lib") #pragma comment(lib, "opencv_world" OPENCV_VERSION ".lib")
#ifdef TRACK_OPTFLOW #ifdef TRACK_OPTFLOW
/*
#pragma comment(lib, "opencv_cudaoptflow" OPENCV_VERSION ".lib") #pragma comment(lib, "opencv_cudaoptflow" OPENCV_VERSION ".lib")
#pragma comment(lib, "opencv_cudaimgproc" OPENCV_VERSION ".lib") #pragma comment(lib, "opencv_cudaimgproc" OPENCV_VERSION ".lib")
#pragma comment(lib, "opencv_core" OPENCV_VERSION ".lib") #pragma comment(lib, "opencv_core" OPENCV_VERSION ".lib")
#pragma comment(lib, "opencv_imgproc" OPENCV_VERSION ".lib") #pragma comment(lib, "opencv_imgproc" OPENCV_VERSION ".lib")
#pragma comment(lib, "opencv_highgui" OPENCV_VERSION ".lib") #pragma comment(lib, "opencv_highgui" OPENCV_VERSION ".lib")
*/
#endif // TRACK_OPTFLOW #endif // TRACK_OPTFLOW
#endif // USE_CMAKE_LIBS #endif // USE_CMAKE_LIBS
#else // OpenCV 2.x #else // OpenCV 2.x

@ -257,7 +257,12 @@ void delta_yolo_class(float *output, float *delta, int index, int class_id, int
{ {
int n; int n;
if (delta[index + stride*class_id]){ if (delta[index + stride*class_id]){
delta[index + stride*class_id] = (1 - label_smooth_eps) - output[index + stride*class_id]; float y_true = 1;
if(label_smooth_eps) y_true = y_true * (1 - label_smooth_eps) + 0.5*label_smooth_eps;
float result_delta = y_true - output[index + stride*class_id];
if(!isnan(result_delta) && !isinf(result_delta)) delta[index + stride*class_id] = result_delta;
//delta[index + stride*class_id] = 1 - output[index + stride*class_id];
if (classes_multipliers) delta[index + stride*class_id] *= classes_multipliers[class_id]; if (classes_multipliers) delta[index + stride*class_id] *= classes_multipliers[class_id];
if(avg_cat) *avg_cat += output[index + stride*class_id]; if(avg_cat) *avg_cat += output[index + stride*class_id];
return; return;
@ -285,7 +290,11 @@ void delta_yolo_class(float *output, float *delta, int index, int class_id, int
else { else {
// default // default
for (n = 0; n < classes; ++n) { for (n = 0; n < classes; ++n) {
delta[index + stride*n] = ((n == class_id) ? (1 - label_smooth_eps) : (0 + label_smooth_eps/classes)) - output[index + stride*n]; float y_true = ((n == class_id) ? 1 : 0);
if (label_smooth_eps) y_true = y_true * (1 - label_smooth_eps) + 0.5*label_smooth_eps;
float result_delta = y_true - output[index + stride*n];
if (!isnan(result_delta) && !isinf(result_delta)) delta[index + stride*n] = result_delta;
if (classes_multipliers && n == class_id) delta[index + stride*class_id] *= classes_multipliers[class_id]; if (classes_multipliers && n == class_id) delta[index + stride*class_id] *= classes_multipliers[class_id];
if (n == class_id && avg_cat) *avg_cat += output[index + stride*n]; if (n == class_id && avg_cat) *avg_cat += output[index + stride*n];
} }
@ -362,17 +371,18 @@ void forward_yolo_layer(const layer l, network_state state)
for (t = 0; t < l.max_boxes; ++t) { for (t = 0; t < l.max_boxes; ++t) {
box truth = float_to_box_stride(state.truth + t*(4 + 1) + b*l.truths, 1); box truth = float_to_box_stride(state.truth + t*(4 + 1) + b*l.truths, 1);
int class_id = state.truth[t*(4 + 1) + b*l.truths + 4]; int class_id = state.truth[t*(4 + 1) + b*l.truths + 4];
if (class_id >= l.classes) { if (class_id >= l.classes || class_id < 0) {
printf("\n Warning: in txt-labels class_id=%d >= classes=%d in cfg-file. In txt-labels class_id should be [from 0 to %d] \n", class_id, l.classes, l.classes - 1); printf("\n Warning: in txt-labels class_id=%d >= classes=%d in cfg-file. In txt-labels class_id should be [from 0 to %d] \n", class_id, l.classes, l.classes - 1);
printf("\n truth.x = %f, truth.y = %f, truth.w = %f, truth.h = %f, class_id = %d \n", truth.x, truth.y, truth.w, truth.h, class_id); printf("\n truth.x = %f, truth.y = %f, truth.w = %f, truth.h = %f, class_id = %d \n", truth.x, truth.y, truth.w, truth.h, class_id);
if (check_mistakes) getchar(); if (check_mistakes) getchar();
continue; // if label contains class_id more than number of classes in the cfg-file continue; // if label contains class_id more than number of classes in the cfg-file and class_id check garbage value
} }
if (!truth.x) break; // continue; if (!truth.x) break; // continue;
int class_index = entry_index(l, b, n*l.w*l.h + j*l.w + i, 4 + 1); int class_index = entry_index(l, b, n*l.w*l.h + j*l.w + i, 4 + 1);
int obj_index = entry_index(l, b, n*l.w*l.h + j*l.w + i, 4); int obj_index = entry_index(l, b, n*l.w*l.h + j*l.w + i, 4);
float objectness = l.output[obj_index]; float objectness = l.output[obj_index];
if (isnan(objectness) || isinf(objectness)) l.output[obj_index] = 0;
int class_id_match = compare_yolo_class(l.output, l.classes, class_index, l.w*l.h, objectness, class_id, 0.25f); int class_id_match = compare_yolo_class(l.output, l.classes, class_index, l.w*l.h, objectness, class_id, 0.25f);
float iou = box_iou(pred, truth); float iou = box_iou(pred, truth);
@ -415,7 +425,7 @@ void forward_yolo_layer(const layer l, network_state state)
system(buff); system(buff);
} }
int class_id = state.truth[t*(4 + 1) + b*l.truths + 4]; int class_id = state.truth[t*(4 + 1) + b*l.truths + 4];
if (class_id >= l.classes) continue; // if label contains class_id more than number of classes in the cfg-file if (class_id >= l.classes || class_id < 0) continue; // if label contains class_id more than number of classes in the cfg-file and class_id check garbage value
if (!truth.x) break; // continue; if (!truth.x) break; // continue;
float best_iou = 0; float best_iou = 0;
@ -531,6 +541,9 @@ void forward_yolo_layer(const layer l, network_state state)
} }
} }
if (count == 0) count = 1;
if (class_count == 0) class_count = 1;
//*(l.cost) = pow(mag_array(l.delta, l.outputs * l.batch), 2); //*(l.cost) = pow(mag_array(l.delta, l.outputs * l.batch), 2);
//printf("Region %d Avg IOU: %f, Class: %f, Obj: %f, No Obj: %f, .5R: %f, .75R: %f, count: %d\n", state.index, avg_iou / count, avg_cat / class_count, avg_obj / count, avg_anyobj / (l.w*l.h*l.n*l.batch), recall / count, recall75 / count, count); //printf("Region %d Avg IOU: %f, Class: %f, Obj: %f, No Obj: %f, .5R: %f, .75R: %f, count: %d\n", state.index, avg_iou / count, avg_cat / class_count, avg_obj / count, avg_anyobj / (l.w*l.h*l.n*l.batch), recall / count, recall75 / count, count);
@ -810,6 +823,6 @@ void forward_yolo_layer_gpu(const layer l, network_state state)
void backward_yolo_layer_gpu(const layer l, network_state state) void backward_yolo_layer_gpu(const layer l, network_state state)
{ {
axpy_ongpu(l.batch*l.inputs, 1, l.delta_gpu, 1, state.delta, 1); axpy_ongpu(l.batch*l.inputs, state.net.loss_scale, l.delta_gpu, 1, state.delta, 1);
} }
#endif #endif
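Note: the yolo layer gets the same symmetric label smoothing as the Gaussian variant, plus several robustness guards: class deltas are committed only when finite, negative class_id labels are rejected along with out-of-range ones, a NaN/Inf objectness is zeroed, count/class_count are clamped to 1 before being used as divisors, and the backward pass multiplies the propagated delta by state.net.loss_scale (the "scale up" half of the mixed-precision scheme whose "scale down" half sits in the update functions above). The guard pattern in isolation:

    /* sketch: commit a class delta only when it is a finite number */
    float result_delta = y_true - output[index + stride*n];
    if (!isnan(result_delta) && !isinf(result_delta))   /* <math.h> */
        delta[index + stride*n] = result_delta;          /* otherwise keep the previous value */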
