Optimized memory allocation for detection (inference only): buffers needed only for training are no longer allocated

pull/4269/head
AlexeyAB 6 years ago
parent 0eee8404bf
commit d91d59a22f
  1. 1
      include/darknet.h
  2. 25
      src/conv_lstm_layer.c
  3. 2
      src/conv_lstm_layer.h
  4. 6
      src/convolutional_kernels.cu
  5. 90
      src/convolutional_layer.c
  6. 3
      src/convolutional_layer.h
  7. 9
      src/crnn_layer.c
  8. 2
      src/crnn_layer.h
  9. 4
      src/gemm.c
  10. 30
      src/maxpool_layer.c
  11. 2
      src/maxpool_layer.h
  12. 4
      src/maxpool_layer_kernels.cu
  13. 1
      src/network.c
  14. 14
      src/parser.c
  15. 14
      src/shortcut_layer.c
  16. 2
      src/shortcut_layer.h
  17. 2
      src/yolo_layer.c

@ -190,6 +190,7 @@ struct layer {
void(*backward_gpu) (struct layer, struct network_state); void(*backward_gpu) (struct layer, struct network_state);
void(*update_gpu) (struct layer, int, float, float, float); void(*update_gpu) (struct layer, int, float, float, float);
layer *share_layer; layer *share_layer;
int train;
int batch_normalize; int batch_normalize;
int shortcut; int shortcut;
int batch; int batch;

@ -32,7 +32,7 @@ static void increment_layer(layer *l, int steps)
} }
layer make_conv_lstm_layer(int batch, int h, int w, int c, int output_filters, int groups, int steps, int size, int stride, int dilation, int pad, ACTIVATION activation, int batch_normalize, int peephole, int xnor) layer make_conv_lstm_layer(int batch, int h, int w, int c, int output_filters, int groups, int steps, int size, int stride, int dilation, int pad, ACTIVATION activation, int batch_normalize, int peephole, int xnor, int train)
{ {
fprintf(stderr, "CONV_LSTM Layer: %d x %d x %d image, %d filters\n", h, w, c, output_filters); fprintf(stderr, "CONV_LSTM Layer: %d x %d x %d image, %d filters\n", h, w, c, output_filters);
/* /*
@ -48,6 +48,7 @@ layer make_conv_lstm_layer(int batch, int h, int w, int c, int output_filters, i
*/ */
batch = batch / steps; batch = batch / steps;
layer l = { (LAYER_TYPE)0 }; layer l = { (LAYER_TYPE)0 };
l.train = train;
l.batch = batch; l.batch = batch;
l.type = CONV_LSTM; l.type = CONV_LSTM;
l.steps = steps; l.steps = steps;
@ -66,44 +67,44 @@ layer make_conv_lstm_layer(int batch, int h, int w, int c, int output_filters, i
// U // U
l.uf = (layer*)calloc(1, sizeof(layer)); l.uf = (layer*)calloc(1, sizeof(layer));
*(l.uf) = make_convolutional_layer(batch, steps, h, w, c, output_filters, groups, size, stride, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0, 0, NULL, 0); *(l.uf) = make_convolutional_layer(batch, steps, h, w, c, output_filters, groups, size, stride, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0, 0, NULL, 0, train);
l.uf->batch = batch; l.uf->batch = batch;
if (l.workspace_size < l.uf->workspace_size) l.workspace_size = l.uf->workspace_size; if (l.workspace_size < l.uf->workspace_size) l.workspace_size = l.uf->workspace_size;
l.ui = (layer*)calloc(1, sizeof(layer)); l.ui = (layer*)calloc(1, sizeof(layer));
*(l.ui) = make_convolutional_layer(batch, steps, h, w, c, output_filters, groups, size, stride, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0, 0, NULL, 0); *(l.ui) = make_convolutional_layer(batch, steps, h, w, c, output_filters, groups, size, stride, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0, 0, NULL, 0, train);
l.ui->batch = batch; l.ui->batch = batch;
if (l.workspace_size < l.ui->workspace_size) l.workspace_size = l.ui->workspace_size; if (l.workspace_size < l.ui->workspace_size) l.workspace_size = l.ui->workspace_size;
l.ug = (layer*)calloc(1, sizeof(layer)); l.ug = (layer*)calloc(1, sizeof(layer));
*(l.ug) = make_convolutional_layer(batch, steps, h, w, c, output_filters, groups, size, stride, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0, 0, NULL, 0); *(l.ug) = make_convolutional_layer(batch, steps, h, w, c, output_filters, groups, size, stride, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0, 0, NULL, 0, train);
l.ug->batch = batch; l.ug->batch = batch;
if (l.workspace_size < l.ug->workspace_size) l.workspace_size = l.ug->workspace_size; if (l.workspace_size < l.ug->workspace_size) l.workspace_size = l.ug->workspace_size;
l.uo = (layer*)calloc(1, sizeof(layer)); l.uo = (layer*)calloc(1, sizeof(layer));
*(l.uo) = make_convolutional_layer(batch, steps, h, w, c, output_filters, groups, size, stride, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0, 0, NULL, 0); *(l.uo) = make_convolutional_layer(batch, steps, h, w, c, output_filters, groups, size, stride, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0, 0, NULL, 0, train);
l.uo->batch = batch; l.uo->batch = batch;
if (l.workspace_size < l.uo->workspace_size) l.workspace_size = l.uo->workspace_size; if (l.workspace_size < l.uo->workspace_size) l.workspace_size = l.uo->workspace_size;
// W // W
l.wf = (layer*)calloc(1, sizeof(layer)); l.wf = (layer*)calloc(1, sizeof(layer));
*(l.wf) = make_convolutional_layer(batch, steps, h, w, output_filters, output_filters, groups, size, stride, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0, 0, NULL, 0); *(l.wf) = make_convolutional_layer(batch, steps, h, w, output_filters, output_filters, groups, size, stride, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0, 0, NULL, 0, train);
l.wf->batch = batch; l.wf->batch = batch;
if (l.workspace_size < l.wf->workspace_size) l.workspace_size = l.wf->workspace_size; if (l.workspace_size < l.wf->workspace_size) l.workspace_size = l.wf->workspace_size;
l.wi = (layer*)calloc(1, sizeof(layer)); l.wi = (layer*)calloc(1, sizeof(layer));
*(l.wi) = make_convolutional_layer(batch, steps, h, w, output_filters, output_filters, groups, size, stride, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0, 0, NULL, 0); *(l.wi) = make_convolutional_layer(batch, steps, h, w, output_filters, output_filters, groups, size, stride, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0, 0, NULL, 0, train);
l.wi->batch = batch; l.wi->batch = batch;
if (l.workspace_size < l.wi->workspace_size) l.workspace_size = l.wi->workspace_size; if (l.workspace_size < l.wi->workspace_size) l.workspace_size = l.wi->workspace_size;
l.wg = (layer*)calloc(1, sizeof(layer)); l.wg = (layer*)calloc(1, sizeof(layer));
*(l.wg) = make_convolutional_layer(batch, steps, h, w, output_filters, output_filters, groups, size, stride, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0, 0, NULL, 0); *(l.wg) = make_convolutional_layer(batch, steps, h, w, output_filters, output_filters, groups, size, stride, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0, 0, NULL, 0, train);
l.wg->batch = batch; l.wg->batch = batch;
if (l.workspace_size < l.wg->workspace_size) l.workspace_size = l.wg->workspace_size; if (l.workspace_size < l.wg->workspace_size) l.workspace_size = l.wg->workspace_size;
l.wo = (layer*)calloc(1, sizeof(layer)); l.wo = (layer*)calloc(1, sizeof(layer));
*(l.wo) = make_convolutional_layer(batch, steps, h, w, output_filters, output_filters, groups, size, stride, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0, 0, NULL, 0); *(l.wo) = make_convolutional_layer(batch, steps, h, w, output_filters, output_filters, groups, size, stride, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0, 0, NULL, 0, train);
l.wo->batch = batch; l.wo->batch = batch;
if (l.workspace_size < l.wo->workspace_size) l.workspace_size = l.wo->workspace_size; if (l.workspace_size < l.wo->workspace_size) l.workspace_size = l.wo->workspace_size;
@ -111,21 +112,21 @@ layer make_conv_lstm_layer(int batch, int h, int w, int c, int output_filters, i
// V // V
l.vf = (layer*)calloc(1, sizeof(layer)); l.vf = (layer*)calloc(1, sizeof(layer));
if (l.peephole) { if (l.peephole) {
*(l.vf) = make_convolutional_layer(batch, steps, h, w, output_filters, output_filters, groups, size, stride, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0, 0, NULL, 0); *(l.vf) = make_convolutional_layer(batch, steps, h, w, output_filters, output_filters, groups, size, stride, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0, 0, NULL, 0, train);
l.vf->batch = batch; l.vf->batch = batch;
if (l.workspace_size < l.vf->workspace_size) l.workspace_size = l.vf->workspace_size; if (l.workspace_size < l.vf->workspace_size) l.workspace_size = l.vf->workspace_size;
} }
l.vi = (layer*)calloc(1, sizeof(layer)); l.vi = (layer*)calloc(1, sizeof(layer));
if (l.peephole) { if (l.peephole) {
*(l.vi) = make_convolutional_layer(batch, steps, h, w, output_filters, output_filters, groups, size, stride, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0, 0, NULL, 0); *(l.vi) = make_convolutional_layer(batch, steps, h, w, output_filters, output_filters, groups, size, stride, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0, 0, NULL, 0, train);
l.vi->batch = batch; l.vi->batch = batch;
if (l.workspace_size < l.vi->workspace_size) l.workspace_size = l.vi->workspace_size; if (l.workspace_size < l.vi->workspace_size) l.workspace_size = l.vi->workspace_size;
} }
l.vo = (layer*)calloc(1, sizeof(layer)); l.vo = (layer*)calloc(1, sizeof(layer));
if (l.peephole) { if (l.peephole) {
*(l.vo) = make_convolutional_layer(batch, steps, h, w, output_filters, output_filters, groups, size, stride, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0, 0, NULL, 0); *(l.vo) = make_convolutional_layer(batch, steps, h, w, output_filters, output_filters, groups, size, stride, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0, 0, NULL, 0, train);
l.vo->batch = batch; l.vo->batch = batch;
if (l.workspace_size < l.vo->workspace_size) l.workspace_size = l.vo->workspace_size; if (l.workspace_size < l.vo->workspace_size) l.workspace_size = l.vo->workspace_size;
} }

@ -9,7 +9,7 @@
#ifdef __cplusplus #ifdef __cplusplus
extern "C" { extern "C" {
#endif #endif
layer make_conv_lstm_layer(int batch, int h, int w, int c, int output_filters, int groups, int steps, int size, int stride, int dilation, int pad, ACTIVATION activation, int batch_normalize, int peephole, int xnor); layer make_conv_lstm_layer(int batch, int h, int w, int c, int output_filters, int groups, int steps, int size, int stride, int dilation, int pad, ACTIVATION activation, int batch_normalize, int peephole, int xnor, int train);
void resize_conv_lstm_layer(layer *l, int w, int h); void resize_conv_lstm_layer(layer *l, int w, int h);
void free_state_conv_lstm(layer l); void free_state_conv_lstm(layer l);
void randomize_state_conv_lstm(layer l); void randomize_state_conv_lstm(layer l);

@ -986,7 +986,8 @@ void assisted_excitation_forward_gpu(convolutional_layer l, network_state state)
} }
else { else {
if (iteration_num < state.net.burn_in) return; if (iteration_num < state.net.burn_in) return;
else if (iteration_num > l.assisted_excitation) return; else
if (iteration_num > l.assisted_excitation) return;
else else
alpha = (1 + cos(3.141592 * iteration_num / (state.net.burn_in + l.assisted_excitation))) / 2; // from 1 to 0 alpha = (1 + cos(3.141592 * iteration_num / (state.net.burn_in + l.assisted_excitation))) / 2; // from 1 to 0
} }
@ -1018,6 +1019,7 @@ void assisted_excitation_forward_gpu(convolutional_layer l, network_state state)
for (t = 0; t < state.net.num_boxes; ++t) { for (t = 0; t < state.net.num_boxes; ++t) {
box truth = float_to_box_stride(truth_cpu + t*(4 + 1) + b*l.truths, 1); box truth = float_to_box_stride(truth_cpu + t*(4 + 1) + b*l.truths, 1);
if (!truth.x) break; // continue; if (!truth.x) break; // continue;
//float beta = 0;
float beta = 1 - alpha; // from 0 to 1 float beta = 1 - alpha; // from 0 to 1
float dw = (1 - truth.w) * beta; float dw = (1 - truth.w) * beta;
float dh = (1 - truth.h) * beta; float dh = (1 - truth.h) * beta;
@ -1162,8 +1164,10 @@ void push_convolutional_layer(convolutional_layer l)
cuda_convert_f32_to_f16(l.weights_gpu, l.nweights, l.weights_gpu16); cuda_convert_f32_to_f16(l.weights_gpu, l.nweights, l.weights_gpu16);
#endif #endif
cuda_push_array(l.biases_gpu, l.biases, l.n); cuda_push_array(l.biases_gpu, l.biases, l.n);
if (l.train) {
cuda_push_array(l.weight_updates_gpu, l.weight_updates, l.nweights); cuda_push_array(l.weight_updates_gpu, l.weight_updates, l.nweights);
cuda_push_array(l.bias_updates_gpu, l.bias_updates, l.n); cuda_push_array(l.bias_updates_gpu, l.bias_updates, l.n);
}
if (l.batch_normalize){ if (l.batch_normalize){
cuda_push_array(l.scales_gpu, l.scales, l.n); cuda_push_array(l.scales_gpu, l.scales, l.n);
cuda_push_array(l.rolling_mean_gpu, l.rolling_mean, l.n); cuda_push_array(l.rolling_mean_gpu, l.rolling_mean, l.n);

@ -123,7 +123,7 @@ size_t get_workspace_size32(layer l){
l.dweightDesc, l.dweightDesc,
l.bf_algo, l.bf_algo,
&s)); &s));
if (s > most) most = s; if (s > most && l.train) most = s;
CHECK_CUDNN(cudnnGetConvolutionBackwardDataWorkspaceSize(cudnn_handle(), CHECK_CUDNN(cudnnGetConvolutionBackwardDataWorkspaceSize(cudnn_handle(),
l.weightDesc, l.weightDesc,
l.ddstTensorDesc, l.ddstTensorDesc,
@ -131,7 +131,7 @@ size_t get_workspace_size32(layer l){
l.dsrcTensorDesc, l.dsrcTensorDesc,
l.bd_algo, l.bd_algo,
&s)); &s));
if (s > most) most = s; if (s > most && l.train) most = s;
return most; return most;
} }
#endif #endif
@ -164,7 +164,7 @@ size_t get_workspace_size16(layer l) {
l.dweightDesc16, l.dweightDesc16,
l.bf_algo16, l.bf_algo16,
&s)); &s));
if (s > most) most = s; if (s > most && l.train) most = s;
CHECK_CUDNN(cudnnGetConvolutionBackwardDataWorkspaceSize(cudnn_handle(), CHECK_CUDNN(cudnnGetConvolutionBackwardDataWorkspaceSize(cudnn_handle(),
l.weightDesc16, l.weightDesc16,
l.ddstTensorDesc16, l.ddstTensorDesc16,
@ -172,7 +172,7 @@ size_t get_workspace_size16(layer l) {
l.dsrcTensorDesc16, l.dsrcTensorDesc16,
l.bd_algo16, l.bd_algo16,
&s)); &s));
if (s > most) most = s; if (s > most && l.train) most = s;
return most; return most;
} }
#endif #endif
@ -333,12 +333,43 @@ void cudnn_convolutional_setup(layer *l, int cudnn_preference)
#endif #endif
#endif #endif
convolutional_layer make_convolutional_layer(int batch, int steps, int h, int w, int c, int n, int groups, int size, int stride_x, int stride_y, int dilation, int padding, ACTIVATION activation, int batch_normalize, int binary, int xnor, int adam, int use_bin_output, int index, int antialiasing, convolutional_layer *share_layer, int assisted_excitation)
/*
 * Release the batch-normalization buffers held by a convolutional layer.
 *
 * l: layer whose batch-norm buffers (host side, and device side when built
 *    with GPU) are released. The struct is modified in place.
 *
 * Nothing is released when l->share_layer is set (presumably the buffers
 * are then owned by the shared layer -- confirm against the allocator).
 *
 * Every released member is reset to NULL. The layer is passed by pointer and
 * stays alive after this call, so a stale address left in the struct could
 * be freed a second time, or dereferenced, by any later cleanup or reuse of
 * the same layer.
 */
void free_convolutional_batchnorm(convolutional_layer *l)
{
    if (l->share_layer) return;

    /* Host buffers: free(NULL) is a no-op, so no guard is needed. When the
     * layer was created with train == 0, make_convolutional_layer skips the
     * training-only buffers, so several of these are already NULL. */
    free(l->scales);            l->scales = NULL;
    free(l->scale_updates);     l->scale_updates = NULL;
    free(l->mean);              l->mean = NULL;
    free(l->variance);          l->variance = NULL;
    free(l->mean_delta);        l->mean_delta = NULL;
    free(l->variance_delta);    l->variance_delta = NULL;
    free(l->rolling_mean);      l->rolling_mean = NULL;
    free(l->rolling_variance);  l->rolling_variance = NULL;
    free(l->x);                 l->x = NULL;
    free(l->x_norm);            l->x_norm = NULL;

#ifdef GPU
    /* Device buffers: skip handles that were never allocated (inference-only
     * layers) instead of relying on how cuda_free treats a NULL pointer. */
    if (l->scales_gpu)           { cuda_free(l->scales_gpu);           l->scales_gpu = NULL; }
    if (l->scale_updates_gpu)    { cuda_free(l->scale_updates_gpu);    l->scale_updates_gpu = NULL; }
    if (l->mean_gpu)             { cuda_free(l->mean_gpu);             l->mean_gpu = NULL; }
    if (l->variance_gpu)         { cuda_free(l->variance_gpu);         l->variance_gpu = NULL; }
    if (l->mean_delta_gpu)       { cuda_free(l->mean_delta_gpu);       l->mean_delta_gpu = NULL; }
    if (l->variance_delta_gpu)   { cuda_free(l->variance_delta_gpu);   l->variance_delta_gpu = NULL; }
    if (l->rolling_mean_gpu)     { cuda_free(l->rolling_mean_gpu);     l->rolling_mean_gpu = NULL; }
    if (l->rolling_variance_gpu) { cuda_free(l->rolling_variance_gpu); l->rolling_variance_gpu = NULL; }
    if (l->x_gpu)                { cuda_free(l->x_gpu);                l->x_gpu = NULL; }
    if (l->x_norm_gpu)           { cuda_free(l->x_norm_gpu);           l->x_norm_gpu = NULL; }
#endif
}
convolutional_layer make_convolutional_layer(int batch, int steps, int h, int w, int c, int n, int groups, int size, int stride_x, int stride_y, int dilation, int padding, ACTIVATION activation, int batch_normalize, int binary, int xnor, int adam, int use_bin_output, int index, int antialiasing, convolutional_layer *share_layer, int assisted_excitation, int train)
{ {
int total_batch = batch*steps; int total_batch = batch*steps;
int i; int i;
convolutional_layer l = { (LAYER_TYPE)0 }; convolutional_layer l = { (LAYER_TYPE)0 };
l.type = CONVOLUTIONAL; l.type = CONVOLUTIONAL;
l.train = train;
if (xnor) groups = 1; // disable groups for XNOR-net if (xnor) groups = 1; // disable groups for XNOR-net
if (groups < 1) groups = 1; if (groups < 1) groups = 1;
@ -382,11 +413,13 @@ convolutional_layer make_convolutional_layer(int batch, int steps, int h, int w,
} }
else { else {
l.weights = (float*)calloc(l.nweights, sizeof(float)); l.weights = (float*)calloc(l.nweights, sizeof(float));
l.weight_updates = (float*)calloc(l.nweights, sizeof(float));
l.biases = (float*)calloc(n, sizeof(float)); l.biases = (float*)calloc(n, sizeof(float));
if (train) {
l.weight_updates = (float*)calloc(l.nweights, sizeof(float));
l.bias_updates = (float*)calloc(n, sizeof(float)); l.bias_updates = (float*)calloc(n, sizeof(float));
} }
}
// float scale = 1./sqrt(size*size*c); // float scale = 1./sqrt(size*size*c);
float scale = sqrt(2./(size*size*c/groups)); float scale = sqrt(2./(size*size*c/groups));
@ -401,7 +434,7 @@ convolutional_layer make_convolutional_layer(int batch, int steps, int h, int w,
l.activation = activation; l.activation = activation;
l.output = (float*)calloc(total_batch*l.outputs, sizeof(float)); l.output = (float*)calloc(total_batch*l.outputs, sizeof(float));
l.delta = (float*)calloc(total_batch*l.outputs, sizeof(float)); if (train) l.delta = (float*)calloc(total_batch*l.outputs, sizeof(float));
l.forward = forward_convolutional_layer; l.forward = forward_convolutional_layer;
l.backward = backward_convolutional_layer; l.backward = backward_convolutional_layer;
@ -445,24 +478,28 @@ convolutional_layer make_convolutional_layer(int batch, int steps, int h, int w,
} }
else { else {
l.scales = (float*)calloc(n, sizeof(float)); l.scales = (float*)calloc(n, sizeof(float));
l.scale_updates = (float*)calloc(n, sizeof(float));
for (i = 0; i < n; ++i) { for (i = 0; i < n; ++i) {
l.scales[i] = 1; l.scales[i] = 1;
} }
if (train) {
l.scale_updates = (float*)calloc(n, sizeof(float));
l.mean = (float*)calloc(n, sizeof(float)); l.mean = (float*)calloc(n, sizeof(float));
l.variance = (float*)calloc(n, sizeof(float)); l.variance = (float*)calloc(n, sizeof(float));
l.mean_delta = (float*)calloc(n, sizeof(float)); l.mean_delta = (float*)calloc(n, sizeof(float));
l.variance_delta = (float*)calloc(n, sizeof(float)); l.variance_delta = (float*)calloc(n, sizeof(float));
}
l.rolling_mean = (float*)calloc(n, sizeof(float)); l.rolling_mean = (float*)calloc(n, sizeof(float));
l.rolling_variance = (float*)calloc(n, sizeof(float)); l.rolling_variance = (float*)calloc(n, sizeof(float));
} }
if (train) {
l.x = (float*)calloc(total_batch * l.outputs, sizeof(float)); l.x = (float*)calloc(total_batch * l.outputs, sizeof(float));
l.x_norm = (float*)calloc(total_batch * l.outputs, sizeof(float)); l.x_norm = (float*)calloc(total_batch * l.outputs, sizeof(float));
} }
}
if(adam){ if(adam){
l.adam = 1; l.adam = 1;
l.m = (float*)calloc(l.nweights, sizeof(float)); l.m = (float*)calloc(l.nweights, sizeof(float));
@ -501,17 +538,17 @@ convolutional_layer make_convolutional_layer(int batch, int steps, int h, int w,
} }
else { else {
l.weights_gpu = cuda_make_array(l.weights, l.nweights); l.weights_gpu = cuda_make_array(l.weights, l.nweights);
l.weight_updates_gpu = cuda_make_array(l.weight_updates, l.nweights); if (train) l.weight_updates_gpu = cuda_make_array(l.weight_updates, l.nweights);
#ifdef CUDNN_HALF #ifdef CUDNN_HALF
l.weights_gpu16 = cuda_make_array(NULL, l.nweights / 2 + 1); l.weights_gpu16 = cuda_make_array(NULL, l.nweights / 2 + 1);
l.weight_updates_gpu16 = cuda_make_array(NULL, l.nweights / 2 + 1); if (train) l.weight_updates_gpu16 = cuda_make_array(NULL, l.nweights / 2 + 1);
#endif // CUDNN_HALF #endif // CUDNN_HALF
l.biases_gpu = cuda_make_array(l.biases, n); l.biases_gpu = cuda_make_array(l.biases, n);
l.bias_updates_gpu = cuda_make_array(l.bias_updates, n); if (train) l.bias_updates_gpu = cuda_make_array(l.bias_updates, n);
} }
l.output_gpu = cuda_make_array(l.output, total_batch*out_h*out_w*n); l.output_gpu = cuda_make_array(l.output, total_batch*out_h*out_w*n);
l.delta_gpu = cuda_make_array(l.delta, total_batch*out_h*out_w*n); if (train) l.delta_gpu = cuda_make_array(l.delta, total_batch*out_h*out_w*n);
if(binary){ if(binary){
l.binary_weights_gpu = cuda_make_array(l.weights, l.nweights); l.binary_weights_gpu = cuda_make_array(l.weights, l.nweights);
@ -535,20 +572,26 @@ convolutional_layer make_convolutional_layer(int batch, int steps, int h, int w,
} }
else { else {
l.scales_gpu = cuda_make_array(l.scales, n); l.scales_gpu = cuda_make_array(l.scales, n);
if (train) {
l.scale_updates_gpu = cuda_make_array(l.scale_updates, n); l.scale_updates_gpu = cuda_make_array(l.scale_updates, n);
l.mean_gpu = cuda_make_array(l.mean, n); l.mean_gpu = cuda_make_array(l.mean, n);
l.variance_gpu = cuda_make_array(l.variance, n); l.variance_gpu = cuda_make_array(l.variance, n);
l.rolling_mean_gpu = cuda_make_array(l.mean, n);
l.rolling_variance_gpu = cuda_make_array(l.variance, n);
l.mean_delta_gpu = cuda_make_array(l.mean, n); l.mean_delta_gpu = cuda_make_array(l.mean, n);
l.variance_delta_gpu = cuda_make_array(l.variance, n); l.variance_delta_gpu = cuda_make_array(l.variance, n);
} }
l.rolling_mean_gpu = cuda_make_array(l.mean, n);
l.rolling_variance_gpu = cuda_make_array(l.variance, n);
}
if (train) {
l.x_gpu = cuda_make_array(l.output, total_batch*out_h*out_w*n); l.x_gpu = cuda_make_array(l.output, total_batch*out_h*out_w*n);
l.x_norm_gpu = cuda_make_array(l.output, total_batch*out_h*out_w*n); l.x_norm_gpu = cuda_make_array(l.output, total_batch*out_h*out_w*n);
} }
}
if (l.assisted_excitation) if (l.assisted_excitation)
{ {
@ -594,7 +637,7 @@ convolutional_layer make_convolutional_layer(int batch, int steps, int h, int w,
blur_size = 2; blur_size = 2;
blur_pad = 0; blur_pad = 0;
} }
*(l.input_layer) = make_convolutional_layer(batch, steps, out_h, out_w, n, n, n, blur_size, blur_stride_x, blur_stride_y, 1, blur_pad, LINEAR, 0, 0, 0, 0, 0, index, 0, NULL, 0); *(l.input_layer) = make_convolutional_layer(batch, steps, out_h, out_w, n, n, n, blur_size, blur_stride_x, blur_stride_y, 1, blur_pad, LINEAR, 0, 0, 0, 0, 0, index, 0, NULL, 0, train);
const int blur_nweights = n * blur_size * blur_size; // (n / n) * n * blur_size * blur_size; const int blur_nweights = n * blur_size * blur_size; // (n / n) * n * blur_size * blur_size;
int i; int i;
if (blur_size == 2) { if (blur_size == 2) {
@ -649,7 +692,7 @@ void denormalize_convolutional_layer(convolutional_layer l)
void test_convolutional_layer() void test_convolutional_layer()
{ {
convolutional_layer l = make_convolutional_layer(1, 1, 5, 5, 3, 2, 1, 5, 2, 2, 1, 1, LEAKY, 1, 0, 0, 0, 0, 0, 0, NULL, 0); convolutional_layer l = make_convolutional_layer(1, 1, 5, 5, 3, 2, 1, 5, 2, 2, 1, 1, LEAKY, 1, 0, 0, 0, 0, 0, 0, NULL, 0, 0);
l.batch_normalize = 1; l.batch_normalize = 1;
float data[] = {1,1,1,1,1, float data[] = {1,1,1,1,1,
1,1,1,1,1, 1,1,1,1,1,
@ -688,11 +731,14 @@ void resize_convolutional_layer(convolutional_layer *l, int w, int h)
l->inputs = l->w * l->h * l->c; l->inputs = l->w * l->h * l->c;
l->output = (float*)realloc(l->output, total_batch * l->outputs * sizeof(float)); l->output = (float*)realloc(l->output, total_batch * l->outputs * sizeof(float));
if (l->train) {
l->delta = (float*)realloc(l->delta, total_batch * l->outputs * sizeof(float)); l->delta = (float*)realloc(l->delta, total_batch * l->outputs * sizeof(float));
if (l->batch_normalize) { if (l->batch_normalize) {
l->x = (float*)realloc(l->x, total_batch * l->outputs * sizeof(float)); l->x = (float*)realloc(l->x, total_batch * l->outputs * sizeof(float));
l->x_norm = (float*)realloc(l->x_norm, total_batch * l->outputs * sizeof(float)); l->x_norm = (float*)realloc(l->x_norm, total_batch * l->outputs * sizeof(float));
} }
}
if (l->xnor) { if (l->xnor) {
//l->binary_input = realloc(l->inputs*l->batch, sizeof(float)); //l->binary_input = realloc(l->inputs*l->batch, sizeof(float));
@ -700,10 +746,12 @@ void resize_convolutional_layer(convolutional_layer *l, int w, int h)
#ifdef GPU #ifdef GPU
if (old_w < w || old_h < h) { if (old_w < w || old_h < h) {
if (l->train) {
cuda_free(l->delta_gpu); cuda_free(l->delta_gpu);
cuda_free(l->output_gpu);
l->delta_gpu = cuda_make_array(l->delta, total_batch*l->outputs); l->delta_gpu = cuda_make_array(l->delta, total_batch*l->outputs);
}
cuda_free(l->output_gpu);
l->output_gpu = cuda_make_array(l->output, total_batch*l->outputs); l->output_gpu = cuda_make_array(l->output, total_batch*l->outputs);
if (l->batch_normalize) { if (l->batch_normalize) {
@ -1246,7 +1294,7 @@ void assisted_excitation_forward(convolutional_layer l, network_state state)
} }
} }
if(1) // visualize ground truth if(0) // visualize ground truth
{ {
#ifdef OPENCV #ifdef OPENCV
for (b = 0; b < l.batch; ++b) for (b = 0; b < l.batch; ++b)

@ -28,9 +28,10 @@ void create_convolutional_cudnn_tensors(layer *l);
void cuda_convert_f32_to_f16(float* input_f32, size_t size, float *output_f16); void cuda_convert_f32_to_f16(float* input_f32, size_t size, float *output_f16);
#endif #endif
#endif #endif
void free_convolutional_batchnorm(convolutional_layer *l);
size_t get_convolutional_workspace_size(layer l); size_t get_convolutional_workspace_size(layer l);
convolutional_layer make_convolutional_layer(int batch, int steps, int h, int w, int c, int n, int groups, int size, int stride_x, int stride_y, int dilation, int padding, ACTIVATION activation, int batch_normalize, int binary, int xnor, int adam, int use_bin_output, int index, int antialiasing, convolutional_layer *share_layer, int assisted_excitation); convolutional_layer make_convolutional_layer(int batch, int steps, int h, int w, int c, int n, int groups, int size, int stride_x, int stride_y, int dilation, int padding, ACTIVATION activation, int batch_normalize, int binary, int xnor, int adam, int use_bin_output, int index, int antialiasing, convolutional_layer *share_layer, int assisted_excitation, int train);
void denormalize_convolutional_layer(convolutional_layer l); void denormalize_convolutional_layer(convolutional_layer l);
void resize_convolutional_layer(convolutional_layer *layer, int w, int h); void resize_convolutional_layer(convolutional_layer *layer, int w, int h);
void forward_convolutional_layer(const convolutional_layer layer, network_state state); void forward_convolutional_layer(const convolutional_layer layer, network_state state);

@ -26,11 +26,12 @@ static void increment_layer(layer *l, int steps)
#endif #endif
} }
layer make_crnn_layer(int batch, int h, int w, int c, int hidden_filters, int output_filters, int groups, int steps, int size, int stride, int dilation, int pad, ACTIVATION activation, int batch_normalize, int xnor) layer make_crnn_layer(int batch, int h, int w, int c, int hidden_filters, int output_filters, int groups, int steps, int size, int stride, int dilation, int pad, ACTIVATION activation, int batch_normalize, int xnor, int train)
{ {
fprintf(stderr, "CRNN Layer: %d x %d x %d image, %d filters\n", h,w,c,output_filters); fprintf(stderr, "CRNN Layer: %d x %d x %d image, %d filters\n", h,w,c,output_filters);
batch = batch / steps; batch = batch / steps;
layer l = { (LAYER_TYPE)0 }; layer l = { (LAYER_TYPE)0 };
l.train = train;
l.batch = batch; l.batch = batch;
l.type = CRNN; l.type = CRNN;
l.steps = steps; l.steps = steps;
@ -50,17 +51,17 @@ layer make_crnn_layer(int batch, int h, int w, int c, int hidden_filters, int ou
l.state = (float*)calloc(l.hidden * l.batch * (l.steps + 1), sizeof(float)); l.state = (float*)calloc(l.hidden * l.batch * (l.steps + 1), sizeof(float));
l.input_layer = (layer*)calloc(1, sizeof(layer)); l.input_layer = (layer*)calloc(1, sizeof(layer));
*(l.input_layer) = make_convolutional_layer(batch, steps, h, w, c, hidden_filters, groups, size, stride, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0, 0, NULL, 0); *(l.input_layer) = make_convolutional_layer(batch, steps, h, w, c, hidden_filters, groups, size, stride, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0, 0, NULL, 0, train);
l.input_layer->batch = batch; l.input_layer->batch = batch;
if (l.workspace_size < l.input_layer->workspace_size) l.workspace_size = l.input_layer->workspace_size; if (l.workspace_size < l.input_layer->workspace_size) l.workspace_size = l.input_layer->workspace_size;
l.self_layer = (layer*)calloc(1, sizeof(layer)); l.self_layer = (layer*)calloc(1, sizeof(layer));
*(l.self_layer) = make_convolutional_layer(batch, steps, h, w, hidden_filters, hidden_filters, groups, size, stride, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0, 0, NULL, 0); *(l.self_layer) = make_convolutional_layer(batch, steps, h, w, hidden_filters, hidden_filters, groups, size, stride, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0, 0, NULL, 0, train);
l.self_layer->batch = batch; l.self_layer->batch = batch;
if (l.workspace_size < l.self_layer->workspace_size) l.workspace_size = l.self_layer->workspace_size; if (l.workspace_size < l.self_layer->workspace_size) l.workspace_size = l.self_layer->workspace_size;
l.output_layer = (layer*)calloc(1, sizeof(layer)); l.output_layer = (layer*)calloc(1, sizeof(layer));
*(l.output_layer) = make_convolutional_layer(batch, steps, h, w, hidden_filters, output_filters, groups, size, stride, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0, 0, NULL, 0); *(l.output_layer) = make_convolutional_layer(batch, steps, h, w, hidden_filters, output_filters, groups, size, stride, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0, 0, NULL, 0, train);
l.output_layer->batch = batch; l.output_layer->batch = batch;
if (l.workspace_size < l.output_layer->workspace_size) l.workspace_size = l.output_layer->workspace_size; if (l.workspace_size < l.output_layer->workspace_size) l.workspace_size = l.output_layer->workspace_size;

@ -9,7 +9,7 @@
#ifdef __cplusplus #ifdef __cplusplus
extern "C" { extern "C" {
#endif #endif
layer make_crnn_layer(int batch, int h, int w, int c, int hidden_filters, int output_filters, int groups, int steps, int size, int stride, int dilation, int pad, ACTIVATION activation, int batch_normalize, int xnor); layer make_crnn_layer(int batch, int h, int w, int c, int hidden_filters, int output_filters, int groups, int steps, int size, int stride, int dilation, int pad, ACTIVATION activation, int batch_normalize, int xnor, int train);
void resize_crnn_layer(layer *l, int w, int h); void resize_crnn_layer(layer *l, int w, int h);
void free_state_crnn(layer l); void free_state_crnn(layer l);

@ -1949,7 +1949,7 @@ void forward_maxpool_layer_avx(float *src, float *dst, int *indexes, int size, i
} }
} }
dst[out_index] = max; dst[out_index] = max;
indexes[out_index] = max_i; if (indexes) indexes[out_index] = max_i;
} }
} }
} }
@ -2452,7 +2452,7 @@ void forward_maxpool_layer_avx(float *src, float *dst, int *indexes, int size, i
} }
} }
dst[out_index] = max; dst[out_index] = max;
indexes[out_index] = max_i; if (indexes) indexes[out_index] = max_i;
} }
} }
} }

@ -46,10 +46,11 @@ void cudnn_maxpool_setup(layer *l)
} }
maxpool_layer make_maxpool_layer(int batch, int h, int w, int c, int size, int stride_x, int stride_y, int padding, int maxpool_depth, int out_channels, int antialiasing) maxpool_layer make_maxpool_layer(int batch, int h, int w, int c, int size, int stride_x, int stride_y, int padding, int maxpool_depth, int out_channels, int antialiasing, int train)
{ {
maxpool_layer l = { (LAYER_TYPE)0 }; maxpool_layer l = { (LAYER_TYPE)0 };
l.type = MAXPOOL; l.type = MAXPOOL;
l.train = train;
const int blur_stride_x = stride_x; const int blur_stride_x = stride_x;
const int blur_stride_y = stride_y; const int blur_stride_y = stride_y;
@ -82,18 +83,22 @@ maxpool_layer make_maxpool_layer(int batch, int h, int w, int c, int size, int s
l.stride_x = stride_x; l.stride_x = stride_x;
l.stride_y = stride_y; l.stride_y = stride_y;
int output_size = l.out_h * l.out_w * l.out_c * batch; int output_size = l.out_h * l.out_w * l.out_c * batch;
if (train) {
l.indexes = (int*)calloc(output_size, sizeof(int)); l.indexes = (int*)calloc(output_size, sizeof(int));
l.output = (float*)calloc(output_size, sizeof(float));
l.delta = (float*)calloc(output_size, sizeof(float)); l.delta = (float*)calloc(output_size, sizeof(float));
}
l.output = (float*)calloc(output_size, sizeof(float));
l.forward = forward_maxpool_layer; l.forward = forward_maxpool_layer;
l.backward = backward_maxpool_layer; l.backward = backward_maxpool_layer;
#ifdef GPU #ifdef GPU
l.forward_gpu = forward_maxpool_layer_gpu; l.forward_gpu = forward_maxpool_layer_gpu;
l.backward_gpu = backward_maxpool_layer_gpu; l.backward_gpu = backward_maxpool_layer_gpu;
if (train) {
l.indexes_gpu = cuda_make_int_array(output_size); l.indexes_gpu = cuda_make_int_array(output_size);
l.output_gpu = cuda_make_array(l.output, output_size);
l.delta_gpu = cuda_make_array(l.delta, output_size); l.delta_gpu = cuda_make_array(l.delta, output_size);
}
l.output_gpu = cuda_make_array(l.output, output_size);
cudnn_maxpool_setup(&l); cudnn_maxpool_setup(&l);
#endif // GPU #endif // GPU
@ -114,7 +119,7 @@ maxpool_layer make_maxpool_layer(int batch, int h, int w, int c, int size, int s
blur_size = 2; blur_size = 2;
blur_pad = 0; blur_pad = 0;
} }
*(l.input_layer) = make_convolutional_layer(batch, 1, l.out_h, l.out_w, l.out_c, l.out_c, l.out_c, blur_size, blur_stride_x, blur_stride_y, 1, blur_pad, LINEAR, 0, 0, 0, 0, 0, 1, 0, NULL, 0); *(l.input_layer) = make_convolutional_layer(batch, 1, l.out_h, l.out_w, l.out_c, l.out_c, l.out_c, blur_size, blur_stride_x, blur_stride_y, 1, blur_pad, LINEAR, 0, 0, 0, 0, 0, 1, 0, NULL, 0, train);
const int blur_nweights = l.out_c * blur_size * blur_size; // (n / n) * n * blur_size * blur_size; const int blur_nweights = l.out_c * blur_size * blur_size; // (n / n) * n * blur_size * blur_size;
int i; int i;
if (blur_size == 2) { if (blur_size == 2) {
@ -163,17 +168,22 @@ void resize_maxpool_layer(maxpool_layer *l, int w, int h)
l->outputs = l->out_w * l->out_h * l->out_c; l->outputs = l->out_w * l->out_h * l->out_c;
int output_size = l->outputs * l->batch; int output_size = l->outputs * l->batch;
if (l->train) {
l->indexes = (int*)realloc(l->indexes, output_size * sizeof(int)); l->indexes = (int*)realloc(l->indexes, output_size * sizeof(int));
l->output = (float*)realloc(l->output, output_size * sizeof(float));
l->delta = (float*)realloc(l->delta, output_size * sizeof(float)); l->delta = (float*)realloc(l->delta, output_size * sizeof(float));
}
l->output = (float*)realloc(l->output, output_size * sizeof(float));
#ifdef GPU #ifdef GPU
CHECK_CUDA(cudaFree((float *)l->indexes_gpu));
CHECK_CUDA(cudaFree(l->output_gpu)); CHECK_CUDA(cudaFree(l->output_gpu));
l->output_gpu = cuda_make_array(l->output, output_size);
if (l->train) {
CHECK_CUDA(cudaFree((float *)l->indexes_gpu));
CHECK_CUDA(cudaFree(l->delta_gpu)); CHECK_CUDA(cudaFree(l->delta_gpu));
l->indexes_gpu = cuda_make_int_array(output_size); l->indexes_gpu = cuda_make_int_array(output_size);
l->output_gpu = cuda_make_array(l->output, output_size);
l->delta_gpu = cuda_make_array(l->delta, output_size); l->delta_gpu = cuda_make_array(l->delta, output_size);
}
cudnn_maxpool_setup(l); cudnn_maxpool_setup(l);
#endif #endif
@ -203,7 +213,7 @@ void forward_maxpool_layer(const maxpool_layer l, network_state state)
max = (val > max) ? val : max; max = (val > max) ? val : max;
} }
l.output[out_index] = max; l.output[out_index] = max;
l.indexes[out_index] = max_i; if (l.indexes) l.indexes[out_index] = max_i;
} }
} }
} }
@ -245,7 +255,7 @@ void forward_maxpool_layer(const maxpool_layer l, network_state state)
} }
} }
l.output[out_index] = max; l.output[out_index] = max;
l.indexes[out_index] = max_i; if (l.indexes) l.indexes[out_index] = max_i;
} }
} }
} }

@ -12,7 +12,7 @@ typedef layer maxpool_layer;
extern "C" { extern "C" {
#endif #endif
image get_maxpool_image(maxpool_layer l); image get_maxpool_image(maxpool_layer l);
maxpool_layer make_maxpool_layer(int batch, int h, int w, int c, int size, int stride_x, int stride_y, int padding, int maxpool_depth, int out_channels, int antialiasing); maxpool_layer make_maxpool_layer(int batch, int h, int w, int c, int size, int stride_x, int stride_y, int padding, int maxpool_depth, int out_channels, int antialiasing, int train);
void resize_maxpool_layer(maxpool_layer *l, int w, int h); void resize_maxpool_layer(maxpool_layer *l, int w, int h);
void forward_maxpool_layer(const maxpool_layer l, network_state state); void forward_maxpool_layer(const maxpool_layer l, network_state state);
void backward_maxpool_layer(const maxpool_layer l, network_state state); void backward_maxpool_layer(const maxpool_layer l, network_state state);

@ -36,7 +36,7 @@ __global__ void forward_maxpool_depth_layer_kernel(int n, int w, int h, int c, i
max = (val > max) ? val : max; max = (val > max) ? val : max;
} }
output[out_index] = max; output[out_index] = max;
indexes[out_index] = max_i; if (indexes) indexes[out_index] = max_i;
} }
} }
@ -88,7 +88,7 @@ __global__ void forward_maxpool_layer_kernel(int n, int in_h, int in_w, int in_c
} }
} }
output[out_index] = max; output[out_index] = max;
indexes[out_index] = max_i; if (indexes) indexes[out_index] = max_i;
} }
__global__ void backward_maxpool_layer_kernel(int n, int in_h, int in_w, int in_c, int stride_x, int stride_y, int size, int pad, float *delta, float *prev_delta, int *indexes) __global__ void backward_maxpool_layer_kernel(int n, int in_h, int in_w, int in_c, int stride_x, int stride_y, int size, int pad, float *delta, float *prev_delta, int *indexes)

@ -1071,6 +1071,7 @@ void fuse_conv_batchnorm(network net)
} }
} }
free_convolutional_batchnorm(l);
l->batch_normalize = 0; l->batch_normalize = 0;
#ifdef GPU #ifdef GPU
if (gpu_index >= 0) { if (gpu_index >= 0) {

@ -130,6 +130,7 @@ typedef struct size_params{
int c; int c;
int index; int index;
int time_steps; int time_steps;
int train;
network net; network net;
} size_params; } size_params;
@ -199,7 +200,7 @@ convolutional_layer parse_convolutional(list *options, size_params params)
int xnor = option_find_int_quiet(options, "xnor", 0); int xnor = option_find_int_quiet(options, "xnor", 0);
int use_bin_output = option_find_int_quiet(options, "bin_output", 0); int use_bin_output = option_find_int_quiet(options, "bin_output", 0);
convolutional_layer layer = make_convolutional_layer(batch,1,h,w,c,n,groups,size,stride_x,stride_y,dilation,padding,activation, batch_normalize, binary, xnor, params.net.adam, use_bin_output, params.index, antialiasing, share_layer, assisted_excitation); convolutional_layer layer = make_convolutional_layer(batch,1,h,w,c,n,groups,size,stride_x,stride_y,dilation,padding,activation, batch_normalize, binary, xnor, params.net.adam, use_bin_output, params.index, antialiasing, share_layer, assisted_excitation, params.train);
layer.flipped = option_find_int_quiet(options, "flipped", 0); layer.flipped = option_find_int_quiet(options, "flipped", 0);
layer.dot = option_find_float_quiet(options, "dot", 0); layer.dot = option_find_float_quiet(options, "dot", 0);
@ -230,7 +231,7 @@ layer parse_crnn(list *options, size_params params)
int batch_normalize = option_find_int_quiet(options, "batch_normalize", 0); int batch_normalize = option_find_int_quiet(options, "batch_normalize", 0);
int xnor = option_find_int_quiet(options, "xnor", 0); int xnor = option_find_int_quiet(options, "xnor", 0);
layer l = make_crnn_layer(params.batch, params.h, params.w, params.c, hidden_filters, output_filters, groups, params.time_steps, size, stride, dilation, padding, activation, batch_normalize, xnor); layer l = make_crnn_layer(params.batch, params.h, params.w, params.c, hidden_filters, output_filters, groups, params.time_steps, size, stride, dilation, padding, activation, batch_normalize, xnor, params.train);
l.shortcut = option_find_int_quiet(options, "shortcut", 0); l.shortcut = option_find_int_quiet(options, "shortcut", 0);
@ -291,7 +292,7 @@ layer parse_conv_lstm(list *options, size_params params)
int xnor = option_find_int_quiet(options, "xnor", 0); int xnor = option_find_int_quiet(options, "xnor", 0);
int peephole = option_find_int_quiet(options, "peephole", 0); int peephole = option_find_int_quiet(options, "peephole", 0);
layer l = make_conv_lstm_layer(params.batch, params.h, params.w, params.c, output_filters, groups, params.time_steps, size, stride, dilation, padding, activation, batch_normalize, peephole, xnor); layer l = make_conv_lstm_layer(params.batch, params.h, params.w, params.c, output_filters, groups, params.time_steps, size, stride, dilation, padding, activation, batch_normalize, peephole, xnor, params.train);
l.state_constrain = option_find_int_quiet(options, "state_constrain", params.time_steps * 32); l.state_constrain = option_find_int_quiet(options, "state_constrain", params.time_steps * 32);
l.shortcut = option_find_int_quiet(options, "shortcut", 0); l.shortcut = option_find_int_quiet(options, "shortcut", 0);
@ -630,7 +631,7 @@ maxpool_layer parse_maxpool(list *options, size_params params)
batch=params.batch; batch=params.batch;
if(!(h && w && c)) error("Layer before maxpool layer must output image."); if(!(h && w && c)) error("Layer before maxpool layer must output image.");
maxpool_layer layer = make_maxpool_layer(batch, h, w, c, size, stride_x, stride_y, padding, maxpool_depth, out_channels, antialiasing); maxpool_layer layer = make_maxpool_layer(batch, h, w, c, size, stride_x, stride_y, padding, maxpool_depth, out_channels, antialiasing, params.train);
return layer; return layer;
} }
@ -684,7 +685,7 @@ layer parse_shortcut(list *options, size_params params, network net)
layer from = net.layers[index]; layer from = net.layers[index];
if (from.antialiasing) from = *from.input_layer; if (from.antialiasing) from = *from.input_layer;
layer s = make_shortcut_layer(batch, index, params.w, params.h, params.c, from.out_w, from.out_h, from.out_c, assisted_excitation); layer s = make_shortcut_layer(batch, index, params.w, params.h, params.c, from.out_w, from.out_h, from.out_c, assisted_excitation, params.train);
char *activation_s = option_find_str(options, "activation", "linear"); char *activation_s = option_find_str(options, "activation", "linear");
ACTIVATION activation = get_activation(activation_s); ACTIVATION activation = get_activation(activation_s);
@ -944,6 +945,9 @@ network parse_network_cfg_custom(char *filename, int batch, int time_steps)
net.gpu_index = gpu_index; net.gpu_index = gpu_index;
size_params params; size_params params;
if (batch > 0) params.train = 0; // allocates memory for Detection only
else params.train = 1; // allocates memory for Detection & Training
section *s = (section *)n->val; section *s = (section *)n->val;
list *options = s->options; list *options = s->options;
if(!is_network(s)) error("First section must be [net] or [network]"); if(!is_network(s)) error("First section must be [net] or [network]");

@ -5,11 +5,12 @@
#include <stdio.h> #include <stdio.h>
#include <assert.h> #include <assert.h>
layer make_shortcut_layer(int batch, int index, int w, int h, int c, int w2, int h2, int c2, int assisted_excitation) layer make_shortcut_layer(int batch, int index, int w, int h, int c, int w2, int h2, int c2, int assisted_excitation, int train)
{ {
if(assisted_excitation) fprintf(stderr, "Shortcut Layer - AE: %d\n", index); if(assisted_excitation) fprintf(stderr, "Shortcut Layer - AE: %d\n", index);
else fprintf(stderr,"Shortcut Layer: %d\n", index); else fprintf(stderr,"Shortcut Layer: %d\n", index);
layer l = { (LAYER_TYPE)0 }; layer l = { (LAYER_TYPE)0 };
l.train = train;
l.type = SHORTCUT; l.type = SHORTCUT;
l.batch = batch; l.batch = batch;
l.w = w2; l.w = w2;
@ -27,7 +28,7 @@ layer make_shortcut_layer(int batch, int index, int w, int h, int c, int w2, int
l.index = index; l.index = index;
l.delta = (float*)calloc(l.outputs * batch, sizeof(float)); if (train) l.delta = (float*)calloc(l.outputs * batch, sizeof(float));
l.output = (float*)calloc(l.outputs * batch, sizeof(float)); l.output = (float*)calloc(l.outputs * batch, sizeof(float));
l.forward = forward_shortcut_layer; l.forward = forward_shortcut_layer;
@ -36,7 +37,7 @@ layer make_shortcut_layer(int batch, int index, int w, int h, int c, int w2, int
l.forward_gpu = forward_shortcut_layer_gpu; l.forward_gpu = forward_shortcut_layer_gpu;
l.backward_gpu = backward_shortcut_layer_gpu; l.backward_gpu = backward_shortcut_layer_gpu;
l.delta_gpu = cuda_make_array(l.delta, l.outputs*batch); if (train) l.delta_gpu = cuda_make_array(l.delta, l.outputs*batch);
l.output_gpu = cuda_make_array(l.output, l.outputs*batch); l.output_gpu = cuda_make_array(l.output, l.outputs*batch);
if (l.assisted_excitation) if (l.assisted_excitation)
{ {
@ -56,14 +57,17 @@ void resize_shortcut_layer(layer *l, int w, int h)
l->h = l->out_h = h; l->h = l->out_h = h;
l->outputs = w*h*l->out_c; l->outputs = w*h*l->out_c;
l->inputs = l->outputs; l->inputs = l->outputs;
l->delta = (float*)realloc(l->delta, l->outputs * l->batch * sizeof(float)); if (l->train) l->delta = (float*)realloc(l->delta, l->outputs * l->batch * sizeof(float));
l->output = (float*)realloc(l->output, l->outputs * l->batch * sizeof(float)); l->output = (float*)realloc(l->output, l->outputs * l->batch * sizeof(float));
#ifdef GPU #ifdef GPU
cuda_free(l->output_gpu); cuda_free(l->output_gpu);
cuda_free(l->delta_gpu);
l->output_gpu = cuda_make_array(l->output, l->outputs*l->batch); l->output_gpu = cuda_make_array(l->output, l->outputs*l->batch);
if (l->train) {
cuda_free(l->delta_gpu);
l->delta_gpu = cuda_make_array(l->delta, l->outputs*l->batch); l->delta_gpu = cuda_make_array(l->delta, l->outputs*l->batch);
}
#endif #endif
} }

@ -7,7 +7,7 @@
#ifdef __cplusplus #ifdef __cplusplus
extern "C" { extern "C" {
#endif #endif
layer make_shortcut_layer(int batch, int index, int w, int h, int c, int w2, int h2, int c2, int assisted_excitation); layer make_shortcut_layer(int batch, int index, int w, int h, int c, int w2, int h2, int c2, int assisted_excitation, int train);
void forward_shortcut_layer(const layer l, network_state state); void forward_shortcut_layer(const layer l, network_state state);
void backward_shortcut_layer(const layer l, network_state state); void backward_shortcut_layer(const layer l, network_state state);
void resize_shortcut_layer(layer *l, int w, int h); void resize_shortcut_layer(layer *l, int w, int h);

@ -282,7 +282,6 @@ void forward_yolo_layer(const layer l, network_state state)
box pred = get_yolo_box(l.output, l.biases, l.mask[n], box_index, i, j, l.w, l.h, state.net.w, state.net.h, l.w*l.h); box pred = get_yolo_box(l.output, l.biases, l.mask[n], box_index, i, j, l.w, l.h, state.net.w, state.net.h, l.w*l.h);
float best_iou = 0; float best_iou = 0;
int best_t = 0; int best_t = 0;
int class_id_match = 0;
for (t = 0; t < l.max_boxes; ++t) { for (t = 0; t < l.max_boxes; ++t) {
box truth = float_to_box_stride(state.truth + t*(4 + 1) + b*l.truths, 1); box truth = float_to_box_stride(state.truth + t*(4 + 1) + b*l.truths, 1);
int class_id = state.truth[t*(4 + 1) + b*l.truths + 4]; int class_id = state.truth[t*(4 + 1) + b*l.truths + 4];
@ -298,6 +297,7 @@ void forward_yolo_layer(const layer l, network_state state)
int obj_index = entry_index(l, b, n*l.w*l.h + j*l.w + i, 4); int obj_index = entry_index(l, b, n*l.w*l.h + j*l.w + i, 4);
float objectness = l.output[obj_index]; float objectness = l.output[obj_index];
int pred_class_id = get_yolo_class(l.output, l.classes, class_index, l.w*l.h, objectness); int pred_class_id = get_yolo_class(l.output, l.classes, class_index, l.w*l.h, objectness);
int class_id_match = 0;
if (class_id == pred_class_id) class_id_match = 1; if (class_id == pred_class_id) class_id_match = 1;
else class_id_match = 0; else class_id_match = 0;

Loading…
Cancel
Save