Added dilation parameter for convolutional, conv_lstm and crnn layers

pull/3367/head
AlexeyAB 6 years ago
parent 12db38ccbf
commit 88ce9dcca6
Files changed (lines changed):

  1. include/darknet.h (1)
  2. src/col2im.c (53)
  3. src/col2im.h (14)
  4. src/col2im_kernels.cu (79)
  5. src/conv_lstm_layer.c (25)
  6. src/conv_lstm_layer.h (2)
  7. src/convolutional_kernels.cu (34)
  8. src/convolutional_layer.c (49)
  9. src/convolutional_layer.h (2)
  10. src/crnn_layer.c (9)
  11. src/crnn_layer.h (2)
  12. src/im2col.c (53)
  13. src/im2col.h (14)
  14. src/im2col_kernels.cu (72)
  15. src/parser.c (9)

@@ -204,6 +204,7 @@ struct layer {
     int size;
     int side;
     int stride;
+    int dilation;
    int reverse;
     int flatten;
     int spatial;

@@ -37,3 +37,56 @@ void col2im_cpu(float* data_col,
         }
     }
 }
+// ----------------------------------------
+void caffe_set(const int N, const float alpha, float* Y) {
+    if (alpha == 0) {
+        memset(Y, 0, sizeof(float) * N);  // NOLINT(caffe/alt_fn)
+        return;
+    }
+    for (int i = 0; i < N; ++i) {
+        Y[i] = alpha;
+    }
+}
+
+inline int is_a_ge_zero_and_a_lt_b(int a, int b) {
+    return (unsigned)(a) < (unsigned)(b);
+}
+
+// https://github.com/BVLC/caffe/blob/master/src/caffe/util/im2col.cpp
+void col2im_cpu_ext(const float* data_col, const int channels,
+    const int height, const int width, const int kernel_h, const int kernel_w,
+    const int pad_h, const int pad_w,
+    const int stride_h, const int stride_w,
+    const int dilation_h, const int dilation_w,
+    float* data_im)
+{
+    caffe_set(height * width * channels, 0.0F, data_im);
+    const int output_h = (height + 2 * pad_h -
+        (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
+    const int output_w = (width + 2 * pad_w -
+        (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
+    const int channel_size = height * width;
+    for (int channel = channels; channel--; data_im += channel_size) {
+        for (int kernel_row = 0; kernel_row < kernel_h; kernel_row++) {
+            for (int kernel_col = 0; kernel_col < kernel_w; kernel_col++) {
+                int input_row = -pad_h + kernel_row * dilation_h;
+                for (int output_rows = output_h; output_rows; output_rows--) {
+                    if (!is_a_ge_zero_and_a_lt_b(input_row, height)) {
+                        data_col += output_w;
+                    }
+                    else {
+                        int input_col = -pad_w + kernel_col * dilation_w;
+                        for (int output_col = output_w; output_col; output_col--) {
+                            if (is_a_ge_zero_and_a_lt_b(input_col, width)) {
+                                data_im[input_row * width + input_col] += *data_col;
+                            }
+                            data_col++;
+                            input_col += stride_w;
+                        }
+                    }
+                    input_row += stride_h;
+                }
+            }
+        }
+    }
+}
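The output-size expression used above generalizes the usual convolution arithmetic by replacing the kernel size with its dilated extent, dilation * (kernel - 1) + 1; the number of weights stays kernel * kernel. A minimal standalone check of that arithmetic (values chosen for illustration, not from this commit):

    #include <stdio.h>

    // Dilated output size: the kernel's effective extent grows to
    // dilation*(k-1)+1 while the weight count stays k*k.
    int conv_out_size(int in, int k, int pad, int stride, int dilation) {
        return (in + 2 * pad - (dilation * (k - 1) + 1)) / stride + 1;
    }

    int main() {
        // 3x3 kernel, dilation 2 -> effective extent 5:
        // a 224x224 input with pad 2, stride 1 keeps its size.
        printf("%d\n", conv_out_size(224, 3, 2, 1, 2));  // prints 224
        // With pad 1 (pad = size/2, as the parser sets it) the map
        // shrinks by 2 pixels per axis instead.
        printf("%d\n", conv_out_size(224, 3, 1, 1, 2));  // prints 222
        return 0;
    }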

@@ -8,10 +8,24 @@ void col2im_cpu(float* data_col,
     int channels, int height, int width,
     int ksize, int stride, int pad, float* data_im);
+
+void col2im_cpu_ext(const float* data_col, const int channels,
+    const int height, const int width, const int kernel_h, const int kernel_w,
+    const int pad_h, const int pad_w,
+    const int stride_h, const int stride_w,
+    const int dilation_h, const int dilation_w,
+    float* data_im);
+
 #ifdef GPU
 void col2im_ongpu(float *data_col,
     int channels, int height, int width,
     int ksize, int stride, int pad, float *data_im);
+
+void col2im_gpu_ext(const float* data_col, const int channels,
+    const int height, const int width, const int kernel_h, const int kernel_w,
+    const int pad_h, const int pad_w, const int stride_h,
+    const int stride_w, const int dilation_h, const int dilation_w,
+    float* data_im);
 #endif
 #ifdef __cplusplus
 }

@@ -55,3 +55,82 @@ void col2im_ongpu(float *data_col,
     CHECK_CUDA(cudaPeekAtLastError());
 }
+// -----------------------------------------
+// CUDA: use 512 threads per block
+const int CAFFE_CUDA_NUM_THREADS = 512;
+
+// CUDA: number of blocks for threads.
+inline int CAFFE_GET_BLOCKS(const int N) {
+    return (N + CAFFE_CUDA_NUM_THREADS - 1) / CAFFE_CUDA_NUM_THREADS;
+}
+
+// CUDA: grid stride looping
+#define CUDA_KERNEL_LOOP(i, n) \
+    for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
+        i < (n); \
+        i += blockDim.x * gridDim.x)
+
+// https://github.com/BVLC/caffe/blob/master/src/caffe/util/im2col.cu
+__global__ void col2im_gpu_kernel_ext(const int n, const float* data_col,
+    const int height, const int width, const int channels,
+    const int kernel_h, const int kernel_w,
+    const int pad_h, const int pad_w,
+    const int stride_h, const int stride_w,
+    const int dilation_h, const int dilation_w,
+    const int height_col, const int width_col,
+    float* data_im) {
+    CUDA_KERNEL_LOOP(index, n) {
+        float val = 0;
+        const int w_im = index % width + pad_w;
+        const int h_im = (index / width) % height + pad_h;
+        const int c_im = index / (width * height);
+        int kernel_extent_w = (kernel_w - 1) * dilation_w + 1;
+        int kernel_extent_h = (kernel_h - 1) * dilation_h + 1;
+        // compute the start and end of the output
+        const int w_col_start =
+            (w_im < kernel_extent_w) ? 0 : (w_im - kernel_extent_w) / stride_w + 1;
+        const int w_col_end = min(w_im / stride_w + 1, width_col);
+        const int h_col_start =
+            (h_im < kernel_extent_h) ? 0 : (h_im - kernel_extent_h) / stride_h + 1;
+        const int h_col_end = min(h_im / stride_h + 1, height_col);
+        // TODO: use LCM of stride and dilation to avoid unnecessary loops
+        for (int h_col = h_col_start; h_col < h_col_end; h_col += 1) {
+            for (int w_col = w_col_start; w_col < w_col_end; w_col += 1) {
+                int h_k = (h_im - h_col * stride_h);
+                int w_k = (w_im - w_col * stride_w);
+                if (h_k % dilation_h == 0 && w_k % dilation_w == 0) {
+                    h_k /= dilation_h;
+                    w_k /= dilation_w;
+                    int data_col_index = (((c_im * kernel_h + h_k) * kernel_w + w_k) *
+                        height_col + h_col) * width_col + w_col;
+                    val += data_col[data_col_index];
+                }
+            }
+        }
+        data_im[index] = val;
+    }
+}
+
+void col2im_gpu_ext(const float* data_col, const int channels,
+    const int height, const int width, const int kernel_h, const int kernel_w,
+    const int pad_h, const int pad_w, const int stride_h,
+    const int stride_w, const int dilation_h, const int dilation_w,
+    float* data_im)
+{
+    int height_col = (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) /
+        stride_h + 1;
+    int width_col = (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) /
+        stride_w + 1;
+    int num_kernels = channels * height * width;
+    // To avoid atomic operations, launch one thread per bottom (input)
+    // element and let each thread sum up its top (column) contributions.
+    // NOLINT_NEXT_LINE(whitespace/operators)
+    col2im_gpu_kernel_ext<<<CAFFE_GET_BLOCKS(num_kernels),
+        CAFFE_CUDA_NUM_THREADS>>>(
+            num_kernels, data_col, height, width, channels, kernel_h, kernel_w,
+            pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
+            height_col, width_col, data_im);
+    CHECK_CUDA(cudaPeekAtLastError());
+}
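CUDA_KERNEL_LOOP above is Caffe's grid-stride loop: each thread walks the index space in steps of the total thread count, so correctness never depends on the launch covering n indices one-to-one. A minimal sketch of the pattern on its own (saxpy_kernel is illustrative, not part of this commit):

    __global__ void saxpy_kernel(const int n, const float a,
                                 const float* x, float* y) {
        // Same shape as CUDA_KERNEL_LOOP(i, n): each thread handles
        // indices i, i + stride, i + 2*stride, ... for any grid size.
        for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
             i += blockDim.x * gridDim.x) {
            y[i] = a * x[i] + y[i];
        }
    }

    // Launched with the same helpers defined above:
    //   saxpy_kernel<<<CAFFE_GET_BLOCKS(n), CAFFE_CUDA_NUM_THREADS>>>(n, a, x, y);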

@@ -32,7 +32,7 @@ static void increment_layer(layer *l, int steps)
 }
-layer make_conv_lstm_layer(int batch, int h, int w, int c, int output_filters, int groups, int steps, int size, int stride, int pad, ACTIVATION activation, int batch_normalize, int peephole, int xnor)
+layer make_conv_lstm_layer(int batch, int h, int w, int c, int output_filters, int groups, int steps, int size, int stride, int dilation, int pad, ACTIVATION activation, int batch_normalize, int peephole, int xnor)
 {
     fprintf(stderr, "CONV_LSTM Layer: %d x %d x %d image, %d filters\n", h, w, c, output_filters);
     /*

@@ -53,6 +53,7 @@ layer make_conv_lstm_layer(int batch, int h, int w, int c, int output_filters, i
     l.steps = steps;
     l.size = size;
     l.stride = stride;
+    l.dilation = dilation;
     l.pad = pad;
     l.h = h;
     l.w = w;

@@ -65,44 +66,44 @@ layer make_conv_lstm_layer(int batch, int h, int w, int c, int output_filters, i
     // U
     l.uf = (layer*)calloc(1, sizeof(layer));
-    *(l.uf) = make_convolutional_layer(batch, steps, h, w, c, output_filters, groups, size, stride, pad, activation, batch_normalize, 0, xnor, 0, 0, 0);
+    *(l.uf) = make_convolutional_layer(batch, steps, h, w, c, output_filters, groups, size, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0);
     l.uf->batch = batch;
     if (l.workspace_size < l.uf->workspace_size) l.workspace_size = l.uf->workspace_size;
     l.ui = (layer*)calloc(1, sizeof(layer));
-    *(l.ui) = make_convolutional_layer(batch, steps, h, w, c, output_filters, groups, size, stride, pad, activation, batch_normalize, 0, xnor, 0, 0, 0);
+    *(l.ui) = make_convolutional_layer(batch, steps, h, w, c, output_filters, groups, size, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0);
     l.ui->batch = batch;
     if (l.workspace_size < l.ui->workspace_size) l.workspace_size = l.ui->workspace_size;
     l.ug = (layer*)calloc(1, sizeof(layer));
-    *(l.ug) = make_convolutional_layer(batch, steps, h, w, c, output_filters, groups, size, stride, pad, activation, batch_normalize, 0, xnor, 0, 0, 0);
+    *(l.ug) = make_convolutional_layer(batch, steps, h, w, c, output_filters, groups, size, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0);
     l.ug->batch = batch;
     if (l.workspace_size < l.ug->workspace_size) l.workspace_size = l.ug->workspace_size;
     l.uo = (layer*)calloc(1, sizeof(layer));
-    *(l.uo) = make_convolutional_layer(batch, steps, h, w, c, output_filters, groups, size, stride, pad, activation, batch_normalize, 0, xnor, 0, 0, 0);
+    *(l.uo) = make_convolutional_layer(batch, steps, h, w, c, output_filters, groups, size, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0);
     l.uo->batch = batch;
     if (l.workspace_size < l.uo->workspace_size) l.workspace_size = l.uo->workspace_size;
     // W
     l.wf = (layer*)calloc(1, sizeof(layer));
-    *(l.wf) = make_convolutional_layer(batch, steps, h, w, output_filters, output_filters, groups, size, stride, pad, activation, batch_normalize, 0, xnor, 0, 0, 0);
+    *(l.wf) = make_convolutional_layer(batch, steps, h, w, output_filters, output_filters, groups, size, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0);
     l.wf->batch = batch;
     if (l.workspace_size < l.wf->workspace_size) l.workspace_size = l.wf->workspace_size;
     l.wi = (layer*)calloc(1, sizeof(layer));
-    *(l.wi) = make_convolutional_layer(batch, steps, h, w, output_filters, output_filters, groups, size, stride, pad, activation, batch_normalize, 0, xnor, 0, 0, 0);
+    *(l.wi) = make_convolutional_layer(batch, steps, h, w, output_filters, output_filters, groups, size, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0);
     l.wi->batch = batch;
     if (l.workspace_size < l.wi->workspace_size) l.workspace_size = l.wi->workspace_size;
     l.wg = (layer*)calloc(1, sizeof(layer));
-    *(l.wg) = make_convolutional_layer(batch, steps, h, w, output_filters, output_filters, groups, size, stride, pad, activation, batch_normalize, 0, xnor, 0, 0, 0);
+    *(l.wg) = make_convolutional_layer(batch, steps, h, w, output_filters, output_filters, groups, size, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0);
     l.wg->batch = batch;
     if (l.workspace_size < l.wg->workspace_size) l.workspace_size = l.wg->workspace_size;
     l.wo = (layer*)calloc(1, sizeof(layer));
-    *(l.wo) = make_convolutional_layer(batch, steps, h, w, output_filters, output_filters, groups, size, stride, pad, activation, batch_normalize, 0, xnor, 0, 0, 0);
+    *(l.wo) = make_convolutional_layer(batch, steps, h, w, output_filters, output_filters, groups, size, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0);
     l.wo->batch = batch;
     if (l.workspace_size < l.wo->workspace_size) l.workspace_size = l.wo->workspace_size;

@@ -110,21 +111,21 @@ layer make_conv_lstm_layer(int batch, int h, int w, int c, int output_filters, i
     // V
     l.vf = (layer*)calloc(1, sizeof(layer));
     if (l.peephole) {
-        *(l.vf) = make_convolutional_layer(batch, steps, h, w, output_filters, output_filters, groups, size, stride, pad, activation, batch_normalize, 0, xnor, 0, 0, 0);
+        *(l.vf) = make_convolutional_layer(batch, steps, h, w, output_filters, output_filters, groups, size, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0);
         l.vf->batch = batch;
         if (l.workspace_size < l.vf->workspace_size) l.workspace_size = l.vf->workspace_size;
     }
     l.vi = (layer*)calloc(1, sizeof(layer));
     if (l.peephole) {
-        *(l.vi) = make_convolutional_layer(batch, steps, h, w, output_filters, output_filters, groups, size, stride, pad, activation, batch_normalize, 0, xnor, 0, 0, 0);
+        *(l.vi) = make_convolutional_layer(batch, steps, h, w, output_filters, output_filters, groups, size, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0);
         l.vi->batch = batch;
         if (l.workspace_size < l.vi->workspace_size) l.workspace_size = l.vi->workspace_size;
     }
     l.vo = (layer*)calloc(1, sizeof(layer));
     if (l.peephole) {
-        *(l.vo) = make_convolutional_layer(batch, steps, h, w, output_filters, output_filters, groups, size, stride, pad, activation, batch_normalize, 0, xnor, 0, 0, 0);
+        *(l.vo) = make_convolutional_layer(batch, steps, h, w, output_filters, output_filters, groups, size, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0);
         l.vo->batch = batch;
         if (l.workspace_size < l.vo->workspace_size) l.workspace_size = l.vo->workspace_size;
     }

@@ -9,7 +9,7 @@
 #ifdef __cplusplus
 extern "C" {
 #endif
-layer make_conv_lstm_layer(int batch, int h, int w, int c, int output_filters, int groups, int steps, int size, int stride, int pad, ACTIVATION activation, int batch_normalize, int peephole, int xnor);
+layer make_conv_lstm_layer(int batch, int h, int w, int c, int output_filters, int groups, int steps, int size, int stride, int dilation, int pad, ACTIVATION activation, int batch_normalize, int peephole, int xnor);
 void resize_conv_lstm_layer(layer *l, int w, int h);
 void free_state_conv_lstm(layer l);
 void randomize_state_conv_lstm(layer l);

@@ -566,7 +566,17 @@ void forward_convolutional_layer_gpu(convolutional_layer l, network_state state)
         b = im;
     }
     else {
-        im2col_ongpu(im, l.c / l.groups, l.h, l.w, l.size, l.stride, l.pad, state.workspace);
+        //im2col_ongpu(im, l.c / l.groups, l.h, l.w, l.size, l.stride, l.pad, state.workspace);
+        im2col_gpu_ext(im,            // input
+            l.c / l.groups,           // input channels
+            l.h, l.w,                 // input size (h, w)
+            l.size, l.size,           // kernel size (h, w)
+            l.pad, l.pad,             // padding (h, w)
+            l.stride, l.stride,       // stride (h, w)
+            l.dilation, l.dilation,   // dilation (h, w)
+            state.workspace);         // output
     }
     gemm_ongpu(0, 0, m, n, k, 1., a, k, b, n, 1., c + i*m*n, n);
 }

@@ -798,7 +808,15 @@ void backward_convolutional_layer_gpu(convolutional_layer l, network_state state
     float *im = state.input + (i*l.groups + j)*l.c / l.groups*l.h*l.w;
-    im2col_ongpu(im, l.c / l.groups, l.h, l.w, l.size, l.stride, l.pad, state.workspace);
+    //im2col_ongpu(im, l.c / l.groups, l.h, l.w, l.size, l.stride, l.pad, state.workspace);
+    im2col_gpu_ext(im,            // input
+        l.c / l.groups,           // input channels
+        l.h, l.w,                 // input size (h, w)
+        l.size, l.size,           // kernel size (h, w)
+        l.pad, l.pad,             // padding (h, w)
+        l.stride, l.stride,       // stride (h, w)
+        l.dilation, l.dilation,   // dilation (h, w)
+        state.workspace);         // output
     gemm_ongpu(0, 1, m, n, k, 1, a + i*m*k, k, b, k, 1, c, n);
     if (state.delta) {

@@ -811,7 +829,17 @@ void backward_convolutional_layer_gpu(convolutional_layer l, network_state state
     float *delta = state.delta + (i*l.groups + j)*l.c / l.groups*l.h*l.w;
-    col2im_ongpu(state.workspace, l.c / l.groups, l.h, l.w, l.size, l.stride, l.pad, delta);
+    //col2im_ongpu(state.workspace, l.c / l.groups, l.h, l.w, l.size, l.stride, l.pad, delta);
+    col2im_gpu_ext(
+        state.workspace,          // input
+        l.c / l.groups,           // input channels
+        l.h, l.w,                 // input size (h, w)
+        l.size, l.size,           // kernel size (h, w)
+        l.pad, l.pad,             // padding size (h, w)
+        l.stride, l.stride,       // stride size (h, w)
+        l.dilation, l.dilation,   // dilation size (h, w)
+        delta);                   // output (delta)
     if (l.binary || l.xnor) {
         swap_binary(&l);
     }
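All three calls above stage their columns in state.workspace: one (l.c / l.groups) * l.size * l.size column per output pixel, which darknet's existing workspace sizing is expected to cover. A minimal sketch of that count (the helper name workspace_floats is illustrative, not part of this commit):

    #include <stddef.h>

    // Column-buffer entries needed for one batch image and one group pass:
    // a (c/groups)*size*size column for each of the out_h*out_w output pixels.
    size_t workspace_floats(int out_h, int out_w, int size, int c, int groups) {
        return (size_t)out_h * out_w * size * size * c / groups;
    }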

@@ -275,9 +275,9 @@ void cudnn_convolutional_setup(layer *l, int cudnn_preference)
     CHECK_CUDNN(cudnnSetTensor4dDescriptor(l->normDstTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l->batch, l->out_c, l->out_h, l->out_w));
 #if(CUDNN_MAJOR >= 6)
-    CHECK_CUDNN(cudnnSetConvolution2dDescriptor(l->convDesc, l->pad, l->pad, l->stride, l->stride, 1, 1, CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT));    // cudnn >= 6.0
+    CHECK_CUDNN(cudnnSetConvolution2dDescriptor(l->convDesc, l->pad, l->pad, l->stride, l->stride, l->dilation, l->dilation, CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT));    // cudnn >= 6.0
 #else
-    CHECK_CUDNN(cudnnSetConvolution2dDescriptor(l->convDesc, l->pad, l->pad, l->stride, l->stride, 1, 1, CUDNN_CROSS_CORRELATION));    // cudnn 5.1
+    CHECK_CUDNN(cudnnSetConvolution2dDescriptor(l->convDesc, l->pad, l->pad, l->stride, l->stride, l->dilation, l->dilation, CUDNN_CROSS_CORRELATION));    // cudnn 5.1
 #endif
     int forward_algo = CUDNN_CONVOLUTION_FWD_PREFER_FASTEST;
     int backward_algo = CUDNN_CONVOLUTION_BWD_DATA_PREFER_FASTEST;

@@ -331,7 +331,7 @@ void cudnn_convolutional_setup(layer *l, int cudnn_preference)
 #endif
 #endif
-convolutional_layer make_convolutional_layer(int batch, int steps, int h, int w, int c, int n, int groups, int size, int stride, int padding, ACTIVATION activation, int batch_normalize, int binary, int xnor, int adam, int use_bin_output, int index)
+convolutional_layer make_convolutional_layer(int batch, int steps, int h, int w, int c, int n, int groups, int size, int stride, int dilation, int padding, ACTIVATION activation, int batch_normalize, int binary, int xnor, int adam, int use_bin_output, int index)
 {
     int total_batch = batch*steps;
     int i;

@@ -353,6 +353,7 @@ convolutional_layer make_convolutional_layer(int batch, int steps, int h, int w,
     l.batch = batch;
     l.steps = steps;
     l.stride = stride;
+    l.dilation = dilation;
     l.size = size;
     l.pad = padding;
     l.batch_normalize = batch_normalize;

@@ -525,7 +526,7 @@ void denormalize_convolutional_layer(convolutional_layer l)
 void test_convolutional_layer()
 {
-    convolutional_layer l = make_convolutional_layer(1, 1, 5, 5, 3, 2, 1, 5, 2, 1, LEAKY, 1, 0, 0, 0, 0, 0);
+    convolutional_layer l = make_convolutional_layer(1, 1, 5, 5, 3, 2, 1, 5, 2, 1, 1, LEAKY, 1, 0, 0, 0, 0, 0);
     l.batch_normalize = 1;
     float data[] = {1,1,1,1,1,
         1,1,1,1,1,

@@ -981,8 +982,17 @@ void forward_convolutional_layer(convolutional_layer l, network_state state)
     }
     else {
         //printf(" l.index = %d - FP32 \n", l.index);
-        im2col_cpu(state.input + (i*l.groups + j)*l.c / l.groups*l.h*l.w,
-            l.c / l.groups, l.h, l.w, l.size, l.stride, l.pad, b);
+        float *im = state.input + (i*l.groups + j)*l.c / l.groups*l.h*l.w;
+        //im2col_cpu(im, l.c / l.groups, l.h, l.w, l.size, l.stride, l.pad, b);
+        im2col_cpu_ext(im,            // input
+            l.c / l.groups,           // input channels
+            l.h, l.w,                 // input size (h, w)
+            l.size, l.size,           // kernel size (h, w)
+            l.pad, l.pad,             // padding (h, w)
+            l.stride, l.stride,       // stride (h, w)
+            l.dilation, l.dilation,   // dilation (h, w)
+            b);                       // output
         gemm(0, 0, m, n, k, 1, a, k, b, n, 1, c, n);
     // bit-count to float

@@ -1028,8 +1038,17 @@ void backward_convolutional_layer(convolutional_layer l, network_state state)
     float *im = state.input + (i*l.groups + j)*l.c / l.groups*l.h*l.w;
-    im2col_cpu(im, l.c / l.groups, l.h, l.w,
-        l.size, l.stride, l.pad, b);
+    //im2col_cpu(im, l.c / l.groups, l.h, l.w, l.size, l.stride, l.pad, b);
+    im2col_cpu_ext(
+        im,                       // input
+        l.c / l.groups,           // input channels
+        l.h, l.w,                 // input size (h, w)
+        l.size, l.size,           // kernel size (h, w)
+        l.pad, l.pad,             // padding (h, w)
+        l.stride, l.stride,       // stride (h, w)
+        l.dilation, l.dilation,   // dilation (h, w)
+        b);                       // output
     gemm(0, 1, m, n, k, 1, a, k, b, k, 1, c, n);
     if (state.delta) {

@@ -1039,8 +1058,18 @@ void backward_convolutional_layer(convolutional_layer l, network_state state)
     gemm(1, 0, n, k, m, 1, a, n, b, k, 0, c, k);
-    col2im_cpu(state.workspace, l.c / l.groups, l.h, l.w, l.size, l.stride,
-        l.pad, state.delta + (i*l.groups + j)*l.c / l.groups*l.h*l.w);
+    //col2im_cpu(state.workspace, l.c / l.groups, l.h, l.w, l.size, l.stride,
+    //    l.pad, state.delta + (i*l.groups + j)*l.c / l.groups*l.h*l.w);
+    col2im_cpu_ext(
+        state.workspace,          // input
+        l.c / l.groups,           // input channels
+        l.h, l.w,                 // input size (h, w)
+        l.size, l.size,           // kernel size (h, w)
+        l.pad, l.pad,             // padding (h, w)
+        l.stride, l.stride,       // stride (h, w)
+        l.dilation, l.dilation,   // dilation (h, w)
+        state.delta + (i*l.groups + j)*l.c / l.groups*l.h*l.w);  // output (delta)
 }
 }
 }

@@ -30,7 +30,7 @@ void cuda_convert_f32_to_f16(float* input_f32, size_t size, float *output_f16);
 #endif
 size_t get_convolutional_workspace_size(layer l);
-convolutional_layer make_convolutional_layer(int batch, int steps, int h, int w, int c, int n, int groups, int size, int stride, int padding, ACTIVATION activation, int batch_normalize, int binary, int xnor, int adam, int use_bin_output, int index);
+convolutional_layer make_convolutional_layer(int batch, int steps, int h, int w, int c, int n, int groups, int size, int stride, int dilation, int padding, ACTIVATION activation, int batch_normalize, int binary, int xnor, int adam, int use_bin_output, int index);
 void denormalize_convolutional_layer(convolutional_layer l);
 void resize_convolutional_layer(convolutional_layer *layer, int w, int h);
 void forward_convolutional_layer(const convolutional_layer layer, network_state state);

@@ -26,7 +26,7 @@ static void increment_layer(layer *l, int steps)
 #endif
 }
-layer make_crnn_layer(int batch, int h, int w, int c, int hidden_filters, int output_filters, int groups, int steps, int size, int stride, int pad, ACTIVATION activation, int batch_normalize, int xnor)
+layer make_crnn_layer(int batch, int h, int w, int c, int hidden_filters, int output_filters, int groups, int steps, int size, int stride, int dilation, int pad, ACTIVATION activation, int batch_normalize, int xnor)
 {
     fprintf(stderr, "CRNN Layer: %d x %d x %d image, %d filters\n", h,w,c,output_filters);
     batch = batch / steps;

@@ -36,6 +36,7 @@ layer make_crnn_layer(int batch, int h, int w, int c, int hidden_filters, int ou
     l.steps = steps;
     l.size = size;
     l.stride = stride;
+    l.dilation = dilation;
     l.pad = pad;
     l.h = h;
     l.w = w;

@@ -49,17 +50,17 @@ layer make_crnn_layer(int batch, int h, int w, int c, int hidden_filters, int ou
     l.state = (float*)calloc(l.hidden * l.batch * (l.steps + 1), sizeof(float));
     l.input_layer = (layer*)calloc(1, sizeof(layer));
-    *(l.input_layer) = make_convolutional_layer(batch, steps, h, w, c, hidden_filters, groups, size, stride, pad, activation, batch_normalize, 0, xnor, 0, 0, 0);
+    *(l.input_layer) = make_convolutional_layer(batch, steps, h, w, c, hidden_filters, groups, size, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0);
     l.input_layer->batch = batch;
     if (l.workspace_size < l.input_layer->workspace_size) l.workspace_size = l.input_layer->workspace_size;
     l.self_layer = (layer*)calloc(1, sizeof(layer));
-    *(l.self_layer) = make_convolutional_layer(batch, steps, h, w, hidden_filters, hidden_filters, groups, size, stride, pad, activation, batch_normalize, 0, xnor, 0, 0, 0);
+    *(l.self_layer) = make_convolutional_layer(batch, steps, h, w, hidden_filters, hidden_filters, groups, size, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0);
     l.self_layer->batch = batch;
     if (l.workspace_size < l.self_layer->workspace_size) l.workspace_size = l.self_layer->workspace_size;
     l.output_layer = (layer*)calloc(1, sizeof(layer));
-    *(l.output_layer) = make_convolutional_layer(batch, steps, h, w, hidden_filters, output_filters, groups, size, stride, pad, activation, batch_normalize, 0, xnor, 0, 0, 0);
+    *(l.output_layer) = make_convolutional_layer(batch, steps, h, w, hidden_filters, output_filters, groups, size, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0);
     l.output_layer->batch = batch;
     if (l.workspace_size < l.output_layer->workspace_size) l.workspace_size = l.output_layer->workspace_size;

@@ -9,7 +9,7 @@
 #ifdef __cplusplus
 extern "C" {
 #endif
-layer make_crnn_layer(int batch, int h, int w, int c, int hidden_filters, int output_filters, int groups, int steps, int size, int stride, int pad, ACTIVATION activation, int batch_normalize, int xnor);
+layer make_crnn_layer(int batch, int h, int w, int c, int hidden_filters, int output_filters, int groups, int steps, int size, int stride, int dilation, int pad, ACTIVATION activation, int batch_normalize, int xnor);
 void resize_crnn_layer(layer *l, int w, int h);
 void free_state_crnn(layer l);

@@ -37,3 +37,56 @@ void im2col_cpu(float* data_im,
         }
     }
 }
+
+// Checks a >= 0 && a < b with a single unsigned comparison. Casting a
+// negative a to unsigned maps it above 0x800...; b is signed and always
+// positive, so it always stays below 0x800.... One condition thus
+// replaces two.
+inline int is_a_ge_zero_and_a_lt_b(int a, int b) {
+    return (unsigned)(a) < (unsigned)(b);
+}
+
+// https://github.com/BVLC/caffe/blob/master/src/caffe/util/im2col.cpp
+void im2col_cpu_ext(const float* data_im, const int channels,
+    const int height, const int width, const int kernel_h, const int kernel_w,
+    const int pad_h, const int pad_w,
+    const int stride_h, const int stride_w,
+    const int dilation_h, const int dilation_w,
+    float* data_col)
+{
+    const int output_h = (height + 2 * pad_h -
+        (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
+    const int output_w = (width + 2 * pad_w -
+        (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
+    const int channel_size = height * width;
+    for (int channel = channels; channel--; data_im += channel_size) {
+        for (int kernel_row = 0; kernel_row < kernel_h; kernel_row++) {
+            for (int kernel_col = 0; kernel_col < kernel_w; kernel_col++) {
+                int input_row = -pad_h + kernel_row * dilation_h;
+                for (int output_rows = output_h; output_rows; output_rows--) {
+                    if (!is_a_ge_zero_and_a_lt_b(input_row, height)) {
+                        for (int output_cols = output_w; output_cols; output_cols--) {
+                            *(data_col++) = 0;
+                        }
+                    }
+                    else {
+                        int input_col = -pad_w + kernel_col * dilation_w;
+                        for (int output_col = output_w; output_col; output_col--) {
+                            if (is_a_ge_zero_and_a_lt_b(input_col, width)) {
+                                *(data_col++) = data_im[input_row * width + input_col];
+                            }
+                            else {
+                                *(data_col++) = 0;
+                            }
+                            input_col += stride_w;
+                        }
+                    }
+                    input_row += stride_h;
+                }
+            }
+        }
+    }
+}
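The column buffer written by im2col_cpu_ext holds one row of output_h * output_w values for every (channel, kernel_y, kernel_x) triple, i.e. channels * kernel_h * kernel_w rows in total. A minimal standalone sketch of a call, assuming it is linked against im2col.c (sizes and data are illustrative):

    #include <stdlib.h>

    // Assumed available from im2col.h.
    void im2col_cpu_ext(const float* data_im, const int channels,
        const int height, const int width, const int kernel_h, const int kernel_w,
        const int pad_h, const int pad_w,
        const int stride_h, const int stride_w,
        const int dilation_h, const int dilation_w,
        float* data_col);

    int main() {
        const int c = 3, h = 8, w = 8, k = 3, pad = 2, stride = 1, dilation = 2;
        const int out_h = (h + 2 * pad - (dilation * (k - 1) + 1)) / stride + 1;  // 8
        const int out_w = (w + 2 * pad - (dilation * (k - 1) + 1)) / stride + 1;  // 8
        float* im = (float*)calloc(c * h * w, sizeof(float));
        // One column-buffer row per (channel, kernel_y, kernel_x) triple.
        float* col = (float*)calloc(c * k * k * out_h * out_w, sizeof(float));
        im2col_cpu_ext(im, c, h, w, k, k, pad, pad, stride, stride,
                       dilation, dilation, col);
        free(im);
        free(col);
        return 0;
    }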

@@ -14,12 +14,26 @@ void im2col_cpu(float* data_im,
 float im2col_get_pixel(float* im, int height, int width, int channels,
     int row, int col, int channel, int pad);
+
+void im2col_cpu_ext(const float* data_im, const int channels,
+    const int height, const int width, const int kernel_h, const int kernel_w,
+    const int pad_h, const int pad_w,
+    const int stride_h, const int stride_w,
+    const int dilation_h, const int dilation_w,
+    float* data_col);
+
 #ifdef GPU
 void im2col_ongpu(float *im,
     int channels, int height, int width,
     int ksize, int stride, int pad,float *data_col);
+
+void im2col_gpu_ext(const float* data_im, const int channels,
+    const int height, const int width, const int kernel_h, const int kernel_w,
+    const int pad_h, const int pad_w,
+    const int stride_h, const int stride_w,
+    const int dilation_h, const int dilation_w,
+    float* data_col);
+
 void im2col_align_ongpu(float *im,
     int channels, int height, int width,
     int ksize, int stride, int pad, float *data_col, int bit_align);

@@ -2214,3 +2214,75 @@ void convolve_bin_gpu(float *input, float *weights, float *output, int in_w, int
 }
 // --------------------------------
+
+// CUDA: use 512 threads per block
+const int CAFFE_CUDA_NUM_THREADS = 512;
+
+// CUDA: number of blocks for threads.
+inline int CAFFE_GET_BLOCKS(const int N) {
+    return (N + CAFFE_CUDA_NUM_THREADS - 1) / CAFFE_CUDA_NUM_THREADS;
+}
+
+// CUDA: grid stride looping
+#define CUDA_KERNEL_LOOP(i, n) \
+    for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
+        i < (n); \
+        i += blockDim.x * gridDim.x)
+
+// https://github.com/BVLC/caffe/blob/master/src/caffe/util/im2col.cu
+__global__ void im2col_gpu_kernel_ext(const int n, const float* data_im,
+    const int height, const int width, const int kernel_h, const int kernel_w,
+    const int pad_h, const int pad_w,
+    const int stride_h, const int stride_w,
+    const int dilation_h, const int dilation_w,
+    const int height_col, const int width_col,
+    float* data_col) {
+    CUDA_KERNEL_LOOP(index, n) {
+        const int h_index = index / width_col;
+        const int h_col = h_index % height_col;
+        const int w_col = index % width_col;
+        const int c_im = h_index / height_col;
+        const int c_col = c_im * kernel_h * kernel_w;
+        const int h_offset = h_col * stride_h - pad_h;
+        const int w_offset = w_col * stride_w - pad_w;
+        float* data_col_ptr = data_col;
+        data_col_ptr += (c_col * height_col + h_col) * width_col + w_col;
+        const float* data_im_ptr = data_im;
+        data_im_ptr += (c_im * height + h_offset) * width + w_offset;
+        for (int i = 0; i < kernel_h; ++i) {
+            for (int j = 0; j < kernel_w; ++j) {
+                int h_im = h_offset + i * dilation_h;
+                int w_im = w_offset + j * dilation_w;
+                *data_col_ptr =
+                    (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) ?
+                    data_im_ptr[i * dilation_h * width + j * dilation_w] : 0;
+                data_col_ptr += height_col * width_col;
+            }
+        }
+    }
+}
+
+void im2col_gpu_ext(const float* data_im, const int channels,
+    const int height, const int width, const int kernel_h, const int kernel_w,
+    const int pad_h, const int pad_w,
+    const int stride_h, const int stride_w,
+    const int dilation_h, const int dilation_w,
+    float* data_col)
+{
+    // Launch channels * height_col * width_col threads, each responsible
+    // for copying one column of a single-channel grid.
+    int height_col = (height + 2 * pad_h -
+        (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
+    int width_col = (width + 2 * pad_w -
+        (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
+    int num_kernels = channels * height_col * width_col;
+    // NOLINT_NEXT_LINE(whitespace/operators)
+    im2col_gpu_kernel_ext<<<CAFFE_GET_BLOCKS(num_kernels),
+        CAFFE_CUDA_NUM_THREADS>>>(
+            num_kernels, data_im, height, width, kernel_h, kernel_w, pad_h,
+            pad_w, stride_h, stride_w, dilation_h, dilation_w, height_col,
+            width_col, data_col);
+    CHECK_CUDA(cudaPeekAtLastError());
+}
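Since the CPU and GPU paths implement the same transform, one way to sanity-check the port is to compare their outputs on random data. A rough host-side sketch, assuming a build with -DGPU and linkage against darknet's im2col objects (error handling elided):

    #include "im2col.h"      // assumed: declares im2col_cpu_ext / im2col_gpu_ext
    #include <cuda_runtime.h>
    #include <math.h>
    #include <stdio.h>
    #include <stdlib.h>

    int main() {
        const int c = 2, h = 16, w = 16, k = 3, pad = 2, stride = 1, dil = 2;
        const int out_h = (h + 2 * pad - (dil * (k - 1) + 1)) / stride + 1;
        const int out_w = (w + 2 * pad - (dil * (k - 1) + 1)) / stride + 1;
        const int im_n = c * h * w, col_n = c * k * k * out_h * out_w;

        float* im = (float*)malloc(im_n * sizeof(float));
        float* col_cpu = (float*)malloc(col_n * sizeof(float));
        float* col_gpu = (float*)malloc(col_n * sizeof(float));
        for (int i = 0; i < im_n; ++i) im[i] = rand() / (float)RAND_MAX;

        // Reference: CPU path.
        im2col_cpu_ext(im, c, h, w, k, k, pad, pad, stride, stride, dil, dil, col_cpu);

        // Same transform on the GPU.
        float *d_im, *d_col;
        cudaMalloc((void**)&d_im, im_n * sizeof(float));
        cudaMalloc((void**)&d_col, col_n * sizeof(float));
        cudaMemcpy(d_im, im, im_n * sizeof(float), cudaMemcpyHostToDevice);
        im2col_gpu_ext(d_im, c, h, w, k, k, pad, pad, stride, stride, dil, dil, d_col);
        cudaMemcpy(col_gpu, d_col, col_n * sizeof(float), cudaMemcpyDeviceToHost);

        // im2col only copies or zero-fills, so outputs should match exactly.
        float max_diff = 0;
        for (int i = 0; i < col_n; ++i)
            max_diff = fmaxf(max_diff, fabsf(col_cpu[i] - col_gpu[i]));
        printf("max |cpu - gpu| = %g\n", max_diff);

        cudaFree(d_im); cudaFree(d_col);
        free(im); free(col_cpu); free(col_gpu);
        return 0;
    }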

@@ -153,6 +153,7 @@ convolutional_layer parse_convolutional(list *options, size_params params)
     int groups = option_find_int_quiet(options, "groups", 1);
     int size = option_find_int(options, "size",1);
     int stride = option_find_int(options, "stride",1);
+    int dilation = option_find_int_quiet(options, "dilation", 1);
     int pad = option_find_int_quiet(options, "pad",0);
     int padding = option_find_int_quiet(options, "padding",0);
     if(pad) padding = size/2;

@@ -171,7 +172,7 @@ convolutional_layer parse_convolutional(list *options, size_params params)
     int xnor = option_find_int_quiet(options, "xnor", 0);
     int use_bin_output = option_find_int_quiet(options, "bin_output", 0);
-    convolutional_layer layer = make_convolutional_layer(batch,1,h,w,c,n,groups,size,stride,padding,activation, batch_normalize, binary, xnor, params.net.adam, use_bin_output, params.index);
+    convolutional_layer layer = make_convolutional_layer(batch,1,h,w,c,n,groups,size,stride,dilation,padding,activation, batch_normalize, binary, xnor, params.net.adam, use_bin_output, params.index);
     layer.flipped = option_find_int_quiet(options, "flipped", 0);
     layer.dot = option_find_float_quiet(options, "dot", 0);

@@ -188,6 +189,7 @@ layer parse_crnn(list *options, size_params params)
 {
     int size = option_find_int_quiet(options, "size", 3);
     int stride = option_find_int_quiet(options, "stride", 1);
+    int dilation = option_find_int_quiet(options, "dilation", 1);
     int pad = option_find_int_quiet(options, "pad", 0);
     int padding = option_find_int_quiet(options, "padding", 0);
     if (pad) padding = size / 2;

@@ -200,7 +202,7 @@ layer parse_crnn(list *options, size_params params)
     int batch_normalize = option_find_int_quiet(options, "batch_normalize", 0);
     int xnor = option_find_int_quiet(options, "xnor", 0);
-    layer l = make_crnn_layer(params.batch, params.h, params.w, params.c, hidden_filters, output_filters, groups, params.time_steps, size, stride, padding, activation, batch_normalize, xnor);
+    layer l = make_crnn_layer(params.batch, params.h, params.w, params.c, hidden_filters, output_filters, groups, params.time_steps, size, stride, dilation, padding, activation, batch_normalize, xnor);
     l.shortcut = option_find_int_quiet(options, "shortcut", 0);

@@ -248,6 +250,7 @@ layer parse_conv_lstm(list *options, size_params params)
     // a ConvLSTM with a larger transitional kernel should be able to capture faster motions
     int size = option_find_int_quiet(options, "size", 3);
     int stride = option_find_int_quiet(options, "stride", 1);
+    int dilation = option_find_int_quiet(options, "dilation", 1);
     int pad = option_find_int_quiet(options, "pad", 0);
     int padding = option_find_int_quiet(options, "padding", 0);
     if (pad) padding = size / 2;

@@ -260,7 +263,7 @@ layer parse_conv_lstm(list *options, size_params params)
     int xnor = option_find_int_quiet(options, "xnor", 0);
     int peephole = option_find_int_quiet(options, "peephole", 0);
-    layer l = make_conv_lstm_layer(params.batch, params.h, params.w, params.c, output_filters, groups, params.time_steps, size, stride, padding, activation, batch_normalize, peephole, xnor);
+    layer l = make_conv_lstm_layer(params.batch, params.h, params.w, params.c, output_filters, groups, params.time_steps, size, stride, dilation, padding, activation, batch_normalize, peephole, xnor);
     l.state_constrain = option_find_int_quiet(options, "state_constrain", params.time_steps * 32);
     l.shortcut = option_find_int_quiet(options, "shortcut", 0);
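Because the key is read with option_find_int_quiet and defaults to 1, existing cfg files are unaffected. A sketch of how a dilated layer might be declared in a cfg file; only the dilation line is what this commit adds, the other values are illustrative:

    [convolutional]
    batch_normalize=1
    filters=64
    size=3
    stride=1
    dilation=2
    pad=1
    activation=leaky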
