From f7a6f7b87cdfd49c9930d2c2a2d7fa5b52b30940 Mon Sep 17 00:00:00 2001
From: AlexeyAB
Date: Thu, 21 Nov 2019 14:11:52 +0300
Subject: [PATCH] Fixed MISH as in thomasbrandon/mish-cuda implementation with
 1 Threshold

---
 src/activation_kernels.cu | 11 ++++++++---
 src/activations.c         |  9 ++++++---
 2 files changed, 14 insertions(+), 6 deletions(-)

diff --git a/src/activation_kernels.cu b/src/activation_kernels.cu
index 67504e71..6ef165ce 100644
--- a/src/activation_kernels.cu
+++ b/src/activation_kernels.cu
@@ -204,9 +204,14 @@ __global__ void activate_array_mish_kernel(float *x, int n, float *activation_in
 {
     int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
     if (i < n) {
+        const float MISH_THRESHOLD = 20;
         float x_val = x[i];
         activation_input[i] = x_val;    // store value before activation
-        output_gpu[i] = x_val * tanh_activate_kernel(log(1 + expf(x_val)));
+        //output_gpu[i] = x_val * tanh_activate_kernel(log(1 + expf(x_val)));
+
+        // https://github.com/thomasbrandon/mish-cuda/blob/master/csrc/mish.h#L17-L20
+        if (x_val < MISH_THRESHOLD) output_gpu[i] = x_val * tanh_activate_kernel(log1p(expf(x_val)));
+        else output_gpu[i] = x_val * tanh_activate_kernel(x_val);
     }
 }
 
@@ -279,12 +284,12 @@ __global__ void gradient_array_mish_kernel(int n, float *activation_input_gpu, f
 {
     int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
     if (i < n) {
-        const float THRESHOLD = 20.0f;
+        const float MISH_THRESHOLD = 20.0f;
 
         // implementation from TensorFlow: https://github.com/tensorflow/addons/commit/093cdfa85d334cbe19a37624c33198f3140109ed
         // implementation from Pytorch: https://github.com/thomasbrandon/mish-cuda/blob/master/csrc/mish.h#L26-L31
         float inp = activation_input_gpu[i];
-        const float sp = (inp < THRESHOLD) ? log1p(exp(inp)) : inp;
+        const float sp = (inp < MISH_THRESHOLD) ? log1p(exp(inp)) : inp;
         const float grad_sp = 1 - exp(-sp);
         const float tsp = tanh(sp);
         const float grad_tsp = (1 - tsp*tsp) * grad_sp;
diff --git a/src/activations.c b/src/activations.c
index 55b060bd..83580cb2 100644
--- a/src/activations.c
+++ b/src/activations.c
@@ -137,12 +137,15 @@ void activate_array_swish(float *x, const int n, float * output_sigmoid, float *
 // https://github.com/digantamisra98/Mish
 void activate_array_mish(float *x, const int n, float * activation_input, float * output)
 {
+    const float MISH_THRESHOLD = 20;
     int i;
     #pragma omp parallel for
     for (i = 0; i < n; ++i) {
         float x_val = x[i];
         activation_input[i] = x_val;    // store value before activation
-        output[i] = x_val * tanh_activate(log(1 + expf(x_val)));
+        //output[i] = x_val * tanh_activate(log(1 + expf(x_val)));
+        if (x_val < MISH_THRESHOLD) output[i] = x_val * tanh_activate(log1p(expf(x_val)));
+        else output[i] = x_val * tanh_activate(x_val);
     }
 }
 
@@ -207,12 +210,12 @@ void gradient_array_mish(const int n, const float * activation_input, float * de
     int i;
     #pragma omp parallel for
     for (i = 0; i < n; ++i) {
-        const float THRESHOLD = 20.0f;
+        const float MISH_THRESHOLD = 20.0f;
 
         // implementation from TensorFlow: https://github.com/tensorflow/addons/commit/093cdfa85d334cbe19a37624c33198f3140109ed
         // implementation from Pytorch: https://github.com/thomasbrandon/mish-cuda/blob/master/csrc/mish.h#L26-L31
         float inp = activation_input[i];
-        const float sp = (inp < THRESHOLD) ? log1p(exp(inp)) : inp;
+        const float sp = (inp < MISH_THRESHOLD) ? log1p(exp(inp)) : inp;
         const float grad_sp = 1 - exp(-sp);
         const float tsp = tanh(sp);
         const float grad_tsp = (1 - tsp*tsp) * grad_sp;
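
Note (not part of the commit): below is a minimal standalone C sketch of the thresholded Mish that this patch applies, useful for sanity-checking values outside of darknet. The names mish_ref(), mish_grad_ref() and the sample inputs are illustrative assumptions, not darknet functions; the math mirrors the kernels above: softplus(x) = log1p(exp(x)) is replaced by x once x >= MISH_THRESHOLD (20), where the two agree to float precision and expf() can no longer overflow, and the gradient follows the same TensorFlow/mish-cuda derivation as gradient_array_mish().

/* mish_check.c - reference check for the thresholded Mish (illustrative only) */
#include <math.h>
#include <stdio.h>

static const float MISH_THRESHOLD = 20.0f;   /* same threshold as in the patch */

/* mish(x) = x * tanh(softplus(x)), with softplus thresholded as in the kernels */
static float mish_ref(float x)
{
    const float sp = (x < MISH_THRESHOLD) ? log1pf(expf(x)) : x;
    return x * tanhf(sp);
}

/* d/dx mish(x), following the same steps as gradient_array_mish() */
static float mish_grad_ref(float x)
{
    const float sp = (x < MISH_THRESHOLD) ? log1pf(expf(x)) : x;
    const float grad_sp = 1.0f - expf(-sp);   /* equals sigmoid(x), the derivative of softplus */
    const float tsp = tanhf(sp);
    const float grad_tsp = (1.0f - tsp * tsp) * grad_sp;
    return x * grad_tsp + tsp;
}

int main(void)
{
    const float xs[] = { -5.0f, 0.0f, 1.0f, 20.0f, 100.0f };
    int i;
    for (i = 0; i < 5; ++i)
        printf("x = %7.1f  mish = %12.4f  d/dx = %8.4f\n",
               xs[i], mish_ref(xs[i]), mish_grad_ref(xs[i]));
    return 0;
}

Build with something like "gcc -O2 -o mish_check mish_check.c -lm". The value at x = -5 should come out near -0.0336 and the derivative at 0 near 0.6, matching the published Mish reference values, which is a quick way to confirm that the threshold branch does not change results for ordinary inputs while avoiding expf() overflow for very large ones (e.g. x = 100).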