Added MISH activation, use activation=mish in [convolutional] layers

pull/4269/head
AlexeyAB
parent d628e8eab7
commit bf8ea4183d

7 changed files:
1. include/darknet.h (6 changed lines)
2. src/activation_kernels.cu (36 changed lines)
3. src/activations.c (25 changed lines)
4. src/activations.h (6 changed lines)
5. src/convolutional_kernels.cu (9 changed lines)
6. src/convolutional_layer.c (13 changed lines)
7. src/layer.c (4 changed lines)
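
As the commit message says, the activation is selected from the cfg file. A minimal [convolutional] block using it might look like this (the filter/size/stride values are illustrative, not part of this commit; only activation=mish is what the change enables):

    [convolutional]
    batch_normalize=1
    filters=64
    size=3
    stride=1
    pad=1
    activation=mish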

include/darknet.h

@@ -102,7 +102,7 @@ typedef struct tree {
 // activations.h
 typedef enum {
-    LOGISTIC, RELU, RELIE, LINEAR, RAMP, TANH, PLSE, LEAKY, ELU, LOGGY, STAIR, HARDTAN, LHTAN, SELU, SWISH
+    LOGISTIC, RELU, RELIE, LINEAR, RAMP, TANH, PLSE, LEAKY, ELU, LOGGY, STAIR, HARDTAN, LHTAN, SELU, SWISH, MISH
 }ACTIVATION;
 // parser.h

@@ -347,7 +347,7 @@ struct layer {
     float *col_image;
     float * delta;
     float * output;
-    float * output_sigmoid;
+    float * activation_input;
     int delta_pinned;
     int output_pinned;
     float * loss;

@@ -532,7 +532,7 @@ struct layer {
     float * input_antialiasing_gpu;
     float * output_gpu;
-    float * output_sigmoid_gpu;
+    float * activation_input_gpu;
     float * loss_gpu;
     float * delta_gpu;
     float * rand_gpu;
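
Note on the rename: output_sigmoid was Swish-specific, since Swish caches sigmoid(x) for its backward pass (the Swish function bodies are unchanged and still write sigmoid values into it). The buffer is now named activation_input because it is shared by both activations, and for Mish it stores the raw pre-activation value x itself.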

src/activation_kernels.cu

@@ -199,6 +199,16 @@ __global__ void activate_array_swish_kernel(float *x, int n, float *output_sigmo
     }
 }
+
+__global__ void activate_array_mish_kernel(float *x, int n, float *activation_input, float *output_gpu)
+{
+    int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
+    if (i < n) {
+        float x_val = x[i];
+        activation_input[i] = x_val; // store value before activation
+        output_gpu[i] = x_val * tanh_activate_kernel(log(1 + expf(x_val)));
+    }
+}

 __global__ void activate_array_leaky_kernel(float *x, int n)
 {
     int index = blockIdx.x*blockDim.x + threadIdx.x;
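
One caveat in the forward kernel: softplus is computed directly as log(1 + expf(x_val)), and expf overflows to infinity for inputs above roughly 88 in single precision, which can propagate inf/NaN through the activation. This commit does not guard against that; a common thresholded variant, sketched here as an assumption rather than as part of this change, looks like:

    // Hypothetical numerically stable softplus; NOT part of this commit.
    __device__ float softplus_stable_kernel(float x, float threshold)
    {
        if (x > threshold) return x;              // log(1 + e^x) ~= x for large x, avoids expf overflow
        else if (x < -threshold) return expf(x);  // log(1 + e^x) ~= e^x for very negative x
        return logf(expf(x) + 1.0f);
    }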
@@ -263,6 +273,18 @@ __global__ void gradient_array_swish_kernel(float *x, int n, float *sigmoid_gpu,
     }
 }
+
+__global__ void gradient_array_mish_kernel(int n, float *activation_input, float *delta)
+{
+    int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
+    if (i < n) {
+        float x = activation_input[i];
+        float d = 2 * expf(x) + expf(2 * x) + 2;
+        float w = 4 * (x + 1) + 4 * expf(2 * x) + expf(3 * x) + expf(x)*(4 * x + 6);
+        float derivative = expf(x) * w / (d * d);
+        delta[i] *= derivative;
+    }
+}

 __global__ void gradient_array_leaky_kernel(float *x, int n, float *delta)
 {
     int index = blockIdx.x*blockDim.x + threadIdx.x;
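
The expressions in the gradient kernel are a closed form of the Mish derivative. With softplus $\mathrm{sp}(x) = \ln(1 + e^x)$, the function and its derivative can be written as

    f(x) = x \tanh(\mathrm{sp}(x)), \qquad
    f'(x) = \frac{e^x \,\omega}{\delta^2}, \quad
    \omega = 4(x+1) + 4e^{2x} + e^{3x} + e^x(4x + 6), \quad
    \delta = 2e^x + e^{2x} + 2

which is exactly the per-element w and d above; delta[i] holds dL/dy on entry and dL/dx on exit, per the chain rule. As a spot check, at x = 0: omega = 15 and delta = 5, so f'(0) = 15/25 = 0.6, matching tanh(ln 2) = 3/5 obtained by differentiating f directly.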
@@ -333,6 +355,13 @@ extern "C" void activate_array_swish_ongpu(float *x, int n, float *output_sigmoi
     CHECK_CUDA(cudaPeekAtLastError());
 }
+
+extern "C" void activate_array_mish_ongpu(float *x, int n, float *activation_input_gpu, float *output_gpu)
+{
+    const int num_blocks = get_number_of_blocks(n, BLOCK);
+    activate_array_mish_kernel << <cuda_gridsize(n), BLOCK, 0, get_cuda_stream() >> >(x, n, activation_input_gpu, output_gpu);
+    CHECK_CUDA(cudaPeekAtLastError());
+}

 extern "C" void gradient_array_ongpu(float *x, int n, ACTIVATION a, float *delta)
 {
     const int num_blocks = get_number_of_blocks(n, BLOCK);

@@ -354,4 +383,11 @@ extern "C" void gradient_array_swish_ongpu(float *x, int n, float *sigmoid_gpu,
     const int num_blocks = get_number_of_blocks(n, BLOCK);
     gradient_array_swish_kernel << <cuda_gridsize(n), BLOCK, 0, get_cuda_stream() >> > (x, n, sigmoid_gpu, delta);
     CHECK_CUDA(cudaPeekAtLastError());
 }
+
+extern "C" void gradient_array_mish_ongpu(int n, float *activation_input_gpu, float *delta)
+{
+    const int num_blocks = get_number_of_blocks(n, BLOCK);
+    gradient_array_mish_kernel << <cuda_gridsize(n), BLOCK, 0, get_cuda_stream() >> > (n, activation_input_gpu, delta);
+    CHECK_CUDA(cudaPeekAtLastError());
+}

src/activations.c

@@ -46,6 +46,7 @@ ACTIVATION get_activation(char *s)
 {
     if (strcmp(s, "logistic")==0) return LOGISTIC;
     if (strcmp(s, "swish") == 0) return SWISH;
+    if (strcmp(s, "mish") == 0) return MISH;
     if (strcmp(s, "loggy")==0) return LOGGY;
     if (strcmp(s, "relu")==0) return RELU;
     if (strcmp(s, "elu")==0) return ELU;

@@ -133,6 +134,17 @@ void activate_array_swish(float *x, const int n, float * output_sigmoid, float *
     }
 }
+
+void activate_array_mish(float *x, const int n, float * activation_input, float * output)
+{
+    int i;
+    #pragma omp parallel for
+    for (i = 0; i < n; ++i) {
+        float x_val = x[i];
+        activation_input[i] = x_val; // store value before activation
+        output[i] = x_val * tanh_activate(log(1 + expf(x_val)));
+    }
+}

 float gradient(float x, ACTIVATION a)
 {
     switch(a){

@@ -187,3 +199,16 @@ void gradient_array_swish(const float *x, const int n, const float * sigmoid, fl
         delta[i] *= swish + sigmoid[i]*(1 - swish);
     }
 }
+
+void gradient_array_mish(const int n, const float * activation_input, float * delta)
+{
+    int i;
+    #pragma omp parallel for
+    for (i = 0; i < n; ++i) {
+        float x = activation_input[i];
+        float d = 2 * expf(x) + expf(2 * x) + 2;
+        float w = 4 * (x + 1) + 4 * expf(2 * x) + expf(3 * x) + expf(x)*(4 * x + 6);
+        float derivative = expf(x) * w / (d * d);
+        delta[i] *= derivative;
+    }
+}
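
Since the CPU and GPU paths share the same closed form, a quick central-difference comparison is an easy sanity check. A minimal, self-contained harness (hypothetical test code, not part of the commit; compile with gcc -std=c99 and -lm):

    #include <math.h>
    #include <stdio.h>

    /* Forward pass: mish(x) = x * tanh(ln(1 + e^x)) */
    static float mish(float x) { return x * tanhf(logf(1 + expf(x))); }

    /* Closed-form derivative, same expressions as gradient_array_mish() */
    static float mish_grad(float x)
    {
        float d = 2 * expf(x) + expf(2 * x) + 2;
        float w = 4 * (x + 1) + 4 * expf(2 * x) + expf(3 * x) + expf(x) * (4 * x + 6);
        return expf(x) * w / (d * d);
    }

    int main(void)
    {
        const float h = 1e-3f;
        for (float x = -3.0f; x <= 3.0f; x += 0.5f) {
            /* central difference approximates f'(x) to O(h^2) */
            float numeric = (mish(x + h) - mish(x - h)) / (2 * h);
            printf("x=%+.2f  analytic=%+.6f  numeric=%+.6f\n", x, mish_grad(x), numeric);
        }
        return 0;
    }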

src/activations.h

@@ -5,7 +5,7 @@
 #include "math.h"
 //typedef enum{
-//    LOGISTIC, RELU, RELIE, LINEAR, RAMP, TANH, PLSE, LEAKY, ELU, LOGGY, STAIR, HARDTAN, LHTAN, SELU
+//    LOGISTIC, RELU, RELIE, LINEAR, RAMP, TANH, PLSE, LEAKY, ELU, LOGGY, STAIR, HARDTAN, LHTAN, SELU, SWISH, MISH
 //}ACTIVATION;
 #ifdef __cplusplus

@@ -18,13 +18,17 @@ float activate(float x, ACTIVATION a);
 float gradient(float x, ACTIVATION a);
 void gradient_array(const float *x, const int n, const ACTIVATION a, float *delta);
 void gradient_array_swish(const float *x, const int n, const float * sigmoid, float * delta);
+void gradient_array_mish(const int n, const float * activation_input, float * delta);
 void activate_array(float *x, const int n, const ACTIVATION a);
 void activate_array_swish(float *x, const int n, float * output_sigmoid, float * output);
+void activate_array_mish(float *x, const int n, float * activation_input, float * output);
 #ifdef GPU
 void activate_array_ongpu(float *x, int n, ACTIVATION a);
 void activate_array_swish_ongpu(float *x, int n, float *output_sigmoid_gpu, float *output_gpu);
+void activate_array_mish_ongpu(float *x, int n, float *activation_input_gpu, float *output_gpu);
 void gradient_array_ongpu(float *x, int n, ACTIVATION a, float *delta);
 void gradient_array_swish_ongpu(float *x, int n, float *sigmoid_gpu, float *delta);
+void gradient_array_mish_ongpu(int n, float *activation_input_gpu, float *delta);
 #endif

 static inline float stair_activate(float x)

src/convolutional_kernels.cu

@@ -392,7 +392,8 @@ void forward_convolutional_layer_gpu(convolutional_layer l, network_state state)
     */
     //add_bias_gpu(l.output_gpu, l.biases_gpu, l.batch, l.n, l.out_w*l.out_h);
-    if (l.activation == SWISH) activate_array_swish_ongpu(l.output_gpu, l.outputs*l.batch, l.output_sigmoid_gpu, l.output_gpu);
+    if (l.activation == SWISH) activate_array_swish_ongpu(l.output_gpu, l.outputs*l.batch, l.activation_input_gpu, l.output_gpu);
+    else if (l.activation == MISH) activate_array_mish_ongpu(l.output_gpu, l.outputs*l.batch, l.activation_input_gpu, l.output_gpu);
     else if (l.activation != LINEAR && l.activation != LEAKY) activate_array_ongpu(l.output_gpu, l.outputs*l.batch, l.activation);
     //if(l.activation != LINEAR && l.activation != LEAKY) activate_array_ongpu(l.output_gpu, l.outputs*l.batch, l.activation);
     //if (l.binary || l.xnor) swap_binary(&l);

@@ -596,7 +597,8 @@ void forward_convolutional_layer_gpu(convolutional_layer l, network_state state)
     //#ifndef CUDNN_HALF
     //#endif // no CUDNN_HALF
-    if (l.activation == SWISH) activate_array_swish_ongpu(l.output_gpu, l.outputs*l.batch, l.output_sigmoid_gpu, l.output_gpu);
+    if (l.activation == SWISH) activate_array_swish_ongpu(l.output_gpu, l.outputs*l.batch, l.activation_input_gpu, l.output_gpu);
+    else if (l.activation == MISH) activate_array_mish_ongpu(l.output_gpu, l.outputs*l.batch, l.activation_input_gpu, l.output_gpu);
     else if (l.activation != LINEAR) activate_array_ongpu(l.output_gpu, l.outputs*l.batch, l.activation);
     //if(l.dot > 0) dot_error_gpu(l);
     if(l.binary || l.xnor) swap_binary(&l);

@@ -639,7 +641,8 @@ void backward_convolutional_layer_gpu(convolutional_layer l, network_state state
     if(state.net.try_fix_nan) constrain_ongpu(l.outputs*l.batch, 1, l.delta_gpu, 1);
-    if (l.activation == SWISH) gradient_array_swish_ongpu(l.output_gpu, l.outputs*l.batch, l.output_sigmoid_gpu, l.delta_gpu);
+    if (l.activation == SWISH) gradient_array_swish_ongpu(l.output_gpu, l.outputs*l.batch, l.activation_input_gpu, l.delta_gpu);
+    else if (l.activation == MISH) gradient_array_mish_ongpu(l.outputs*l.batch, l.activation_input_gpu, l.delta_gpu);
     else gradient_array_ongpu(l.output_gpu, l.outputs*l.batch, l.activation, l.delta_gpu);
     if (!l.batch_normalize)

src/convolutional_layer.c

@@ -473,10 +473,10 @@ convolutional_layer make_convolutional_layer(int batch, int steps, int h, int w,
         l.scale_v = (float*)calloc(n, sizeof(float));
     }
-    if(l.activation == SWISH) l.output_sigmoid = (float*)calloc(total_batch*l.outputs, sizeof(float));
+    if (l.activation == SWISH || l.activation == MISH) l.activation_input = (float*)calloc(total_batch*l.outputs, sizeof(float));
 #ifdef GPU
-    if (l.activation == SWISH) l.output_sigmoid_gpu = cuda_make_array(l.output_sigmoid, total_batch*out_h*out_w*n);
+    if (l.activation == SWISH || l.activation == MISH) l.activation_input_gpu = cuda_make_array(l.activation_input, total_batch*out_h*out_w*n);
     l.forward_gpu = forward_convolutional_layer_gpu;
     l.backward_gpu = backward_convolutional_layer_gpu;

@@ -1100,7 +1100,8 @@ void forward_convolutional_layer(convolutional_layer l, network_state state)
     add_bias(l.output, l.biases, l.batch, l.n, out_h*out_w);
     //activate_array(l.output, m*n*l.batch, l.activation);
-    if (l.activation == SWISH) activate_array_swish(l.output, l.outputs*l.batch, l.output_sigmoid, l.output);
+    if (l.activation == SWISH) activate_array_swish(l.output, l.outputs*l.batch, l.activation_input, l.output);
+    else if (l.activation == MISH) activate_array_mish(l.output, l.outputs*l.batch, l.activation_input, l.output);
     else activate_array_cpu_custom(l.output, m*n*l.batch, l.activation);
     return;

@@ -1139,7 +1140,8 @@ void forward_convolutional_layer(convolutional_layer l, network_state state)
     add_bias(l.output, l.biases, l.batch, l.n, out_h*out_w);
     //activate_array(l.output, m*n*l.batch, l.activation);
-    if (l.activation == SWISH) activate_array_swish(l.output, l.outputs*l.batch, l.output_sigmoid, l.output);
+    if (l.activation == SWISH) activate_array_swish(l.output, l.outputs*l.batch, l.activation_input, l.output);
+    else if (l.activation == MISH) activate_array_mish(l.output, l.outputs*l.batch, l.activation_input, l.output);
     else activate_array_cpu_custom(l.output, l.outputs*l.batch, l.activation);
     if(l.binary || l.xnor) swap_binary(&l);

@@ -1276,7 +1278,8 @@ void backward_convolutional_layer(convolutional_layer l, network_state state)
     int n = l.size*l.size*l.c / l.groups;
     int k = l.out_w*l.out_h;
-    if (l.activation == SWISH) gradient_array_swish(l.output, l.outputs*l.batch, l.output_sigmoid, l.delta);
+    if (l.activation == SWISH) gradient_array_swish(l.output, l.outputs*l.batch, l.activation_input, l.delta);
+    else if (l.activation == MISH) gradient_array_mish(l.outputs*l.batch, l.activation_input, l.delta);
     else gradient_array(l.output, l.outputs*l.batch, l.activation, l.delta);
     if (l.batch_normalize) {
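
Note the asymmetry in the backward calls: the SWISH branch passes l.output plus the cached buffer, since its derivative reuses both the activated output and the stored sigmoid, while the MISH branch passes only the element count and l.activation_input, because the closed-form Mish derivative shown earlier depends solely on the stored pre-activation values.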

src/layer.c

@@ -90,7 +90,7 @@ void free_layer(layer l)
 #endif // GPU
     if (l.delta) free(l.delta), l.delta = NULL;
     if (l.output) free(l.output), l.output = NULL;
-    if (l.output_sigmoid) free(l.output_sigmoid), l.output_sigmoid = NULL;
+    if (l.activation_input) free(l.activation_input), l.activation_input = NULL;
     if (l.squared) free(l.squared);
     if (l.norms) free(l.norms);
     if (l.spatial_mean) free(l.spatial_mean);

@@ -176,7 +176,7 @@ void free_layer(layer l)
     if (l.scale_updates_gpu) cuda_free(l.scale_updates_gpu), l.scale_updates_gpu = NULL;
     if (l.input_antialiasing_gpu) cuda_free(l.input_antialiasing_gpu), l.input_antialiasing_gpu = NULL;
     if (l.output_gpu) cuda_free(l.output_gpu), l.output_gpu = NULL;
-    if (l.output_sigmoid_gpu) cuda_free(l.output_sigmoid_gpu), l.output_sigmoid_gpu = NULL;
+    if (l.activation_input_gpu) cuda_free(l.activation_input_gpu), l.activation_input_gpu = NULL;
     if (l.delta_gpu) cuda_free(l.delta_gpu), l.delta_gpu = NULL;
     if (l.rand_gpu) cuda_free(l.rand_gpu);
     if (l.squared_gpu) cuda_free(l.squared_gpu);
