|
|
|
@ -51,12 +51,23 @@ void im2col_cpu(float* data_im, int batch, |
|
|
|
|
#include "opencl.h" |
|
|
|
|
#include <math.h> |
|
|
|
|
|
|
|
|
|
cl_kernel get_im2col_kernel() |
|
|
|
|
cl_kernel get_im2col_pad_kernel() |
|
|
|
|
{ |
|
|
|
|
static int init = 0; |
|
|
|
|
static cl_kernel im2col_kernel; |
|
|
|
|
if(!init){ |
|
|
|
|
im2col_kernel = get_kernel("src/im2col.cl", "im2col", 0); |
|
|
|
|
im2col_kernel = get_kernel("src/im2col.cl", "im2col_pad", 0); |
|
|
|
|
init = 1; |
|
|
|
|
} |
|
|
|
|
return im2col_kernel; |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
cl_kernel get_im2col_nopad_kernel() |
|
|
|
|
{ |
|
|
|
|
static int init = 0; |
|
|
|
|
static cl_kernel im2col_kernel; |
|
|
|
|
if(!init){ |
|
|
|
|
im2col_kernel = get_kernel("src/im2col.cl", "im2col_nopad", 0); |
|
|
|
|
init = 1; |
|
|
|
|
} |
|
|
|
|
return im2col_kernel; |
|
|
|
@ -68,32 +79,34 @@ void im2col_ongpu(cl_mem data_im, int batch, |
|
|
|
|
int ksize, int stride, int pad, cl_mem data_col) |
|
|
|
|
{ |
|
|
|
|
cl_setup(); |
|
|
|
|
cl_kernel im2col_kernel = get_im2col_kernel(); |
|
|
|
|
cl_command_queue queue = cl.queue; |
|
|
|
|
|
|
|
|
|
cl_uint i = 0; |
|
|
|
|
cl.error = clSetKernelArg(im2col_kernel, i++, sizeof(data_im), (void*) &data_im); |
|
|
|
|
cl.error = clSetKernelArg(im2col_kernel, i++, sizeof(batch), (void*) &batch); |
|
|
|
|
cl.error = clSetKernelArg(im2col_kernel, i++, sizeof(channels), (void*) &channels); |
|
|
|
|
cl.error = clSetKernelArg(im2col_kernel, i++, sizeof(height), (void*) &height); |
|
|
|
|
cl.error = clSetKernelArg(im2col_kernel, i++, sizeof(width), (void*) &width); |
|
|
|
|
cl.error = clSetKernelArg(im2col_kernel, i++, sizeof(ksize), (void*) &ksize); |
|
|
|
|
cl.error = clSetKernelArg(im2col_kernel, i++, sizeof(stride), (void*) &stride); |
|
|
|
|
cl.error = clSetKernelArg(im2col_kernel, i++, sizeof(pad), (void*) &pad); |
|
|
|
|
cl.error = clSetKernelArg(im2col_kernel, i++, sizeof(data_col), (void*) &data_col); |
|
|
|
|
check_error(cl); |
|
|
|
|
|
|
|
|
|
int height_col = (height - ksize) / stride + 1; |
|
|
|
|
int width_col = (width - ksize) / stride + 1; |
|
|
|
|
int channels_col = channels * ksize * ksize; |
|
|
|
|
cl_kernel kernel = get_im2col_nopad_kernel(); |
|
|
|
|
|
|
|
|
|
if (pad){ |
|
|
|
|
height_col = 1 + (height-1) / stride; |
|
|
|
|
width_col = 1 + (width-1) / stride; |
|
|
|
|
kernel = get_im2col_pad_kernel(); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
cl_command_queue queue = cl.queue; |
|
|
|
|
|
|
|
|
|
cl_uint i = 0; |
|
|
|
|
cl.error = clSetKernelArg(kernel, i++, sizeof(data_im), (void*) &data_im); |
|
|
|
|
cl.error = clSetKernelArg(kernel, i++, sizeof(batch), (void*) &batch); |
|
|
|
|
cl.error = clSetKernelArg(kernel, i++, sizeof(channels), (void*) &channels); |
|
|
|
|
cl.error = clSetKernelArg(kernel, i++, sizeof(height), (void*) &height); |
|
|
|
|
cl.error = clSetKernelArg(kernel, i++, sizeof(width), (void*) &width); |
|
|
|
|
cl.error = clSetKernelArg(kernel, i++, sizeof(ksize), (void*) &ksize); |
|
|
|
|
cl.error = clSetKernelArg(kernel, i++, sizeof(stride), (void*) &stride); |
|
|
|
|
cl.error = clSetKernelArg(kernel, i++, sizeof(data_col), (void*) &data_col); |
|
|
|
|
check_error(cl); |
|
|
|
|
|
|
|
|
|
size_t global_size = batch*channels_col*height_col*width_col; |
|
|
|
|
|
|
|
|
|
clEnqueueNDRangeKernel(queue, im2col_kernel, 1, 0, |
|
|
|
|
clEnqueueNDRangeKernel(queue, kernel, 1, 0, |
|
|
|
|
&global_size, 0, 0, 0, 0); |
|
|
|
|
check_error(cl); |
|
|
|
|
} |
|
|
|
|