Batch inference added

Branch: pull/4099/head
Author: enes, 6 years ago
Parent: 53160fa666
Commit: 1108eafa95

Changed files:
  darknet.py (81)
  include/darknet.h (6)
  src/network.c (81)
  src/yolo_layer.c (48)
  src/yolo_layer.h (2)

@@ -61,6 +61,9 @@ class DETECTION(Structure):
                ("objectness", c_float),
                ("sort_class", c_int)]

class DETNUMPAIR(Structure):
    _fields_ = [("num", c_int),
                ("dets", POINTER(DETECTION))]

class IMAGE(Structure):
    _fields_ = [("w", c_int),
@@ -157,6 +160,9 @@ make_network_boxes.restype = POINTER(DETECTION)
free_detections = lib.free_detections
free_detections.argtypes = [POINTER(DETECTION), c_int]

free_batch_detections = lib.free_batch_detections
free_batch_detections.argtypes = [POINTER(DETNUMPAIR), c_int]

free_ptrs = lib.free_ptrs
free_ptrs.argtypes = [POINTER(c_void_p), c_int]
@@ -206,6 +212,11 @@ predict_image_letterbox = lib.network_predict_image_letterbox
predict_image_letterbox.argtypes = [c_void_p, IMAGE]
predict_image_letterbox.restype = POINTER(c_float)

network_predict_custom = lib.network_predict_custom
network_predict_custom.argtypes = [c_void_p, IMAGE, c_int, c_int, c_int,
                                   c_float, c_float, POINTER(c_int), c_int, c_int]
network_predict_custom.restype = POINTER(DETNUMPAIR)

def array_to_image(arr):
    import numpy as np
    # need to return old values to avoid python freeing memory
@@ -441,5 +452,75 @@ def performDetect(imagePath="data/dog.jpg", thresh= 0.25, configPath = "./cfg/yo
        print("Unable to show image: "+str(e))
    return detections
def performBatchDetect(thresh=0.25, configPath="./cfg/yolov3.cfg", weightPath="yolov3.weights", metaPath="./cfg/coco.data", hier_thresh=.5, nms=.45, batch_size=3):
    import cv2
    import numpy as np
    # NB! All images in the batch must have the same width and height.
    # You can change the images, but make sure they share the same dimensions.
    img_samples = ['data/person.jpg', 'data/person.jpg', 'data/person.jpg']
    image_list = [cv2.imread(k) for k in img_samples]
    if len(image_list) != batch_size:
        raise ValueError(
            "Batch size must be equal to the number of images passed to the function")
    net = load_net_custom(configPath.encode('utf-8'), weightPath.encode('utf-8'), 0, batch_size)
    meta = load_meta(metaPath.encode('utf-8'))
    pred_height, pred_width, c = image_list[0].shape
    net_width, net_height = (network_width(net), network_height(net))

    # Resize to the network input size, convert BGR->RGB and HWC->CHW, then
    # flatten the whole batch into one contiguous float buffer in [0, 1].
    img_list = []
    for custom_image_bgr in image_list:
        custom_image = cv2.cvtColor(custom_image_bgr, cv2.COLOR_BGR2RGB)
        custom_image = cv2.resize(
            custom_image, (net_width, net_height), interpolation=cv2.INTER_NEAREST)
        custom_image = custom_image.transpose(2, 0, 1)
        img_list.append(custom_image)

    arr = np.concatenate(img_list, axis=0)
    arr = np.ascontiguousarray(arr.flat, dtype=np.float32) / 255.0
    data = arr.ctypes.data_as(POINTER(c_float))
    im = IMAGE(net_width, net_height, c, data)

    batch_dets = network_predict_custom(net, im, batch_size, pred_width,
                                        pred_height, thresh, hier_thresh, None, 0, 0)
    batch_boxes = []
    batch_scores = []
    batch_classes = []
    for b in range(batch_size):
        num = batch_dets[b].num
        dets = batch_dets[b].dets
        if nms:
            do_nms_obj(dets, num, meta.classes, nms)
        boxes = []
        scores = []
        classes = []
        for i in range(num):
            det = dets[i]
            score = -1
            label = None
            for c in range(det.classes):
                p = det.prob[c]
                if p > score:
                    score = p
                    label = c
            if score > thresh:
                box = det.bbox
                left, top, right, bottom = map(int, (box.x - box.w / 2, box.y - box.h / 2,
                                                     box.x + box.w / 2, box.y + box.h / 2))
                boxes.append((top, left, bottom, right))
                scores.append(score)
                classes.append(label)
                boxColor = (int(255 * (1 - (score ** 2))), int(255 * (score ** 2)), 0)
                cv2.rectangle(image_list[b], (left, top),
                              (right, bottom), boxColor, 2)
        cv2.imwrite(os.path.basename(img_samples[b]), image_list[b])
        batch_boxes.append(boxes)
        batch_scores.append(scores)
        batch_classes.append(classes)
    free_batch_detections(batch_dets, batch_size)
    return batch_boxes, batch_scores, batch_classes
if __name__ == "__main__":
    print(performDetect())
    # Uncomment the following line to see batch inference working
    #print(performBatchDetect())
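As a quick illustration of the new batch API, here is a minimal usage sketch (an assumption-laden example, not part of this commit): it presumes darknet.py is importable, the default yolov3.cfg / yolov3.weights / coco.data files are in place, and data/person.jpg exists, since performBatchDetect() currently hard-codes its three sample images.

import darknet

# One forward pass over the three hard-coded sample images; returns per-image
# lists of (top, left, bottom, right) boxes, scores, and class indices.
batch_boxes, batch_scores, batch_classes = darknet.performBatchDetect(
    thresh=0.25, batch_size=3)

for b, (boxes, scores, classes) in enumerate(zip(batch_boxes, batch_scores, batch_classes)):
    for box, score, cls in zip(boxes, scores, classes):
        print("image %d: class %d, score %.2f, box %s" % (b, cls, score, box))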

@@ -730,6 +730,12 @@ typedef struct detection{
    int sort_class;
} detection;

// network.c - batch inference
typedef struct detNumPair {
    int num;
    detection *dets;
} detNumPair, *pdetNumPair;

// matrix.h
typedef struct matrix {
    int rows, cols;
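Each detNumPair simply couples the number of detections produced for one image of the batch with the array holding those detections; the DETNUMPAIR structure added to darknet.py mirrors it field for field. A hypothetical helper (not part of this commit) sketching how a POINTER(DETNUMPAIR) result could be unpacked into plain Python data:

def detnumpairs_to_lists(batch_dets, batch_size):
    """Convert a POINTER(DETNUMPAIR) result into one list per image of
    ((x, y, w, h), objectness, per-class probabilities) tuples."""
    results = []
    for b in range(batch_size):
        num, dets = batch_dets[b].num, batch_dets[b].dets
        per_image = []
        for i in range(num):
            det = dets[i]
            probs = [det.prob[c] for c in range(det.classes)]
            per_image.append(((det.bbox.x, det.bbox.y, det.bbox.w, det.bbox.h),
                              det.objectness, probs))
        results.append(per_image)
    return results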

@@ -694,6 +694,22 @@ int num_detections(network *net, float thresh)
    return s;
}

int num_detections_custom(network *net, float thresh, int b)
{
    int i;
    int s = 0;
    for (i = 0; i < net->n; ++i) {
        layer l = net->layers[i];
        if (l.type == YOLO) {
            s += yolo_num_detections_custom(l, thresh, b);
        }
        if (l.type == DETECTION || l.type == REGION) {
            s += l.w*l.h*l.n;
        }
    }
    return s;
}

detection *make_network_boxes(network *net, float thresh, int *num)
{
    layer l = net->layers[net->n - 1];
@@ -710,6 +726,21 @@ detection *make_network_boxes(network *net, float thresh, int *num)
    return dets;
}

detection *make_network_boxes_custom(network *net, float thresh, int *num, int batch)
{
    int i;
    layer l = net->layers[net->n - 1];
    int nboxes = num_detections_custom(net, thresh, batch);
    if (num) *num = nboxes;
    detection* dets = (detection*)calloc(nboxes, sizeof(detection));
    for (i = 0; i < nboxes; ++i) {
        dets[i].prob = (float*)calloc(l.classes, sizeof(float));
        if (l.coords > 4) {
            dets[i].mask = (float*)calloc(l.coords - 4, sizeof(float));
        }
    }
    return dets;
}

void custom_get_region_detections(layer l, int w, int h, int net_w, int net_h, float thresh, int *map, float hier, int relative, detection *dets, int letter)
{
@@ -761,6 +792,33 @@ void fill_network_boxes(network *net, int w, int h, float thresh, float hier, in
    }
}

void fill_network_boxes_custom(network *net, int w, int h, float thresh, float hier, int *map, int relative, detection *dets, int letter, int batch)
{
    int prev_classes = -1;
    int j;
    for (j = 0; j < net->n; ++j) {
        layer l = net->layers[j];
        if (l.type == YOLO) {
            int count = get_yolo_detections_custom(l, w, h, net->w, net->h, thresh, map, relative, dets, letter, batch);
            dets += count;
            if (prev_classes < 0) prev_classes = l.classes;
            else if (prev_classes != l.classes) {
                printf(" Error: Different [yolo] layers have different number of classes = %d and %d - check your cfg-file! \n",
                    prev_classes, l.classes);
            }
        }
        if (l.type == REGION) {
            custom_get_region_detections(l, w, h, net->w, net->h, thresh, map, hier, relative, dets, letter);
            //get_region_detections(l, w, h, net->w, net->h, thresh, map, hier, relative, dets);
            dets += l.w*l.h*l.n;
        }
        if (l.type == DETECTION) {
            get_detection_detections(l, w, h, thresh, dets);
            dets += l.w*l.h*l.n;
        }
    }
}

detection *get_network_boxes(network *net, int w, int h, float thresh, float hier, int *map, int relative, int *num, int letter)
{
    detection *dets = make_network_boxes(net, thresh, num);
@@ -778,6 +836,14 @@ void free_detections(detection *dets, int n)
    free(dets);
}

void free_batch_detections(detNumPair *detNumPairs, int n)
{
    int i;
    for (i = 0; i < n; ++i)
        free_detections(detNumPairs[i].dets, detNumPairs[i].num);
    free(detNumPairs);
}

// JSON format:
//{
// "frame_id":8990,
@@ -849,6 +915,21 @@ float *network_predict_image(network *net, image im)
    return p;
}

detNumPair* network_predict_custom(network *net, image im, int batch, int w, int h, float thresh, float hier, int *map, int relative, int letter)
{
    set_batch_network(net, batch);
    network_predict(*net, im.data);
    detNumPair *pdets = malloc(batch * sizeof(detNumPair));
    int num;
    for (int b = 0; b < batch; b++) {
        detection *dets = make_network_boxes_custom(net, thresh, &num, b);
        fill_network_boxes_custom(net, w, h, thresh, hier, map, relative, dets, letter, b);
        pdets[b].num = num;
        pdets[b].dets = dets;
    }
    return pdets;
}

float *network_predict_image_letterbox(network *net, image im)
{
    //image imr = letterbox_image(im, net->w, net->h);
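network_predict_custom() runs a single forward pass for the whole batch and then heap-allocates one detNumPair per image, each with its own detection array and per-detection prob buffers, so every call must eventually be paired with free_batch_detections(). A hypothetical Python-side wrapper (assuming the ctypes bindings from darknet.py are in scope; not part of this commit) that makes that pairing explicit:

def predict_batch_counts(net, batch_image, batch_size, pred_w, pred_h,
                         thresh=0.25, hier_thresh=0.5):
    """Run the batched forward pass, return the per-image detection counts,
    and always release the C-side allocations."""
    dets = network_predict_custom(net, batch_image, batch_size, pred_w, pred_h,
                                  thresh, hier_thresh, None, 0, 0)
    try:
        return [dets[b].num for b in range(batch_size)]
    finally:
        free_batch_detections(dets, batch_size)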

@@ -461,6 +461,21 @@ int yolo_num_detections(layer l, float thresh)
    return count;
}

int yolo_num_detections_custom(layer l, float thresh, int batch)
{
    int i, n;
    int count = 0;
    for (i = 0; i < l.w*l.h; ++i) {
        for (n = 0; n < l.n; ++n) {
            int obj_index = entry_index(l, batch, n*l.w*l.h + i, 4);
            if (l.output[obj_index] > thresh) {
                ++count;
            }
        }
    }
    return count;
}

void avg_flipped_yolo(layer l)
{
    int i,j,n,z;
@@ -522,6 +537,39 @@ int get_yolo_detections(layer l, int w, int h, int netw, int neth, float thresh,
    return count;
}

int get_yolo_detections_custom(layer l, int w, int h, int netw, int neth, float thresh, int *map, int relative, detection *dets, int letter, int batch)
{
    //printf("\n l.batch = %d, l.w = %d, l.h = %d, l.n = %d \n", l.batch, l.w, l.h, l.n);
    int i, j, n;
    float *predictions = l.output;
    //if (l.batch == 2) avg_flipped_yolo(l);
    int count = 0;
    for (i = 0; i < l.w*l.h; ++i) {
        int row = i / l.w;
        int col = i % l.w;
        for (n = 0; n < l.n; ++n) {
            int obj_index = entry_index(l, batch, n*l.w*l.h + i, 4);
            float objectness = predictions[obj_index];
            //if (objectness <= thresh) continue;    // incorrect behavior for Nan values
            if (objectness > thresh) {
                //printf("\n objectness = %f, thresh = %f, i = %d, n = %d \n", objectness, thresh, i, n);
                int box_index = entry_index(l, batch, n*l.w*l.h + i, 0);
                dets[count].bbox = get_yolo_box(predictions, l.biases, l.mask[n], box_index, col, row, l.w, l.h, netw, neth, l.w*l.h);
                dets[count].objectness = objectness;
                dets[count].classes = l.classes;
                for (j = 0; j < l.classes; ++j) {
                    int class_index = entry_index(l, batch, n*l.w*l.h + i, 4 + 1 + j);
                    float prob = objectness*predictions[class_index];
                    dets[count].prob[j] = (prob > thresh) ? prob : 0;
                }
                ++count;
            }
        }
    }
    correct_yolo_boxes(dets, count, w, h, netw, neth, relative, letter);
    return count;
}

#ifdef GPU

void forward_yolo_layer_gpu(const layer l, network_state state)
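The only change compared to yolo_num_detections() / get_yolo_detections() is that the batch index is forwarded to entry_index(), which in darknet's usual output layout advances the read position by one full image worth of predictions (l.outputs floats) per batch slot. A rough Python restatement of that indexing, under the assumption that the layer output is laid out as [batch][anchor][x, y, w, h, obj, class scores][h*w]:

def yolo_entry_index(batch, anchor, loc, entry, w, h, classes, num_anchors):
    # floats produced per image by this [yolo] layer
    outputs = w * h * num_anchors * (4 + 1 + classes)
    # 'batch' selects the image slice inside the shared l.output buffer
    return batch * outputs + anchor * w * h * (4 + 1 + classes) + entry * w * h + loc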

@@ -13,7 +13,9 @@ void forward_yolo_layer(const layer l, network_state state);
void backward_yolo_layer(const layer l, network_state state);
void resize_yolo_layer(layer *l, int w, int h);
int yolo_num_detections(layer l, float thresh);
int yolo_num_detections_custom(layer l, float thresh, int batch);
int get_yolo_detections(layer l, int w, int h, int netw, int neth, float thresh, int *map, int relative, detection *dets, int letter);
int get_yolo_detections_custom(layer l, int w, int h, int netw, int neth, float thresh, int *map, int relative, detection *dets, int letter, int batch);
void correct_yolo_boxes(detection *dets, int n, int w, int h, int netw, int neth, int relative, int letter);
#ifdef GPU
