diff --git a/src/activation_kernels.cu b/src/activation_kernels.cu
index d8ff25f4..67504e71 100644
--- a/src/activation_kernels.cu
+++ b/src/activation_kernels.cu
@@ -275,15 +275,27 @@ __global__ void gradient_array_swish_kernel(float *x, int n, float *sigmoid_gpu,
 }
 
 // https://github.com/digantamisra98/Mish
-__global__ void gradient_array_mish_kernel(int n, float *activation_input, float *delta)
+__global__ void gradient_array_mish_kernel(int n, float *activation_input_gpu, float *delta)
 {
     int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
     if (i < n) {
-        float x = activation_input[i];
-        float d = 2 * expf(x) + expf(2 * x) + 2;
-        float w = 4 * (x + 1) + 4 * expf(2 * x) + expf(3 * x) + expf(x)*(4 * x + 6);
-        float derivative = expf(x) * w / (d * d);
-        delta[i] *= derivative;
+        const float THRESHOLD = 20.0f;
+
+        // implementation from TensorFlow: https://github.com/tensorflow/addons/commit/093cdfa85d334cbe19a37624c33198f3140109ed
+        // implementation from Pytorch: https://github.com/thomasbrandon/mish-cuda/blob/master/csrc/mish.h#L26-L31
+        float inp = activation_input_gpu[i];
+        const float sp = (inp < THRESHOLD) ? log1p(exp(inp)) : inp;
+        const float grad_sp = 1 - exp(-sp);
+        const float tsp = tanh(sp);
+        const float grad_tsp = (1 - tsp*tsp) * grad_sp;
+        const float grad = inp * grad_tsp + tsp;
+        delta[i] *= grad;
+
+        //float x = activation_input[i];
+        //float d = 2 * expf(x) + expf(2 * x) + 2;
+        //float w = 4 * (x + 1) + 4 * expf(2 * x) + expf(3 * x) + expf(x)*(4 * x + 6);
+        //float derivative = expf(x) * w / (d * d);
+        //delta[i] *= derivative;
     }
 }
 
diff --git a/src/activations.c b/src/activations.c
index da92af0a..55b060bd 100644
--- a/src/activations.c
+++ b/src/activations.c
@@ -207,10 +207,23 @@ void gradient_array_mish(const int n, const float * activation_input, float * de
     int i;
     #pragma omp parallel for
     for (i = 0; i < n; ++i) {
-        float x = activation_input[i];
-        float d = 2 * expf(x) + expf(2 * x) + 2;
-        float w = 4 * (x + 1) + 4 * expf(2 * x) + expf(3 * x) + expf(x)*(4 * x + 6);
-        float derivative = expf(x) * w / (d * d);
-        delta[i] *= derivative;
+        const float THRESHOLD = 20.0f;
+
+        // implementation from TensorFlow: https://github.com/tensorflow/addons/commit/093cdfa85d334cbe19a37624c33198f3140109ed
+        // implementation from Pytorch: https://github.com/thomasbrandon/mish-cuda/blob/master/csrc/mish.h#L26-L31
+        float inp = activation_input[i];
+        const float sp = (inp < THRESHOLD) ? log1p(exp(inp)) : inp;
+        const float grad_sp = 1 - exp(-sp);
+        const float tsp = tanh(sp);
+        const float grad_tsp = (1 - tsp*tsp) * grad_sp;
+        const float grad = inp * grad_tsp + tsp;
+        delta[i] *= grad;
+
+
+        //float x = activation_input[i];
+        //float d = 2 * expf(x) + expf(2 * x) + 2;
+        //float w = 4 * (x + 1) + 4 * expf(2 * x) + expf(3 * x) + expf(x)*(4 * x + 6);
+        //float derivative = expf(x) * w / (d * d);
+        //delta[i] *= derivative;
     }
 }
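
Note on the change above: both hunks replace the closed-form Mish derivative (whose expf(2*x), expf(3*x), and squared-denominator terms overflow float32 once x grows past roughly 22) with the numerically stable softplus form used by the referenced TensorFlow Addons and mish-cuda implementations. With sp = softplus(x) = log(1 + e^x), mish(x) = x * tanh(sp) and d/dx mish(x) = tanh(sp) + x * (1 - tanh^2(sp)) * sigmoid(x), where sigmoid(x) = 1 - e^(-sp); the THRESHOLD = 20.0f branch short-circuits log1p(exp(x)) to x, which is already exact at float precision there and avoids overflowing exp(). The following standalone C sketch (not part of the patch; mish_fn and mish_grad are illustrative names) checks that analytic gradient against a central finite difference:

// mish_grad_check.c -- standalone sanity check, not part of this patch.
// Compile: gcc -std=c99 mish_grad_check.c -lm -o mish_grad_check
#include <math.h>
#include <stdio.h>

// mish(x) = x * tanh(softplus(x)), with the same THRESHOLD cutoff as the patch
static float mish_fn(float x)
{
    const float THRESHOLD = 20.0f;
    const float sp = (x < THRESHOLD) ? log1pf(expf(x)) : x;
    return x * tanhf(sp);
}

// analytic gradient, same algebra as the new gradient_array_mish body
static float mish_grad(float x)
{
    const float THRESHOLD = 20.0f;
    const float sp = (x < THRESHOLD) ? log1pf(expf(x)) : x;
    const float grad_sp = 1 - expf(-sp);   // d softplus/dx = sigmoid(x) = 1 - e^(-sp)
    const float tsp = tanhf(sp);
    return x * (1 - tsp * tsp) * grad_sp + tsp;
}

int main(void)
{
    // compare against a central difference at a few points, including large |x|
    const float xs[] = { -30.0f, -4.0f, -1.0f, 0.0f, 0.5f, 3.0f, 25.0f };
    const float h = 1e-3f;
    for (int i = 0; i < (int)(sizeof(xs) / sizeof(xs[0])); ++i) {
        const float x = xs[i];
        const float num = (mish_fn(x + h) - mish_fn(x - h)) / (2 * h);
        printf("x=%7.2f  analytic=%9.6f  numeric=%9.6f\n", x, mish_grad(x), num);
    }
    return 0;
}

The analytic and numeric columns should agree to roughly 1e-3 across the whole range, including x = -30 and x = 25, whereas evaluating the old expf(x)*w/(d*d) form at x = 25 would overflow float32 in the d*d denominator.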