|
|
|
@ -217,11 +217,33 @@ __global__ void activate_array_swish_kernel(float *x, int n, float *output_sigmo |
|
|
|
|
if (i < n) { |
|
|
|
|
float x_val = x[i]; |
|
|
|
|
float sigmoid = logistic_activate_kernel(x_val); |
|
|
|
|
output_sigmoid_gpu[i] = sigmoid; |
|
|
|
|
if (output_sigmoid_gpu) output_sigmoid_gpu[i] = sigmoid; |
|
|
|
|
output_gpu[i] = x_val * sigmoid; |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
__device__ float mish_njuffa(float x) |
|
|
|
|
{ |
|
|
|
|
float r; |
|
|
|
|
float e = expf(x); |
|
|
|
|
r = 1.0f / fmaf(fmaf(-0.5f, e, -1.0f), e, -1.0f); |
|
|
|
|
r = fmaf(r, x, x); |
|
|
|
|
return r; |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
__device__ float mish_yashas(float x) |
|
|
|
|
{ |
|
|
|
|
auto e = __expf(x); |
|
|
|
|
if (x <= -18.0f) |
|
|
|
|
return x * e; |
|
|
|
|
|
|
|
|
|
auto n = e * e + 2 * e; |
|
|
|
|
if (x <= -5.0f) |
|
|
|
|
return x * __fdividef(n, n + 2); |
|
|
|
|
|
|
|
|
|
return x - 2 * __fdividef(x, n + 2); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// https://github.com/digantamisra98/Mish |
|
|
|
|
__global__ void activate_array_mish_kernel(float *x, int n, float *activation_input, float *output_gpu) |
|
|
|
|
{ |
|
|
|
@ -229,13 +251,15 @@ __global__ void activate_array_mish_kernel(float *x, int n, float *activation_in |
|
|
|
|
if (i < n) { |
|
|
|
|
const float MISH_THRESHOLD = 20; |
|
|
|
|
float x_val = x[i]; |
|
|
|
|
activation_input[i] = x_val; // store value before activation |
|
|
|
|
if (activation_input) activation_input[i] = x_val; // store value before activation |
|
|
|
|
//output_gpu[i] = x_val * tanh_activate_kernel(logf(1 + expf(x_val))); |
|
|
|
|
|
|
|
|
|
// Pytorch: https://github.com/thomasbrandon/mish-cuda/blob/master/csrc/mish.h#L17-L20 |
|
|
|
|
// TF: https://github.com/tensorflow/addons/blob/093cdfa85d334cbe19a37624c33198f3140109ed/tensorflow_addons/custom_ops/activations/cc/kernels/mish_op.h#L40-L49 |
|
|
|
|
// log1p(x) == log(x + 1) |
|
|
|
|
output_gpu[i] = x_val * tanh_activate_kernel( softplus_kernel(x_val, MISH_THRESHOLD) ); |
|
|
|
|
//output_gpu[i] = mish_yashas(x_val); |
|
|
|
|
//output_gpu[i] = mish_njuffa(x_val); |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|