Adam optimizer fixed

pull/2160/head
AlexeyAB 6 years ago
parent 64e478db07
commit 08f0f80b66
  1. 17
      src/blas_kernels.cu

@ -99,7 +99,7 @@ __global__ void dot_kernel(float *output, float scale, int batch, int n, int siz
int f1 = index / n; int f1 = index / n;
int f2 = index % n; int f2 = index % n;
if (f2 <= f1) return; if (f2 <= f1) return;
float sum = 0; float sum = 0;
float norm1 = 0; float norm1 = 0;
float norm2 = 0; float norm2 = 0;
@ -140,19 +140,20 @@ void backward_bias_gpu(float *bias_updates, float *delta, int batch, int n, int
check_error(cudaPeekAtLastError()); check_error(cudaPeekAtLastError());
} }
__global__ void adam_kernel(int N, float *x, float *m, float *v, float B1, float B2, float rate, float eps, int t) __global__ void adam_kernel(int N, float *x, float *m, float *v, float B1, float B2, float rate, float eps, int t)
{ {
int index = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x; int index = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
if (index >= N) return; if (index >= N) return;
x[index] = x[index] - (rate * sqrtf(1.F-powf(B2, t)) / (1.F-powf(B1, t)) * m[index] / (sqrtf(v[index]) + eps)); float mhat = m[index] / (1.f - powf(B1, t));
//if(index == 0) printf("%f %f %f %f\n", m[index], v[index], (rate * sqrtf(1.F-powf(B2, t)) / (1.F-powf(B1, t)) * m[index] / (sqrt(v[index]) + eps))); float vhat = v[index] / (1.f - powf(B2, t));
x[index] = x[index] + rate * mhat / (sqrtf(vhat) + eps);
} }
extern "C" void adam_gpu(int n, float *x, float *m, float *v, float B1, float B2, float rate, float eps, int t) extern "C" void adam_gpu(int n, float *x, float *m, float *v, float B1, float B2, float rate, float eps, int t)
{ {
adam_kernel<<<cuda_gridsize(n), BLOCK>>>(n, x, m, v, B1, B2, rate, eps, t); adam_kernel << <cuda_gridsize(n), BLOCK >> >(n, x, m, v, B1, B2, rate, eps, t);
check_error(cudaPeekAtLastError()); check_error(cudaPeekAtLastError());
} }
@ -175,7 +176,7 @@ __global__ void normalize_kernel(int N, float *x, float *mean, float *variance,
int index = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x; int index = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
if (index >= N) return; if (index >= N) return;
int f = (index/spatial)%filters; int f = (index/spatial)%filters;
x[index] = (x[index] - mean[f])/(sqrtf(variance[f]) + .000001f); x[index] = (x[index] - mean[f])/(sqrtf(variance[f]) + .000001f);
} }
@ -184,7 +185,7 @@ __global__ void normalize_delta_kernel(int N, float *x, float *mean, float *vari
int index = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x; int index = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
if (index >= N) return; if (index >= N) return;
int f = (index/spatial)%filters; int f = (index/spatial)%filters;
delta[index] = delta[index] * 1.F/(sqrtf(variance[f]) + .000001f) + variance_delta[f] * 2. * (x[index] - mean[f]) / (spatial * batch) + mean_delta[f]/(spatial*batch); delta[index] = delta[index] * 1.F/(sqrtf(variance[f]) + .000001f) + variance_delta[f] * 2. * (x[index] - mean[f]) / (spatial * batch) + mean_delta[f]/(spatial*batch);
} }

Loading…
Cancel
Save