You can train with mini_batch=128 (batch=256, subdivisions=2) on a GPU with 8 GB VRAM plus 128 GB of CPU RAM or more. Set in the cfg-file: optimized_memory=3 and workspace_size_limit_MB=2000 or 4000.
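For example, a minimal sketch of the relevant [net] entries (values taken from the description above; the rest of the cfg stays as in a standard model):

[net]
# mini_batch = batch / subdivisions = 256 / 2 = 128
batch=256
subdivisions=2
# 0 = off; higher levels move more layer buffers to CPU-pinned RAM (3 = maximum optimization)
optimized_memory=3
# per-layer cuDNN workspace limit in MB (default 1024, applied when optimized_memory >= 2)
workspace_size_limit_MB=2000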

pull/6241/head
AlexeyAB 6 years ago
parent b832c727dc
commit 0d98f20be1
1. include/darknet.h (8 changed lines)
2. src/batchnorm_layer.c (9 changed lines)
3. src/convolutional_layer.c (49 changed lines)
4. src/convolutional_layer.h (1 changed line)
5. src/dark_cuda.c (137 changed lines)
6. src/dark_cuda.h (7 changed lines)
7. src/detector.c (2 changed lines)
8. src/layer.c (10 changed lines)
9. src/network.c (3 changed lines)
10. src/network_kernels.cu (9 changed lines)
11. src/parser.c (98 changed lines)

@@ -239,6 +239,8 @@ struct layer {
int xnor;
int peephole;
int use_bin_output;
int keep_delta_gpu;
int optimized_memory;
int steps;
int state_constrain;
int hidden;
@@ -677,7 +679,13 @@ typedef struct network {
size_t *max_input16_size;
size_t *max_output16_size;
int wait_stream;
float *global_delta_gpu;
float *state_delta_gpu;
size_t max_delta_gpu_size;
#endif
int optimized_memory;
size_t workspace_size_limit;
} network;
// network.h

@@ -184,7 +184,6 @@ void forward_batchnorm_layer_gpu(layer l, network_state state)
if (state.train) {
simple_copy_ongpu(l.outputs*l.batch, l.output_gpu, l.x_gpu);
//copy_ongpu(l.outputs*l.batch, l.output_gpu, 1, l.x_gpu, 1);
#ifdef CUDNN
float one = 1;
float zero = 0;
@@ -259,16 +258,16 @@ void backward_batchnorm_layer_gpu(layer l, network_state state)
l.normDstTensorDesc,
l.delta_gpu, // input
l.normDstTensorDesc,
l.x_norm_gpu, // output
l.output_gpu, //l.x_norm_gpu, // output
l.normTensorDesc,
l.scales_gpu, // output (should be FP32)
l.scales_gpu, // input (should be FP32)
l.scale_updates_gpu, // output (should be FP32)
l.bias_updates_gpu, // output (should be FP32)
.00001,
l.mean_gpu, // input (should be FP32)
l.variance_gpu); // input (should be FP32)
simple_copy_ongpu(l.outputs*l.batch, l.x_norm_gpu, l.delta_gpu);
//copy_ongpu(l.outputs*l.batch, l.x_norm_gpu, 1, l.delta_gpu, 1);
simple_copy_ongpu(l.outputs*l.batch, l.output_gpu, l.delta_gpu);
//simple_copy_ongpu(l.outputs*l.batch, l.x_norm_gpu, l.delta_gpu);
#else
backward_bias_gpu(l.bias_updates_gpu, l.delta_gpu, l.batch, l.out_c, l.out_w*l.out_h);
backward_scale_gpu(l.x_norm_gpu, l.delta_gpu, l.batch, l.out_c, l.out_w*l.out_h, l.scale_updates_gpu);

@@ -212,7 +212,7 @@ void create_convolutional_cudnn_tensors(layer *l)
CHECK_CUDNN(cudnnCreateConvolutionDescriptor(&l->convDesc));
}
void cudnn_convolutional_setup(layer *l, int cudnn_preference)
void cudnn_convolutional_setup(layer *l, int cudnn_preference, size_t workspace_size_specify)
{
// CUDNN_HALF
@@ -291,6 +291,13 @@ void cudnn_convolutional_setup(layer *l, int cudnn_preference)
backward_filter = CUDNN_CONVOLUTION_BWD_FILTER_NO_WORKSPACE;
printf(" CUDNN-slow ");
}
if (cudnn_preference == cudnn_specify)
{
forward_algo = CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT;
backward_algo = CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT;
backward_filter = CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT;
//printf(" CUDNN-specified %zu ", workspace_size_specify);
}
CHECK_CUDNN(cudnnGetConvolutionForwardAlgorithm(cudnn_handle(),
l->srcTensorDesc,
@@ -298,7 +305,7 @@ void cudnn_convolutional_setup(layer *l, int cudnn_preference)
l->convDesc,
l->dstTensorDesc,
(cudnnConvolutionFwdPreference_t)forward_algo,
0,
workspace_size_specify,
&l->fw_algo));
CHECK_CUDNN(cudnnGetConvolutionBackwardDataAlgorithm(cudnn_handle(),
l->weightDesc,
@@ -306,7 +313,7 @@ void cudnn_convolutional_setup(layer *l, int cudnn_preference)
l->convDesc,
l->dsrcTensorDesc,
(cudnnConvolutionBwdDataPreference_t)backward_algo,
0,
workspace_size_specify,
&l->bd_algo));
CHECK_CUDNN(cudnnGetConvolutionBackwardFilterAlgorithm(cudnn_handle(),
l->srcTensorDesc,
@@ -314,7 +321,7 @@ void cudnn_convolutional_setup(layer *l, int cudnn_preference)
l->convDesc,
l->dweightDesc,
(cudnnConvolutionBwdFilterPreference_t)backward_filter,
0,
workspace_size_specify,
&l->bf_algo));
//if (data_type == CUDNN_DATA_HALF)
@@ -439,7 +446,9 @@ convolutional_layer make_convolutional_layer(int batch, int steps, int h, int w,
l.activation = activation;
l.output = (float*)calloc(total_batch*l.outputs, sizeof(float));
#ifndef GPU
if (train) l.delta = (float*)calloc(total_batch*l.outputs, sizeof(float));
#endif // not GPU
l.forward = forward_convolutional_layer;
l.backward = backward_convolutional_layer;
@@ -500,10 +509,14 @@ convolutional_layer make_convolutional_layer(int batch, int steps, int h, int w,
l.rolling_variance = (float*)calloc(n, sizeof(float));
}
#ifndef GPU
if (train) {
l.x = (float*)calloc(total_batch * l.outputs, sizeof(float));
l.x_norm = (float*)calloc(total_batch * l.outputs, sizeof(float));
}
if (l.activation == SWISH || l.activation == MISH) l.activation_input = (float*)calloc(total_batch*l.outputs, sizeof(float));
#endif // not GPU
}
if(adam){
l.adam = 1;
@@ -515,10 +528,11 @@ convolutional_layer make_convolutional_layer(int batch, int steps, int h, int w,
l.scale_v = (float*)calloc(n, sizeof(float));
}
if (l.activation == SWISH || l.activation == MISH) l.activation_input = (float*)calloc(total_batch*l.outputs, sizeof(float));
#ifdef GPU
if (l.activation == SWISH || l.activation == MISH) l.activation_input_gpu = cuda_make_array(l.activation_input, total_batch*l.outputs);
if (l.activation == SWISH || l.activation == MISH) {
l.activation_input_gpu = cuda_make_array(l.activation_input, total_batch*l.outputs);
}
l.forward_gpu = forward_convolutional_layer_gpu;
l.backward_gpu = backward_convolutional_layer_gpu;
@@ -583,9 +597,10 @@ convolutional_layer make_convolutional_layer(int batch, int steps, int h, int w,
l.mean_gpu = cuda_make_array(l.mean, n);
l.variance_gpu = cuda_make_array(l.variance, n);
#ifndef CUDNN
l.mean_delta_gpu = cuda_make_array(l.mean, n);
l.variance_delta_gpu = cuda_make_array(l.variance, n);
#endif // CUDNN
}
l.rolling_mean_gpu = cuda_make_array(l.mean, n);
@@ -594,7 +609,7 @@ convolutional_layer make_convolutional_layer(int batch, int steps, int h, int w,
if (train) {
l.x_gpu = cuda_make_array(l.output, total_batch*out_h*out_w*n);
l.x_norm_gpu = cuda_make_array(l.output, total_batch*out_h*out_w*n);
//l.x_norm_gpu = cuda_make_array(l.output, total_batch*out_h*out_w*n);
}
}
@@ -606,7 +621,7 @@ convolutional_layer make_convolutional_layer(int batch, int steps, int h, int w,
}
#ifdef CUDNN
create_convolutional_cudnn_tensors(&l);
cudnn_convolutional_setup(&l, cudnn_fastest);
cudnn_convolutional_setup(&l, cudnn_fastest, 0);
#endif // CUDNN
}
#endif // GPU
@@ -790,7 +805,7 @@ void resize_convolutional_layer(convolutional_layer *l, int w, int h)
}
}
#ifdef CUDNN
cudnn_convolutional_setup(l, cudnn_fastest);
cudnn_convolutional_setup(l, cudnn_fastest, 0);
#endif
#endif
l->workspace_size = get_convolutional_workspace_size(*l);
@@ -802,12 +817,24 @@ void resize_convolutional_layer(convolutional_layer *l, int w, int h)
CHECK_CUDA(cudaMemGetInfo(&free_byte, &total_byte));
if (l->workspace_size > free_byte || l->workspace_size >= total_byte / 2) {
printf(" used slow CUDNN algo without Workspace! Need memory: %zu, available: %zu\n", l->workspace_size, (free_byte < total_byte/2) ? free_byte : total_byte/2);
cudnn_convolutional_setup(l, cudnn_smallest);
cudnn_convolutional_setup(l, cudnn_smallest, 0);
l->workspace_size = get_convolutional_workspace_size(*l);
}
#endif
}
void set_specified_workspace_limit(convolutional_layer *l, size_t workspace_size_limit)
{
#ifdef CUDNN
size_t free_byte;
size_t total_byte;
CHECK_CUDA(cudaMemGetInfo(&free_byte, &total_byte));
cudnn_convolutional_setup(l, cudnn_specify, workspace_size_limit);
l->workspace_size = get_convolutional_workspace_size(*l);
//printf("Set specified workspace limit for cuDNN: %zu, available: %zu, workspace = %zu \n", workspace_size_limit, free_byte, l->workspace_size);
#endif // CUDNN
}
void add_bias(float *output, float *biases, int batch, int n, int size)
{
int i,j,b;

@@ -33,6 +33,7 @@ void free_convolutional_batchnorm(convolutional_layer *l);
size_t get_convolutional_workspace_size(layer l);
convolutional_layer make_convolutional_layer(int batch, int steps, int h, int w, int c, int n, int groups, int size, int stride_x, int stride_y, int dilation, int padding, ACTIVATION activation, int batch_normalize, int binary, int xnor, int adam, int use_bin_output, int index, int antialiasing, convolutional_layer *share_layer, int assisted_excitation, int train);
void denormalize_convolutional_layer(convolutional_layer l);
void set_specified_workspace_limit(convolutional_layer *l, size_t workspace_size_limit);
void resize_convolutional_layer(convolutional_layer *layer, int w, int h);
void forward_convolutional_layer(const convolutional_layer layer, network_state state);
void update_convolutional_layer(convolutional_layer layer, int batch, float learning_rate, float momentum, float decay);

@@ -226,16 +226,138 @@ cublasHandle_t blas_handle()
return handle[i];
}
static float **pinned_ptr = NULL;
static size_t pinned_num_of_blocks = 0;
static size_t pinned_index = 0;
static size_t pinned_block_id = 0;
static const size_t pinned_block_size = (size_t)1024 * 1024 * 1024 * 1; // 1 GB block size
static pthread_mutex_t mutex_pinned = PTHREAD_MUTEX_INITIALIZER;
// free CPU-pinned memory
void free_pinned_memory()
{
if (pinned_ptr) {
int k;
for (k = 0; k < pinned_num_of_blocks; ++k) {
cuda_free_host(pinned_ptr[k]);
}
free(pinned_ptr);
pinned_ptr = NULL;
}
}
// custom CPU-pinned memory allocation
void pre_allocate_pinned_memory(const size_t size)
{
const size_t num_of_blocks = size / pinned_block_size + ((size % pinned_block_size) ? 1 : 0);
printf("pre_allocate... pinned_ptr = %p \n", pinned_ptr);
pthread_mutex_lock(&mutex_pinned);
if (!pinned_ptr) {
pinned_ptr = (float **)calloc(num_of_blocks, sizeof(float *));
if(!pinned_ptr) error("calloc failed in pre_allocate() \n");
printf("pre_allocate: size = %Iu MB, num_of_blocks = %Iu, block_size = %Iu MB \n",
size / (1024*1024), num_of_blocks, pinned_block_size / (1024 * 1024));
int k;
for (k = 0; k < num_of_blocks; ++k) {
cudaError_t status = cudaHostAlloc((void **)&pinned_ptr[k], pinned_block_size, cudaHostRegisterMapped);
if (status != cudaSuccess) fprintf(stderr, " Can't pre-allocate CUDA-pinned buffer on CPU-RAM \n");
CHECK_CUDA(status);
if (!pinned_ptr[k]) error("cudaHostAlloc failed\n");
else {
printf(" Allocated %d pinned block \n", pinned_block_size);
}
}
pinned_num_of_blocks = num_of_blocks;
}
pthread_mutex_unlock(&mutex_pinned);
}
// simple - get pre-allocated pinned memory
float *cuda_make_array_pinned_preallocated(float *x, size_t n)
{
pthread_mutex_lock(&mutex_pinned);
float *x_cpu = NULL;
const size_t memory_step = 4096;
const size_t size = sizeof(float)*n;
const size_t allocation_size = ((size / memory_step) + 1) * memory_step;
if (pinned_ptr && pinned_block_id < pinned_num_of_blocks && (allocation_size < pinned_block_size/2))
{
if ((allocation_size + pinned_index) > pinned_block_size) {
const float filled = (float)100 * pinned_index / pinned_block_size;
printf("\n Pinned block_id = %d, filled = %f %% \n", pinned_block_id, filled);
pinned_block_id++;
pinned_index = 0;
}
if ((allocation_size + pinned_index) < pinned_block_size && pinned_block_id < pinned_num_of_blocks) {
x_cpu = (float *)((char *)pinned_ptr[pinned_block_id] + pinned_index);
pinned_index += allocation_size;
}
else {
//printf("Pre-allocated pinned memory is over! \n");
}
}
if(!x_cpu) {
if (allocation_size > pinned_block_size / 2) {
printf("Try to allocate new pinned memory, size = %d MB \n", size / (1024 * 1024));
cudaError_t status = cudaHostAlloc((void **)&x_cpu, size, cudaHostRegisterMapped);
if (status != cudaSuccess) fprintf(stderr, " Can't allocate CUDA-pinned memory on CPU-RAM (pre-allocated memory is over too) \n");
CHECK_CUDA(status);
}
else {
printf("Try to allocate new pinned BLOCK, size = %d MB \n", size / (1024 * 1024));
pinned_num_of_blocks++;
pinned_block_id = pinned_num_of_blocks - 1;
pinned_index = 0;
pinned_ptr = (float **)realloc(pinned_ptr, pinned_num_of_blocks * sizeof(float *));
cudaError_t status = cudaHostAlloc((void **)&pinned_ptr[pinned_block_id], pinned_block_size, cudaHostRegisterMapped);
if (status != cudaSuccess) fprintf(stderr, " Can't pre-allocate CUDA-pinned buffer on CPU-RAM \n");
CHECK_CUDA(status);
x_cpu = pinned_ptr[pinned_block_id];
pinned_index = allocation_size; // reserve this allocation inside the new block
}
}
if (x) {
cudaError_t status = cudaMemcpyAsync(x_cpu, x, size, cudaMemcpyDefault, get_cuda_stream());
CHECK_CUDA(status);
}
pthread_mutex_unlock(&mutex_pinned);
return x_cpu;
}
float *cuda_make_array_pinned(float *x, size_t n)
{
float *x_gpu;
size_t size = sizeof(float)*n;
//cudaError_t status = cudaMalloc((void **)&x_gpu, size);
cudaError_t status = cudaHostAlloc((void **)&x_gpu, size, cudaHostRegisterMapped);
if (status != cudaSuccess) fprintf(stderr, " Can't allocate CUDA-pinned memory on CPU-RAM \n");
CHECK_CUDA(status);
if (x) {
status = cudaMemcpyAsync(x_gpu, x, size, cudaMemcpyDefault, get_cuda_stream());
CHECK_CUDA(status);
}
if (!x_gpu) error("cudaHostAlloc failed\n");
return x_gpu;
}
float *cuda_make_array(float *x, size_t n)
{
float *x_gpu;
size_t size = sizeof(float)*n;
cudaError_t status = cudaMalloc((void **)&x_gpu, size);
//cudaError_t status = cudaMallocManaged((void **)&x_gpu, size, cudaMemAttachGlobal);
//status = cudaMemAdvise(x_gpu, size, cudaMemAdviseSetPreferredLocation, cudaCpuDeviceId);
if (status != cudaSuccess) fprintf(stderr, " Try to set subdivisions=64 in your cfg-file. \n");
CHECK_CUDA(status);
if(x){
//status = cudaMemcpy(x_gpu, x, size, cudaMemcpyHostToDevice);
status = cudaMemcpyAsync(x_gpu, x, size, cudaMemcpyHostToDevice, get_cuda_stream());
status = cudaMemcpyAsync(x_gpu, x, size, cudaMemcpyDefault, get_cuda_stream());
CHECK_CUDA(status);
}
if(!x_gpu) error("Cuda malloc failed\n");
@@ -301,11 +423,18 @@ void cuda_free(float *x_gpu)
CHECK_CUDA(status);
}
void cuda_free_host(float *x_cpu)
{
//cudaStreamSynchronize(get_cuda_stream());
cudaError_t status = cudaFreeHost(x_cpu);
CHECK_CUDA(status);
}
void cuda_push_array(float *x_gpu, float *x, size_t n)
{
size_t size = sizeof(float)*n;
//cudaError_t status = cudaMemcpy(x_gpu, x, size, cudaMemcpyHostToDevice);
cudaError_t status = cudaMemcpyAsync(x_gpu, x, size, cudaMemcpyHostToDevice, get_cuda_stream());
cudaError_t status = cudaMemcpyAsync(x_gpu, x, size, cudaMemcpyDefault, get_cuda_stream());
CHECK_CUDA(status);
}
@@ -313,7 +442,7 @@ void cuda_pull_array(float *x_gpu, float *x, size_t n)
{
size_t size = sizeof(float)*n;
//cudaError_t status = cudaMemcpy(x, x_gpu, size, cudaMemcpyDeviceToHost);
cudaError_t status = cudaMemcpyAsync(x, x_gpu, size, cudaMemcpyDeviceToHost, get_cuda_stream());
cudaError_t status = cudaMemcpyAsync(x, x_gpu, size, cudaMemcpyDefault, get_cuda_stream());
CHECK_CUDA(status);
cudaStreamSynchronize(get_cuda_stream());
}
@@ -321,7 +450,7 @@ void cuda_pull_array(float *x_gpu, float *x, size_t n)
void cuda_pull_array_async(float *x_gpu, float *x, size_t n)
{
size_t size = sizeof(float)*n;
cudaError_t status = cudaMemcpyAsync(x, x_gpu, size, cudaMemcpyDeviceToHost, get_cuda_stream());
cudaError_t status = cudaMemcpyAsync(x, x_gpu, size, cudaMemcpyDefault, get_cuda_stream());
check_error(status);
//cudaStreamSynchronize(get_cuda_stream());
}

@@ -57,6 +57,10 @@ extern "C" {
#define CHECK_CUDA(X) check_error_extended(X, __FILE__ " : " __FUNCTION__, __LINE__, __DATE__ " - " __TIME__ );
cublasHandle_t blas_handle();
void free_pinned_memory();
void pre_allocate_pinned_memory(size_t size);
float *cuda_make_array_pinned_preallocated(float *x, size_t n);
float *cuda_make_array_pinned(float *x, size_t n);
float *cuda_make_array(float *x, size_t n);
int *cuda_make_int_array(size_t n);
int *cuda_make_int_array_new_api(int *x, size_t n);
@@ -64,6 +68,7 @@ extern "C" {
//LIB_API void cuda_pull_array(float *x_gpu, float *x, size_t n);
//LIB_API void cuda_set_device(int n);
int cuda_get_device();
void cuda_free_host(float *x_cpu);
void cuda_free(float *x_gpu);
void cuda_random(float *x_gpu, size_t n);
float cuda_compare(float *x_gpu, float *x, size_t n, char *s);
@@ -75,7 +80,7 @@ extern "C" {
#ifdef CUDNN
cudnnHandle_t cudnn_handle();
enum {cudnn_fastest, cudnn_smallest};
enum {cudnn_fastest, cudnn_smallest, cudnn_specify};
void cudnn_check_error_extended(cudnnStatus_t status, const char *file, int line, const char *date_time);
#define CHECK_CUDNN(X) cudnn_check_error_extended(X, __FILE__ " : " __FUNCTION__, __LINE__, __DATE__ " - " __TIME__ );
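A hedged usage sketch (illustrative only, not part of this diff) of how the new pinned-memory entry points are driven by the rest of the change: parser.c pre-allocates pinned blocks when optimized_memory >= 2, layer buffers are then sub-allocated from those blocks, and free_network() releases them.

// Editorial sketch; assumes a GPU build of darknet. The element count is a
// hypothetical stand-in for l.batch*l.outputs.
#include <stddef.h>
#include "dark_cuda.h"

void pinned_memory_usage_sketch(void)
{
    // parser.c: reserve 8 GB of CPU-pinned RAM up front for training
    pre_allocate_pinned_memory((size_t)1024 * 1024 * 1024 * 8);

    // parser.c: carve a layer buffer out of the pre-allocated pinned blocks,
    // used in place of a cudaMalloc'ed l.output_gpu / l.x_gpu / l.activation_input_gpu
    size_t n = (size_t)64 * 608 * 608;  // hypothetical l.batch*l.outputs
    float *output_gpu = cuda_make_array_pinned_preallocated(NULL, n);
    (void)output_gpu;

    // network.c: free_network() releases every pre-allocated pinned block
    free_pinned_memory();
}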

@@ -278,7 +278,7 @@ void train_detector(char *datacfg, char *cfgfile, char *weightfile, int *gpus, i
//network net_combined = combine_train_valid_networks(net, net_map);
iter_map = i;
mean_average_precision = validate_detector_map(datacfg, cfgfile, weightfile, 0.25, 0.5, 0, net.letter_box, &net_map);// &net_combined);
//mean_average_precision = validate_detector_map(datacfg, cfgfile, weightfile, 0.25, 0.5, 0, net.letter_box, &net_map);// &net_combined);
printf("\n mean_average_precision (mAP@0.5) = %f \n", mean_average_precision);
if (mean_average_precision > best_map) {
best_map = mean_average_precision;

@@ -154,7 +154,6 @@ void free_layer(layer l)
if (l.rolling_variance_gpu) cuda_free(l.rolling_variance_gpu), l.rolling_variance_gpu = NULL;
if (l.variance_delta_gpu) cuda_free(l.variance_delta_gpu), l.variance_delta_gpu = NULL;
if (l.mean_delta_gpu) cuda_free(l.mean_delta_gpu), l.mean_delta_gpu = NULL;
if (l.x_gpu) cuda_free(l.x_gpu); // dont free
if (l.x_norm_gpu) cuda_free(l.x_norm_gpu);
// assisted excitation
@@ -175,9 +174,12 @@ void free_layer(layer l)
if (l.scales_gpu) cuda_free(l.scales_gpu), l.scales_gpu = NULL;
if (l.scale_updates_gpu) cuda_free(l.scale_updates_gpu), l.scale_updates_gpu = NULL;
if (l.input_antialiasing_gpu) cuda_free(l.input_antialiasing_gpu), l.input_antialiasing_gpu = NULL;
if (l.output_gpu) cuda_free(l.output_gpu), l.output_gpu = NULL;
if (l.activation_input_gpu) cuda_free(l.activation_input_gpu), l.activation_input_gpu = NULL;
if (l.delta_gpu) cuda_free(l.delta_gpu), l.delta_gpu = NULL;
if (l.optimized_memory < 2) {
if (l.x_gpu) cuda_free(l.x_gpu); l.x_gpu = NULL;
if (l.output_gpu) cuda_free(l.output_gpu), l.output_gpu = NULL;
if (l.activation_input_gpu) cuda_free(l.activation_input_gpu), l.activation_input_gpu = NULL;
}
if (l.delta_gpu && l.keep_delta_gpu && l.optimized_memory < 3) cuda_free(l.delta_gpu), l.delta_gpu = NULL;
if (l.rand_gpu) cuda_free(l.rand_gpu);
if (l.squared_gpu) cuda_free(l.squared_gpu);
if (l.norms_gpu) cuda_free(l.norms_gpu);

@@ -544,7 +544,7 @@ int resize_network(network *net, int w, int h)
#ifdef GPU
l.output_gpu = net->layers[i-1].output_gpu;
l.delta_gpu = net->layers[i-1].delta_gpu;
#endif
#endif
}else if (l.type == UPSAMPLE) {
resize_upsample_layer(&l, w, h);
}else if(l.type == REORG){
@@ -1035,6 +1035,7 @@ void free_network(network net)
#ifdef GPU
if (gpu_index >= 0) cuda_free(net.workspace);
else free(net.workspace);
free_pinned_memory();
if (net.input_state_gpu) cuda_free(net.input_state_gpu);
if (net.input_pinned_cpu) { // CPU
if (net.input_pinned_cpu_flag) cudaFreeHost(net.input_pinned_cpu);

@@ -105,10 +105,19 @@ void backward_network_gpu(network net, network_state state)
layer prev = net.layers[i-1];
state.input = prev.output_gpu;
state.delta = prev.delta_gpu;
if (net.optimized_memory && !prev.keep_delta_gpu) {
state.delta = net.state_delta_gpu;
}
}
if (l.onlyforward) continue;
l.backward_gpu(l, state);
if (i != 0) {
layer prev = net.layers[i - 1]; // guard: layer 0 has no previous layer
if (net.optimized_memory && state.delta && !prev.keep_delta_gpu) {
simple_copy_ongpu(prev.outputs*prev.batch, state.delta, prev.delta_gpu);
fill_ongpu(prev.outputs*prev.batch, 0, net.state_delta_gpu, 1);
}
}
/*
if(i != 0)
{

@@ -896,6 +896,9 @@ void parse_net_options(list *options, network *net)
net->batch *= net->time_steps;
net->subdivisions = subdivs;
net->optimized_memory = option_find_int_quiet(options, "optimized_memory", 0);
net->workspace_size_limit = (size_t)1024*1024 * option_find_float_quiet(options, "workspace_size_limit_MB", 1024); // 1024 MB by default
net->adam = option_find_int_quiet(options, "adam", 0);
if(net->adam){
net->B1 = option_find_float(options, "B1", .9);
@@ -1015,6 +1018,13 @@ network parse_network_cfg_custom(char *filename, int batch, int time_steps)
if(!is_network(s)) error("First section must be [net] or [network]");
parse_net_options(options, &net);
#ifdef GPU
printf("net.optimized_memory = %d \n", net.optimized_memory);
if (net.optimized_memory >= 2 && params.train) {
pre_allocate_pinned_memory((size_t)1024 * 1024 * 1024 * 8); // pre-allocate 8 GB CPU-RAM for pinned memory
}
#endif // GPU
params.h = net.h;
params.w = net.w;
params.c = net.c;
@@ -1066,17 +1076,22 @@ network parse_network_cfg_custom(char *filename, int batch, int time_steps)
l = parse_crop(options, params);
}else if(lt == COST){
l = parse_cost(options, params);
l.keep_delta_gpu = 1;
}else if(lt == REGION){
l = parse_region(options, params);
l.keep_delta_gpu = 1;
}else if (lt == YOLO) {
l = parse_yolo(options, params);
l.keep_delta_gpu = 1;
}else if (lt == GAUSSIAN_YOLO) {
l = parse_gaussian_yolo(options, params);
l.keep_delta_gpu = 1;
}else if(lt == DETECTION){
l = parse_detection(options, params);
}else if(lt == SOFTMAX){
l = parse_softmax(options, params);
net.hierarchy = l.softmax_tree;
l.keep_delta_gpu = 1;
}else if(lt == NORMALIZATION){
l = parse_normalization(options, params);
}else if(lt == BATCHNORM){
@@ -1092,22 +1107,28 @@ network parse_network_cfg_custom(char *filename, int batch, int time_steps)
}else if(lt == ROUTE){
l = parse_route(options, params);
int k;
for (k = 0; k < l.n; ++k) net.layers[l.input_layers[k]].use_bin_output = 0;
for (k = 0; k < l.n; ++k) {
net.layers[l.input_layers[k]].use_bin_output = 0;
net.layers[l.input_layers[k]].keep_delta_gpu = 1;
}
}else if (lt == UPSAMPLE) {
l = parse_upsample(options, params, net);
}else if(lt == SHORTCUT){
l = parse_shortcut(options, params, net);
net.layers[count - 1].use_bin_output = 0;
net.layers[l.index].use_bin_output = 0;
net.layers[l.index].keep_delta_gpu = 1;
}else if (lt == SCALE_CHANNELS) {
l = parse_scale_channels(options, params, net);
net.layers[count - 1].use_bin_output = 0;
net.layers[l.index].use_bin_output = 0;
net.layers[l.index].keep_delta_gpu = 1;
}
else if (lt == SAM) {
l = parse_sam(options, params, net);
net.layers[count - 1].use_bin_output = 0;
net.layers[l.index].use_bin_output = 0;
net.layers[l.index].keep_delta_gpu = 1;
}else if(lt == DROPOUT){
l = parse_dropout(options, params);
l.output = net.layers[count-1].output;
@@ -1132,6 +1153,42 @@ network parse_network_cfg_custom(char *filename, int batch, int time_steps)
}else{
fprintf(stderr, "Type not recognized: %s\n", s->type);
}
#ifdef GPU
// further GPU-memory optimization: net.optimized_memory == 2
if (net.optimized_memory >= 2 && params.train)
{
l.optimized_memory = net.optimized_memory;
if (l.output_gpu) {
cuda_free(l.output_gpu);
//l.output_gpu = cuda_make_array_pinned(l.output, l.batch*l.outputs); // l.steps
l.output_gpu = cuda_make_array_pinned_preallocated(NULL, l.batch*l.outputs); // l.steps
}
if (l.activation_input_gpu) {
cuda_free(l.activation_input_gpu);
l.activation_input_gpu = cuda_make_array_pinned_preallocated(NULL, l.batch*l.outputs); // l.steps
}
if (l.x_gpu) {
cuda_free(l.x_gpu);
l.x_gpu = cuda_make_array_pinned_preallocated(NULL, l.batch*l.outputs); // l.steps
}
// maximum optimization
if (net.optimized_memory >= 3) {
if (l.delta_gpu) {
cuda_free(l.delta_gpu);
//l.delta_gpu = cuda_make_array_pinned_preallocated(NULL, l.batch*l.outputs); // l.steps
//printf("\n\n PINNED DELTA GPU = %d \n", l.batch*l.outputs);
}
}
if (l.type == CONVOLUTIONAL) {
set_specified_workspace_limit(&l, net.workspace_size_limit); // limit from workspace_size_limit_MB (default 1024 MB)
}
}
#endif // GPU
l.onlyforward = option_find_int_quiet(options, "onlyforward", 0);
l.stopbackward = option_find_int_quiet(options, "stopbackward", 0);
l.dontload = option_find_int_quiet(options, "dontload", 0);
@@ -1162,6 +1219,45 @@ network parse_network_cfg_custom(char *filename, int batch, int time_steps)
if (l.bflops > 0) bflops += l.bflops;
}
free_list(sections);
#ifdef GPU
if (net.optimized_memory && params.train)
{
int k;
for (k = 0; k < net.n; ++k) {
layer l = net.layers[k];
// delta GPU-memory optimization: net.optimized_memory == 1
if (!l.keep_delta_gpu) {
const size_t delta_size = l.outputs*l.batch; // l.steps
if (net.max_delta_gpu_size < delta_size) {
net.max_delta_gpu_size = delta_size;
if (net.global_delta_gpu) cuda_free(net.global_delta_gpu);
if (net.state_delta_gpu) cuda_free(net.state_delta_gpu);
assert(net.max_delta_gpu_size > 0);
net.global_delta_gpu = (float *)cuda_make_array(NULL, net.max_delta_gpu_size);
net.state_delta_gpu = (float *)cuda_make_array(NULL, net.max_delta_gpu_size);
}
if (l.delta_gpu) {
if (net.optimized_memory < 3) cuda_free(l.delta_gpu); // at level >= 3 delta_gpu was already freed above
}
l.delta_gpu = net.global_delta_gpu;
}
// maximum optimization
if (net.optimized_memory >= 3) {
if (l.delta_gpu && l.keep_delta_gpu) {
//cuda_free(l.delta_gpu); // already called above
l.delta_gpu = cuda_make_array_pinned_preallocated(NULL, l.batch*l.outputs); // l.steps
//printf("\n\n PINNED DELTA GPU = %d \n", l.batch*l.outputs);
}
}
net.layers[k] = l;
}
}
#endif
net.outputs = get_network_output_size(net);
net.output = get_network_output(net);
fprintf(stderr, "Total BFLOPS %5.3f \n", bflops);
