Fixed BiFPN and label_smoothing for Detection

pull/4976/head
AlexeyAB 5 years ago
parent d79d2815b9
commit d11caf486d
  1. 2
      include/darknet.h
  2. 2
      src/activation_kernels.cu
  3. 1
      src/blas.h
  4. 64
      src/blas_kernels.cu
  5. 2
      src/darknet.c
  6. 2
      src/detector.c
  7. 23
      src/gaussian_yolo_layer.c
  8. 2
      src/network.c
  9. 6
      src/network_kernels.cu
  10. 10
      src/parser.c
  11. 17
      src/shortcut_layer.c
  12. 26
      src/yolo_layer.c

@ -324,6 +324,8 @@ struct layer {
int onlyforward; int onlyforward;
int stopbackward; int stopbackward;
int dont_update;
int burnin_update;
int dontload; int dontload;
int dontsave; int dontsave;
int dontloadscales; int dontloadscales;

@ -493,6 +493,7 @@ __global__ void activate_array_normalize_channels_softmax_kernel(float *x, int s
for (k = 0; k < channels; ++k) { for (k = 0; k < channels; ++k) {
float val = x[wh_i + k * wh_step + b*wh_step*channels]; float val = x[wh_i + k * wh_step + b*wh_step*channels];
val = expf(val - max_val) / sum; val = expf(val - max_val) / sum;
if (isnan(val) || isinf(val)) val = 0;
output_gpu[wh_i + k * wh_step + b*wh_step*channels] = val; output_gpu[wh_i + k * wh_step + b*wh_step*channels] = val;
} }
} }
@ -535,6 +536,7 @@ __global__ void gradient_array_normalize_channels_softmax_kernel(float *x, int s
float delta = delta_gpu[index]; float delta = delta_gpu[index];
float grad = x[index] * (1 - x[index]); float grad = x[index] * (1 - x[index]);
delta = delta * grad; delta = delta * grad;
if (isnan(delta) || isinf(delta)) delta = 0;
delta_gpu[index] = delta; delta_gpu[index] = delta;
} }
} }

@ -60,6 +60,7 @@ void fix_nan_and_inf_cpu(float *input, size_t size);
#ifdef GPU #ifdef GPU
void constrain_weight_updates_ongpu(int N, float learning_rate, float *weights_gpu, float *weight_updates_gpu);
void axpy_ongpu(int N, float ALPHA, float * X, int INCX, float * Y, int INCY); void axpy_ongpu(int N, float ALPHA, float * X, int INCX, float * Y, int INCY);
void axpy_ongpu_offset(int N, float ALPHA, float * X, int OFFX, int INCX, float * Y, int OFFY, int INCY); void axpy_ongpu_offset(int N, float ALPHA, float * X, int OFFX, int INCX, float * Y, int OFFY, int INCY);
void simple_copy_ongpu(int size, float *src, float *dst); void simple_copy_ongpu(int size, float *src, float *dst);

@ -418,6 +418,24 @@ __global__ void reorg_kernel(int N, float *x, int w, int h, int c, int batch, in
//else out[0] = x[0]; //else out[0] = x[0];
} }
// Clamps a single pending weight update to at most |w * learning_rate| in
// magnitude while keeping the update's own sign.
//   w             - current weight value
//   wu            - pending update for that weight
//   learning_rate - scale factor applied to |w| to obtain the limit
// Returns wu unchanged when |wu| does not exceed the limit (this includes
// wu == 0) or when the comparison is unordered because a value is NaN.
static __host__ __device__ inline float constrain_weight_update_value(float w, float wu, float learning_rate)
{
    const float abs_limit = fabsf(w * learning_rate);
    // copysignf keeps the direction of the update; the earlier `wu / wu`
    // expression always evaluated to +1 and therefore flipped negative updates.
    if (fabsf(wu) > abs_limit) return copysignf(abs_limit, wu);
    return wu;
}

// Element-wise limiter for weight updates: for every i < N, the magnitude of
// weight_updates_gpu[i] is limited to |weights_gpu[i] * learning_rate|.
// The flat index is built from a 2D grid of 1D blocks
// (blockIdx.x + blockIdx.y * gridDim.x), and threads with i >= N do nothing.
__global__ void constrain_weight_updates_kernel(int N, float learning_rate, float *weights_gpu, float *weight_updates_gpu)
{
    const int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
    if (i < N) {
        const float wu = weight_updates_gpu[i];
        const float limited = constrain_weight_update_value(weights_gpu[i], wu, learning_rate);
        // Skip the store when nothing changed. A NaN update compares unequal to
        // itself and is simply written back as the same NaN.
        if (limited != wu) weight_updates_gpu[i] = limited;
    }
}
// Host-side launcher for constrain_weight_updates_kernel.
// Launches the kernel for N elements with BLOCK threads per block on the
// stream returned by get_cuda_stream(), then checks for a pending CUDA error.
// weights_gpu and weight_updates_gpu are passed straight through to the kernel.
extern "C" void constrain_weight_updates_ongpu(int N, float learning_rate, float *weights_gpu, float *weight_updates_gpu)
{
    const dim3 grid = cuda_gridsize(N);
    const cudaStream_t stream = get_cuda_stream();

    constrain_weight_updates_kernel<<<grid, BLOCK, 0, stream>>>(
        N, learning_rate, weights_gpu, weight_updates_gpu);

    // Peek (rather than clear) the last error so a bad launch is reported here.
    CHECK_CUDA(cudaPeekAtLastError());
}
__global__ void axpy_kernel(int N, float ALPHA, float *X, int OFFX, int INCX, float *Y, int OFFY, int INCY) __global__ void axpy_kernel(int N, float ALPHA, float *X, int OFFX, int INCX, float *Y, int OFFY, int INCY)
{ {
int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x; int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
@ -801,6 +819,12 @@ __device__ float relu(float src) {
return 0; return 0;
} }
// Floors the input at a small positive constant: returns src when it is
// greater than 0.001, otherwise returns 0.001 (so the result is never
// zero or negative).
__device__ float lrelu(float src) {
    const float floor_value = 0.001f;
    return (src > floor_value) ? src : floor_value;
}
__global__ void shortcut_singlelayer_simple_kernel(int size, int src_outputs, int batch, int n, int *outputs_of_layers_gpu, float **layers_output_gpu, float *out, float *in, float *weights_gpu, int nweights, WEIGHTS_NORMALIZATION_T weights_normalizion) __global__ void shortcut_singlelayer_simple_kernel(int size, int src_outputs, int batch, int n, int *outputs_of_layers_gpu, float **layers_output_gpu, float *out, float *in, float *weights_gpu, int nweights, WEIGHTS_NORMALIZATION_T weights_normalizion)
{ {
const int id = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x; const int id = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
@ -852,7 +876,7 @@ __global__ void shortcut_multilayer_kernel(int size, int src_outputs, int batch,
for (int i = 0; i < (n + 1); ++i) { for (int i = 0; i < (n + 1); ++i) {
const int weights_index = src_i / step + i*layer_step; // [0 or c or (c, h ,w)] const int weights_index = src_i / step + i*layer_step; // [0 or c or (c, h ,w)]
const float w = weights_gpu[weights_index]; const float w = weights_gpu[weights_index];
if (weights_normalizion == RELU_NORMALIZATION) sum += relu(w); if (weights_normalizion == RELU_NORMALIZATION) sum += lrelu(w);
else if (weights_normalizion == SOFTMAX_NORMALIZATION) sum += expf(w - max_val); else if (weights_normalizion == SOFTMAX_NORMALIZATION) sum += expf(w - max_val);
} }
} }
@ -861,7 +885,7 @@ __global__ void shortcut_multilayer_kernel(int size, int src_outputs, int batch,
if (weights_gpu) { if (weights_gpu) {
float w = weights_gpu[src_i / step]; float w = weights_gpu[src_i / step];
if (weights_normalizion == RELU_NORMALIZATION) w = relu(w) / sum; if (weights_normalizion == RELU_NORMALIZATION) w = lrelu(w) / sum;
else if (weights_normalizion == SOFTMAX_NORMALIZATION) w = expf(w - max_val) / sum; else if (weights_normalizion == SOFTMAX_NORMALIZATION) w = expf(w - max_val) / sum;
out_val = in[id] * w; // [0 or c or (c, h ,w)] out_val = in[id] * w; // [0 or c or (c, h ,w)]
@ -879,7 +903,7 @@ __global__ void shortcut_multilayer_kernel(int size, int src_outputs, int batch,
if (weights_gpu) { if (weights_gpu) {
const int weights_index = src_i / step + (i + 1)*layer_step; // [0 or c or (c, h ,w)] const int weights_index = src_i / step + (i + 1)*layer_step; // [0 or c or (c, h ,w)]
float w = weights_gpu[weights_index]; float w = weights_gpu[weights_index];
if (weights_normalizion == RELU_NORMALIZATION) w = relu(w) / sum; if (weights_normalizion == RELU_NORMALIZATION) w = lrelu(w) / sum;
else if (weights_normalizion == SOFTMAX_NORMALIZATION) w = expf(w - max_val) / sum; else if (weights_normalizion == SOFTMAX_NORMALIZATION) w = expf(w - max_val) / sum;
out_val += add[add_index] * w; // [0 or c or (c, h ,w)] out_val += add[add_index] * w; // [0 or c or (c, h ,w)]
@ -935,34 +959,27 @@ __global__ void backward_shortcut_multilayer_kernel(int size, int src_outputs, i
for (i = 0; i < (n + 1); ++i) { for (i = 0; i < (n + 1); ++i) {
const int weights_index = src_i / step + i*layer_step; // [0 or c or (c, h ,w)] const int weights_index = src_i / step + i*layer_step; // [0 or c or (c, h ,w)]
const float w = weights_gpu[weights_index]; const float w = weights_gpu[weights_index];
if (weights_normalizion == RELU_NORMALIZATION) sum += relu(w); if (weights_normalizion == RELU_NORMALIZATION) sum += lrelu(w);
else if (weights_normalizion == SOFTMAX_NORMALIZATION) sum += expf(w - max_val); else if (weights_normalizion == SOFTMAX_NORMALIZATION) sum += expf(w - max_val);
} }
/*
grad = 0;
for (i = 0; i < (n + 1); ++i) {
const int weights_index = src_i / step + i*layer_step; // [0 or c or (c, h ,w)]
const float delta_w = delta_in[id] * in[id];
const float w = weights_gpu[weights_index];
if (weights_normalizion == RELU_NORMALIZATION) grad += delta_w * relu(w) / sum;
else if (weights_normalizion == SOFTMAX_NORMALIZATION) grad += delta_w * expf(w - max_val) / sum;
}
*/
} }
if (weights_gpu) { if (weights_gpu) {
float w = weights_gpu[src_i / step]; float w = weights_gpu[src_i / step];
if (weights_normalizion == RELU_NORMALIZATION) w = relu(w) / sum; if (weights_normalizion == RELU_NORMALIZATION) w = lrelu(w) / sum;
else if (weights_normalizion == SOFTMAX_NORMALIZATION) w = expf(w - max_val) / sum; else if (weights_normalizion == SOFTMAX_NORMALIZATION) w = expf(w - max_val) / sum;
if (weights_normalizion == RELU_NORMALIZATION) grad = w; if (weights_normalizion == RELU_NORMALIZATION) grad = w;
else if (weights_normalizion == SOFTMAX_NORMALIZATION) grad = w*(1-w); else if (weights_normalizion == SOFTMAX_NORMALIZATION) grad = w*(1-w);
delta_out[id] += delta_in[id] * w; // [0 or c or (c, h ,w)] delta_out[id] += delta_in[id] * w; // [0 or c or (c, h ,w)]
float weights_update_tmp = delta_in[id] * in[id] * grad / step; float weights_update_tmp = delta_in[id] * in[id] * grad;// / step;
if (layer_step == 1 && (size/32) > (id/32 + 1)) { if (layer_step == 1 && (size/32) > (id/32 + 1)) {
if (isnan(weights_update_tmp) || isinf(weights_update_tmp)) {
weights_update_tmp = 0;
}
float wu = warpAllReduceSum(weights_update_tmp); float wu = warpAllReduceSum(weights_update_tmp);
if (threadIdx.x % 32 == 0) { if (threadIdx.x % 32 == 0) {
if (!isnan(wu) && !isinf(wu)) if (!isnan(wu) && !isinf(wu))
@ -997,13 +1014,18 @@ __global__ void backward_shortcut_multilayer_kernel(int size, int src_outputs, i
else if (weights_normalizion == SOFTMAX_NORMALIZATION) grad = w*(1 - w); else if (weights_normalizion == SOFTMAX_NORMALIZATION) grad = w*(1 - w);
layer_delta[add_index] += delta_in[id] * w; layer_delta[add_index] += delta_in[id] * w;
float weights_update_tmp = delta_in[id] * add[add_index] * grad / step; float weights_update_tmp = delta_in[id] * add[add_index] * grad;// / step;
if (layer_step == 1 && (size / 32) > (id / 32 + 1)) { if (layer_step == 1 && (size / 32) > (id / 32 + 1)) {
if (isnan(weights_update_tmp) || isinf(weights_update_tmp)) {
weights_update_tmp = 0;
}
float wu = warpAllReduceSum(weights_update_tmp); float wu = warpAllReduceSum(weights_update_tmp);
if (threadIdx.x % 32 == 0) { if (threadIdx.x % 32 == 0) {
if (!isnan(wu) && !isinf(wu)) if (!isnan(wu) && !isinf(wu))
atomicAdd(&weight_updates_gpu[weights_index], wu); atomicAdd(&weight_updates_gpu[weights_index], wu);
//if(weights_gpu[weights_index] != 1) printf(" wu = %f, weights_update_tmp = %f, w = %f, weights_gpu[weights_index] = %f, grad = %f, weights_normalizion = %d ",
// wu, weights_update_tmp, w, weights_gpu[weights_index], grad, weights_normalizion);
} }
} }
else { else {
@ -1360,8 +1382,9 @@ __global__ void fix_nan_and_inf_kernel(float *input, size_t size)
const int index = blockIdx.x*blockDim.x + threadIdx.x; const int index = blockIdx.x*blockDim.x + threadIdx.x;
if (index < size) { if (index < size) {
float val = input[index]; float val = input[index];
if (isnan(val) || isinf(val)) if (isnan(val) || isinf(val)) {
input[index] = 1.0f / index; // pseudo random value input[index] = 1.0f / (fabs((float)index) + 1); // pseudo random value
}
} }
} }
@ -1380,8 +1403,9 @@ __global__ void reset_nan_and_inf_kernel(float *input, size_t size)
const int index = blockIdx.x*blockDim.x + threadIdx.x; const int index = blockIdx.x*blockDim.x + threadIdx.x;
if (index < size) { if (index < size) {
float val = input[index]; float val = input[index];
if (isnan(val) || isinf(val)) if (isnan(val) || isinf(val)) {
input[index] = 0; input[index] = 0;
}
} }
} }

@ -172,7 +172,7 @@ void oneoff(char *cfgfile, char *weightfile, char *outfile)
void partial(char *cfgfile, char *weightfile, char *outfile, int max) void partial(char *cfgfile, char *weightfile, char *outfile, int max)
{ {
gpu_index = -1; gpu_index = -1;
network net = parse_network_cfg(cfgfile); network net = parse_network_cfg_custom(cfgfile, 1, 1);
if(weightfile){ if(weightfile){
load_weights_upto(&net, weightfile, max); load_weights_upto(&net, weightfile, max);
} }

@ -356,7 +356,7 @@ void train_detector(char *datacfg, char *cfgfile, char *weightfile, int *gpus, i
save_weights(net, buff); save_weights(net, buff);
} }
if (iteration >= (iter_save_last + 100) || iteration % 100 == 0) { if (iteration >= (iter_save_last + 100) || (iteration % 100 == 0 && iteration > 1)) {
iter_save_last = iteration; iter_save_last = iteration;
#ifdef GPU #ifdef GPU
if (ngpus != 1) sync_nets(nets, ngpus, 0); if (ngpus != 1) sync_nets(nets, ngpus, 0);

@ -371,25 +371,20 @@ void delta_gaussian_yolo_class(float *output, float *delta, int index, int class
{ {
int n; int n;
if (delta[index]){ if (delta[index]){
if (label_smooth_eps > 0) { float y_true = 1;
float out_val = output[index + stride*class_id] * (1 - label_smooth_eps) + 0.5*label_smooth_eps; if (label_smooth_eps) y_true = y_true * (1 - label_smooth_eps) + 0.5*label_smooth_eps;
delta[index + stride*class_id] = 1 - out_val; delta[index + stride*class_id] = y_true - output[index + stride*class_id];
} //delta[index + stride*class_id] = 1 - output[index + stride*class_id];
else {
delta[index + stride*class_id] = 1 - output[index + stride*class_id];
}
if (classes_multipliers) delta[index + stride*class_id] *= classes_multipliers[class_id]; if (classes_multipliers) delta[index + stride*class_id] *= classes_multipliers[class_id];
if(avg_cat) *avg_cat += output[index + stride*class_id]; if(avg_cat) *avg_cat += output[index + stride*class_id];
return; return;
} }
for(n = 0; n < classes; ++n){ for(n = 0; n < classes; ++n){
if (label_smooth_eps > 0) { float y_true = ((n == class_id) ? 1 : 0);
float out_val = output[index + stride*class_id] * (1 - label_smooth_eps) + 0.5*label_smooth_eps; if (label_smooth_eps) y_true = y_true * (1 - label_smooth_eps) + 0.5*label_smooth_eps;
delta[index + stride*n] = ((n == class_id) ? 1 : 0) - out_val; delta[index + stride*n] = y_true - output[index + stride*n];
}
else {
delta[index + stride*n] = ((n == class_id) ? 1 : 0) - output[index + stride*n];
}
if (classes_multipliers && n == class_id) delta[index + stride*class_id] *= classes_multipliers[class_id]; if (classes_multipliers && n == class_id) delta[index + stride*class_id] *= classes_multipliers[class_id];
if(n == class_id && avg_cat) *avg_cat += output[index + stride*n]; if(n == class_id && avg_cat) *avg_cat += output[index + stride*n];
} }

@ -1135,7 +1135,7 @@ void fuse_conv_batchnorm(network net)
//cuda_pull_array(l.weights_gpu, l.weights, l.nweights); //cuda_pull_array(l.weights_gpu, l.weights, l.nweights);
int i; int i;
for (i = 0; i < l->nweights; ++i) printf(" w = %f,", l->weights[i]); for (i = 0; i < l->nweights; ++i) printf(" w = %f,", l->weights[i]);
printf(" l->nweights = %d \n", l->nweights); printf(" l->nweights = %d, j = %d \n", l->nweights, j);
} }
// nweights - l.n or l.n*l.c or (l.n*l.c*l.h*l.w) // nweights - l.n or l.n*l.c or (l.n*l.c*l.h*l.w)

@ -169,7 +169,8 @@ void backward_network_gpu(network net, network_state state)
for(i = net.n-1; i >= 0; --i){ for(i = net.n-1; i >= 0; --i){
state.index = i; state.index = i;
layer l = net.layers[i]; layer l = net.layers[i];
if (l.stopbackward) break; if (l.stopbackward == 1) break;
if (l.stopbackward > get_current_iteration(net)) break;
if(i == 0){ if(i == 0){
state.input = original_input; state.input = original_input;
state.delta = original_delta; state.delta = original_delta;
@ -247,7 +248,8 @@ void update_network_gpu(network net)
layer l = net.layers[i]; layer l = net.layers[i];
l.t = get_current_batch(net); l.t = get_current_batch(net);
if (iteration_num > (net.max_batches * 1 / 2)) l.deform = 0; if (iteration_num > (net.max_batches * 1 / 2)) l.deform = 0;
if(l.update_gpu){ if (l.burnin_update && (l.burnin_update*net.burn_in > iteration_num)) continue;
if(l.update_gpu && l.dont_update < iteration_num){
l.update_gpu(l, update_batch, rate, net.momentum, net.decay, net.loss_scale); l.update_gpu(l, update_batch, rate, net.momentum, net.decay, net.loss_scale);
} }
} }

@ -1365,6 +1365,8 @@ network parse_network_cfg_custom(char *filename, int batch, int time_steps)
l.clip = option_find_float_quiet(options, "clip", 0); l.clip = option_find_float_quiet(options, "clip", 0);
l.dynamic_minibatch = net.dynamic_minibatch; l.dynamic_minibatch = net.dynamic_minibatch;
l.onlyforward = option_find_int_quiet(options, "onlyforward", 0); l.onlyforward = option_find_int_quiet(options, "onlyforward", 0);
l.dont_update = option_find_int_quiet(options, "dont_update", 0);
l.burnin_update = option_find_int_quiet(options, "burnin_update", 0);
l.stopbackward = option_find_int_quiet(options, "stopbackward", 0); l.stopbackward = option_find_int_quiet(options, "stopbackward", 0);
l.dontload = option_find_int_quiet(options, "dontload", 0); l.dontload = option_find_int_quiet(options, "dontload", 0);
l.dontloadscales = option_find_int_quiet(options, "dontloadscales", 0); l.dontloadscales = option_find_int_quiet(options, "dontloadscales", 0);
@ -1556,8 +1558,14 @@ void save_shortcut_weights(layer l, FILE *fp)
#ifdef GPU #ifdef GPU
if (gpu_index >= 0) { if (gpu_index >= 0) {
pull_shortcut_layer(l); pull_shortcut_layer(l);
printf("\n pull_shortcut_layer \n");
} }
#endif #endif
for (int i = 0; i < l.nweights; ++i) printf(" %f, ", l.weight_updates[i]);
printf(" l.nweights = %d - update \n", l.nweights);
for (int i = 0; i < l.nweights; ++i) printf(" %f, ", l.weights[i]);
printf(" l.nweights = %d \n\n", l.nweights);
int num = l.nweights; int num = l.nweights;
fwrite(l.weights, sizeof(float), num, fp); fwrite(l.weights, sizeof(float), num, fp);
} }
@ -1843,7 +1851,7 @@ void load_shortcut_weights(layer l, FILE *fp)
read_bytes = fread(l.weights, sizeof(float), num, fp); read_bytes = fread(l.weights, sizeof(float), num, fp);
if (read_bytes > 0 && read_bytes < num) printf("\n Warning: Unexpected end of wights-file! l.weights - l.index = %d \n", l.index); if (read_bytes > 0 && read_bytes < num) printf("\n Warning: Unexpected end of wights-file! l.weights - l.index = %d \n", l.index);
//for (int i = 0; i < l.nweights; ++i) printf(" %f, ", l.weights[i]); //for (int i = 0; i < l.nweights; ++i) printf(" %f, ", l.weights[i]);
//printf("\n\n"); //printf(" read_bytes = %d \n\n", read_bytes);
#ifdef GPU #ifdef GPU
if (gpu_index >= 0) { if (gpu_index >= 0) {
push_shortcut_layer(l); push_shortcut_layer(l);

@ -256,10 +256,24 @@ void update_shortcut_layer_gpu(layer l, int batch, float learning_rate_init, flo
reset_nan_and_inf(l.weight_updates_gpu, l.nweights); reset_nan_and_inf(l.weight_updates_gpu, l.nweights);
fix_nan_and_inf(l.weights_gpu, l.nweights); fix_nan_and_inf(l.weights_gpu, l.nweights);
axpy_ongpu(l.nweights, -decay*batch, l.weights_gpu, 1, l.weight_updates_gpu, 1); constrain_weight_updates_ongpu(l.nweights, 0.001, l.weights_gpu, l.weight_updates_gpu);
/*
cuda_pull_array_async(l.weights_gpu, l.weights, l.nweights);
cuda_pull_array_async(l.weight_updates_gpu, l.weight_updates, l.nweights);
CHECK_CUDA(cudaStreamSynchronize(get_cuda_stream()));
for (int i = 0; i < l.nweights; ++i) printf(" %f, ", l.weight_updates[i]);
printf(" l.nweights = %d - updates \n", l.nweights);
for (int i = 0; i < l.nweights; ++i) printf(" %f, ", l.weights[i]);
printf(" l.nweights = %d \n\n", l.nweights);
*/
//axpy_ongpu(l.nweights, -decay*batch, l.weights_gpu, 1, l.weight_updates_gpu, 1);
axpy_ongpu(l.nweights, learning_rate / batch, l.weight_updates_gpu, 1, l.weights_gpu, 1); axpy_ongpu(l.nweights, learning_rate / batch, l.weight_updates_gpu, 1, l.weights_gpu, 1);
scal_ongpu(l.nweights, momentum, l.weight_updates_gpu, 1); scal_ongpu(l.nweights, momentum, l.weight_updates_gpu, 1);
//fill_ongpu(l.nweights, 0, l.weight_updates_gpu, 1);
//if (l.clip) { //if (l.clip) {
// constrain_ongpu(l.nweights, l.clip, l.weights_gpu, 1); // constrain_ongpu(l.nweights, l.clip, l.weights_gpu, 1);
//} //}
@ -268,6 +282,7 @@ void update_shortcut_layer_gpu(layer l, int batch, float learning_rate_init, flo
void pull_shortcut_layer(layer l) void pull_shortcut_layer(layer l)
{ {
cuda_pull_array_async(l.weight_updates_gpu, l.weight_updates, l.nweights);
cuda_pull_array_async(l.weights_gpu, l.weights, l.nweights); cuda_pull_array_async(l.weights_gpu, l.weights, l.nweights);
CHECK_CUDA(cudaPeekAtLastError()); CHECK_CUDA(cudaPeekAtLastError());
CHECK_CUDA(cudaStreamSynchronize(get_cuda_stream())); CHECK_CUDA(cudaStreamSynchronize(get_cuda_stream()));

@ -257,13 +257,12 @@ void delta_yolo_class(float *output, float *delta, int index, int class_id, int
{ {
int n; int n;
if (delta[index + stride*class_id]){ if (delta[index + stride*class_id]){
if (label_smooth_eps > 0) { float y_true = 1;
float out_val = output[index + stride*class_id] * (1 - label_smooth_eps) + 0.5*label_smooth_eps; if(label_smooth_eps) y_true = y_true * (1 - label_smooth_eps) + 0.5*label_smooth_eps;
delta[index + stride*class_id] = 1 - out_val; float result_delta = y_true - output[index + stride*class_id];
} if(!isnan(result_delta) && !isinf(result_delta)) delta[index + stride*class_id] = result_delta;
else { //delta[index + stride*class_id] = 1 - output[index + stride*class_id];
delta[index + stride*class_id] = 1 - output[index + stride*class_id];
}
if (classes_multipliers) delta[index + stride*class_id] *= classes_multipliers[class_id]; if (classes_multipliers) delta[index + stride*class_id] *= classes_multipliers[class_id];
if(avg_cat) *avg_cat += output[index + stride*class_id]; if(avg_cat) *avg_cat += output[index + stride*class_id];
return; return;
@ -291,13 +290,11 @@ void delta_yolo_class(float *output, float *delta, int index, int class_id, int
else { else {
// default // default
for (n = 0; n < classes; ++n) { for (n = 0; n < classes; ++n) {
if (label_smooth_eps > 0) { float y_true = ((n == class_id) ? 1 : 0);
float out_val = output[index + stride*class_id] * (1 - label_smooth_eps) + 0.5*label_smooth_eps; if (label_smooth_eps) y_true = y_true * (1 - label_smooth_eps) + 0.5*label_smooth_eps;
delta[index + stride*n] = ((n == class_id) ? 1 : 0) - out_val; float result_delta = y_true - output[index + stride*n];
} if (!isnan(result_delta) && !isinf(result_delta)) delta[index + stride*n] = result_delta;
else {
delta[index + stride*n] = ((n == class_id) ? 1 : 0) - output[index + stride*n];
}
if (classes_multipliers && n == class_id) delta[index + stride*class_id] *= classes_multipliers[class_id]; if (classes_multipliers && n == class_id) delta[index + stride*class_id] *= classes_multipliers[class_id];
if (n == class_id && avg_cat) *avg_cat += output[index + stride*n]; if (n == class_id && avg_cat) *avg_cat += output[index + stride*n];
} }
@ -385,6 +382,7 @@ void forward_yolo_layer(const layer l, network_state state)
int class_index = entry_index(l, b, n*l.w*l.h + j*l.w + i, 4 + 1); int class_index = entry_index(l, b, n*l.w*l.h + j*l.w + i, 4 + 1);
int obj_index = entry_index(l, b, n*l.w*l.h + j*l.w + i, 4); int obj_index = entry_index(l, b, n*l.w*l.h + j*l.w + i, 4);
float objectness = l.output[obj_index]; float objectness = l.output[obj_index];
if (isnan(objectness) || isinf(objectness)) l.output[obj_index] = 0;
int class_id_match = compare_yolo_class(l.output, l.classes, class_index, l.w*l.h, objectness, class_id, 0.25f); int class_id_match = compare_yolo_class(l.output, l.classes, class_index, l.w*l.h, objectness, class_id, 0.25f);
float iou = box_iou(pred, truth); float iou = box_iou(pred, truth);

Loading…
Cancel
Save