Fixed relu to lrelu in fuse_conv_batchnorm()

pull/5172/head
AlexeyAB 5 years ago
parent d6181c67cd
commit 2614a231f0
1 changed file with 8 additions and 2 deletions
    src/network.c

@@ -1092,6 +1092,12 @@ static float relu(float src) {
     return 0;
 }
 
+static float lrelu(float src) {
+    const float eps = 0.001;
+    if (src > eps) return src;
+    return eps;
+}
+
 void fuse_conv_batchnorm(network net)
 {
     int j;
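Note: despite the name, the new lrelu() is not a leaky ReLU with a negative slope; it clamps its argument from below at eps = 0.001, so it never returns zero. A self-contained sketch of the two helpers side by side (the relu() body is filled in from its standard form and the context lines above, and the sample inputs in main() are illustrative only, not part of the commit):

#include <stdio.h>

/* relu() as it appears just above the new helper in network.c
 * (body filled in from its standard definition and the hunk's context lines). */
static float relu(float src) {
    if (src > 0) return src;
    return 0;
}

/* The helper added by this commit: clamps the input from below at eps. */
static float lrelu(float src) {
    const float eps = 0.001;
    if (src > eps) return src;
    return eps;
}

int main(void) {
    float samples[] = { -1.0f, 0.0f, 0.0005f, 2.0f };   /* illustrative values */
    for (int i = 0; i < 4; ++i)
        printf("src = %+.4f  relu = %.4f  lrelu = %.4f\n",
               samples[i], relu(samples[i]), lrelu(samples[i]));
    return 0;
}

Compiled with gcc and run, this prints relu = 0.0000 but lrelu = 0.0010 for the non-positive inputs, and both return the input unchanged once it exceeds eps.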
@@ -1160,14 +1166,14 @@ void fuse_conv_batchnorm(network net)
                 for (i = 0; i < (l->n + 1); ++i) {
                     int w_index = chan + i * layer_step;
                     float w = l->weights[w_index];
-                    if (l->weights_normalizion == RELU_NORMALIZATION) sum += relu(w);
+                    if (l->weights_normalizion == RELU_NORMALIZATION) sum += lrelu(w);
                     else if (l->weights_normalizion == SOFTMAX_NORMALIZATION) sum += expf(w - max_val);
                 }
                 for (i = 0; i < (l->n + 1); ++i) {
                     int w_index = chan + i * layer_step;
                     float w = l->weights[w_index];
-                    if (l->weights_normalizion == RELU_NORMALIZATION) w = relu(w) / sum;
+                    if (l->weights_normalizion == RELU_NORMALIZATION) w = lrelu(w) / sum;
                     else if (l->weights_normalizion == SOFTMAX_NORMALIZATION) w = expf(w - max_val) / sum;
                     l->weights[w_index] = w;
                 }
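Why the change matters: the two loops above first accumulate the activation of every shortcut weight into sum, then divide each activated weight by that sum. With plain relu(), a channel whose weights are all non-positive contributes nothing to sum, so the denominator can collapse to (or near) zero; lrelu() guarantees every term is at least eps. A minimal sketch of the RELU_NORMALIZATION path (the weight values, the count n, and the zero-initialized sum are illustrative assumptions, not values taken from network.c):

#include <stdio.h>

/* The helper added by this commit. */
static float lrelu(float src) {
    const float eps = 0.001;
    if (src > eps) return src;
    return eps;
}

int main(void) {
    /* A worst case for the old code: every shortcut weight is non-positive,
     * so relu(w) would contribute 0 to sum and the division below would
     * produce NaN/inf. With lrelu(), each term is at least eps. */
    float weights[] = { -0.2f, -0.1f, 0.0f };
    const int n = 3;

    float sum = 0;                                   /* assumed start value */
    for (int i = 0; i < n; ++i) sum += lrelu(weights[i]);

    for (int i = 0; i < n; ++i) {
        float w = lrelu(weights[i]) / sum;
        printf("normalized weight[%d] = %f\n", i, w); /* 0.333333 each */
    }
    return 0;
}

With the old relu() the same inputs would give sum == 0 and NaN results; after the fix each normalized weight comes out as 1/3.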
