diff --git a/src/network.c b/src/network.c
index 8f09dab7..37463cec 100644
--- a/src/network.c
+++ b/src/network.c
@@ -1092,6 +1092,12 @@ static float relu(float src) {
     return 0;
 }
 
+static float lrelu(float src) {
+    const float eps = 0.001;
+    if (src > eps) return src;
+    return eps;
+}
+
 void fuse_conv_batchnorm(network net)
 {
     int j;
@@ -1160,14 +1166,14 @@ void fuse_conv_batchnorm(network net)
                 for (i = 0; i < (l->n + 1); ++i) {
                     int w_index = chan + i * layer_step;
                     float w = l->weights[w_index];
-                    if (l->weights_normalizion == RELU_NORMALIZATION) sum += relu(w);
+                    if (l->weights_normalizion == RELU_NORMALIZATION) sum += lrelu(w);
                     else if (l->weights_normalizion == SOFTMAX_NORMALIZATION) sum += expf(w - max_val);
                 }
 
                 for (i = 0; i < (l->n + 1); ++i) {
                     int w_index = chan + i * layer_step;
                     float w = l->weights[w_index];
-                    if (l->weights_normalizion == RELU_NORMALIZATION) w = relu(w) / sum;
+                    if (l->weights_normalizion == RELU_NORMALIZATION) w = lrelu(w) / sum;
                     else if (l->weights_normalizion == SOFTMAX_NORMALIZATION) w = expf(w - max_val) / sum;
                     l->weights[w_index] = w;
                 }
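
For context, below is a minimal standalone sketch of why the patch swaps relu() for lrelu() in the RELU_NORMALIZATION path. Only relu() and lrelu() are copied from the patched file; the normalize() helper, the sample weights, and main() are hypothetical scaffolding for illustration, not darknet code. Note that despite its name, lrelu() here is a clamp-to-eps, not a textbook leaky ReLU with a negative slope.

/* Sketch: with plain relu(), a channel whose weights are all non-positive
 * yields sum == 0, so every normalized weight becomes 0/0 (NaN). The
 * clamped lrelu() keeps each contribution at least eps, so the sum is
 * always positive and the division stays finite. */
#include <stdio.h>

static float relu(float src) {
    if (src > 0) return src;
    return 0;
}

/* As in the patch: values at or below eps are clamped to eps. */
static float lrelu(float src) {
    const float eps = 0.001;
    if (src > eps) return src;
    return eps;
}

/* Hypothetical helper mirroring the two loops in fuse_conv_batchnorm:
 * first accumulate the normalizer, then divide each weight by it. */
static void normalize(float *w, int n, float (*act)(float)) {
    float sum = 0;
    for (int i = 0; i < n; ++i) sum += act(w[i]);
    for (int i = 0; i < n; ++i) w[i] = act(w[i]) / sum;  /* 0/0 -> NaN with relu */
}

int main(void) {
    float a[3] = { -0.5f, -0.2f, -0.1f };  /* all negative: relu sum is 0 */
    float b[3] = { -0.5f, -0.2f, -0.1f };
    normalize(a, 3, relu);   /* every element becomes NaN */
    normalize(b, 3, lrelu);  /* every element becomes eps / (3 * eps) = 1/3 */
    for (int i = 0; i < 3; ++i) printf("relu: %f  lrelu: %f\n", a[i], b[i]);
    return 0;
}

Clamping to eps guarantees the accumulated sum is at least (n + 1) * eps, so the second loop can never divide by zero even when every fused weight is non-positive, while weights above eps normalize exactly as before.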