diff --git a/include/darknet.h b/include/darknet.h
index ac83788a..9b0ddeec 100644
--- a/include/darknet.h
+++ b/include/darknet.h
@@ -104,7 +104,7 @@ typedef struct tree {
 
 // activations.h
 typedef enum {
-    LOGISTIC, RELU, RELU6, RELIE, LINEAR, RAMP, TANH, PLSE, LEAKY, ELU, LOGGY, STAIR, HARDTAN, LHTAN, SELU, SWISH, MISH, NORM_CHAN, NORM_CHAN_SOFTMAX, NORM_CHAN_SOFTMAX_MAXVAL
+    LOGISTIC, RELU, RELU6, RELIE, LINEAR, RAMP, TANH, PLSE, LEAKY, ELU, LOGGY, STAIR, HARDTAN, LHTAN, SELU, GELU, SWISH, MISH, NORM_CHAN, NORM_CHAN_SOFTMAX, NORM_CHAN_SOFTMAX_MAXVAL
 }ACTIVATION;
 
 // parser.h
diff --git a/src/activation_kernels.cu b/src/activation_kernels.cu
index ba1113fd..89e01c04 100644
--- a/src/activation_kernels.cu
+++ b/src/activation_kernels.cu
@@ -36,6 +36,7 @@ __device__ float relie_activate_kernel(float x){return (x>0) ? x : .01f*x;}
 __device__ float ramp_activate_kernel(float x){return x*(x>0)+.1f*x;}
 __device__ float leaky_activate_kernel(float x){return (x>0) ? x : .1f*x;}
 __device__ float tanh_activate_kernel(float x){return (2/(1 + expf(-2*x)) - 1);}
+__device__ float gelu_activate_kernel(float x){return (0.5*x*(1 + tanhf(0.797885*x + 0.035677*powf(x, 3))));}
 __device__ float softplus_kernel(float x, float threshold = 20) {
     if (x > threshold) return x;    // too large
     else if (x < -threshold) return expf(x);    // too small
@@ -75,6 +76,11 @@ __device__ float relie_gradient_kernel(float x){return (x>0) ? 1 : .01f;}
 __device__ float ramp_gradient_kernel(float x){return (x>0)+.1f;}
 __device__ float leaky_gradient_kernel(float x){return (x>0) ? 1 : .1f;}
 __device__ float tanh_gradient_kernel(float x){return 1-x*x;}
+__device__ float sech_gpu(float x) { return 2 / (expf(x) + expf(-x)); }
+__device__ float gelu_gradient_kernel(float x) {
+    const float x3 = powf(x, 3);
+    return 0.5*tanhf(0.0356774*x3 + 0.797885*x) + (0.0535161*x3 + 0.398942*x) * powf(sech_gpu(0.0356774*x3 + 0.797885*x), 2) + 0.5;
+}
 __device__ float plse_gradient_kernel(float x){return (x < 0 || x > 1) ? .01f : .125f;}
 __device__ float stair_gradient_kernel(float x)
 {
@@ -99,6 +105,8 @@ __device__ float activate_kernel(float x, ACTIVATION a)
             return elu_activate_kernel(x);
         case SELU:
             return selu_activate_kernel(x);
+        case GELU:
+            return gelu_activate_kernel(x);
         case RELIE:
             return relie_activate_kernel(x);
         case RAMP:
@@ -138,6 +146,8 @@ __device__ float gradient_kernel(float x, ACTIVATION a)
             return elu_gradient_kernel(x);
         case SELU:
             return selu_gradient_kernel(x);
+        case GELU:
+            return gelu_gradient_kernel(x);
         case RELIE:
             return relie_gradient_kernel(x);
         case RAMP:
@@ -245,6 +255,14 @@ __global__ void activate_array_selu_kernel(float *x, int n)
     }
 }
 
+__global__ void activate_array_gelu_kernel(float *x, int n)
+{
+    int index = blockIdx.x*blockDim.x + threadIdx.x;
+    if (index < n) {
+        x[index] = gelu_activate_kernel(x[index]);
+    }
+}
+
 __global__ void activate_array_logistic_kernel(float *x, int n)
 {
     int index = blockIdx.x*blockDim.x + threadIdx.x;
@@ -343,6 +361,14 @@ __global__ void gradient_array_selu_kernel(float *x, int n, float *delta)
     }
 }
 
+__global__ void gradient_array_gelu_kernel(float *x, int n, float *delta)
+{
+    int index = blockIdx.x*blockDim.x + threadIdx.x;
+    if (index < n) {
+        delta[index] *= gelu_gradient_kernel(x[index]);
+    }
+}
+
 __global__ void gradient_array_logistic_kernel(float *x, int n, float *delta)
 {
     int index = blockIdx.x*blockDim.x + threadIdx.x;
@@ -394,6 +420,7 @@ extern "C" void activate_array_ongpu(float *x, int n, ACTIVATION a)
     else if (a == RELU) activate_array_relu_kernel << <num_blocks, BLOCK, 0, get_cuda_stream() >> >(x, n);
     else if (a == RELU6) activate_array_relu6_kernel << <num_blocks, BLOCK, 0, get_cuda_stream() >> >(x, n);
     else if (a == SELU) activate_array_selu_kernel << <num_blocks, BLOCK, 0, get_cuda_stream() >> >(x, n);
+    else if (a == GELU) activate_array_gelu_kernel << <num_blocks, BLOCK, 0, get_cuda_stream() >> >(x, n);
     else activate_array_kernel<<<cuda_gridsize(n), BLOCK, 0, get_cuda_stream()>>>(x, n, a);
 
     CHECK_CUDA(cudaPeekAtLastError());
@@ -429,6 +456,7 @@ extern "C" void gradient_array_ongpu(float *x, int n, ACTIVATION a, float *delta)
         exit(0);
     }
     else if (a == SELU) gradient_array_selu_kernel << <num_blocks, BLOCK, 0, get_cuda_stream() >> >(x, n, delta);
+    else if (a == GELU) gradient_array_gelu_kernel << <num_blocks, BLOCK, 0, get_cuda_stream() >> >(x, n, delta);
     else gradient_array_kernel << <cuda_gridsize(n), BLOCK, 0, get_cuda_stream() >> > (x, n, a, delta);
 
     CHECK_CUDA(cudaPeekAtLastError());
diff --git a/src/activations.c b/src/activations.c
index 0b68fda0..dac252b6 100644
--- a/src/activations.c
+++ b/src/activations.c
@@ -19,6 +19,8 @@ char *get_activation_string(ACTIVATION a)
             return "elu";
         case SELU:
             return "selu";
+        case GELU:
+            return "gelu";
         case RELIE:
             return "relie";
         case RAMP:
@@ -56,6 +58,7 @@ ACTIVATION get_activation(char *s)
     if (strcmp(s, "relu6") == 0) return RELU6;
     if (strcmp(s, "elu")==0) return ELU;
     if (strcmp(s, "selu") == 0) return SELU;
+    if (strcmp(s, "gelu") == 0) return GELU;
     if (strcmp(s, "relie")==0) return RELIE;
     if (strcmp(s, "plse")==0) return PLSE;
     if (strcmp(s, "hardtan")==0) return HARDTAN;
@@ -84,6 +87,8 @@ float activate(float x, ACTIVATION a)
            return elu_activate(x);
        case SELU:
            return selu_activate(x);
+       case GELU:
+           return gelu_activate(x);
        case RELIE:
            return relie_activate(x);
        case RAMP:
@@ -300,6 +305,8 @@ float gradient(float x, ACTIVATION a)
            return elu_gradient(x);
        case SELU:
            return selu_gradient(x);
+       case GELU:
+           return gelu_gradient(x);
        case RELIE:
            return relie_gradient(x);
        case RAMP:
diff --git a/src/activations.h b/src/activations.h
index 9e9b0053..8ac457d2 100644
--- a/src/activations.h
+++ b/src/activations.h
@@ -65,6 +65,7 @@ static inline float ramp_activate(float x){return x*(x>0)+.1f*x;}
 static inline float leaky_activate(float x){return (x>0) ? x : .1f*x;}
 //static inline float tanh_activate(float x){return (expf(2*x)-1)/(expf(2*x)+1);}
 static inline float tanh_activate(float x) { return (2 / (1 + expf(-2 * x)) - 1); }
+static inline float gelu_activate(float x) { return (0.5*x*(1 + tanhf(0.797885*x + 0.035677*powf(x, 3)))); }
 static inline float softplus_activate(float x, float threshold) {
     if (x > threshold) return x;    // too large
     else if (x < -threshold) return expf(x);    // too small
@@ -114,6 +115,12 @@ static inline float relie_gradient(float x){return (x>0) ? 1 : .01f;}
 static inline float ramp_gradient(float x){return (x>0)+.1f;}
 static inline float leaky_gradient(float x){return (x>0) ? 1 : .1f;}
 static inline float tanh_gradient(float x){return 1-x*x;}
+
+static inline float sech(float x) { return 2 / (expf(x) + expf(-x)); }
+static inline float gelu_gradient(float x) {
+    const float x3 = powf(x, 3);
+    return 0.5*tanhf(0.0356774*x3 + 0.797885*x) + (0.0535161*x3 + 0.398942*x) * powf(sech(0.0356774*x3 + 0.797885*x), 2) + 0.5;
+}
 static inline float plse_gradient(float x){return (x < 0 || x > 1) ? .01f : .125f;}
 
 #ifdef __cplusplus
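
For reference, the hard-coded constants in gelu_activate() / gelu_activate_kernel() come from the standard tanh approximation of GELU (Hendrycks & Gimpel, 2016):

\[ \mathrm{gelu}(x) \approx 0.5\,x\bigl(1 + \tanh\bigl(\sqrt{2/\pi}\,(x + 0.044715\,x^{3})\bigr)\bigr), \qquad \sqrt{2/\pi} \approx 0.797885,\quad 0.044715\,\sqrt{2/\pi} \approx 0.0356774 \]

(the activation rounds the second constant to 0.035677). Writing \(u(x) = 0.797885\,x + 0.0356774\,x^{3}\), the derivative is

\[ \frac{d}{dx}\,\mathrm{gelu}(x) = 0.5 + 0.5\tanh(u) + \bigl(0.398942\,x + 0.0535161\,x^{3}\bigr)\,\mathrm{sech}^{2}(u), \]

which is exactly what gelu_gradient() and gelu_gradient_kernel() evaluate: 0.398942 = 0.5 * 0.797885 and 0.0535161 = 1.5 * 0.0356774.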
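
A quick way to sanity-check that the analytic gradient matches the forward function is a central-difference test. The standalone harness below is illustrative only and is not part of the patch; it copies gelu_activate(), gelu_gradient() and the sech() helper from the activations.h hunk above so that it compiles on its own with gcc test_gelu.c -lm (the file name is arbitrary).

/* test_gelu.c - standalone sanity check for the GELU activation/gradient pair.
 * gelu_activate(), gelu_gradient() and sech() are copied from the activations.h
 * hunk above; the finite-difference harness itself is only illustrative. */
#include <math.h>
#include <stdio.h>
#include <stdlib.h>

static inline float sech(float x) { return 2 / (expf(x) + expf(-x)); }

static inline float gelu_activate(float x) { return (0.5*x*(1 + tanhf(0.797885*x + 0.035677*powf(x, 3)))); }

static inline float gelu_gradient(float x) {
    const float x3 = powf(x, 3);
    return 0.5*tanhf(0.0356774*x3 + 0.797885*x) + (0.0535161*x3 + 0.398942*x) * powf(sech(0.0356774*x3 + 0.797885*x), 2) + 0.5;
}

int main(void)
{
    const float eps = 1e-3f;
    float max_err = 0.f;
    for (float x = -6.f; x <= 6.f; x += 0.25f) {
        /* central-difference estimate of d(gelu)/dx at x */
        float numeric  = (gelu_activate(x + eps) - gelu_activate(x - eps)) / (2.f * eps);
        float analytic = gelu_gradient(x);
        float err = fabsf(analytic - numeric);
        if (err > max_err) max_err = err;
        printf("x=% .2f  analytic=% .6f  numeric=% .6f  |diff|=%.2e\n", x, analytic, numeric, err);
    }
    printf("max |analytic - numeric| = %.2e\n", max_err);
    return max_err < 1e-2f ? EXIT_SUCCESS : EXIT_FAILURE;
}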
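
Because get_activation() now maps the string "gelu" to the new enum value, the activation can be selected directly from a .cfg file. A hypothetical layer block for illustration (the remaining keys are just ordinary darknet convolutional settings):

[convolutional]
batch_normalize=1
filters=64
size=3
stride=1
pad=1
activation=gelu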