mirror of https://github.com/AlexeyAB/darknet.git
parent
5ef74c2031
commit
cd8d53df21
22 changed files with 752 additions and 176 deletions
@ -0,0 +1,28 @@ |
|||||||
|
typedef enum{ |
||||||
|
SIGMOID, RELU, LINEAR, RAMP, TANH |
||||||
|
}ACTIVATION; |
||||||
|
|
||||||
|
float activate(float x, ACTIVATION a, float dropout) |
||||||
|
{ |
||||||
|
//if((float)rand()/RAND_MAX < dropout) return 0; |
||||||
|
switch(a){ |
||||||
|
case LINEAR: |
||||||
|
return linear_activate(x)/(1-dropout); |
||||||
|
case SIGMOID: |
||||||
|
return sigmoid_activate(x)/(1-dropout); |
||||||
|
case RELU: |
||||||
|
return relu_activate(x)/(1-dropout); |
||||||
|
case RAMP: |
||||||
|
return ramp_activate(x)/(1-dropout); |
||||||
|
case TANH: |
||||||
|
return tanh_activate(x)/(1-dropout); |
||||||
|
} |
||||||
|
return 0; |
||||||
|
} |
||||||
|
|
||||||
|
__kernel void activate_array(__global float *x, |
||||||
|
const int n, const ACTIVATION a, const float dropout) |
||||||
|
{ |
||||||
|
int i = get_global_id(0); |
||||||
|
x[i] = activate(x[i], a, dropout); |
||||||
|
} |
@ -0,0 +1,14 @@ |
|||||||
|
#include "mini_blas.h" |
||||||
|
|
||||||
|
void axpy_cpu(int N, float ALPHA, float *X, int INCX, float *Y, int INCY) |
||||||
|
{ |
||||||
|
int i; |
||||||
|
for(i = 0; i < N; ++i) Y[i*INCY] += ALPHA*X[i*INCX]; |
||||||
|
} |
||||||
|
|
||||||
|
void scal_cpu(int N, float ALPHA, float *X, int INCX) |
||||||
|
{ |
||||||
|
int i; |
||||||
|
for(i = 0; i < N; ++i) X[i*INCX] *= ALPHA; |
||||||
|
} |
||||||
|
|
@ -0,0 +1,283 @@ |
|||||||
|
#include "mini_blas.h" |
||||||
|
|
||||||
|
void gemm(int TA, int TB, int M, int N, int K, float ALPHA,
|
||||||
|
float *A, int lda,
|
||||||
|
float *B, int ldb, |
||||||
|
float BETA, |
||||||
|
float *C, int ldc) |
||||||
|
{ |
||||||
|
#ifdef GPU |
||||||
|
gemm_gpu( TA, TB, M, N, K, ALPHA,A,lda, B, ldb,BETA,C,ldc); |
||||||
|
#else |
||||||
|
gemm_cpu( TA, TB, M, N, K, ALPHA,A,lda, B, ldb,BETA,C,ldc); |
||||||
|
#endif |
||||||
|
} |
||||||
|
|
||||||
|
void gemm_nn(int M, int N, int K, float ALPHA,
|
||||||
|
float *A, int lda,
|
||||||
|
float *B, int ldb, |
||||||
|
float *C, int ldc) |
||||||
|
{ |
||||||
|
int i,j,k; |
||||||
|
for(i = 0; i < M; ++i){ |
||||||
|
for(k = 0; k < K; ++k){ |
||||||
|
register float A_PART = ALPHA*A[i*lda+k]; |
||||||
|
for(j = 0; j < N; ++j){ |
||||||
|
C[i*ldc+j] += A_PART*B[k*ldb+j]; |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
void gemm_nt(int M, int N, int K, float ALPHA,
|
||||||
|
float *A, int lda,
|
||||||
|
float *B, int ldb, |
||||||
|
float *C, int ldc) |
||||||
|
{ |
||||||
|
int i,j,k; |
||||||
|
for(i = 0; i < M; ++i){ |
||||||
|
for(j = 0; j < N; ++j){ |
||||||
|
register float sum = 0; |
||||||
|
for(k = 0; k < K; ++k){ |
||||||
|
sum += ALPHA*A[i*lda+k]*B[k+j*ldb]; |
||||||
|
} |
||||||
|
C[i*ldc+j] += sum; |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
void gemm_tn(int M, int N, int K, float ALPHA,
|
||||||
|
float *A, int lda,
|
||||||
|
float *B, int ldb, |
||||||
|
float *C, int ldc) |
||||||
|
{ |
||||||
|
int i,j,k; |
||||||
|
for(i = 0; i < M; ++i){ |
||||||
|
for(k = 0; k < K; ++k){ |
||||||
|
register float A_PART = ALPHA*A[k*lda+i]; |
||||||
|
for(j = 0; j < N; ++j){ |
||||||
|
C[i*ldc+j] += A_PART*B[k*ldb+j]; |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
void gemm_tt(int M, int N, int K, float ALPHA,
|
||||||
|
float *A, int lda,
|
||||||
|
float *B, int ldb, |
||||||
|
float *C, int ldc) |
||||||
|
{ |
||||||
|
int i,j,k; |
||||||
|
for(i = 0; i < M; ++i){ |
||||||
|
for(j = 0; j < N; ++j){ |
||||||
|
for(k = 0; k < K; ++k){ |
||||||
|
C[i*ldc+j] += ALPHA*A[i+k*lda]*B[k+j*ldb]; |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
|
||||||
|
void gemm_cpu(int TA, int TB, int M, int N, int K, float ALPHA,
|
||||||
|
float *A, int lda,
|
||||||
|
float *B, int ldb, |
||||||
|
float BETA, |
||||||
|
float *C, int ldc) |
||||||
|
{ |
||||||
|
int i, j; |
||||||
|
for(i = 0; i < M; ++i){ |
||||||
|
for(j = 0; j < N; ++j){ |
||||||
|
C[i*ldc + j] *= BETA; |
||||||
|
} |
||||||
|
} |
||||||
|
if(!TA && !TB) |
||||||
|
gemm_nn(M, N, K, ALPHA,A,lda, B, ldb,C,ldc); |
||||||
|
else if(TA && !TB) |
||||||
|
gemm_tn(M, N, K, ALPHA,A,lda, B, ldb,C,ldc); |
||||||
|
else if(!TA && TB) |
||||||
|
gemm_nt(M, N, K, ALPHA,A,lda, B, ldb,C,ldc); |
||||||
|
else |
||||||
|
gemm_tt(M, N, K, ALPHA,A,lda, B, ldb,C,ldc); |
||||||
|
} |
||||||
|
|
||||||
|
#ifdef GPU |
||||||
|
|
||||||
|
#include "opencl.h" |
||||||
|
#include <math.h> |
||||||
|
|
||||||
|
#define STR_HELPER(x) #x |
||||||
|
#define STR(x) STR_HELPER(x) |
||||||
|
|
||||||
|
#define BLOCK 8 |
||||||
|
|
||||||
|
cl_kernel get_gemm_kernel() |
||||||
|
{ |
||||||
|
static int init = 0; |
||||||
|
static cl_kernel gemm_kernel; |
||||||
|
if(!init){ |
||||||
|
gemm_kernel = get_kernel("src/gemm.cl", "gemm", "-D BLOCK=" STR(BLOCK) ); |
||||||
|
init = 1; |
||||||
|
} |
||||||
|
return gemm_kernel; |
||||||
|
} |
||||||
|
|
||||||
|
void gemm_ongpu(int TA, int TB, int M, int N, int K, float ALPHA,
|
||||||
|
cl_mem A_gpu, int lda,
|
||||||
|
cl_mem B_gpu, int ldb, |
||||||
|
float BETA, |
||||||
|
cl_mem C_gpu, int ldc) |
||||||
|
{ |
||||||
|
cl_setup(); |
||||||
|
cl_kernel gemm_kernel = get_gemm_kernel(); |
||||||
|
cl_command_queue queue = cl.queue; |
||||||
|
|
||||||
|
cl_uint i = 0; |
||||||
|
cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(TA), (void*) &TA); |
||||||
|
cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(TB), (void*) &TB); |
||||||
|
cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(M), (void*) &M); |
||||||
|
cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(N), (void*) &N); |
||||||
|
cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(K), (void*) &K); |
||||||
|
cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(ALPHA), (void*) &ALPHA); |
||||||
|
cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(A_gpu), (void*) &A_gpu); |
||||||
|
cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(lda), (void*) &lda); |
||||||
|
cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(B_gpu), (void*) &B_gpu); |
||||||
|
cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(ldb), (void*) &ldb); |
||||||
|
cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(BETA), (void*) &BETA); |
||||||
|
cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(C_gpu), (void*) &C_gpu); |
||||||
|
cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(ldc), (void*) &ldc); |
||||||
|
check_error(cl); |
||||||
|
|
||||||
|
const size_t global_size[] = {ceil((float)M/BLOCK)*BLOCK, ceil((float)N/BLOCK)*BLOCK}; |
||||||
|
const size_t local_size[] = {BLOCK, BLOCK}; |
||||||
|
|
||||||
|
clEnqueueNDRangeKernel(queue, gemm_kernel, 2, 0, global_size, local_size, 0, 0, 0); |
||||||
|
check_error(cl); |
||||||
|
} |
||||||
|
|
||||||
|
|
||||||
|
void gemm_gpu(int TA, int TB, int M, int N, int K, float ALPHA,
|
||||||
|
float *A, int lda,
|
||||||
|
float *B, int ldb, |
||||||
|
float BETA, |
||||||
|
float *C, int ldc) |
||||||
|
{ |
||||||
|
cl_setup(); |
||||||
|
cl_context context = cl.context; |
||||||
|
cl_command_queue queue = cl.queue; |
||||||
|
|
||||||
|
size_t size = sizeof(float)*(TA ? lda*K:lda*M); |
||||||
|
cl_mem A_gpu = clCreateBuffer(context, |
||||||
|
CL_MEM_READ_ONLY|CL_MEM_COPY_HOST_PTR, |
||||||
|
size, A, &cl.error); |
||||||
|
check_error(cl); |
||||||
|
|
||||||
|
size = sizeof(float)*(TB ? ldb*N:ldb*K); |
||||||
|
cl_mem B_gpu = clCreateBuffer(context, |
||||||
|
CL_MEM_READ_ONLY|CL_MEM_COPY_HOST_PTR, |
||||||
|
size, B, &cl.error); |
||||||
|
check_error(cl); |
||||||
|
|
||||||
|
size = sizeof(float)*(ldc*M); |
||||||
|
cl_mem C_gpu = clCreateBuffer(context, |
||||||
|
CL_MEM_READ_WRITE|CL_MEM_COPY_HOST_PTR, |
||||||
|
size, C, &cl.error); |
||||||
|
check_error(cl); |
||||||
|
|
||||||
|
gemm_ongpu(TA, TB, M, N, K, ALPHA, A_gpu, lda, B_gpu, ldb, BETA, C_gpu, ldc); |
||||||
|
|
||||||
|
clEnqueueReadBuffer(queue, C_gpu, CL_TRUE, 0, size, C, 0, 0, 0); |
||||||
|
check_error(cl); |
||||||
|
|
||||||
|
clReleaseMemObject(A_gpu); |
||||||
|
clReleaseMemObject(B_gpu); |
||||||
|
clReleaseMemObject(C_gpu); |
||||||
|
} |
||||||
|
|
||||||
|
#include <stdio.h> |
||||||
|
#include <stdlib.h> |
||||||
|
#include <string.h> |
||||||
|
#include <time.h> |
||||||
|
|
||||||
|
void time_gpu_random_matrix(int TA, int TB, int m, int k, int n) |
||||||
|
{ |
||||||
|
float *a; |
||||||
|
if(!TA) a = random_matrix(m,k); |
||||||
|
else a = random_matrix(k,m); |
||||||
|
int lda = (!TA)?k:m; |
||||||
|
float *b; |
||||||
|
if(!TB) b = random_matrix(k,n); |
||||||
|
else b = random_matrix(n,k); |
||||||
|
int ldb = (!TB)?n:k; |
||||||
|
|
||||||
|
float *c = random_matrix(m,n); |
||||||
|
int i; |
||||||
|
clock_t start = clock(), end; |
||||||
|
for(i = 0; i<1000; ++i){ |
||||||
|
gemm_gpu(TA,TB,m,n,k,1,a,lda,b,ldb,1,c,n); |
||||||
|
} |
||||||
|
end = clock(); |
||||||
|
printf("Matrix Multiplication %dx%d * %dx%d, TA=%d, TB=%d: %lf ms\n",m,k,k,n, TA, TB, (float)(end-start)/CLOCKS_PER_SEC); |
||||||
|
free(a); |
||||||
|
free(b); |
||||||
|
free(c); |
||||||
|
} |
||||||
|
|
||||||
|
void test_gpu_accuracy(int TA, int TB, int m, int k, int n) |
||||||
|
{ |
||||||
|
srand(0); |
||||||
|
float *a; |
||||||
|
if(!TA) a = random_matrix(m,k); |
||||||
|
else a = random_matrix(k,m); |
||||||
|
int lda = (!TA)?k:m; |
||||||
|
float *b; |
||||||
|
if(!TB) b = random_matrix(k,n); |
||||||
|
else b = random_matrix(n,k); |
||||||
|
int ldb = (!TB)?n:k; |
||||||
|
|
||||||
|
float *c = random_matrix(m,n); |
||||||
|
float *c_gpu = random_matrix(m,n); |
||||||
|
memset(c, 0, m*n*sizeof(float)); |
||||||
|
memset(c_gpu, 0, m*n*sizeof(float)); |
||||||
|
int i; |
||||||
|
//pm(m,k,b);
|
||||||
|
gemm_gpu(TA,TB,m,n,k,1,a,lda,b,ldb,1,c_gpu,n); |
||||||
|
//pm(m, n, c_gpu);
|
||||||
|
gemm_cpu(TA,TB,m,n,k,1,a,lda,b,ldb,1,c,n); |
||||||
|
//pm(m, n, c);
|
||||||
|
double sse = 0; |
||||||
|
for(i = 0; i < m*n; ++i) { |
||||||
|
//printf("%f %f\n", c[i], c_gpu[i]);
|
||||||
|
sse += pow(c[i]-c_gpu[i], 2); |
||||||
|
} |
||||||
|
printf("Matrix Multiplication %dx%d * %dx%d, TA=%d, TB=%d: %g MSE\n",m,k,k,n, TA, TB, sse/(m*n)); |
||||||
|
free(a); |
||||||
|
free(b); |
||||||
|
free(c); |
||||||
|
} |
||||||
|
|
||||||
|
void test_gpu_blas() |
||||||
|
{ |
||||||
|
test_gpu_accuracy(0,0,17,10,10);
|
||||||
|
test_gpu_accuracy(1,0,17,10,10);
|
||||||
|
test_gpu_accuracy(0,1,17,10,10);
|
||||||
|
test_gpu_accuracy(1,1,17,10,10);
|
||||||
|
|
||||||
|
test_gpu_accuracy(0,0,1000,10,100);
|
||||||
|
test_gpu_accuracy(1,0,1000,10,100);
|
||||||
|
test_gpu_accuracy(0,1,1000,10,100);
|
||||||
|
test_gpu_accuracy(1,1,1000,10,100);
|
||||||
|
|
||||||
|
time_gpu_random_matrix(0,0,1000,1000,100);
|
||||||
|
time_random_matrix(0,0,1000,1000,100);
|
||||||
|
|
||||||
|
time_gpu_random_matrix(0,1,1000,1000,100);
|
||||||
|
time_random_matrix(0,1,1000,1000,100);
|
||||||
|
|
||||||
|
time_gpu_random_matrix(1,0,1000,1000,100);
|
||||||
|
time_random_matrix(1,0,1000,1000,100);
|
||||||
|
|
||||||
|
time_gpu_random_matrix(1,1,1000,1000,100);
|
||||||
|
time_random_matrix(1,1,1000,1000,100);
|
||||||
|
|
||||||
|
} |
||||||
|
#endif |
||||||
|
|
@ -0,0 +1,121 @@ |
|||||||
|
#include "mini_blas.h" |
||||||
|
|
||||||
|
//From Berkeley Vision's Caffe!
|
||||||
|
//https://github.com/BVLC/caffe/blob/master/LICENSE
|
||||||
|
void im2col_cpu(float* data_im, |
||||||
|
const int batch, const int channels, const int height, const int width, |
||||||
|
const int ksize, const int stride, float* data_col)
|
||||||
|
{ |
||||||
|
int c,h,w,b; |
||||||
|
int height_col = (height - ksize) / stride + 1; |
||||||
|
int width_col = (width - ksize) / stride + 1; |
||||||
|
int channels_col = channels * ksize * ksize; |
||||||
|
int im_size = height*width*channels; |
||||||
|
int col_size = height_col*width_col*channels_col; |
||||||
|
for(b = 0; b < batch; ++b){ |
||||||
|
for ( c = 0; c < channels_col; ++c) { |
||||||
|
int w_offset = c % ksize; |
||||||
|
int h_offset = (c / ksize) % ksize; |
||||||
|
int c_im = c / ksize / ksize; |
||||||
|
for ( h = 0; h < height_col; ++h) { |
||||||
|
for ( w = 0; w < width_col; ++w) { |
||||||
|
data_col[(c * height_col + h) * width_col + w] = |
||||||
|
data_im[(c_im * height + h * stride + h_offset) * width |
||||||
|
+ w * stride + w_offset]; |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
data_im += im_size; |
||||||
|
data_col+= col_size; |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
|
||||||
|
#ifdef GPU |
||||||
|
|
||||||
|
#include "opencl.h" |
||||||
|
#include <math.h> |
||||||
|
|
||||||
|
cl_kernel get_im2col_kernel() |
||||||
|
{ |
||||||
|
static int init = 0; |
||||||
|
static cl_kernel im2col_kernel; |
||||||
|
if(!init){ |
||||||
|
im2col_kernel = get_kernel("src/im2col.cl", "im2col", 0); |
||||||
|
init = 1; |
||||||
|
} |
||||||
|
return im2col_kernel; |
||||||
|
} |
||||||
|
|
||||||
|
|
||||||
|
void im2col_ongpu(cl_mem data_im, const int batch, |
||||||
|
const int channels, const int height, const int width, |
||||||
|
const int ksize, const int stride, cl_mem data_col)
|
||||||
|
{ |
||||||
|
cl_setup(); |
||||||
|
cl_kernel im2col_kernel = get_im2col_kernel(); |
||||||
|
cl_command_queue queue = cl.queue; |
||||||
|
|
||||||
|
cl_uint i = 0; |
||||||
|
cl.error = clSetKernelArg(im2col_kernel, i++, sizeof(data_im), (void*) &data_im); |
||||||
|
cl.error = clSetKernelArg(im2col_kernel, i++, sizeof(batch), (void*) &batch); |
||||||
|
cl.error = clSetKernelArg(im2col_kernel, i++, sizeof(channels), (void*) &channels); |
||||||
|
cl.error = clSetKernelArg(im2col_kernel, i++, sizeof(height), (void*) &height); |
||||||
|
cl.error = clSetKernelArg(im2col_kernel, i++, sizeof(width), (void*) &width); |
||||||
|
cl.error = clSetKernelArg(im2col_kernel, i++, sizeof(ksize), (void*) &ksize); |
||||||
|
cl.error = clSetKernelArg(im2col_kernel, i++, sizeof(stride), (void*) &stride); |
||||||
|
cl.error = clSetKernelArg(im2col_kernel, i++, sizeof(data_col), (void*) &data_col); |
||||||
|
check_error(cl); |
||||||
|
|
||||||
|
int height_col = (height - ksize) / stride + 1; |
||||||
|
int width_col = (width - ksize) / stride + 1; |
||||||
|
int channels_col = channels * ksize * ksize; |
||||||
|
|
||||||
|
size_t global_size[2]; |
||||||
|
size_t local_size[2]; |
||||||
|
global_size[0] = batch; |
||||||
|
global_size[1] = channels_col; |
||||||
|
local_size[0] = height_col; |
||||||
|
local_size[1] = width_col; |
||||||
|
|
||||||
|
clEnqueueNDRangeKernel(queue, im2col_kernel, 2, 0, |
||||||
|
global_size, local_size, 0, 0, 0); |
||||||
|
check_error(cl); |
||||||
|
} |
||||||
|
|
||||||
|
void im2col_gpu(float *data_im, |
||||||
|
const int batch, const int channels, const int height, const int width, |
||||||
|
const int ksize, const int stride, |
||||||
|
float *data_col)
|
||||||
|
{ |
||||||
|
cl_setup(); |
||||||
|
cl_context context = cl.context; |
||||||
|
cl_command_queue queue = cl.queue; |
||||||
|
|
||||||
|
size_t size = sizeof(float)*(channels*height*width*batch); |
||||||
|
cl_mem im_gpu = clCreateBuffer(context, |
||||||
|
CL_MEM_READ_ONLY|CL_MEM_COPY_HOST_PTR, |
||||||
|
size, data_im, &cl.error); |
||||||
|
check_error(cl); |
||||||
|
|
||||||
|
int height_col = (height - ksize) / stride + 1; |
||||||
|
int width_col = (width - ksize) / stride + 1; |
||||||
|
int channels_col = channels * ksize * ksize; |
||||||
|
|
||||||
|
size = sizeof(float)*(height_col*width_col*channels_col*batch); |
||||||
|
cl_mem col_gpu = clCreateBuffer(context, |
||||||
|
CL_MEM_WRITE_ONLY|CL_MEM_COPY_HOST_PTR, |
||||||
|
size, data_col, &cl.error); |
||||||
|
check_error(cl); |
||||||
|
|
||||||
|
im2col_ongpu(im_gpu, batch, channels, height, width, |
||||||
|
ksize, stride, col_gpu); |
||||||
|
|
||||||
|
clEnqueueReadBuffer(queue, col_gpu, CL_TRUE, 0, size, data_col, 0, 0, 0); |
||||||
|
check_error(cl); |
||||||
|
|
||||||
|
clReleaseMemObject(col_gpu); |
||||||
|
clReleaseMemObject(im_gpu); |
||||||
|
} |
||||||
|
|
||||||
|
#endif |
@ -0,0 +1,26 @@ |
|||||||
|
|
||||||
|
__kernel void im2col(__global float *data_im, |
||||||
|
const int batch, const int channels, const int height, const int width, |
||||||
|
const int ksize, const int stride, __global float *data_col) |
||||||
|
{ |
||||||
|
int b = get_global_id(0); |
||||||
|
int c = get_global_id(1); |
||||||
|
|
||||||
|
int h = get_local_id(0); |
||||||
|
int w = get_local_id(1); |
||||||
|
|
||||||
|
int height_col = (height - ksize) / stride + 1; |
||||||
|
int width_col = (width - ksize) / stride + 1; |
||||||
|
int channels_col = channels * ksize * ksize; |
||||||
|
|
||||||
|
int im_offset = height*width*channels*b; |
||||||
|
int col_offset = height_col*width_col*channels_col*b; |
||||||
|
|
||||||
|
int w_offset = c % ksize; |
||||||
|
int h_offset = (c / ksize) % ksize; |
||||||
|
int c_im = c / ksize / ksize; |
||||||
|
|
||||||
|
data_col[(c * height_col + h) * width_col + w + col_offset] = |
||||||
|
data_im[(c_im * height + h * stride + h_offset) * width |
||||||
|
+ w * stride + w_offset + im_offset]; |
||||||
|
} |
Loading…
Reference in new issue