Use ignore_thresh only if class_id matched. Temporary changed Assisted_Excitation (reduces background activations rather than enhancing objects activations). Added antialiasiong=2 for 2x2.

pull/4140/head
AlexeyAB 6 years ago
parent 2eb68d5177
commit e6486ab594
  1. 40
      src/convolutional_kernels.cu
  2. 20
      src/convolutional_layer.c
  3. 3
      src/detector.c
  4. 2
      src/http_stream.cpp
  5. 19
      src/maxpool_layer.c
  6. 32
      src/yolo_layer.c

@ -948,6 +948,30 @@ void assisted_activation_gpu(float alpha, float *output, float *gt_gpu, float *a
assisted_activation_kernel << <num_blocks, BLOCK, 0, get_cuda_stream() >> > (alpha, output, gt_gpu, a_avg_gpu, size, channels, batches);
}
__global__ void assisted_activation2_kernel(float alpha, float *output, float *gt_gpu, float *a_avg_gpu, int size, int channels, int batches)
{
int i = blockIdx.x * blockDim.x + threadIdx.x;
int xy = i % size;
int b = i / size;
float beta = 1 - alpha;
if (b < batches) {
for (int c = 0; c < channels; ++c) {
if(gt_gpu[i] == 0)
output[xy + size*(c + channels*b)] *= beta;
}
}
}
void assisted_activation2_gpu(float alpha, float *output, float *gt_gpu, float *a_avg_gpu, int size, int channels, int batches)
{
const int num_blocks = get_number_of_blocks(size*batches, BLOCK);
assisted_activation2_kernel << <num_blocks, BLOCK, 0, get_cuda_stream() >> > (alpha, output, gt_gpu, a_avg_gpu, size, channels, batches);
}
void assisted_excitation_forward_gpu(convolutional_layer l, network_state state)
{
const int iteration_num = (*state.net.seen) / (state.net.batch*state.net.subdivisions);
@ -958,12 +982,13 @@ void assisted_excitation_forward_gpu(convolutional_layer l, network_state state)
// calculate alpha
//const float alpha = (1 + cos(3.141592 * iteration_num)) / (2 * state.net.max_batches);
//const float alpha = (1 + cos(3.141592 * epoch)) / (2 * state.net.max_batches);
//const float alpha = (1 + cos(3.141592 * iteration_num / state.net.max_batches)) / 2;
float alpha = (1 + cos(3.141592 * iteration_num / state.net.max_batches));
float alpha = (1 + cos(3.141592 * iteration_num / state.net.max_batches)) / 2;
//float alpha = (1 + cos(3.141592 * iteration_num / state.net.max_batches));
if (l.assisted_excitation > 1) {
if (iteration_num > l.assisted_excitation) alpha = 0;
else alpha = (1 + cos(3.141592 * iteration_num / l.assisted_excitation));
if (iteration_num < state.net.burn_in) alpha = 0;
else if (iteration_num > l.assisted_excitation) alpha = 0;
else alpha = (1 + cos(3.141592 * iteration_num / l.assisted_excitation)) / 2;
}
//printf("\n epoch = %f, alpha = %f, seen = %d, max_batches = %d, train_images_num = %d \n",
@ -1017,7 +1042,8 @@ void assisted_excitation_forward_gpu(convolutional_layer l, network_state state)
//CHECK_CUDA(cudaPeekAtLastError());
// calc new output
assisted_activation_gpu(alpha, l.output_gpu, l.gt_gpu, l.a_avg_gpu, l.out_w * l.out_h, l.out_c, l.batch);
assisted_activation2_gpu(alpha, l.output_gpu, l.gt_gpu, l.a_avg_gpu, l.out_w * l.out_h, l.out_c, l.batch);
//assisted_activation_gpu(alpha, l.output_gpu, l.gt_gpu, l.a_avg_gpu, l.out_w * l.out_h, l.out_c, l.batch);
//cudaStreamSynchronize(get_cuda_stream());
//CHECK_CUDA(cudaPeekAtLastError());
@ -1070,13 +1096,13 @@ void assisted_excitation_forward_gpu(convolutional_layer l, network_state state)
printf(" Assisted Excitation alpha = %f \n", alpha);
image img = float_to_image(l.out_w, l.out_h, 1, &gt[l.out_w*l.out_h*b]);
char buff[100];
sprintf(buff, "a_excitation_%d", b);
sprintf(buff, "a_excitation_gt_%d", b);
show_image_cv(img, buff);
//image img2 = float_to_image(l.out_w, l.out_h, 1, &l.output[l.out_w*l.out_h*l.out_c*b]);
image img2 = float_to_image_scaled(l.out_w, l.out_h, 1, &l.output[l.out_w*l.out_h*l.out_c*b]);
char buff2[100];
sprintf(buff2, "a_excitation_act_%d", b);
sprintf(buff2, "a_excitation_output_%d", b);
show_image_cv(img2, buff2);
/*

@ -587,10 +587,24 @@ convolutional_layer make_convolutional_layer(int batch, int steps, int h, int w,
if (l.antialiasing) {
printf("AA: ");
l.input_layer = (layer*)calloc(1, sizeof(layer));
const int blur_size = 3;
*(l.input_layer) = make_convolutional_layer(batch, steps, out_h, out_w, n, n, n, blur_size, blur_stride_x, blur_stride_y, 1, blur_size / 2, LINEAR, 0, 0, 0, 0, 0, index, 0, NULL, 0);
int blur_size = 3;
int blur_pad = blur_size / 2;
if (l.antialiasing == 2) {
blur_size = 2;
blur_pad = 0;
}
*(l.input_layer) = make_convolutional_layer(batch, steps, out_h, out_w, n, n, n, blur_size, blur_stride_x, blur_stride_y, 1, blur_pad, LINEAR, 0, 0, 0, 0, 0, index, 0, NULL, 0);
const int blur_nweights = n * blur_size * blur_size; // (n / n) * n * blur_size * blur_size;
int i;
if (blur_size == 2) {
for (i = 0; i < blur_nweights; i += (blur_size*blur_size)) {
l.input_layer->weights[i + 0] = 1 / 4.f;
l.input_layer->weights[i + 1] = 1 / 4.f;
l.input_layer->weights[i + 2] = 1 / 4.f;
l.input_layer->weights[i + 3] = 1 / 4.f;
}
}
else {
for (i = 0; i < blur_nweights; i += (blur_size*blur_size)) {
/*
l.input_layer->weights[i + 0] = 0;
@ -616,7 +630,7 @@ convolutional_layer make_convolutional_layer(int batch, int steps, int h, int w,
l.input_layer->weights[i + 6] = 1 / 16.f;
l.input_layer->weights[i + 7] = 2 / 16.f;
l.input_layer->weights[i + 8] = 1 / 16.f;
}
}
for (i = 0; i < n; ++i) l.input_layer->biases[i] = 0;
#ifdef GPU

@ -798,7 +798,7 @@ float validate_detector_map(char *datacfg, char *cfgfile, char *weightfile, floa
replace_image_to_label(path, labelpath);
int num_labels = 0;
box_label *truth = read_boxes(labelpath, &num_labels);
int i, j;
int j;
for (j = 0; j < num_labels; ++j) {
truth_classes_count[truth[j].id]++;
}
@ -818,6 +818,7 @@ float validate_detector_map(char *datacfg, char *cfgfile, char *weightfile, floa
const int checkpoint_detections_count = detections_count;
int i;
for (i = 0; i < nboxes; ++i) {
int class_id;

@ -48,7 +48,7 @@ static int close_socket(SOCKET s) {
cerr << "Close socket: out = " << close_output << ", in = " << close_input << " \n";
return result;
}
#else // nix
#else // _WIN32 - else: nix
#include "darkunistd.h"
#include <sys/time.h>
#include <sys/types.h>

@ -108,10 +108,24 @@ maxpool_layer make_maxpool_layer(int batch, int h, int w, int c, int size, int s
if (l.antialiasing) {
printf("AA: ");
l.input_layer = (layer*)calloc(1, sizeof(layer));
const int blur_size = 3;
*(l.input_layer) = make_convolutional_layer(batch, 1, l.out_h, l.out_w, l.out_c, l.out_c, l.out_c, blur_size, blur_stride_x, blur_stride_y, 1, blur_size / 2, LINEAR, 0, 0, 0, 0, 0, 1, 0, NULL, 0);
int blur_size = 3;
int blur_pad = blur_size / 2;
if (l.antialiasing == 2) {
blur_size = 2;
blur_pad = 0;
}
*(l.input_layer) = make_convolutional_layer(batch, 1, l.out_h, l.out_w, l.out_c, l.out_c, l.out_c, blur_size, blur_stride_x, blur_stride_y, 1, blur_pad, LINEAR, 0, 0, 0, 0, 0, 1, 0, NULL, 0);
const int blur_nweights = l.out_c * blur_size * blur_size; // (n / n) * n * blur_size * blur_size;
int i;
if (blur_size == 2) {
for (i = 0; i < blur_nweights; i += (blur_size*blur_size)) {
l.input_layer->weights[i + 0] = 1 / 4.f;
l.input_layer->weights[i + 1] = 1 / 4.f;
l.input_layer->weights[i + 2] = 1 / 4.f;
l.input_layer->weights[i + 3] = 1 / 4.f;
}
}
else {
for (i = 0; i < blur_nweights; i += (blur_size*blur_size)) {
/*
l.input_layer->weights[i + 0] = 0;
@ -138,6 +152,7 @@ maxpool_layer make_maxpool_layer(int batch, int h, int w, int c, int size, int s
l.input_layer->weights[i + 7] = 2 / 16.f;
l.input_layer->weights[i + 8] = 1 / 16.f;
}
}
for (i = 0; i < l.out_c; ++i) l.input_layer->biases[i] = 0;
#ifdef GPU
l.input_antialiasing_gpu = cuda_make_array(NULL, l.batch*l.outputs);

@ -128,6 +128,26 @@ box get_yolo_box(float *x, float *biases, int n, int index, int i, int j, int lw
return b;
}
int get_yolo_class(float *output, int classes, int class_index, int stride, float objectness)
{
int class_id = 0;
float max_prob = FLT_MIN;
int j;
for (j = 0; j < classes; ++j) {
float prob = objectness * output[class_index + stride*j];
if (prob > max_prob) {
max_prob = prob;
class_id = j;
}
//int class_index = entry_index(l, 0, n*l.w*l.h + i, 4 + 1 + j);
//float prob = objectness*predictions[class_index];
//dets[count].prob[j] = (prob > thresh) ? prob : 0;
}
return class_id;
}
ious delta_yolo_box(box truth, float *x, float *biases, int n, int index, int i, int j, int lw, int lh, int w, int h, float *delta, float scale, int stride, float iou_normalizer, IOU_LOSS iou_loss)
{
ious all_ious = { 0 };
@ -272,6 +292,7 @@ void forward_yolo_layer(const layer l, network_state state)
box pred = get_yolo_box(l.output, l.biases, l.mask[n], box_index, i, j, l.w, l.h, state.net.w, state.net.h, l.w*l.h);
float best_iou = 0;
int best_t = 0;
int class_id_match = 0;
for (t = 0; t < l.max_boxes; ++t) {
box truth = float_to_box_stride(state.truth + t*(4 + 1) + b*l.truths, 1);
int class_id = state.truth[t*(4 + 1) + b*l.truths + 4];
@ -282,8 +303,17 @@ void forward_yolo_layer(const layer l, network_state state)
continue; // if label contains class_id more than number of classes in the cfg-file
}
if (!truth.x) break; // continue;
int class_index = entry_index(l, b, n*l.w*l.h + j*l.w + i, 4 + 1);
int obj_index = entry_index(l, b, n*l.w*l.h + j*l.w + i, 4);
float objectness = l.output[obj_index];
int pred_class_id = get_yolo_class(l.output, l.classes, class_index, l.w*l.h, objectness);
if (class_id == pred_class_id) class_id_match = 1;
else class_id_match = 0;
float iou = box_iou(pred, truth);
if (iou > best_iou) {
//if (iou > best_iou) {
if (iou > best_iou && class_id_match == 1) {
best_iou = iou;
best_t = t;
}

Loading…
Cancel
Save