diff --git a/src/convolutional_kernels.cu b/src/convolutional_kernels.cu
index a7387d01..4343aca9 100644
--- a/src/convolutional_kernels.cu
+++ b/src/convolutional_kernels.cu
@@ -127,136 +127,94 @@ void forward_convolutional_layer_gpu(convolutional_layer l, network_state state)
     if (l.align_bit_weights_gpu && !state.train)
     {
         cudaError_t status;
-        status = cudaMemcpy(l.align_bit_weights, l.align_bit_weights_gpu, l.align_bit_weights_size, cudaMemcpyDeviceToHost);
-        check_error(status);
+        //status = cudaMemcpy(l.align_bit_weights, l.align_bit_weights_gpu, l.align_bit_weights_size, cudaMemcpyDeviceToHost);
+        //check_error(status);
 
-        float *input = (float *)calloc(l.c*l.h*l.w*l.batch, sizeof(float));
-        float *workspace = (float *)calloc(l.bit_align*l.size*l.size*l.c, sizeof(float));
-        float *output = (float *)calloc(l.batch*l.out_c*l.out_h*l.out_w, sizeof(float));
+        //float *input = (float *)calloc(l.c*l.h*l.w*l.batch, sizeof(float));
+        //float *workspace = (float *)calloc(l.bit_align*l.size*l.size*l.c, sizeof(float));
+        //float *output = (float *)calloc(l.batch*l.out_c*l.out_h*l.out_w, sizeof(float));
 
-        status = cudaMemcpy(input, state.input, l.c*l.h*l.w*l.batch*sizeof(float), cudaMemcpyDeviceToHost);
-        check_error(status);
+        //status = cudaMemcpy(input, state.input, l.c*l.h*l.w*l.batch*sizeof(float), cudaMemcpyDeviceToHost);
+        //check_error(status);
 
         int m = l.n;
         int k = l.size*l.size*l.c;
         int n = l.out_w*l.out_h;
 
         float * a = l.weights_gpu;
         //float * b = state.workspace;
-        float *b = workspace;
+        //float *b = workspace;
         //float * c = l.output_gpu;
-        float *c = output;
+        //float *c = output;
 
         int ldb_align = l.lda_align;
         size_t new_ldb = k + (ldb_align - k%ldb_align); // (k / 8 + 1) * 8;
         size_t t_intput_size = new_ldb * n;
         size_t t_bit_input_size = t_intput_size / 8;// +1;
-        char *t_bit_input = (char *)calloc(t_bit_input_size, sizeof(char));
-        int src_size = k * l.bit_align;
+        //char *t_bit_input = (char *)calloc(t_bit_input_size, sizeof(char));
+        //int src_size = k * l.bit_align;
 
         //im2col_cpu_custom_bin(input, l.c, l.h, l.w, l.size, l.stride, l.pad, b, l.bit_align);
 
-        float *align_workspace = NULL;
-        int align_workspace_size = l.bit_align * k; // aligned: n*k
-        status = cudaMalloc((void **)&align_workspace, align_workspace_size*sizeof(float));
-        check_error(status);
+        //float *align_workspace = NULL;
+        //int align_workspace_size = l.bit_align * k; // aligned: n*k
+        //status = cudaMalloc((void **)&align_workspace, align_workspace_size*sizeof(float));
+        //check_error(status);
 
         int i = 0;
-        im2col_align_ongpu(state.input + i*l.c*l.h*l.w, l.c, l.h, l.w, l.size, l.stride, l.pad, align_workspace, l.bit_align);
+        im2col_align_ongpu(state.input + i*l.c*l.h*l.w, l.c, l.h, l.w, l.size, l.stride, l.pad, l.align_workspace_gpu, l.bit_align);
 
-        float_to_bit_gpu(align_workspace, (unsigned char *)state.workspace, align_workspace_size);
+        float_to_bit_gpu(l.align_workspace_gpu, (unsigned char *)state.workspace, l.align_workspace_size);
 
         if(1)
         {
-            {
-                /*
-                status = cudaMemcpy(t_bit_input, state.workspace, t_bit_input_size, cudaMemcpyDeviceToHost);
-                check_error(status);
-                for (int y = 0; y < 8; ++y) {
-                    for (int x = 0; x < 8; ++x) {
-                        int index = x + y*l.bit_align;
-                        if (get_bit((unsigned char *)t_bit_input, index)) printf("1, ");
-                        else printf("0, ");
-                    }
-                    printf("\n");
-                }
-                printf("\n");
-                */
-            }
-
-            fill_int8_gpu((unsigned char *)align_workspace, 0, t_bit_input_size);
-
-            transpose_bin_gpu((unsigned char *)state.workspace, (unsigned char *)align_workspace, k, n, l.bit_align, new_ldb, 8);
+            fill_int8_gpu((unsigned char *)l.align_workspace_gpu, 0, t_bit_input_size);
+
+            transpose_bin_gpu((unsigned char *)state.workspace, (unsigned char *)l.align_workspace_gpu, k, n, l.bit_align, new_ldb, 8);
 
             //cudaDeviceSynchronize();
             //int size_transposed_array = l.bit_align * new_ldb;
-            status = cudaMemcpy(t_bit_input, align_workspace, t_bit_input_size, cudaMemcpyDeviceToHost);
-            check_error(status);
-
-            /*
-            for (int y = 0; y < 8; ++y) {
-                for (int x = 0; x < 8; ++x) {
-                    int index = x + y*new_ldb;
-                    if (get_bit((unsigned char *)t_bit_input, index)) printf("1, ");
-                    else printf("0, ");
-                }
-                printf("\n");
-            }
-            printf("-----------\n");
-            //getchar();
-            */
+            //status = cudaMemcpy(t_bit_input, l.align_workspace_gpu, t_bit_input_size, cudaMemcpyDeviceToHost);
+            //check_error(status);
         }
 
+        /*
         if (0)
         {
-            status = cudaMemcpy(b, state.workspace, align_workspace_size / 8, cudaMemcpyDeviceToHost);
+            status = cudaMemcpy(b, state.workspace, l.align_workspace_size / 8, cudaMemcpyDeviceToHost);
             check_error(status);
-
-            for (int y = 0; y < 8; ++y) {
-                for (int x = 0; x < 8; ++x) {
-                    int index = x + y*l.bit_align;
-                    if (get_bit((unsigned char *)b, index)) printf("1, ");
-                    else printf("0, ");
-                }
-                printf("\n");
-            }
-            printf("\n");
-
-            //float *im2 = (float *)calloc(align_workspace_size, sizeof(float));
-            //status = cudaMemcpy(im2, align_workspace, align_workspace_size * sizeof(float), cudaMemcpyDeviceToHost);
+            //float *im2 = (float *)calloc(l.align_workspace_size, sizeof(float));
+            //status = cudaMemcpy(im2, l.align_workspace_gpu, l.align_workspace_size * sizeof(float), cudaMemcpyDeviceToHost);
             //check_error(status);
-            //float_to_bit(im2, (unsigned char *)b, align_workspace_size);
+            //float_to_bit(im2, (unsigned char *)b, l.align_workspace_size);
 
             memset(t_bit_input, 0, t_bit_input_size);
 
            // b - [bit_align, k] - [l.bit_align, l.size*l.size*l.c] = src_size
            // t_input - [bit_align, k] - [n', k]
            // t_bit_input - [new_ldb, n] - [k', n]
            transpose_bin((char *)b, t_bit_input, k, n, l.bit_align, new_ldb, 8);
-
-            for (int y = 0; y < 8; ++y) {
-                for (int x = 0; x < 8; ++x) {
-                    int index = x + y*new_ldb;
-                    if (get_bit((unsigned char *)t_bit_input, index)) printf("1, ");
-                    else printf("0, ");
-                }
-                printf("\n");
-            }
-            printf("-----------\n");
-            //getchar();
-
-            //free(im2);
         }
+        */
 
-        // 5x times faster than gemm()-float32
-        gemm_nn_custom_bin_mean_transposed(m, n, k, 1, (unsigned char *)l.align_bit_weights, new_ldb, (unsigned char *)t_bit_input, new_ldb, c, n, l.mean_arr);
+        //status = cudaMemcpy(l.align_bit_weights, l.align_bit_weights_gpu, new_ldb * m / 8, cudaMemcpyDeviceToHost);
+        //check_error(status);
 
-        status = cudaMemcpy(l.output_gpu, output, l.batch*l.out_c*l.out_h*l.out_w * sizeof(float), cudaMemcpyHostToDevice);
-        check_error(status);
+        //status = cudaMemcpy(l.mean_arr, l.mean_arr_gpu, l.n * sizeof(float), cudaMemcpyDeviceToHost);
+        //check_error(status);
 
-        free(t_bit_input);
-        free(input);
-        free(workspace);
-        free(output);
-        cudaFree(align_workspace);
+        // 5x times faster than gemm()-float32
+        //gemm_nn_custom_bin_mean_transposed(m, n, k, 1, (unsigned char *)l.align_bit_weights, new_ldb, (unsigned char *)t_bit_input, new_ldb, c, n, l.mean_arr);
+        //status = cudaMemcpy(l.output_gpu, output, l.batch*l.out_c*l.out_h*l.out_w * sizeof(float), cudaMemcpyHostToDevice);
+        //check_error(status);
+
+        gemm_nn_custom_bin_mean_transposed_gpu(m, n, k, 1,
+            (unsigned char *)l.align_bit_weights_gpu, new_ldb, (unsigned char *)l.align_workspace_gpu, new_ldb, l.output_gpu, n, l.mean_arr_gpu);
+        //cudaDeviceSynchronize();
+
+        //free(t_bit_input);
+        //free(input);
+        //free(workspace);
+        //free(output);
+        //cudaFree(align_workspace);
 
         add_bias_gpu(l.output_gpu, l.biases_gpu, l.batch, l.n, l.out_w*l.out_h);
         activate_array_ongpu(l.output_gpu, l.outputs*l.batch, l.activation);
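
Note on the hunk above: until this change the XNOR inference path left the GPU twice per layer, shuttling the packed weights, the im2col output and the GEMM result through host buffers (input/workspace/output) and running the binary GEMM on the CPU. Now the whole chain, im2col_align_ongpu -> float_to_bit_gpu -> fill_int8_gpu -> transpose_bin_gpu -> gemm_nn_custom_bin_mean_transposed_gpu, stays on the device, writing straight into l.output_gpu and reusing the preallocated l.align_workspace_gpu instead of a per-call cudaMalloc/cudaFree.

float_to_bit_gpu() is called here but not defined in this patch; the following is a minimal sketch of the packing step it is assumed to perform (one sign bit per float, 8 floats per output byte). The kernel name and the sign convention are assumptions, not taken from the source:

    // Hypothetical sketch, not part of the patch.
    __global__ void float_to_bit_sketch(const float *src, unsigned char *dst, size_t size)
    {
        size_t b = (size_t)blockIdx.x*blockDim.x + threadIdx.x;   // one thread per output byte
        if (b * 8 >= size) return;

        unsigned char packed = 0;
        for (int i = 0; i < 8 && b * 8 + i < size; ++i) {
            if (src[b * 8 + i] >= 0) packed |= (1 << i);          // assumed: non-negative -> bit = 1 (+1)
        }
        dst[b] = packed;
    }
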
diff --git a/src/convolutional_layer.c b/src/convolutional_layer.c
index fa927180..c36903dd 100644
--- a/src/convolutional_layer.c
+++ b/src/convolutional_layer.c
@@ -626,8 +626,13 @@ void binary_align_weights(convolutional_layer *l)
     get_mean_array(align_weights, align_weights_size, l->n, l->mean_arr);
 
 #ifdef GPU
+    cudaError_t status;
+    l->align_workspace_size = l->bit_align * l->size * l->size * l->c;
+    status = cudaMalloc((void **)&l->align_workspace_gpu, l->align_workspace_size * sizeof(float));
+    check_error(status);
+
     //l->align_bit_weights_gpu = cuda_make_array(l->align_bit_weights, l->align_bit_weights_size * sizeof(char)/sizeof(float));
-    cudaError_t status = cudaMalloc((void **)&l->align_bit_weights_gpu, l->align_bit_weights_size);
+    status = cudaMalloc((void **)&l->align_bit_weights_gpu, l->align_bit_weights_size);
     check_error(status);
     status = cudaMemcpy(l->align_bit_weights_gpu, l->align_bit_weights, l->align_bit_weights_size, cudaMemcpyHostToDevice);
     check_error(status);
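
Note on the hunk above: binary_align_weights() now also allocates the device-side alignment workspace once at weight-loading time. Its size, l->bit_align * l->size * l->size * l->c floats, is the im2col matrix with each of the k = size*size*c rows padded to bit_align columns (the "[bit_align, k]" layout described in the forward-pass comments), which is what lets forward_convolutional_layer_gpu() above drop its per-call cudaMalloc/cudaFree pair.
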
diff --git a/src/im2col.h b/src/im2col.h
index 435a59cd..5af54157 100644
--- a/src/im2col.h
+++ b/src/im2col.h
@@ -22,5 +22,10 @@ void transpose_bin_gpu(unsigned char *A, unsigned char *B, const int n, const in
 void fill_int8_gpu(unsigned char *src, unsigned char val, size_t size);
 
+void gemm_nn_custom_bin_mean_transposed_gpu(int M, int N, int K, float ALPHA_UNUSED,
+    unsigned char *A, int lda,
+    unsigned char *B, int ldb,
+    float *C, int ldc, float *mean_arr);
+
 #endif
 #endif
diff --git a/src/im2col_kernels.cu b/src/im2col_kernels.cu
index eee820af..6ba503c7 100644
--- a/src/im2col_kernels.cu
+++ b/src/im2col_kernels.cu
@@ -218,7 +218,6 @@ __global__ void transpose_bin_gpu_kernel(unsigned char *A, unsigned char *B, con
         if (j < m - 8) {
             int a_index = i*lda + j;
             int b_index = j*ldb + i;
-            //transpose_8x8_bits_my(&A[a_index/8], &B[b_index/8], lda/8, ldb/8);
             transpose8rS32_reversed_diagonale(&A[a_index / 8], lda / 8, ldb / 8, &B[b_index / 8]);
         }
         else if (j < m) {
@@ -250,5 +249,79 @@ __global__ void fill_int8_gpu_kernel(unsigned char *src, unsigned char val, size
 void fill_int8_gpu(unsigned char *src, unsigned char val, size_t size)
 {
     const int num_blocks = size / BLOCK + 1;
-    fill_int8_gpu_kernel<<<num_blocks, BLOCK>>>(src, val, size);
-}
+    fill_int8_gpu_kernel<<<num_blocks, BLOCK>>>(src, val, size);
+}
+
+// --------------------------------
+
+typedef unsigned long long int uint64_t;
+
+__device__ __host__ static inline uint64_t xnor_int64(uint64_t a, uint64_t b) {
+    return ~(a^b);
+}
+
+__global__ void gemm_nn_custom_bin_mean_transposed_gpu_kernel(int M, int N, int K, float ALPHA_UNUSED,
+    unsigned char *A, int lda,
+    unsigned char *B, int ldb,
+    float *C, int ldc, float *mean_arr)
+{
+    int index = blockIdx.x*blockDim.x + threadIdx.x;
+
+    //if (index == 0)
+    {
+        int i, j, k, h;
+
+        //#pragma omp parallel for
+        //for (i = 0; i < M; ++i)
+        i = index % M;
+        //if(i < M)
+        {   // l.n - filters [16 - 55 - 1024]
+            float mean_val = mean_arr[i];
+
+            //for (j = 0; j < N; ++j)
+            j = index / M;
+            if(j < N)
+            {   // out_h*out_w - one channel output size [169 - 173056]
+                int count = 0;
+
+                for (k = 0; k < K; k += 64) {   // l.size*l.size*l.c - one filter size [27 - 9216]
+                    uint64_t a_bit64 = *((uint64_t *)(A + (i*lda + k) / 8));
+                    uint64_t b_bit64 = *((uint64_t *)(B + (j*ldb + k) / 8));
+                    uint64_t c_bit64 = xnor_int64(a_bit64, b_bit64);
+
+                    //#ifdef WIN32
+                    //  int tmp_count = __popcnt64(c_bit64);
+                    //#else
+                    //  int tmp_count = __builtin_popcountll(c_bit64);
+                    //#endif
+
+                    int tmp_count = __popcll(c_bit64);
+
+                    if (K - k < 64) tmp_count = tmp_count - (64 - (K - k));   // remove extra bits
+                    count += tmp_count;
+                    //binary_int64_printf(c_bit64);
+                    //printf(", count = %d \n\n", tmp_count);
+                }
+
+                C[i*ldc + j] = (2 * count - K) * mean_val;
+            }
+        }
+    }
+}
+
+
+void gemm_nn_custom_bin_mean_transposed_gpu(int M, int N, int K, float ALPHA_UNUSED,
+    unsigned char *A, int lda,
+    unsigned char *B, int ldb,
+    float *C, int ldc, float *mean_arr)
+{
+    size_t size = M*N;
+    const int num_blocks = size / BLOCK + 1;
+
+    gemm_nn_custom_bin_mean_transposed_gpu_kernel<<<num_blocks, BLOCK>>>(
+        M, N, K, ALPHA_UNUSED,
+        A, lda,
+        B, ldb,
+        C, ldc,
+        mean_arr);
+}
\ No newline at end of file
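
Note on the kernel above: one thread computes one output element (i = index % M selects the filter, j = index / M the output pixel). With weights and activations packed as sign bits (1 = +1, 0 = -1), the dot product of two ±1 vectors of length K reduces to popcounts: XNOR sets a bit exactly where the two signs agree, so dot = matches - mismatches = 2*popcount(xnor(a, b)) - K. The `tmp_count - (64 - (K - k))` line subtracts the bits past K in the final 64-bit word, which xnor(0, 0) would otherwise count as matches, and mean_val = mean_arr[i] rescales the ±1 result by the mean absolute value of filter i's real weights. A standalone host-side check of the identity (illustration only, not from the patch; __builtin_popcountll is the GCC/Clang builtin):

    #include <stdio.h>
    #include <stdint.h>

    int main(void)
    {
        /* a = (+1, +1, -1, +1), b = (+1, +1, -1, -1), bit 0 first */
        uint64_t a = 0xB, b = 0x3;
        int K = 4;

        uint64_t c = ~(a ^ b);                 /* XNOR: 1 where the signs agree */
        int count = __builtin_popcountll(c);
        count -= 64 - K;                       /* drop bits beyond K, all set by XNOR(0,0) */
        printf("dot = %d\n", 2 * count - K);   /* prints 2: (+1)+(+1)+(+1)+(-1) */
        return 0;
    }
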
diff --git a/src/layer.h b/src/layer.h
index d533a5a5..271a13de 100644
--- a/src/layer.h
+++ b/src/layer.h
@@ -181,6 +181,9 @@ struct layer{
     char *align_bit_weights_gpu;
     float *mean_arr_gpu;
 
+    float *align_workspace_gpu;
+    int align_workspace_size;
+
     char *align_bit_weights;
     float *mean_arr;
     int align_bit_weights_size;
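
These two struct fields tie the pieces together: binary_align_weights() (convolutional_layer.c above) fills them in once when the weights are loaded, and forward_convolutional_layer_gpu() (convolutional_kernels.cu above) reuses align_workspace_gpu on every inference call as the destination for im2col_align_ongpu() and transpose_bin_gpu(). The buffer is sized in floats but later reinterpreted as packed bytes, so it is large enough for the bit-packed intermediate (t_bit_input_size bytes).
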