Adam optimizer fixed

pull/2160/head
AlexeyAB 6 years ago
parent 64e478db07
commit 08f0f80b66
  1. 9
      src/blas_kernels.cu

@ -140,19 +140,20 @@ void backward_bias_gpu(float *bias_updates, float *delta, int batch, int n, int
check_error(cudaPeekAtLastError());
}
__global__ void adam_kernel(int N, float *x, float *m, float *v, float B1, float B2, float rate, float eps, int t)
{
int index = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
if (index >= N) return;
x[index] = x[index] - (rate * sqrtf(1.F-powf(B2, t)) / (1.F-powf(B1, t)) * m[index] / (sqrtf(v[index]) + eps));
//if(index == 0) printf("%f %f %f %f\n", m[index], v[index], (rate * sqrtf(1.F-powf(B2, t)) / (1.F-powf(B1, t)) * m[index] / (sqrt(v[index]) + eps)));
float mhat = m[index] / (1.f - powf(B1, t));
float vhat = v[index] / (1.f - powf(B2, t));
x[index] = x[index] + rate * mhat / (sqrtf(vhat) + eps);
}
extern "C" void adam_gpu(int n, float *x, float *m, float *v, float B1, float B2, float rate, float eps, int t)
{
adam_kernel<<<cuda_gridsize(n), BLOCK>>>(n, x, m, v, B1, B2, rate, eps, t);
adam_kernel << <cuda_gridsize(n), BLOCK >> >(n, x, m, v, B1, B2, rate, eps, t);
check_error(cudaPeekAtLastError());
}

Loading…
Cancel
Save