Added scale_x_y param to [yolo]-layer (for sigmoid).

The previous commit uses GIoU (Generalized IoU) from https://github.com/generalized-iou/g-darknet

pull/3723/head
Author: AlexeyAB
parent 6e13527f06
commit eac26226a7
Changed files (changed lines):
  1. include/darknet.h (1)
  2. scripts/README.md (2)
  3. src/blas.c (6)
  4. src/blas.h (2)
  5. src/blas_kernels.cu (12)
  6. src/detector.c (4)
  7. src/parser.c (3)
  8. src/yolo_layer.c (4)

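For context on the change itself: the new scale_x_y option rescales the sigmoid-activated x,y outputs of a [yolo] layer as out = sigmoid(t) * scale_x_y - 0.5 * (scale_x_y - 1) (the scal_add_cpu / scal_add_ongpu calls added in yolo_layer.c below), stretching the open interval (0,1) symmetrically so the activated values can actually reach 0 and 1. A minimal stand-alone sketch of that transform, assuming only the formula above; inputs and function names are illustrative, not darknet code:

#include <math.h>
#include <stdio.h>

/* out = sigmoid(t) * scale - 0.5 * (scale - 1): for scale = 1.05 the
   output range widens from (0,1) to roughly (-0.025, 1.025). */
static float scaled_sigmoid(float t, float scale)
{
    float s = 1.f / (1.f + expf(-t));
    return s * scale - 0.5f * (scale - 1.f);
}

int main(void)
{
    const float raw[3] = { -6.f, 0.f, 6.f };
    for (int i = 0; i < 3; ++i)
        printf("t = %+.1f   scale 1.00: %.4f   scale 1.05: %.4f\n",
               raw[i], scaled_sigmoid(raw[i], 1.00f), scaled_sigmoid(raw[i], 1.05f));
    return 0;
}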
@@ -314,6 +314,7 @@ struct layer {
    float *weights;
    float *weight_updates;
+   float scale_x_y;
    float iou_normalizer;
    float cls_normalizer;
    int iou_loss;

@@ -2,6 +2,8 @@
### Datasets:
BDD100K - Diverse Driving Video: https://bair.berkeley.edu/blog/2018/05/30/bdd/
Pascal VOC: http://host.robots.ox.ac.uk/pascal/VOC/voc2012/index.html
MS COCO: http://cocodataset.org/#download

@@ -169,6 +169,12 @@ void scal_cpu(int N, float ALPHA, float *X, int INCX)
    for(i = 0; i < N; ++i) X[i*INCX] *= ALPHA;
}
+void scal_add_cpu(int N, float ALPHA, float BETA, float *X, int INCX)
+{
+    int i;
+    for (i = 0; i < N; ++i) X[i*INCX] = X[i*INCX] * ALPHA + BETA;
+}
void fill_cpu(int N, float ALPHA, float *X, int INCX)
{
    int i;

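For readers who want to poke at the new CPU routine in isolation, here is a hedged usage sketch (a stand-alone copy of scal_add_cpu so it builds without the darknet tree); note that INCX is a stride, so only every INCX-th element is touched:

#include <stdio.h>

/* Copy of the routine above: X[i*INCX] = X[i*INCX] * ALPHA + BETA for i in [0, N). */
void scal_add_cpu(int N, float ALPHA, float BETA, float *X, int INCX)
{
    int i;
    for (i = 0; i < N; ++i) X[i*INCX] = X[i*INCX] * ALPHA + BETA;
}

int main(void)
{
    float x[6] = { 0.f, 1.f, 2.f, 3.f, 4.f, 5.f };
    scal_add_cpu(3, 2.f, -0.5f, x, 2);           /* scale every second element */
    for (int i = 0; i < 6; ++i) printf("%.2f ", x[i]);
    printf("\n");                                /* -0.50 1.00 3.50 3.00 7.50 5.00 */
    return 0;
}

At the yolo_layer.c call sites below the stride is 1, since the x and y planes for an anchor are laid out contiguously (2 * l.w*l.h values).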
@@ -27,6 +27,7 @@ void mul_cpu(int N, float *X, int INCX, float *Y, int INCY);
void axpy_cpu(int N, float ALPHA, float *X, int INCX, float *Y, int INCY);
void copy_cpu(int N, float *X, int INCX, float *Y, int INCY);
void scal_cpu(int N, float ALPHA, float *X, int INCX);
+void scal_add_cpu(int N, float ALPHA, float BETA, float *X, int INCX);
void fill_cpu(int N, float ALPHA, float * X, int INCX);
float dot_cpu(int N, float *X, int INCX, float *Y, int INCY);
void test_gpu_blas();
@@ -61,6 +62,7 @@ void simple_copy_ongpu(int size, float *src, float *dst);
void copy_ongpu(int N, float * X, int INCX, float * Y, int INCY);
void copy_ongpu_offset(int N, float * X, int OFFX, int INCX, float * Y, int OFFY, int INCY);
void scal_ongpu(int N, float ALPHA, float * X, int INCX);
+void scal_add_ongpu(int N, float ALPHA, float BETA, float * X, int INCX);
void supp_ongpu(int N, float ALPHA, float * X, int INCX);
void mask_gpu_new_api(int N, float * X, float mask_num, float * mask, float val);
void mask_ongpu(int N, float * X, float mask_num, float * mask);

@@ -414,6 +414,12 @@ __global__ void scal_kernel(int N, float ALPHA, float *X, int INCX)
    if(i < N) X[i*INCX] *= ALPHA;
}
+__global__ void scal_add_kernel(int N, float ALPHA, float BETA, float *X, int INCX)
+{
+    int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
+    if (i < N) X[i*INCX] = X[i*INCX] * ALPHA + BETA;
+}
__global__ void fill_kernel(int N, float ALPHA, float *X, int INCX)
{
    int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
@@ -644,6 +650,12 @@ extern "C" void scal_ongpu(int N, float ALPHA, float * X, int INCX)
    CHECK_CUDA(cudaPeekAtLastError());
}
+extern "C" void scal_add_ongpu(int N, float ALPHA, float BETA, float * X, int INCX)
+{
+    scal_add_kernel << <cuda_gridsize(N), BLOCK, 0, get_cuda_stream() >> >(N, ALPHA, BETA, X, INCX);
+    CHECK_CUDA(cudaPeekAtLastError());
+}
extern "C" void supp_ongpu(int N, float ALPHA, float * X, int INCX)
{
    supp_kernel<<<cuda_gridsize(N), BLOCK, 0, get_cuda_stream() >>>(N, ALPHA, X, INCX);

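The kernel's index arithmetic, (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x, flattens a 2-D grid so very large N (beyond the per-dimension block limit that cuda_gridsize() works around) still maps each thread to a unique element. A hedged stand-alone check of the same update on a plain 1-D launch, without the darknet helpers (cuda_gridsize, get_cuda_stream, CHECK_CUDA):

#include <cstdio>
#include <cuda_runtime.h>

// Same update as scal_add_kernel, launched on a 1-D grid where
// gridDim.y == 1, so the blockIdx.y*gridDim.x term contributes nothing.
__global__ void scal_add_1d(int N, float ALPHA, float BETA, float *X)
{
    int i = (blockIdx.x + blockIdx.y * gridDim.x) * blockDim.x + threadIdx.x;
    if (i < N) X[i] = X[i] * ALPHA + BETA;
}

int main()
{
    const int N = 8;
    float h[N];
    for (int i = 0; i < N; ++i) h[i] = (float)i;

    float *d = nullptr;
    cudaMalloc(&d, N * sizeof(float));
    cudaMemcpy(d, h, N * sizeof(float), cudaMemcpyHostToDevice);

    // ALPHA = 1.05, BETA = -0.5*(ALPHA - 1), mirroring the scale_x_y usage
    scal_add_1d<<<(N + 255) / 256, 256>>>(N, 1.05f, -0.5f * (1.05f - 1.f), d);
    cudaMemcpy(h, d, N * sizeof(float), cudaMemcpyDeviceToHost);
    cudaFree(d);

    for (int i = 0; i < N; ++i) printf("%.3f ", h[i]);
    printf("\n");
    return 0;
}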
@@ -1046,10 +1046,10 @@ float validate_detector_map(char *datacfg, char *cfgfile, char *weightfile, floa
    const float cur_precision = (float)tp_for_thresh / ((float)tp_for_thresh + (float)fp_for_thresh);
    const float cur_recall = (float)tp_for_thresh / ((float)tp_for_thresh + (float)(unique_truth_count - tp_for_thresh));
    const float f1_score = 2.F * cur_precision * cur_recall / (cur_precision + cur_recall);
-   printf("\n for thresh = %1.2f, precision = %1.2f, recall = %1.2f, F1-score = %1.2f \n",
+   printf("\n for conf_thresh = %1.2f, precision = %1.2f, recall = %1.2f, F1-score = %1.2f \n",
        thresh_calc_avg_iou, cur_precision, cur_recall, f1_score);
-   printf(" for thresh = %0.2f, TP = %d, FP = %d, FN = %d, average IoU = %2.2f %% \n",
+   printf(" for conf_thresh = %0.2f, TP = %d, FP = %d, FN = %d, average IoU = %2.2f %% \n",
        thresh_calc_avg_iou, tp_for_thresh, fp_for_thresh, unique_truth_count - tp_for_thresh, avg_iou * 100);
    mean_average_precision = mean_average_precision / classes;
mean_average_precision = mean_average_precision / classes;

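The printf rename aside, the metrics in this hunk follow the usual definitions: precision = TP / (TP + FP), recall = TP / (TP + FN) with FN = unique_truth_count - TP, and F1 = 2*P*R / (P + R). A tiny stand-alone check with made-up counts (the numbers are illustrative, not from any real run):

#include <stdio.h>

int main(void)
{
    const int tp = 90, fp = 10, fn = 20;                              /* invented counts */
    const float precision = (float)tp / (float)(tp + fp);             /* 0.90 */
    const float recall    = (float)tp / (float)(tp + fn);             /* ~0.82 */
    const float f1 = 2.f * precision * recall / (precision + recall); /* ~0.86 */
    printf("precision = %1.2f, recall = %1.2f, F1-score = %1.2f\n", precision, recall, f1);
    return 0;
}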
@@ -333,6 +333,7 @@ layer parse_yolo(list *options, size_params params)
    }
    //assert(l.outputs == params.inputs);
+   l.scale_x_y = option_find_float_quiet(options, "scale_x_y", 1);
    l.iou_normalizer = option_find_float_quiet(options, "iou_normalizer", 0.75);
    l.cls_normalizer = option_find_float_quiet(options, "cls_normalizer", 1);
    char *iou_loss = option_find_str_quiet(options, "iou_loss", "mse"); // "iou");
@@ -340,7 +341,7 @@ layer parse_yolo(list *options, size_params params)
    if (strcmp(iou_loss, "mse") == 0) l.iou_loss = MSE;
    else if (strcmp(iou_loss, "giou") == 0) l.iou_loss = GIOU;
    else l.iou_loss = IOU;
-   fprintf(stderr, "Yolo layer params: iou loss: %s, iou_normalizer: %f, cls_normalizer: %f\n", (l.iou_loss == MSE ? "mse" : (l.iou_loss == GIOU ? "giou" : "iou")), l.iou_normalizer, l.cls_normalizer);
+   fprintf(stderr, "[yolo] params: iou loss: %s, iou_norm: %2.2f, cls_norm: %2.2f, scale_x_y: %2.2f\n", (l.iou_loss == MSE ? "mse" : (l.iou_loss == GIOU ? "giou" : "iou")), l.iou_normalizer, l.cls_normalizer, l.scale_x_y);
    l.jitter = option_find_float(options, "jitter", .2);
    l.focal_loss = option_find_int_quiet(options, "focal_loss", 0);

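These options are read from each [yolo] section of the .cfg file. A hedged example of how such a section could set the new parameter (the values are purely illustrative, not a recommended configuration; any key left out falls back to the defaults above, e.g. scale_x_y = 1 and iou_loss = mse):

[yolo]
mask = 0,1,2
anchors = 10,13, 16,30, 33,23, 30,61, 62,45, 59,119, 116,90, 156,198, 373,326
classes = 80
num = 9
jitter = .2
iou_loss = giou
iou_normalizer = 0.5
cls_normalizer = 1.0
scale_x_y = 1.05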
@@ -259,7 +259,8 @@ void forward_yolo_layer(const layer l, network_state state)
    for (b = 0; b < l.batch; ++b) {
        for (n = 0; n < l.n; ++n) {
            int index = entry_index(l, b, n*l.w*l.h, 0);
-           activate_array(l.output + index, 2 * l.w*l.h, LOGISTIC);
+           activate_array(l.output + index, 2 * l.w*l.h, LOGISTIC);    // x,y,
+           scal_add_cpu(2 * l.w*l.h, l.scale_x_y, -0.5*(l.scale_x_y - 1), l.output + index, 1);    // scale x,y
            index = entry_index(l, b, n*l.w*l.h, 4);
            activate_array(l.output + index, (1 + l.classes)*l.w*l.h, LOGISTIC);
        }
@@ -553,6 +554,7 @@ void forward_yolo_layer_gpu(const layer l, network_state state)
            // if(y->1) x -> inf
            // if(y->0) x -> -inf
            activate_array_ongpu(l.output_gpu + index, 2*l.w*l.h, LOGISTIC); // x,y
+           scal_add_ongpu(2 * l.w*l.h, l.scale_x_y, -0.5*(l.scale_x_y - 1), l.output_gpu + index, 1); // scale x,y
            index = entry_index(l, b, n*l.w*l.h, 4);
            activate_array_ongpu(l.output_gpu + index, (1+l.classes)*l.w*l.h, LOGISTIC); // classes and objectness
        }

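Why widen the range at all: in the decode step (not part of this diff) the box center is derived from the activated x,y roughly as b.x = (col + x) / grid_w, so with a plain sigmoid the center can approach but never reach a grid-cell border. A hedged sketch of that effect; decode_x below is my paraphrase of the decoding, not the actual get_yolo_box source:

#include <math.h>
#include <stdio.h>

/* Paraphrased decode: logistic activation, then the scal_add step,
   then shift by the cell column and normalize by the grid width. */
static float decode_x(int col, float t, float scale, int grid_w)
{
    float x = 1.f / (1.f + expf(-t));       /* LOGISTIC */
    x = x * scale - 0.5f * (scale - 1.f);   /* scale x,y */
    return (col + x) / (float)grid_w;
}

int main(void)
{
    /* cell col = 5 on a 13x13 grid, strongly positive raw prediction */
    printf("scale 1.00: b.x = %.4f\n", decode_x(5, 8.f, 1.00f, 13));
    printf("scale 1.05: b.x = %.4f\n", decode_x(5, 8.f, 1.05f, 13));
    /* with scale 1.05 the center can cross the col+1 boundary (6/13 ~ 0.4615) */
    return 0;
}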