diff --git a/src/blas.c b/src/blas.c
index ccf05223..70e06991 100644
--- a/src/blas.c
+++ b/src/blas.c
@@ -244,6 +244,28 @@ void l1_cpu(int n, float *pred, float *truth, float *delta, float *error)
     }
 }
 
+void softmax_x_ent_cpu(int n, float *pred, float *truth, float *delta, float *error)
+{
+    int i;
+    for(i = 0; i < n; ++i){
+        float t = truth[i];
+        float p = pred[i];
+        error[i] = (t) ? -log(p) : 0;
+        delta[i] = t-p;
+    }
+}
+
+void logistic_x_ent_cpu(int n, float *pred, float *truth, float *delta, float *error)
+{
+    int i;
+    for(i = 0; i < n; ++i){
+        float t = truth[i];
+        float p = pred[i];
+        error[i] = -t*log(p) - (1-t)*log(1-p);
+        delta[i] = t-p;
+    }
+}
+
 void l2_cpu(int n, float *pred, float *truth, float *delta, float *error)
 {
     int i;
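
For reference: softmax_x_ent_cpu pairs a softmax output with cross-entropy loss. With one-hot truth the per-element error is -log(p) at the true class and 0 elsewhere, and delta = t - p is the negative gradient of that loss with respect to the softmax logits, which backward_softmax_layer later adds straight into the previous layer's delta. A standalone check of that arithmetic (hypothetical demo, not part of the patch):

    /* Hypothetical demo: the softmax cross-entropy pair above on a
     * 3-class example. For one-hot truth, error is -log(p) at the
     * true class and delta = t - p is the negative gradient of the
     * loss w.r.t. the softmax logits. */
    #include <math.h>
    #include <stdio.h>

    int main()
    {
        float pred[3]  = {0.7f, 0.2f, 0.1f};  /* softmax output */
        float truth[3] = {1.0f, 0.0f, 0.0f};  /* one-hot label */
        float delta[3], error[3];
        int i;
        for (i = 0; i < 3; ++i) {
            error[i] = truth[i] ? -logf(pred[i]) : 0;
            delta[i] = truth[i] - pred[i];
        }
        /* prints loss=0.356675 delta0=0.300000 */
        printf("loss=%f delta0=%f\n", error[0], delta[0]);
        return 0;
    }
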
diff --git a/src/blas.h b/src/blas.h
index c40422ac..d5f67250 100644
--- a/src/blas.h
+++ b/src/blas.h
@@ -37,9 +37,12 @@ void weighted_sum_cpu(float *a, float *b, float *s, int num, float *c);
 
 void softmax(float *input, int n, float temp, float *output, int stride);
 void upsample_cpu(float *in, int w, int h, int c, int batch, int stride, int forward, float scale, float *out);
+void softmax_cpu(float *input, int n, int batch, int batch_offset, int groups, int group_offset, int stride, float temp, float *output);
+void softmax_x_ent_cpu(int n, float *pred, float *truth, float *delta, float *error);
 
 #ifdef GPU
 #include "cuda.h"
+#include "tree.h"
 
 void axpy_ongpu(int N, float ALPHA, float * X, int INCX, float * Y, int INCY);
 void axpy_ongpu_offset(int N, float ALPHA, float * X, int OFFX, int INCX, float * Y, int OFFY, int INCY);
@@ -47,6 +50,7 @@ void copy_ongpu(int N, float * X, int INCX, float * Y, int INCY);
 void copy_ongpu_offset(int N, float * X, int OFFX, int INCX, float * Y, int OFFY, int INCY);
 void scal_ongpu(int N, float ALPHA, float * X, int INCX);
 void supp_ongpu(int N, float ALPHA, float * X, int INCX);
+void mask_gpu_new_api(int N, float * X, float mask_num, float * mask, float val);
 void mask_ongpu(int N, float * X, float mask_num, float * mask);
 void const_ongpu(int N, float ALPHA, float *X, int INCX);
 void pow_ongpu(int N, float ALPHA, float *X, int INCX, float *Y, int INCY);
@@ -71,6 +75,7 @@ void scale_bias_gpu(float *output, float *biases, int batch, int n, int size);
 void add_bias_gpu(float *output, float *biases, int batch, int n, int size);
 void backward_bias_gpu(float *bias_updates, float *delta, int batch, int n, int size);
 
+void softmax_x_ent_gpu(int n, float *pred, float *truth, float *delta, float *error);
 void smooth_l1_gpu(int n, float *pred, float *truth, float *delta, float *error);
 void l2_gpu(int n, float *pred, float *truth, float *delta, float *error);
 void weighted_delta_gpu(float *a, float *b, float *s, float *da, float *db, float *ds, int num, float *dc);
@@ -79,6 +84,7 @@ void mult_add_into_gpu(int num, float *a, float *b, float *c);
 
 void reorg_ongpu(float *x, int w, int h, int c, int batch, int stride, int forward, float *out);
 
+void softmax_gpu_new_api(float *input, int n, int batch, int batch_offset, int groups, int group_offset, int stride, float temp, float *output);
 void softmax_gpu(float *input, int n, int offset, int groups, float temp, float *output);
 void adam_gpu(int n, float *x, float *m, float *v, float B1, float B2, float rate, float eps, int t);
 void adam_update_gpu(float *w, float *d, float *m, float *v, float B1, float B2, float eps, float decay, float rate, int n, int batch, int t);
@@ -87,5 +93,7 @@ void flatten_ongpu(float *x, int spatial, int layers, int batch, int forward, fl
 
 void upsample_gpu(float *in, int w, int h, int c, int batch, int stride, int forward, float scale, float *out);
 
+void softmax_tree_gpu(float *input, int spatial, int batch, int stride, float temp, float *output, tree hier);
+
 #endif
 #endif
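
The new softmax_cpu / softmax_gpu_new_api signature describes a strided, grouped view of the input: for every (batch, group) pair, one softmax runs over n elements spaced `stride` floats apart, starting at b*batch_offset + g*group_offset. A minimal CPU sketch of that contract, built on the single-vector softmax() declared above (illustrative only, not the blas.c definition):

    /* Sketch of the new-api contract: one softmax per (batch, group)
     * pair over n elements spaced `stride` apart. Assumes darknet's
     * softmax(input, n, temp, output, stride) from blas.h. */
    #include "blas.h"

    void softmax_cpu_sketch(float *input, int n, int batch, int batch_offset,
                            int groups, int group_offset, int stride,
                            float temp, float *output)
    {
        int b, g;
        for (b = 0; b < batch; ++b) {
            for (g = 0; g < groups; ++g) {
                int off = b*batch_offset + g*group_offset;
                softmax(input + off, n, temp, output + off, stride);
            }
        }
    }

The flat per-batch softmax is the special case groups = 1, group_offset = 0, stride = 1, batch_offset = inputs.
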
diff --git a/src/blas_kernels.cu b/src/blas_kernels.cu
index 34c0008b..657be1de 100644
--- a/src/blas_kernels.cu
+++ b/src/blas_kernels.cu
@@ -7,6 +7,7 @@ extern "C" {
 #include "blas.h"
 #include "cuda.h"
 #include "utils.h"
+#include "tree.h"
 }
 
 __global__ void scale_bias_kernel(float *output, float *biases, int n, int size)
@@ -419,7 +420,13 @@ __global__ void fill_kernel(int N, float ALPHA, float *X, int INCX)
     if(i < N) X[i*INCX] = ALPHA;
 }
 
-__global__ void mask_kernel(int n, float *x, float mask_num, float *mask)
+__global__ void mask_kernel_new_api(int n, float *x, float mask_num, float *mask, float val)
+{
+    int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
+    if (i < n && mask[i] == mask_num) x[i] = val;
+}
+
+__global__ void mask_kernel(int n, float *x, float mask_num, float *mask)
 {
     int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
     if(i < n && mask[i] == mask_num) x[i] = mask_num;
@@ -592,6 +599,12 @@ extern "C" void reorg_ongpu(float *x, int w, int h, int c, int batch, int stride
     check_error(cudaPeekAtLastError());
 }
 
+extern "C" void mask_gpu_new_api(int N, float * X, float mask_num, float * mask, float val)
+{
+    mask_kernel_new_api<<<cuda_gridsize(N), BLOCK>>>(N, X, mask_num, mask, val);
+    check_error(cudaPeekAtLastError());
+}
+
 extern "C" void mask_ongpu(int N, float * X, float mask_num, float * mask)
 {
     mask_kernel<<<cuda_gridsize(N), BLOCK>>>(N, X, mask_num, mask);
@@ -687,6 +700,23 @@ extern "C" void smooth_l1_gpu(int n, float *pred, float *truth, float *delta, fl
     check_error(cudaPeekAtLastError());
 }
 
+__global__ void softmax_x_ent_kernel(int n, float *pred, float *truth, float *delta, float *error)
+{
+    int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
+    if (i < n) {
+        float t = truth[i];
+        float p = pred[i];
+        error[i] = (t) ? -log(p) : 0;
+        delta[i] = t - p;
+    }
+}
+
+extern "C" void softmax_x_ent_gpu(int n, float *pred, float *truth, float *delta, float *error)
+{
+    softmax_x_ent_kernel<<<cuda_gridsize(n), BLOCK>>>(n, pred, truth, delta, error);
+    check_error(cudaPeekAtLastError());
+}
+
 __global__ void l2_kernel(int n, float *pred, float *truth, float *delta, float *error)
 {
     int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
@@ -784,6 +814,40 @@ extern "C" void softmax_gpu(float *input, int n, int offset, int groups, float t
     check_error(cudaPeekAtLastError());
 }
 
+__device__ void softmax_device_new_api(float *input, int n, float temp, int stride, float *output)
+{
+    int i;
+    float sum = 0;
+    float largest = -INFINITY;
+    for (i = 0; i < n; ++i) {
+        float val = input[i*stride];
+        largest = (val > largest) ? val : largest;
+    }
+    for (i = 0; i < n; ++i) {
+        float e = expf(input[i*stride] / temp - largest / temp);
+        sum += e;
+        output[i*stride] = e;
+    }
+    for (i = 0; i < n; ++i) {
+        output[i*stride] /= sum;
+    }
+}
+
+__global__ void softmax_kernel_new_api(float *input, int n, int batch, int batch_offset, int groups, int group_offset, int stride, float temp, float *output)
+{
+    int id = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
+    if (id >= batch*groups) return;
+    int b = id / groups;
+    int g = id % groups;
+    softmax_device_new_api(input + b*batch_offset + g*group_offset, n, temp, stride, output + b*batch_offset + g*group_offset);
+}
+
+extern "C" void softmax_gpu_new_api(float *input, int n, int batch, int batch_offset, int groups, int group_offset, int stride, float temp, float *output)
+{
+    softmax_kernel_new_api<<<cuda_gridsize(batch*groups), BLOCK>>>(input, n, batch, batch_offset, groups, group_offset, stride, temp, output);
+    check_error(cudaPeekAtLastError());
+}
+
 __global__ void upsample_kernel(size_t N, float *x, int w, int h, int c, int batch, int stride, int forward, float scale, float *out)
 {
@@ -814,4 +878,36 @@ extern "C" void upsample_gpu(float *in, int w, int h, int c, int batch, int stri
     size_t size = w*h*c*batch*stride*stride;
     upsample_kernel<<<cuda_gridsize(size), BLOCK>>>(size, in, w, h, c, batch, stride, forward, scale, out);
     check_error(cudaPeekAtLastError());
-}
\ No newline at end of file
+}
+
+__global__ void softmax_tree_kernel(float *input, int spatial, int batch, int stride, float temp, float *output, int groups, int *group_size, int *group_offset)
+{
+    int id = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
+    if (id >= spatial*batch*groups) return;
+    int s = id % spatial;
+    id = id / spatial;
+    int g = id % groups;
+    int b = id / groups;
+    int goff = group_offset[g] * spatial;
+    int boff = b*stride;
+    softmax_device_new_api(input + goff + boff + s, group_size[g], temp, spatial, output + goff + boff + s);
+}
+
+extern "C" void softmax_tree_gpu(float *input, int spatial, int batch, int stride, float temp, float *output, tree hier)
+{
+    int *tree_groups_size = cuda_make_int_array_new_api(hier.group_size, hier.groups);
+    int *tree_groups_offset = cuda_make_int_array_new_api(hier.group_offset, hier.groups);
+    /*
+    static int *tree_groups_size = 0;
+    static int *tree_groups_offset = 0;
+    if(!tree_groups_size){
+        tree_groups_size = cuda_make_int_array(hier.group_size, hier.groups);
+        tree_groups_offset = cuda_make_int_array(hier.group_offset, hier.groups);
+    }
+    */
+    int num = spatial*batch*hier.groups;
+    softmax_tree_kernel<<<cuda_gridsize(num), BLOCK>>>(input, spatial, batch, stride, temp, output, hier.groups, tree_groups_size, tree_groups_offset);
+    check_error(cudaPeekAtLastError());
+    cuda_free((float *)tree_groups_size);
+    cuda_free((float *)tree_groups_offset);
+}
\ No newline at end of file
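
softmax_tree_gpu launches one softmax per (batch, group, spatial cell) triple: group_offset[g] and group_size[g] carve the class axis into the tree's sibling groups, and `spatial` doubles as the element stride because darknet's channel-major layout puts consecutive classes `spatial` floats apart. A CPU analogue of the kernel's indexing (a sketch assuming that input[(b*channels + k)*spatial + s] layout and the tree struct from tree.h):

    /* Sketch: CPU analogue of softmax_tree_kernel. Each (b, g, s)
     * triple gets an independent softmax over group_size[g] classes
     * that sit `spatial` floats apart in channel-major layout. */
    #include "blas.h"
    #include "tree.h"

    void softmax_tree_sketch(float *input, int spatial, int batch, int stride,
                             float temp, float *output, tree hier)
    {
        int b, g, s;
        for (b = 0; b < batch; ++b) {
            for (g = 0; g < hier.groups; ++g) {
                for (s = 0; s < spatial; ++s) {
                    int off = b*stride + hier.group_offset[g]*spatial + s;
                    softmax(input + off, hier.group_size[g], temp,
                            output + off, spatial);
                }
            }
        }
    }
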
diff --git a/src/cuda.c b/src/cuda.c
index 2284dad0..68fb7f8c 100644
--- a/src/cuda.c
+++ b/src/cuda.c
@@ -162,6 +162,20 @@ int *cuda_make_int_array(size_t n)
     return x_gpu;
 }
 
+int *cuda_make_int_array_new_api(int *x, size_t n)
+{
+    int *x_gpu;
+    size_t size = sizeof(int)*n;
+    cudaError_t status = cudaMalloc((void **)&x_gpu, size);
+    check_error(status);
+    if (x) {
+        status = cudaMemcpy(x_gpu, x, size, cudaMemcpyHostToDevice);
+        check_error(status);
+    }
+    if (!x_gpu) error("Cuda malloc failed\n");
+    return x_gpu;
+}
+
 void cuda_free(float *x_gpu)
 {
     //cudaStreamSynchronize(get_cuda_stream());
diff --git a/src/cuda.h b/src/cuda.h
index 50bb82b4..289ee5b4 100644
--- a/src/cuda.h
+++ b/src/cuda.h
@@ -40,6 +40,7 @@ extern "C" {
     cublasHandle_t blas_handle();
     float *cuda_make_array(float *x, size_t n);
     int *cuda_make_int_array(size_t n);
+    int *cuda_make_int_array_new_api(int *x, size_t n);
     void cuda_push_array(float *x_gpu, float *x, size_t n);
     YOLODLL_API void cuda_pull_array(float *x_gpu, float *x, size_t n);
     YOLODLL_API void cuda_set_device(int n);
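
cuda_make_int_array_new_api combines cudaMalloc with an optional host-to-device copy, which is exactly what softmax_tree_gpu needs for its per-group size/offset tables. Note the trade-off visible above: the live code allocates and frees the two tables on every forward call, while the commented-out block shows the cached-static alternative that never frees. A hypothetical caller (darknet's cuda.h assumed):

    /* Hypothetical usage sketch: upload a small host int array,
     * use it from a kernel, then release it. */
    #include "cuda.h"

    void upload_group_sizes_example(int *host_sizes, int groups)
    {
        /* cudaMalloc + cudaMemcpyHostToDevice in one call */
        int *dev_sizes = cuda_make_int_array_new_api(host_sizes, groups);

        /* ... launch kernels that read dev_sizes ... */

        /* cuda_free takes float*, so the int* is cast,
         * as in softmax_tree_gpu */
        cuda_free((float *)dev_sizes);
    }
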
diff --git a/src/layer.h b/src/layer.h
index 8b4cff70..a4fd312e 100644
--- a/src/layer.h
+++ b/src/layer.h
@@ -83,6 +83,7 @@ struct layer{
     int side;
     int stride;
     int reverse;
+    int spatial;
     int pad;
     int sqrt;
     int flip;
@@ -100,6 +101,7 @@ struct layer{
     float shift;
     float ratio;
     int focal_loss;
+    int noloss;
     int softmax;
     int classes;
     int coords;
@@ -198,6 +200,7 @@ struct layer{
     int   * input_sizes;
     float * delta;
     float * output;
+    float * loss;
     float * squared;
     float * norms;
 
@@ -289,6 +292,7 @@ struct layer{
     float * scale_updates_gpu;
 
     float * output_gpu;
+    float * loss_gpu;
     float * delta_gpu;
     float * rand_gpu;
     float * squared_gpu;
diff --git a/src/parser.c b/src/parser.c
index c685abee..c82c4a2f 100644
--- a/src/parser.c
+++ b/src/parser.c
@@ -233,12 +233,17 @@ connected_layer parse_connected(list *options, size_params params)
 
 softmax_layer parse_softmax(list *options, size_params params)
 {
-    int groups = option_find_int_quiet(options, "groups",1);
-    softmax_layer layer = make_softmax_layer(params.batch, params.inputs, groups);
-    layer.temperature = option_find_float_quiet(options, "temperature", 1);
-    char *tree_file = option_find_str(options, "tree", 0);
-    if (tree_file) layer.softmax_tree = read_tree(tree_file);
-    return layer;
+    int groups = option_find_int_quiet(options, "groups", 1);
+    softmax_layer layer = make_softmax_layer(params.batch, params.inputs, groups);
+    layer.temperature = option_find_float_quiet(options, "temperature", 1);
+    char *tree_file = option_find_str(options, "tree", 0);
+    if (tree_file) layer.softmax_tree = read_tree(tree_file);
+    layer.w = params.w;
+    layer.h = params.h;
+    layer.c = params.c;
+    layer.spatial = option_find_float_quiet(options, "spatial", 0);
+    layer.noloss = option_find_int_quiet(options, "noloss", 0);
+    return layer;
 }
 
 int *parse_yolo_mask(char *a, int *num)
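
Two new [softmax] options are parsed here: spatial=1 switches the GPU forward pass to a per-pixel softmax over channels (with temperature fixed at 1 in that branch), and noloss=1 skips the cross-entropy computation even when truth is supplied. Conceptually, the spatial mode computes the following (a sketch assuming the channel-major layout and the softmax() routine from blas.c; not the literal GPU call):

    /* Sketch: what spatial softmax computes. Assuming channel-major
     * layout input[(b*c + k)*w*h + s], each pixel s of each image b
     * gets one softmax over its c channel values, which sit w*h
     * floats apart. */
    #include "blas.h"

    void spatial_softmax_sketch(float *input, int w, int h, int c, int batch,
                                float temp, float *output)
    {
        int spatial = w*h;
        int b, s;
        for (b = 0; b < batch; ++b) {
            for (s = 0; s < spatial; ++s) {
                int off = b*c*spatial + s;
                softmax(input + off, c, temp, output + off, spatial);
            }
        }
    }
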
diff --git a/src/softmax_layer.c b/src/softmax_layer.c
index 27f73fdd..bfe34bc1 100644
--- a/src/softmax_layer.c
+++ b/src/softmax_layer.c
@@ -1,12 +1,31 @@
 #include "softmax_layer.h"
 #include "blas.h"
 #include "cuda.h"
+#include "utils.h"
+#include "blas.h"
+
 #include <float.h>
 #include <math.h>
 #include <stdlib.h>
 #include <stdio.h>
 #include <assert.h>
 
+#define SECRET_NUM -1234
+
+void softmax_tree(float *input, int batch, int inputs, float temp, tree *hierarchy, float *output)
+{
+    int b;
+    for (b = 0; b < batch; ++b) {
+        int i;
+        int count = 0;
+        for (i = 0; i < hierarchy->groups; ++i) {
+            int group_size = hierarchy->group_size[i];
+            softmax(input + b*inputs + count, group_size, temp, output + b*inputs + count, 1);
+            count += group_size;
+        }
+    }
+}
+
 softmax_layer make_softmax_layer(int batch, int inputs, int groups)
 {
     assert(inputs%groups == 0);
@@ -17,8 +36,10 @@ softmax_layer make_softmax_layer(int batch, int inputs, int groups)
     l.groups = groups;
     l.inputs = inputs;
     l.outputs = inputs;
+    l.loss = calloc(inputs*batch, sizeof(float));
     l.output = calloc(inputs*batch, sizeof(float));
     l.delta = calloc(inputs*batch, sizeof(float));
+    l.cost = calloc(1, sizeof(float));
 
     l.forward = forward_softmax_layer;
     l.backward = backward_softmax_layer;
@@ -27,45 +48,35 @@ softmax_layer make_softmax_layer(int batch, int inputs, int groups)
     l.backward_gpu = backward_softmax_layer_gpu;
 
     l.output_gpu = cuda_make_array(l.output, inputs*batch);
+    l.loss_gpu = cuda_make_array(l.loss, inputs*batch);
     l.delta_gpu = cuda_make_array(l.delta, inputs*batch);
 #endif
     return l;
 }
 
-void softmax_tree(float *input, int batch, int inputs, float temp, tree *hierarchy, float *output)
+void forward_softmax_layer(const softmax_layer l, network_state net)
 {
-    int b;
-    for(b = 0; b < batch; ++b){
+    if(l.softmax_tree){
         int i;
         int count = 0;
-        for(i = 0; i < hierarchy->groups; ++i){
-            int group_size = hierarchy->group_size[i];
-            softmax(input+b*inputs + count, group_size, temp, output+b*inputs + count, 1);
+        for (i = 0; i < l.softmax_tree->groups; ++i) {
+            int group_size = l.softmax_tree->group_size[i];
+            softmax_cpu(net.input + count, group_size, l.batch, l.inputs, 1, 0, 1, l.temperature, l.output + count);
             count += group_size;
         }
+    } else {
+        softmax_cpu(net.input, l.inputs/l.groups, l.batch, l.inputs, l.groups, l.inputs/l.groups, 1, l.temperature, l.output);
     }
-}
 
-void forward_softmax_layer(const softmax_layer l, network_state state)
-{
-    int b;
-    int inputs = l.inputs / l.groups;
-    int batch = l.batch * l.groups;
-    if(l.softmax_tree){
-        softmax_tree(state.input, batch, inputs, l.temperature, l.softmax_tree, l.output);
-    } else {
-        for(b = 0; b < batch; ++b){
-            softmax(state.input+b*inputs, inputs, l.temperature, l.output+b*inputs, 1);
-        }
+    if(net.truth && !l.noloss){
+        softmax_x_ent_cpu(l.batch*l.inputs, l.output, net.truth, l.delta, l.loss);
+        l.cost[0] = sum_array(l.loss, l.batch*l.inputs);
     }
 }
 
-void backward_softmax_layer(const softmax_layer l, network_state state)
+void backward_softmax_layer(const softmax_layer l, network_state net)
 {
-    int i;
-    for(i = 0; i < l.inputs*l.batch; ++i){
-        state.delta[i] += l.delta[i];
-    }
+    axpy_cpu(l.inputs*l.batch, 1, l.delta, 1, net.delta, 1);
 }
 
 #ifdef GPU
@@ -75,26 +86,40 @@ void pull_softmax_layer_output(const softmax_layer layer)
     cuda_pull_array(layer.output_gpu, layer.output, layer.inputs*layer.batch);
 }
 
-void forward_softmax_layer_gpu(const softmax_layer l, network_state state)
+void forward_softmax_layer_gpu(const softmax_layer l, network_state net)
 {
-    int inputs = l.inputs / l.groups;
-    int batch = l.batch * l.groups;
     if(l.softmax_tree){
-        int i;
-        int count = 0;
-        for (i = 0; i < l.softmax_tree->groups; ++i) {
-            int group_size = l.softmax_tree->group_size[i];
-            softmax_gpu(state.input+count, group_size, inputs, batch, l.temperature, l.output_gpu + count);
-            count += group_size;
-        }
+        softmax_tree_gpu(net.input, 1, l.batch, l.inputs, l.temperature, l.output_gpu, *l.softmax_tree);
+        /*
+        int i;
+        int count = 0;
+        for (i = 0; i < l.softmax_tree->groups; ++i) {
+            int group_size = l.softmax_tree->group_size[i];
+            softmax_gpu(net.input_gpu + count, group_size, l.batch, l.inputs, 1, 0, 1, l.temperature, l.output_gpu + count);
+            count += group_size;
+        }
+        */
     } else {
-        softmax_gpu(state.input, inputs, inputs, batch, l.temperature, l.output_gpu);
+        if(l.spatial){
+            softmax_gpu_new_api(net.input, l.c, l.batch*l.c, l.inputs/l.c, l.w*l.h, 1, l.w*l.h, 1, l.output_gpu);
+        }else{
+            softmax_gpu_new_api(net.input, l.inputs/l.groups, l.batch, l.inputs, l.groups, l.inputs/l.groups, 1, l.temperature, l.output_gpu);
+        }
+    }
+    if(net.truth && !l.noloss){
+        softmax_x_ent_gpu(l.batch*l.inputs, l.output_gpu, net.truth, l.delta_gpu, l.loss_gpu);
+        if(l.softmax_tree){
+            mask_gpu_new_api(l.batch*l.inputs, l.delta_gpu, SECRET_NUM, net.truth, 0);
+            mask_gpu_new_api(l.batch*l.inputs, l.loss_gpu, SECRET_NUM, net.truth, 0);
+        }
+        cuda_pull_array(l.loss_gpu, l.loss, l.batch*l.inputs);
+        l.cost[0] = sum_array(l.loss, l.batch*l.inputs);
     }
 }
 
-void backward_softmax_layer_gpu(const softmax_layer layer, network_state state)
+void backward_softmax_layer_gpu(const softmax_layer layer, network_state net)
 {
-    axpy_ongpu(layer.batch*layer.inputs, 1, layer.delta_gpu, 1, state.delta, 1);
+    axpy_ongpu(layer.batch*layer.inputs, 1, layer.delta_gpu, 1, net.delta, 1);
 }
 
 #endif
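
For reference, mask_gpu_new_api's semantics on the CPU (a sketch): wherever mask[i] equals mask_num, x[i] is overwritten with val. The tree branch of forward_softmax_layer_gpu uses it twice with val = 0 to zero both delta and loss at positions whose truth entry is the SECRET_NUM sentinel (-1234), so unlabeled tree nodes contribute neither loss nor gradient:

    /* Sketch: CPU reference for mask_gpu_new_api's behavior. */
    void mask_cpu_sketch(int n, float *x, float mask_num, float *mask, float val)
    {
        int i;
        for (i = 0; i < n; ++i) {
            if (mask[i] == mask_num) x[i] = val;
        }
    }
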