Compile fix

pull/1724/head
AlexeyAB 7 years ago
parent 007878393f
commit b141f85cab
  1. 26
      src/convolutional_kernels.cu
  2. 2
      src/convolutional_layer.c
  3. 27
      src/detector.c
  4. 6
      src/im2col.h
  5. 115
      src/im2col_kernels.cu

@ -117,6 +117,7 @@ void forward_convolutional_layer_gpu(convolutional_layer l, network_state state)
}
if(l.xnor){
if (!l.align_bit_weights_gpu || state.train) {
binarize_weights_gpu(l.weights_gpu, l.n, l.c*l.size*l.size, l.binary_weights_gpu);
}
@ -128,6 +129,7 @@ void forward_convolutional_layer_gpu(convolutional_layer l, network_state state)
if (l.align_bit_weights_gpu && !state.train)
{
cudaError_t status = cudaSuccess;
int input_size = l.c*l.h*l.w*l.batch;
int m = l.n;
int k = l.size*l.size*l.c;
@ -139,6 +141,7 @@ void forward_convolutional_layer_gpu(convolutional_layer l, network_state state)
size_t t_intput_size = new_ldb * n;
size_t t_bit_input_size = t_intput_size / 8;// +1;
int i = 0;
im2col_align_ongpu(state.input + i*l.c*l.h*l.w, l.c, l.h, l.w, l.size, l.stride, l.pad, l.align_workspace_gpu, l.bit_align);
//cudaDeviceSynchronize();
@ -152,13 +155,34 @@ void forward_convolutional_layer_gpu(convolutional_layer l, network_state state)
transpose_bin_gpu((unsigned char *)state.workspace, (unsigned char *)l.transposed_align_workspace_gpu, k, n, l.bit_align, new_ldb, 8);
//cudaDeviceSynchronize();
// should be optimized
gemm_nn_custom_bin_mean_transposed_gpu(m, n, k,
(unsigned char *)l.align_bit_weights_gpu, new_ldb, (unsigned char *)l.transposed_align_workspace_gpu, new_ldb, l.output_gpu, n, l.mean_arr_gpu);
//cudaDeviceSynchronize();
//check_error(status);
{
//float_to_bit_gpu(state.input, (unsigned char *)l.align_workspace_gpu, input_size);
/*
float *input_cpu = (float *)calloc(input_size, sizeof(float));
status = cudaMemcpy(input_cpu, state.input, input_size* sizeof(float), cudaMemcpyDeviceToHost);
check_error(status);
convolve_bin_cpu(input_cpu, l.weights, l.output, l.w, l.h, l.c, l.n, l.size, l.pad); // CPU
status = cudaMemcpy(l.output_gpu, l.output, l.outputs * sizeof(float), cudaMemcpyHostToDevice);
check_error(status);
free(input_cpu);
*/
//convolve_bin_gpu(state.input, l.weights_gpu, l.output_gpu, l.w, l.h, l.c, l.n, l.size, l.pad);
//cudaDeviceSynchronize();
//check_error(status);
}
add_bias_gpu(l.output_gpu, l.biases_gpu, l.batch, l.n, l.out_w*l.out_h);
activate_array_ongpu(l.output_gpu, l.outputs*l.batch, l.activation);
if (l.binary || l.xnor) swap_binary(&l);

@ -637,6 +637,8 @@ void binary_align_weights(convolutional_layer *l)
check_error(status);
status = cudaMemcpy(l->align_bit_weights_gpu, l->align_bit_weights, l->align_bit_weights_size, cudaMemcpyHostToDevice);
check_error(status);
status = cudaMemcpy(l->binary_weights_gpu, l->binary_weights, m*k*sizeof(float), cudaMemcpyHostToDevice);
check_error(status);
l->mean_arr_gpu = cuda_make_array(l->mean_arr, l->n);
cudaDeviceSynchronize();

@ -1026,19 +1026,24 @@ void calc_anchors(char *datacfg, int num_of_clusters, int width, int height, int
char buff[1024];
FILE* fw = fopen("anchors.txt", "wb");
printf("\nSaving anchors to the file: anchors.txt \n");
printf("anchors = ");
for (i = 0; i < num_of_clusters; ++i) {
sprintf(buff, "%2.4f,%2.4f", centers->data.fl[i * 2], centers->data.fl[i * 2 + 1]);
printf("%s", buff);
fwrite(buff, sizeof(char), strlen(buff), fw);
if (i + 1 < num_of_clusters) {
fwrite(", ", sizeof(char), 2, fw);
printf(", ");
if (fw) {
printf("\nSaving anchors to the file: anchors.txt \n");
printf("anchors = ");
for (i = 0; i < num_of_clusters; ++i) {
sprintf(buff, "%2.4f,%2.4f", centers->data.fl[i * 2], centers->data.fl[i * 2 + 1]);
printf("%s", buff);
fwrite(buff, sizeof(char), strlen(buff), fw);
if (i + 1 < num_of_clusters) {
fwrite(", ", sizeof(char), 2, fw);
printf(", ");
}
}
printf("\n");
fclose(fw);
}
else {
printf(" Error: file anchors.txt can't be open \n");
}
printf("\n");
fclose(fw);
if (show) {
size_t img_size = 700;

@ -1,6 +1,8 @@
#ifndef IM2COL_H
#define IM2COL_H
#include <stddef.h>
void im2col_cpu(float* data_im,
int channels, int height, int width,
int ksize, int stride, int pad, float* data_col);
@ -27,5 +29,9 @@ void gemm_nn_custom_bin_mean_transposed_gpu(int M, int N, int K,
unsigned char *B, int ldb,
float *C, int ldc, float *mean_arr);
void convolve_bin_gpu(float *input, float *weights, float *output, int in_w, int in_h, int in_c, int n, int size, int pad);
void convolve_bin_cpu(float *input, float *weights, float *output, int in_w, int in_h, int in_c, int n, int size, int pad);
#endif
#endif

@ -461,3 +461,118 @@ void gemm_nn_custom_bin_mean_transposed_gpu(int M, int N, int K,
C, ldc,
mean_arr);
}
// --------------------------------
__global__ void convolve_bin_gpu_kernel(float *input, float *weights, float *output, int in_w, int in_h, int in_c, int n, int size, int pad)
{
int index = blockIdx.x*blockDim.x + threadIdx.x;
int fil;
// filter index
//for (fil = 0; fil < n; ++fil)
int chan, y, x, f_y, f_x;
// channel index
//for (chan = 0; chan < in_c; ++chan)
// input - y
//for (y = 0; y < in_h; ++y)
// input - x
//for (x = 0; x < in_w; ++x)
x = index % in_w;
int index2 = index / in_w;
y = index2 % in_h;
fil = index2 / in_h;
if (fil < n)
{
int const output_index = fil*in_w*in_h + y*in_w + x;
float sum = 0;
for (chan = 0; chan < in_c; ++chan)
{
int const weights_pre_index = fil*in_c*size*size + chan*size*size;
int const input_pre_index = chan*in_w*in_h;
// filter - y
for (f_y = 0; f_y < size; ++f_y)
{
int input_y = y + f_y - pad;
// filter - x
for (f_x = 0; f_x < size; ++f_x)
{
int input_x = x + f_x - pad;
if (input_y < 0 || input_x < 0 || input_y >= in_h || input_x >= in_w) continue;
int input_index = input_pre_index + input_y*in_w + input_x;
int weights_index = weights_pre_index + f_y*size + f_x;
sum += input[input_index] *weights[weights_index];
}
}
// l.output[filters][width][height] +=
// state.input[channels][width][height] *
// l.weights[filters][channels][filter_width][filter_height];
//output[output_index] += sum;
}
output[output_index] = sum;
}
}
void convolve_bin_gpu(float *input, float *weights, float *output, int in_w, int in_h, int in_c, int n, int size, int pad)
{
size_t array_size = in_w*in_h*n; // width X height X filters
const int num_blocks = array_size / BLOCK + 1;
//printf("\n array_size = %d, num_blocks = %d, w = %d, h = %d, n = %d, c = %d, pad = %d \n", array_size, num_blocks, in_w, in_h, n, in_c, pad);
convolve_bin_gpu_kernel << <num_blocks, BLOCK, 0, get_cuda_stream() >> > (input, weights, output, in_w, in_h, in_c, n, size, pad);
}
void convolve_bin_cpu(float *input, float *weights, float *output, int in_w, int in_h, int in_c, int n, int size, int pad)
{
int fil;
// filter index
#pragma omp parallel for // "omp parallel for" - automatic parallelization of loop by using OpenMP
for (fil = 0; fil < n; ++fil) {
int chan, y, x, f_y, f_x;
// channel index
for (chan = 0; chan < in_c; ++chan)
// input - y
for (y = 0; y < in_h; ++y)
// input - x
for (x = 0; x < in_w; ++x)
{
int const output_index = fil*in_w*in_h + y*in_w + x;
int const weights_pre_index = fil*in_c*size*size + chan*size*size;
int const input_pre_index = chan*in_w*in_h;
float sum = 0;
// filter - y
for (f_y = 0; f_y < size; ++f_y)
{
int input_y = y + f_y - pad;
// filter - x
for (f_x = 0; f_x < size; ++f_x)
{
int input_x = x + f_x - pad;
if (input_y < 0 || input_x < 0 || input_y >= in_h || input_x >= in_w) continue;
int input_index = input_pre_index + input_y*in_w + input_x;
int weights_index = weights_pre_index + f_y*size + f_x;
sum += input[input_index] * weights[weights_index];
}
}
// l.output[filters][width][height] +=
// state.input[channels][width][height] *
// l.weights[filters][channels][filter_width][filter_height];
output[output_index] += sum;
}
}
}
Loading…
Cancel
Save