CUDNN_HALF and CC 7.5 by default in darknet.sln

Branch: pull/2352/head
Author: AlexeyAB, 6 years ago; committed by John Aughey
Parent: c00d3c92db
Commit: e1bbeb8367
Changed files:
  1. build/darknet/darknet.vcxproj (8 changes)
  2. include/darknet.h (3 changes)
  3. src/convolutional_kernels.cu (14 changes)
  4. src/network_kernels.cu (5 changes)

build/darknet/darknet.vcxproj:

@@ -89,7 +89,7 @@
       <Optimization>Disabled</Optimization>
       <SDLCheck>true</SDLCheck>
       <AdditionalIncludeDirectories>C:\opencv_3.0\opencv\build\include;..\..\include;..\..\3rdparty\include;%(AdditionalIncludeDirectories);$(CudaToolkitIncludeDir);$(cudnn)\include</AdditionalIncludeDirectories>
-      <PreprocessorDefinitions>CUDNN;_CRTDBG_MAP_ALLOC;_MBCS;_TIMESPEC_DEFINED;_SCL_SECURE_NO_WARNINGS;_CRT_SECURE_NO_WARNINGS;_CRT_RAND_S;GPU;WIN32;NDEBUG;_CONSOLE;_LIB;%(PreprocessorDefinitions)</PreprocessorDefinitions>
+      <PreprocessorDefinitions>CUDNN_HALF;CUDNN;_CRTDBG_MAP_ALLOC;_MBCS;_TIMESPEC_DEFINED;_SCL_SECURE_NO_WARNINGS;_CRT_SECURE_NO_WARNINGS;_CRT_RAND_S;GPU;WIN32;DEBUG;_CONSOLE;_LIB;%(PreprocessorDefinitions)</PreprocessorDefinitions>
       <UndefinePreprocessorDefinitions>OPENCV;</UndefinePreprocessorDefinitions>
       <MultiProcessorCompilation>true</MultiProcessorCompilation>
       <ForcedIncludeFiles>stdlib.h;crtdbg.h;%(ForcedIncludeFiles)</ForcedIncludeFiles>

@@ -102,7 +102,7 @@
       <AssemblyDebug>true</AssemblyDebug>
     </Link>
     <CudaCompile>
-      <CodeGeneration>compute_30,sm_30;compute_52,sm_52</CodeGeneration>
+      <CodeGeneration>compute_30,sm_30;compute_75,sm_75</CodeGeneration>
       <TargetMachinePlatform>64</TargetMachinePlatform>
     </CudaCompile>
   </ItemDefinitionGroup>

@@ -133,7 +133,7 @@
       <IntrinsicFunctions>true</IntrinsicFunctions>
       <SDLCheck>true</SDLCheck>
       <AdditionalIncludeDirectories>C:\opencv_3.0\opencv\build\include;..\..\include;..\..\3rdparty\include;%(AdditionalIncludeDirectories);$(CudaToolkitIncludeDir);$(cudnn)\include</AdditionalIncludeDirectories>
-      <PreprocessorDefinitions>OPENCV;CUDNN;_TIMESPEC_DEFINED;_SCL_SECURE_NO_WARNINGS;_CRT_SECURE_NO_WARNINGS;_CRT_RAND_S;GPU;WIN32;NDEBUG;_CONSOLE;_LIB;%(PreprocessorDefinitions)</PreprocessorDefinitions>
+      <PreprocessorDefinitions>CUDNN_HALF;OPENCV;CUDNN;_TIMESPEC_DEFINED;_SCL_SECURE_NO_WARNINGS;_CRT_SECURE_NO_WARNINGS;_CRT_RAND_S;GPU;WIN32;NDEBUG;_CONSOLE;_LIB;%(PreprocessorDefinitions)</PreprocessorDefinitions>
       <CLanguageStandard>c11</CLanguageStandard>
       <CppLanguageStandard>c++1y</CppLanguageStandard>
       <PrecompiledHeaderCompileAs>CompileAsCpp</PrecompiledHeaderCompileAs>

@@ -152,7 +152,7 @@
     </Link>
     <CudaCompile>
       <TargetMachinePlatform>64</TargetMachinePlatform>
-      <CodeGeneration>compute_30,sm_30;compute_52,sm_52</CodeGeneration>
+      <CodeGeneration>compute_30,sm_30;compute_75,sm_75</CodeGeneration>
     </CudaCompile>
   </ItemDefinitionGroup>
   <ItemGroup>
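The CodeGeneration change above swaps the sm_52 (Maxwell) target for sm_75 (Turing, the RTX 20xx series), which is also the hardware that makes the new CUDNN_HALF define worthwhile: cuDNN's mixed-precision kernels only run on the Tensor Cores introduced with compute capability 7.0. A minimal standalone probe (not part of this commit; device index 0 is assumed) to check whether a machine can use that path:

    #include <stdio.h>
    #include <cuda_runtime.h>

    int main(void) {
        cudaDeviceProp prop;
        // Query device 0, assumed to be the GPU darknet will run on.
        if (cudaGetDeviceProperties(&prop, 0) != cudaSuccess) {
            fprintf(stderr, "no CUDA device found\n");
            return 1;
        }
        printf("GPU: %s, compute capability %d.%d\n", prop.name, prop.major, prop.minor);
        // Tensor Cores, the point of CUDNN_HALF, need compute capability 7.0 or higher.
        if (prop.major >= 7) printf("CUDNN_HALF (Tensor Core) path is usable\n");
        else printf("CUDNN_HALF gives no speedup on this device\n");
        return 0;
    }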

include/darknet.h:

@@ -8,8 +8,9 @@
 #include <stdlib.h>
 #include <stdio.h>
 #include <string.h>
-#include <pthread.h>
 #include <stdint.h>
+#include <assert.h>
+#include <pthread.h>

 #ifdef LIB_EXPORTS
 #if defined(_MSC_VER)
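The new <assert.h> include supports the assert() calls added in the .cu files below. It also explains why the Debug configuration in the vcxproj hunk above dropped NDEBUG: the C standard compiles assert() to nothing whenever NDEBUG is defined, so the new checks would have been dead code even in debug builds. A short illustration of that behavior:

    #define NDEBUG      // defined before <assert.h>, as the old Debug config did project-wide
    #include <assert.h>

    int main(void) {
        assert(1 == 2); // expands to ((void)0) under NDEBUG and never fires
        return 0;
    }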

src/convolutional_kernels.cu:

@@ -139,7 +139,7 @@ __global__ void cuda_f32_to_f16(float* input_f32, size_t size, half *output_f16)
 }

 void cuda_convert_f32_to_f16(float* input_f32, size_t size, float *output_f16) {
-    cuda_f32_to_f16 <<< size / BLOCK + 1, BLOCK, 0, get_cuda_stream() >>> (input_f32, size, (half *)output_f16);
+    cuda_f32_to_f16 <<< get_number_of_blocks(size, BLOCK), BLOCK, 0, get_cuda_stream() >>> (input_f32, size, (half *)output_f16);
     CHECK_CUDA(cudaPeekAtLastError());
 }

@@ -151,7 +151,7 @@ __global__ void cuda_f16_to_f32(half* input_f16, size_t size, float *output_f32)
 }

 void cuda_convert_f16_to_f32(float* input_f16, size_t size, float *output_f32) {
-    cuda_f16_to_f32 <<< size / BLOCK + 1, BLOCK, 0, get_cuda_stream() >>> ((half *)input_f16, size, output_f32);
+    cuda_f16_to_f32 <<< get_number_of_blocks(size, BLOCK), BLOCK, 0, get_cuda_stream() >>> ((half *)input_f16, size, output_f32);
     CHECK_CUDA(cudaPeekAtLastError());
 }
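get_number_of_blocks is not defined in this diff; presumably it is plain ceiling division, so a size that is an exact multiple of BLOCK no longer launches the extra, completely idle block that size / BLOCK + 1 produced. A sketch of that helper, together with the bounds-checked kernel it sizes (the kernel body is inferred from the signature above, and the BLOCK value is an assumption):

    #include <cuda_fp16.h>

    #define BLOCK 512   // assumed: darknet's default CUDA block size

    // Ceiling division: just enough blocks to cover array_size threads.
    static size_t get_number_of_blocks(size_t array_size, size_t block_size)
    {
        return array_size / block_size + ((array_size % block_size) > 0);
    }

    // One thread per element; the bounds check keeps any padding threads harmless.
    __global__ void cuda_f32_to_f16(float *input_f32, size_t size, half *output_f16)
    {
        size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
        if (idx < size) output_f16[idx] = __float2half(input_f32[idx]);
    }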
@@ -161,6 +161,7 @@ half *cuda_make_f16_from_f32_array(float *src, size_t n)
     size_t size = sizeof(half)*n;
     CHECK_CUDA(cudaMalloc((void **)&dst16, size));
     if (src) {
+        assert(n > 0);
         cuda_convert_f32_to_f16(src, n, (float *)dst16);
     }
     if (!dst16) error("Cuda malloc failed\n");

@@ -434,6 +435,7 @@ void forward_convolutional_layer_gpu(convolutional_layer l, network_state state)
             //printf("\n input16_size: cur = %zu \t max = %zu \n", input16_size, *state.net.max_input16_size);
             *state.net.max_input16_size = input16_size;
             if (*state.net.input16_gpu) cuda_free(*state.net.input16_gpu);
+            assert(*state.net.max_input16_size > 0);
             *state.net.input16_gpu = (float *)cuda_make_f16_from_f32_array(NULL, *state.net.max_input16_size);
         }
         float *input16 = *state.net.input16_gpu;

@@ -441,10 +443,12 @@ void forward_convolutional_layer_gpu(convolutional_layer l, network_state state)
         if (*state.net.max_output16_size < output16_size) {
             *state.net.max_output16_size = output16_size;
             if (*state.net.output16_gpu) cuda_free(*state.net.output16_gpu);
+            assert(*state.net.max_output16_size > 0);
             *state.net.output16_gpu = (float *)cuda_make_f16_from_f32_array(NULL, *state.net.max_output16_size);
         }
         float *output16 = *state.net.output16_gpu;

+        assert(input16_size > 0);
         cuda_convert_f32_to_f16(state.input, input16_size, input16);

         //fill_ongpu(output16_size / 2, 0, (float *)output16, 1);

@@ -608,6 +612,7 @@ void backward_convolutional_layer_gpu(convolutional_layer l, network_state state)
         if (*state.net.max_input16_size < input16_size) {
             *state.net.max_input16_size = input16_size;
             if (*state.net.input16_gpu) cuda_free(*state.net.input16_gpu);
+            assert(*state.net.max_input16_size > 0);
             *state.net.input16_gpu = (float *)cuda_make_f16_from_f32_array(NULL, *state.net.max_input16_size);
         }
         float *input16 = *state.net.input16_gpu;

@@ -615,10 +620,13 @@ void backward_convolutional_layer_gpu(convolutional_layer l, network_state state)
         if (*state.net.max_output16_size < delta16_size) {
             *state.net.max_output16_size = delta16_size;
             if (*state.net.output16_gpu) cuda_free(*state.net.output16_gpu);
+            assert(*state.net.max_output16_size > 0);
             *state.net.output16_gpu = (float *)cuda_make_f16_from_f32_array(NULL, *state.net.max_output16_size);
         }
         float *delta16 = *state.net.output16_gpu;

+        assert(input16_size > 0);
+        assert(delta16_size > 0);
         cuda_convert_f32_to_f16(state.input, input16_size, input16);
         cuda_convert_f32_to_f16(l.delta_gpu, delta16_size, delta16);

@@ -664,6 +672,7 @@ void backward_convolutional_layer_gpu(convolutional_layer l, network_state state)
             // calculate conv weight updates
             // Already: l.weight_updates_gpu = (l.weight_updates_gpu - l.weight*decay*batch*subdivision)*momentum
             // so we should copy f32 to f16, or compute: f16=(w_up - w*d*b*s)*m
+            assert((l.c*l.n*l.size*l.size) > 0);
             cuda_convert_f32_to_f16(l.weight_updates_gpu, l.c*l.n*l.size*l.size, l.weight_updates_gpu16);

             CHECK_CUDNN(cudnnConvolutionBackwardFilter(cudnn_handle(),

@@ -815,6 +824,7 @@ void push_convolutional_layer(convolutional_layer layer)
 {
     cuda_push_array(layer.weights_gpu, layer.weights, layer.c*layer.n*layer.size*layer.size);
 #ifdef CUDNN_HALF
+    assert((layer.c*layer.n*layer.size*layer.size) > 0);
     cuda_convert_f32_to_f16(layer.weights_gpu, layer.c*layer.n*layer.size*layer.size, layer.weights_gpu16);
 #endif
     cuda_push_array(layer.biases_gpu, layer.biases, layer.n);
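For context, the f32/f16 staging above only pays off because darknet's cuDNN setup code (elsewhere in the tree, not in this diff) requests Tensor Core math for its convolution descriptors when CUDNN_HALF is defined. That setup amounts to something like the following sketch (the helper name is hypothetical; the two cuDNN calls are real cuDNN 7+ API):

    #include <cudnn.h>

    // Hypothetical helper: ask cuDNN for Tensor Core kernels on this convolution.
    static void enable_tensor_op_math(cudnnConvolutionDescriptor_t convDesc)
    {
    #ifdef CUDNN_HALF
        cudnnSetConvolutionMathType(convDesc, CUDNN_TENSOR_OP_MATH);
    #endif
    }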

src/network_kernels.cu:

@@ -152,7 +152,10 @@ void forward_backward_network_gpu(network net, float *x, float *y)
     int i;
     for (i = 0; i < net.n; ++i) {
         layer l = net.layers[i];
-        cuda_convert_f32_to_f16(l.weights_gpu, l.c*l.n*l.size*l.size, l.weights_gpu16);
+        if (l.weights_gpu) {
+            assert((l.c*l.n*l.size*l.size) > 0);
+            cuda_convert_f32_to_f16(l.weights_gpu, l.c*l.n*l.size*l.size, l.weights_gpu16);
+        }
     }
 #endif
     forward_network_gpu(net, state);
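The new guard matters because most darknet layer types (route, maxpool, yolo, and so on) carry no weights at all, so for them l.weights_gpu is presumably NULL and l.c*l.n*l.size*l.size can be zero; the old unconditional call attempted a conversion from a null device pointer on every such layer.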
