CUDA_CHECK definition for debug

Ref: pull/2352/head
Author: AlexeyAB (6 years ago)
Parent: 61156239e0
Commit: 584f840b40
Changed files:
  1. src/blas_kernels.cu (83 lines changed)
  2. src/cuda.c (13 lines changed)
  3. src/darknet.c (2 lines changed)
  4. src/im2col_kernels.cu (22 lines changed)
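This commit replaces the plain check_error(cudaPeekAtLastError()) call that follows each kernel launch with the CHECK_CUDA(...) macro, and adds the same check after launches in im2col_kernels.cu that previously had no error check at all. As a minimal sketch of what such a macro could look like, assuming only the check_error_extended() signature visible in the src/cuda.c hunk below (the exact definition in the repository's header may differ):

    /* Hypothetical sketch: forward the launch status together with the call
       site so errors are reported with file, line and build time. */
    #define CHECK_CUDA(X) check_error_extended((X), __FILE__, __LINE__, __DATE__ " - " __TIME__)

With something like this in place, a failed kernel launch is reported at the line of the macro invocation rather than as an anonymous error later in the run.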

@ -25,7 +25,7 @@ void scale_bias_gpu(float *output, float *biases, int batch, int n, int size)
dim3 dimBlock(BLOCK, 1, 1);
scale_bias_kernel<<<dimGrid, dimBlock, 0, get_cuda_stream()>>>(output, biases, n, size);
check_error(cudaPeekAtLastError());
CHECK_CUDA(cudaPeekAtLastError());
}
__global__ void backward_scale_kernel(float *x_norm, float *delta, int batch, int n, int size, float *scale_updates)
@ -51,7 +51,7 @@ __global__ void backward_scale_kernel(float *x_norm, float *delta, int batch, in
void backward_scale_gpu(float *x_norm, float *delta, int batch, int n, int size, float *scale_updates)
{
backward_scale_kernel<<<n, BLOCK, 0, get_cuda_stream() >>>(x_norm, delta, batch, n, size, scale_updates);
check_error(cudaPeekAtLastError());
CHECK_CUDA(cudaPeekAtLastError());
}
__global__ void add_bias_kernel(float *output, float *biases, int n, int size)
@ -69,7 +69,7 @@ void add_bias_gpu(float *output, float *biases, int batch, int n, int size)
dim3 dimBlock(BLOCK, 1, 1);
add_bias_kernel<<<dimGrid, dimBlock, 0, get_cuda_stream()>>>(output, biases, n, size);
check_error(cudaPeekAtLastError());
CHECK_CUDA(cudaPeekAtLastError());
}
__global__ void backward_bias_kernel(float *bias_updates, float *delta, int batch, int n, int size)
@ -130,14 +130,14 @@ __global__ void dot_kernel(float *output, float scale, int batch, int n, int siz
void dot_error_gpu(layer l)
{
dot_kernel<<<cuda_gridsize(l.n*l.n), BLOCK, 0, get_cuda_stream()>>>(l.output_gpu, l.dot, l.batch, l.n, l.out_w * l.out_h, l.delta_gpu);
check_error(cudaPeekAtLastError());
CHECK_CUDA(cudaPeekAtLastError());
}
*/
void backward_bias_gpu(float *bias_updates, float *delta, int batch, int n, int size)
{
backward_bias_kernel<<<n, BLOCK, 0, get_cuda_stream() >>>(bias_updates, delta, batch, n, size);
check_error(cudaPeekAtLastError());
CHECK_CUDA(cudaPeekAtLastError());
}
__global__ void adam_kernel(int N, float *x, float *m, float *v, float B1, float B2, float rate, float eps, int t)
@ -154,7 +154,7 @@ __global__ void adam_kernel(int N, float *x, float *m, float *v, float B1, float
extern "C" void adam_gpu(int n, float *x, float *m, float *v, float B1, float B2, float rate, float eps, int t)
{
adam_kernel << <cuda_gridsize(n), BLOCK, 0, get_cuda_stream() >> >(n, x, m, v, B1, B2, rate, eps, t);
check_error(cudaPeekAtLastError());
CHECK_CUDA(cudaPeekAtLastError());
}
extern "C" void adam_update_gpu(float *w, float *d, float *m, float *v, float B1, float B2, float eps, float decay, float rate, int n, int batch, int t)
@ -169,6 +169,7 @@ extern "C" void adam_update_gpu(float *w, float *d, float *m, float *v, float B1
adam_gpu(n, w, m, v, B1, B2, rate, eps, t);
fill_ongpu(n, 0, d, 1);
CHECK_CUDA(cudaPeekAtLastError());
}
__global__ void normalize_kernel(int N, float *x, float *mean, float *variance, int batch, int filters, int spatial)
@ -193,7 +194,7 @@ extern "C" void normalize_delta_gpu(float *x, float *mean, float *variance, floa
{
size_t N = batch*filters*spatial;
normalize_delta_kernel<<<cuda_gridsize(N), BLOCK, 0, get_cuda_stream() >>>(N, x, mean, variance, mean_delta, variance_delta, batch, filters, spatial, delta);
check_error(cudaPeekAtLastError());
CHECK_CUDA(cudaPeekAtLastError());
}
__global__ void variance_delta_kernel(float *x, float *delta, float *mean, float *variance, int batch, int filters, int spatial, float *variance_delta)
@ -298,19 +299,19 @@ __global__ void mean_delta_kernel(float *delta, float *variance, int batch, int
extern "C" void mean_delta_gpu(float *delta, float *variance, int batch, int filters, int spatial, float *mean_delta)
{
mean_delta_kernel<<<cuda_gridsize(filters), BLOCK, 0, get_cuda_stream() >>>(delta, variance, batch, filters, spatial, mean_delta);
check_error(cudaPeekAtLastError());
CHECK_CUDA(cudaPeekAtLastError());
}
extern "C" void fast_mean_delta_gpu(float *delta, float *variance, int batch, int filters, int spatial, float *mean_delta)
{
fast_mean_delta_kernel<<<filters, BLOCK, 0, get_cuda_stream() >>>(delta, variance, batch, filters, spatial, mean_delta);
check_error(cudaPeekAtLastError());
CHECK_CUDA(cudaPeekAtLastError());
}
extern "C" void fast_variance_delta_gpu(float *x, float *delta, float *mean, float *variance, int batch, int filters, int spatial, float *variance_delta)
{
fast_variance_delta_kernel<<<filters, BLOCK, 0, get_cuda_stream() >>>(x, delta, mean, variance, batch, filters, spatial, variance_delta);
check_error(cudaPeekAtLastError());
CHECK_CUDA(cudaPeekAtLastError());
}
__global__ void mean_kernel(float *x, int batch, int filters, int spatial, float *mean)
@ -457,7 +458,7 @@ extern "C" void normalize_gpu(float *x, float *mean, float *variance, int batch,
{
size_t N = batch*filters*spatial;
normalize_kernel<<<cuda_gridsize(N), BLOCK, 0, get_cuda_stream()>>>(N, x, mean, variance, batch, filters, spatial);
check_error(cudaPeekAtLastError());
CHECK_CUDA(cudaPeekAtLastError());
}
__global__ void fast_mean_kernel(float *x, int batch, int filters, int spatial, float *mean)
@ -520,26 +521,26 @@ __global__ void fast_variance_kernel(float *x, float *mean, int batch, int filt
extern "C" void fast_mean_gpu(float *x, int batch, int filters, int spatial, float *mean)
{
fast_mean_kernel<<<filters, BLOCK, 0, get_cuda_stream()>>>(x, batch, filters, spatial, mean);
check_error(cudaPeekAtLastError());
CHECK_CUDA(cudaPeekAtLastError());
}
extern "C" void fast_variance_gpu(float *x, float *mean, int batch, int filters, int spatial, float *variance)
{
fast_variance_kernel<<<filters, BLOCK, 0, get_cuda_stream() >>>(x, mean, batch, filters, spatial, variance);
check_error(cudaPeekAtLastError());
CHECK_CUDA(cudaPeekAtLastError());
}
extern "C" void mean_gpu(float *x, int batch, int filters, int spatial, float *mean)
{
mean_kernel<<<cuda_gridsize(filters), BLOCK, 0, get_cuda_stream() >>>(x, batch, filters, spatial, mean);
check_error(cudaPeekAtLastError());
CHECK_CUDA(cudaPeekAtLastError());
}
extern "C" void variance_gpu(float *x, float *mean, int batch, int filters, int spatial, float *variance)
{
variance_kernel<<<cuda_gridsize(filters), BLOCK, 0, get_cuda_stream() >>>(x, mean, batch, filters, spatial, variance);
check_error(cudaPeekAtLastError());
CHECK_CUDA(cudaPeekAtLastError());
}
extern "C" void axpy_ongpu(int N, float ALPHA, float * X, int INCX, float * Y, int INCY)
@ -550,13 +551,13 @@ extern "C" void axpy_ongpu(int N, float ALPHA, float * X, int INCX, float * Y, i
extern "C" void pow_ongpu(int N, float ALPHA, float * X, int INCX, float * Y, int INCY)
{
pow_kernel<<<cuda_gridsize(N), BLOCK, 0, get_cuda_stream() >>>(N, ALPHA, X, INCX, Y, INCY);
check_error(cudaPeekAtLastError());
CHECK_CUDA(cudaPeekAtLastError());
}
extern "C" void axpy_ongpu_offset(int N, float ALPHA, float * X, int OFFX, int INCX, float * Y, int OFFY, int INCY)
{
axpy_kernel<<<cuda_gridsize(N), BLOCK, 0, get_cuda_stream()>>>(N, ALPHA, X, OFFX, INCX, Y, OFFY, INCY);
check_error(cudaPeekAtLastError());
CHECK_CUDA(cudaPeekAtLastError());
}
extern "C" void copy_ongpu(int N, float * X, int INCX, float * Y, int INCY)
@ -568,19 +569,19 @@ extern "C" void simple_copy_ongpu(int size, float *src, float *dst)
{
const int num_blocks = size / BLOCK + 1;
simple_copy_kernel << <num_blocks, BLOCK, 0, get_cuda_stream() >> >(size, src, dst);
check_error(cudaPeekAtLastError());
CHECK_CUDA(cudaPeekAtLastError());
}
extern "C" void mul_ongpu(int N, float * X, int INCX, float * Y, int INCY)
{
mul_kernel<<<cuda_gridsize(N), BLOCK, 0, get_cuda_stream() >>>(N, X, INCX, Y, INCY);
check_error(cudaPeekAtLastError());
CHECK_CUDA(cudaPeekAtLastError());
}
extern "C" void copy_ongpu_offset(int N, float * X, int OFFX, int INCX, float * Y, int OFFY, int INCY)
{
copy_kernel<<<cuda_gridsize(N), BLOCK, 0, get_cuda_stream()>>>(N, X, OFFX, INCX, Y, OFFY, INCY);
check_error(cudaPeekAtLastError());
CHECK_CUDA(cudaPeekAtLastError());
}
__global__ void flatten_kernel(int N, float *x, int spatial, int layers, int batch, int forward, float *out)
@ -604,57 +605,57 @@ extern "C" void flatten_ongpu(float *x, int spatial, int layers, int batch, int
{
int size = spatial*batch*layers;
flatten_kernel<<<cuda_gridsize(size), BLOCK, 0, get_cuda_stream()>>>(size, x, spatial, layers, batch, forward, out);
check_error(cudaPeekAtLastError());
CHECK_CUDA(cudaPeekAtLastError());
}
extern "C" void reorg_ongpu(float *x, int w, int h, int c, int batch, int stride, int forward, float *out)
{
int size = w*h*c*batch;
reorg_kernel<<<cuda_gridsize(size), BLOCK, 0, get_cuda_stream()>>>(size, x, w, h, c, batch, stride, forward, out);
check_error(cudaPeekAtLastError());
CHECK_CUDA(cudaPeekAtLastError());
}
extern "C" void mask_gpu_new_api(int N, float * X, float mask_num, float * mask, float val)
{
mask_kernel_new_api <<<cuda_gridsize(N), BLOCK, 0, get_cuda_stream() >>>(N, X, mask_num, mask, val);
check_error(cudaPeekAtLastError());
CHECK_CUDA(cudaPeekAtLastError());
}
extern "C" void mask_ongpu(int N, float * X, float mask_num, float * mask)
{
mask_kernel<<<cuda_gridsize(N), BLOCK, 0, get_cuda_stream() >>>(N, X, mask_num, mask);
check_error(cudaPeekAtLastError());
CHECK_CUDA(cudaPeekAtLastError());
}
extern "C" void const_ongpu(int N, float ALPHA, float * X, int INCX)
{
const_kernel<<<cuda_gridsize(N), BLOCK, 0, get_cuda_stream() >>>(N, ALPHA, X, INCX);
check_error(cudaPeekAtLastError());
CHECK_CUDA(cudaPeekAtLastError());
}
extern "C" void constrain_ongpu(int N, float ALPHA, float * X, int INCX)
{
constrain_kernel<<<cuda_gridsize(N), BLOCK, 0, get_cuda_stream() >>>(N, ALPHA, X, INCX);
check_error(cudaPeekAtLastError());
CHECK_CUDA(cudaPeekAtLastError());
}
extern "C" void scal_ongpu(int N, float ALPHA, float * X, int INCX)
{
scal_kernel<<<cuda_gridsize(N), BLOCK, 0, get_cuda_stream()>>>(N, ALPHA, X, INCX);
check_error(cudaPeekAtLastError());
CHECK_CUDA(cudaPeekAtLastError());
}
extern "C" void supp_ongpu(int N, float ALPHA, float * X, int INCX)
{
supp_kernel<<<cuda_gridsize(N), BLOCK, 0, get_cuda_stream() >>>(N, ALPHA, X, INCX);
check_error(cudaPeekAtLastError());
CHECK_CUDA(cudaPeekAtLastError());
}
extern "C" void fill_ongpu(int N, float ALPHA, float * X, int INCX)
{
fill_kernel<<<cuda_gridsize(N), BLOCK, 0, get_cuda_stream()>>>(N, ALPHA, X, INCX);
check_error(cudaPeekAtLastError());
CHECK_CUDA(cudaPeekAtLastError());
}
__global__ void shortcut_kernel(int size, int minw, int minh, int minc, int stride, int sample, int batch, int w1, int h1, int c1, float *add, int w2, int h2, int c2, float *out)
@ -689,7 +690,7 @@ extern "C" void shortcut_gpu(int batch, int w1, int h1, int c1, float *add, int
int size = batch * minw * minh * minc;
shortcut_kernel<<<cuda_gridsize(size), BLOCK, 0, get_cuda_stream()>>>(size, minw, minh, minc, stride, sample, batch, w1, h1, c1, add, w2, h2, c2, out);
check_error(cudaPeekAtLastError());
CHECK_CUDA(cudaPeekAtLastError());
}
__global__ void simple_input_shortcut_kernel(float *in, int size, float *add, float *out)
@ -739,7 +740,7 @@ extern "C" void input_shortcut_gpu(float *in, int batch, int w1, int h1, int c1,
int size = batch * minw * minh * minc;
input_shortcut_kernel << <cuda_gridsize(size), BLOCK, 0, get_cuda_stream() >> >(in, size, minw, minh, minc, stride, sample, batch, w1, h1, c1, add, w2, h2, c2, out);
check_error(cudaPeekAtLastError());
CHECK_CUDA(cudaPeekAtLastError());
}
__global__ void smooth_l1_kernel(int n, float *pred, float *truth, float *delta, float *error)
@ -762,7 +763,7 @@ __global__ void smooth_l1_kernel(int n, float *pred, float *truth, float *delta,
extern "C" void smooth_l1_gpu(int n, float *pred, float *truth, float *delta, float *error)
{
smooth_l1_kernel<<<cuda_gridsize(n), BLOCK, 0, get_cuda_stream() >>>(n, pred, truth, delta, error);
check_error(cudaPeekAtLastError());
CHECK_CUDA(cudaPeekAtLastError());
}
__global__ void softmax_x_ent_kernel(int n, float *pred, float *truth, float *delta, float *error)
@ -779,7 +780,7 @@ __global__ void softmax_x_ent_kernel(int n, float *pred, float *truth, float *de
extern "C" void softmax_x_ent_gpu(int n, float *pred, float *truth, float *delta, float *error)
{
softmax_x_ent_kernel << <cuda_gridsize(n), BLOCK, 0, get_cuda_stream() >> >(n, pred, truth, delta, error);
check_error(cudaPeekAtLastError());
CHECK_CUDA(cudaPeekAtLastError());
}
__global__ void l2_kernel(int n, float *pred, float *truth, float *delta, float *error)
@ -795,7 +796,7 @@ __global__ void l2_kernel(int n, float *pred, float *truth, float *delta, float
extern "C" void l2_gpu(int n, float *pred, float *truth, float *delta, float *error)
{
l2_kernel<<<cuda_gridsize(n), BLOCK, 0, get_cuda_stream() >>>(n, pred, truth, delta, error);
check_error(cudaPeekAtLastError());
CHECK_CUDA(cudaPeekAtLastError());
}
@ -811,7 +812,7 @@ __global__ void weighted_sum_kernel(int n, float *a, float *b, float *s, float *
extern "C" void weighted_sum_gpu(float *a, float *b, float *s, int num, float *c)
{
weighted_sum_kernel<<<cuda_gridsize(num), BLOCK, 0, get_cuda_stream() >>>(num, a, b, s, c);
check_error(cudaPeekAtLastError());
CHECK_CUDA(cudaPeekAtLastError());
}
__global__ void weighted_delta_kernel(int n, float *a, float *b, float *s, float *da, float *db, float *ds, float *dc)
@ -827,7 +828,7 @@ __global__ void weighted_delta_kernel(int n, float *a, float *b, float *s, float
extern "C" void weighted_delta_gpu(float *a, float *b, float *s, float *da, float *db, float *ds, int num, float *dc)
{
weighted_delta_kernel<<<cuda_gridsize(num), BLOCK, 0, get_cuda_stream() >>>(num, a, b, s, da, db, ds, dc);
check_error(cudaPeekAtLastError());
CHECK_CUDA(cudaPeekAtLastError());
}
__global__ void mult_add_into_kernel(int n, float *a, float *b, float *c)
@ -841,7 +842,7 @@ __global__ void mult_add_into_kernel(int n, float *a, float *b, float *c)
extern "C" void mult_add_into_gpu(int num, float *a, float *b, float *c)
{
mult_add_into_kernel<<<cuda_gridsize(num), BLOCK, 0, get_cuda_stream() >>>(num, a, b, c);
check_error(cudaPeekAtLastError());
CHECK_CUDA(cudaPeekAtLastError());
}
@ -876,7 +877,7 @@ extern "C" void softmax_gpu(float *input, int n, int offset, int groups, float t
int inputs = n;
int batch = groups;
softmax_kernel<<<cuda_gridsize(batch), BLOCK, 0, get_cuda_stream()>>>(inputs, offset, batch, input, temp, output);
check_error(cudaPeekAtLastError());
CHECK_CUDA(cudaPeekAtLastError());
}
__device__ void softmax_device_new_api(float *input, int n, float temp, int stride, float *output)
@ -910,7 +911,7 @@ __global__ void softmax_kernel_new_api(float *input, int n, int batch, int batch
extern "C" void softmax_gpu_new_api(float *input, int n, int batch, int batch_offset, int groups, int group_offset, int stride, float temp, float *output)
{
softmax_kernel_new_api << <cuda_gridsize(batch*groups), BLOCK, 0, get_cuda_stream() >> >(input, n, batch, batch_offset, groups, group_offset, stride, temp, output);
check_error(cudaPeekAtLastError());
CHECK_CUDA(cudaPeekAtLastError());
}
@ -942,7 +943,7 @@ extern "C" void upsample_gpu(float *in, int w, int h, int c, int batch, int stri
{
size_t size = w*h*c*batch*stride*stride;
upsample_kernel << <cuda_gridsize(size), BLOCK, 0, get_cuda_stream() >> >(size, in, w, h, c, batch, stride, forward, scale, out);
check_error(cudaPeekAtLastError());
CHECK_CUDA(cudaPeekAtLastError());
}
__global__ void softmax_tree_kernel(float *input, int spatial, int batch, int stride, float temp, float *output, int groups, int *group_size, int *group_offset)
@ -972,7 +973,7 @@ extern "C" void softmax_tree_gpu(float *input, int spatial, int batch, int strid
*/
int num = spatial*batch*hier.groups;
softmax_tree_kernel <<<cuda_gridsize(num), BLOCK, 0, get_cuda_stream() >>>(input, spatial, batch, stride, temp, output, hier.groups, tree_groups_size, tree_groups_offset);
check_error(cudaPeekAtLastError());
CHECK_CUDA(cudaPeekAtLastError());
cuda_free((float *)tree_groups_size);
cuda_free((float *)tree_groups_offset);
}

@ -26,9 +26,6 @@ int cuda_get_device()
void check_error(cudaError_t status)
{
#ifdef DEBUG
cudaDeviceSynchronize();
#endif
cudaError_t status2 = cudaGetLastError();
if (status != cudaSuccess)
{
@ -58,6 +55,11 @@ void check_error(cudaError_t status)
void check_error_extended(cudaError_t status, const char *file, int line, const char *date_time)
{
if (status != cudaSuccess)
printf("CUDA Prev Error: file: %s() : line: %d : build time: %s \n", file, line, date_time);
#ifdef DEBUG
status = cudaDeviceSynchronize();
#endif
if (status != cudaSuccess)
printf("CUDA Error: file: %s() : line: %d : build time: %s \n", file, line, date_time);
check_error(status);
@ -171,6 +173,11 @@ void cudnn_check_error(cudnnStatus_t status)
void cudnn_check_error_extended(cudnnStatus_t status, const char *file, int line, const char *date_time)
{
if (status != cudaSuccess)
printf("\n cuDNN Prev Error in: file: %s() : line: %d : build time: %s \n", file, line, date_time);
#ifdef DEBUG
status = cudaDeviceSynchronize();
#endif
if (status != cudaSuccess)
printf("\n cuDNN Error in: file: %s() : line: %d : build time: %s \n", file, line, date_time);
cudnn_check_error(status);
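Read together, the src/cuda.c hunks remove the #ifdef DEBUG cudaDeviceSynchronize() from check_error() and move the synchronization into the extended checkers, so that in debug builds an asynchronous kernel failure is surfaced at the CHECK_CUDA call site that queued the kernel. Roughly reconstructed from the hunk above (braces and any lines outside the shown context are assumptions), the CUDA variant now reads:

    void check_error_extended(cudaError_t status, const char *file, int line, const char *date_time)
    {
        // report the status handed in by the launching call site
        if (status != cudaSuccess)
            printf("CUDA Prev Error: file: %s() : line: %d : build time: %s \n", file, line, date_time);
    #ifdef DEBUG
        // debug builds: wait for queued work so asynchronous faults are caught here
        status = cudaDeviceSynchronize();
    #endif
        if (status != cudaSuccess)
            printf("CUDA Error: file: %s() : line: %d : build time: %s \n", file, line, date_time);
        check_error(status);
    }

cudnn_check_error_extended() follows the same pattern for cuDNN status codes.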

@ -458,7 +458,7 @@ int main(int argc, char **argv)
#else
if(gpu_index >= 0){
cuda_set_device(gpu_index);
check_error(cudaSetDeviceFlags(cudaDeviceScheduleBlockingSync));
CHECK_CUDA(cudaSetDeviceFlags(cudaDeviceScheduleBlockingSync));
}
#endif

@ -86,6 +86,8 @@ void im2col_ongpu(float *im,
num_kernels, im, height, width, ksize, pad,
stride, height_col,
width_col, data_col);
CHECK_CUDA(cudaPeekAtLastError());
}
// --------------------------------
@ -219,6 +221,8 @@ void im2col_align_ongpu(float *im,
num_kernels, im, height, width, ksize, pad,
stride, height_col,
width_col, data_col, bit_align);
CHECK_CUDA(cudaPeekAtLastError());
}
@ -346,6 +350,8 @@ void im2col_align_bin_ongpu(float *im,
num_kernels, im, height, width, ksize, channels, pad,
stride, height_col,
width_col, data_col, bit_align);
CHECK_CUDA(cudaPeekAtLastError());
}
// --------------------------------
@ -436,6 +442,7 @@ void float_to_bit_gpu(float *src, unsigned char *dst, size_t size)
//const int num_blocks = size / (32*1024) + 1;
const int num_blocks = get_number_of_blocks(size, 32 * 1024);
float_to_bit_gpu_kernel<<<num_blocks, 1024, 0, get_cuda_stream()>>>(src, dst, size);
CHECK_CUDA(cudaPeekAtLastError());
}
// --------------------------------
@ -653,6 +660,7 @@ void transpose_bin_gpu(unsigned char *A, unsigned char *B, const int n, const in
const int num_blocks32 = size32 / BLOCK_TRANSPOSE32 + 1;
transpose_bin_gpu_kernel_32 << <num_blocks32, BLOCK_TRANSPOSE32, 0, get_cuda_stream() >> >((uint32_t *)A, (uint32_t *)B, n, m, lda, ldb, block_size);
//transpose_bin_gpu_kernel << <num_blocks, BLOCK, 0, get_cuda_stream() >> >(A, B, n, m, lda, ldb, block_size);
CHECK_CUDA(cudaPeekAtLastError());
}
// --------------------------------
@ -679,6 +687,7 @@ void transpose_uint32_gpu(uint32_t *src, uint32_t *dst, int src_h, int src_w, in
int size = src_w * src_h;
const int num_blocks = size / BLOCK + 1;
transpose_uint32_kernel << <num_blocks, BLOCK, 0, get_cuda_stream() >> >(src, dst, src_h, src_w, src_align, dst_align);
CHECK_CUDA(cudaPeekAtLastError());
}
// --------------------------------
@ -744,6 +753,7 @@ void transpose_uint32_gpu_2(uint32_t *src, uint32_t *dst, int src_h, int src_w,
int size = src_w_align * src_h_align;
int num_blocks = size / TRANS_BLOCK;
transpose_uint32_kernel_2 << <num_blocks, TRANS_BLOCK, 0, get_cuda_stream() >> >(src, dst, src_h, src_w, src_align, dst_align);
CHECK_CUDA(cudaPeekAtLastError());
}
// --------------------------------
@ -781,6 +791,7 @@ void repack_input_gpu(float *input, float *re_packed_input, int w, int h, int c)
int size = w * h * c;
const int num_blocks = size / BLOCK + 1;
repack_input_kernel << <num_blocks, BLOCK, 0, get_cuda_stream() >> >(input, re_packed_input, w, h, c);
CHECK_CUDA(cudaPeekAtLastError());
}
// --------------------------------
@ -820,6 +831,7 @@ void repack_input_gpu_2(float *input, float *re_packed_input, int w, int h, int
int size = w * h * c;
const int num_blocks = size / BLOCK + 1;
repack_input_kernel_2 << <num_blocks, BLOCK, 0, get_cuda_stream() >> >(input, re_packed_input, w, h, c);
CHECK_CUDA(cudaPeekAtLastError());
}
// --------------------------------
@ -869,6 +881,7 @@ void repack_input_gpu_bin(float *input, uint32_t *re_packed_input_bin, int w, in
const int num_blocks = get_number_of_blocks(size, block_size);
//printf("\n num_blocks = %d, num_blocks/32 = %d, block_size = %d \n", num_blocks, num_blocks / 32, block_size);
repack_input_kernel_bin << <num_blocks, block_size, 0, get_cuda_stream() >> >(input, re_packed_input_bin, w, h, c);
CHECK_CUDA(cudaPeekAtLastError());
}
/*
@ -919,6 +932,7 @@ void repack_input_gpu_bin(float *input, uint32_t *re_packed_input_bin, int w, in
const int num_blocks = get_number_of_blocks(size, block_size);
printf("\n num_blocks = %d, num_blocks/32 = %d, block_size = %d \n", num_blocks, num_blocks/32, block_size);
repack_input_kernel_bin << <num_blocks, block_size, 0, get_cuda_stream() >> >(input, re_packed_input_bin, w, h, c);
CHECK_CUDA(cudaPeekAtLastError());
}
*/
@ -932,6 +946,7 @@ __global__ void fill_int8_gpu_kernel(unsigned char *src, unsigned char val, size
void fill_int8_gpu(unsigned char *src, unsigned char val, size_t size) {
const int num_blocks = size / BLOCK + 1;
fill_int8_gpu_kernel<<<num_blocks, BLOCK, 0, get_cuda_stream()>>>(src, val, size);
CHECK_CUDA(cudaPeekAtLastError());
}
// --------------------------------
@ -1761,7 +1776,7 @@ void gemm_nn_custom_bin_mean_transposed_gpu(int M, int N, int K,
*/
//printf(" shared_memory: (w) lda*BLOCK/N = %d, (i) ldb*BLOCK/M = %d, \t lda = %d \n\n", lda*BLOCK / N, ldb*BLOCK / M, lda);
#if CUDART_VERSION >= 10000
//if (M % 8 == 0 && N % 8 == 0 && M == 128)
//if (M >= 32) // l.n >= 32
if (1)
@ -1784,7 +1799,6 @@ void gemm_nn_custom_bin_mean_transposed_gpu(int M, int N, int K,
//getchar();
}
else
#endif // CUDART_VERSION >= 10000
{
gemm_nn_custom_bin_mean_transposed_gpu_kernel << <num_blocks, BLOCK, 0, get_cuda_stream() >> > (
M, N, K,
@ -1793,7 +1807,7 @@ void gemm_nn_custom_bin_mean_transposed_gpu(int M, int N, int K,
C, ldc,
mean_arr, bias, leaky_activation);
}
CHECK_CUDA(cudaPeekAtLastError());
}
// --------------------------------
@ -1973,6 +1987,7 @@ void convolve_gpu(float *input, float *weights, float *output, int in_w, int in_
//printf("\n array_size = %d, num_blocks = %d, w = %d, h = %d, n = %d, c = %d, pad = %d \n", array_size, num_blocks, in_w, in_h, n, in_c, pad);
convolve_gpu_kernel << <num_blocks, BLOCK, 0, get_cuda_stream() >> > (input, weights, output, in_w, in_h, in_c, n, size, pad);
CHECK_CUDA(cudaPeekAtLastError());
}
// --------------------------------
@ -2188,6 +2203,7 @@ void convolve_bin_gpu(float *input, float *weights, float *output, int in_w, int
//printf("\n array_size = %d, num_blocks = %d, w = %d, h = %d, n = %d, c = %d, pad = %d \n", array_size, num_blocks, in_w, in_h, n, in_c, pad);
convolve_bin_gpu_kernel << <num_blocks, BLOCK, 0, get_cuda_stream() >> > (input, weights, output, in_w, in_h, in_c, n, size, pad, new_lda, mean_arr_gpu);
CHECK_CUDA(cudaPeekAtLastError());
}
// --------------------------------
