mirror of https://github.com/AlexeyAB/darknet.git
parent
c56931dd75
commit
3ff5084590
30 changed files with 1830 additions and 315 deletions
@ -0,0 +1,5 @@ |
|||||||
|
rem C:\Users\Alex\AppData\Local\Programs\Python\Python36\python.exe darknet_video.py |
||||||
|
|
||||||
|
C:\Python27\python.exe darknet_video.py |
||||||
|
|
||||||
|
pause |
@ -0,0 +1,320 @@ |
|||||||
|
from ctypes import * |
||||||
|
import math |
||||||
|
import random |
||||||
|
import os |
||||||
|
import cv2 |
||||||
|
import numpy as np |
||||||
|
import time |
||||||
|
|
||||||
|
|
||||||
|
def sample(probs): |
||||||
|
s = sum(probs) |
||||||
|
probs = [a/s for a in probs] |
||||||
|
r = random.uniform(0, 1) |
||||||
|
for i in range(len(probs)): |
||||||
|
r = r - probs[i] |
||||||
|
if r <= 0: |
||||||
|
return i |
||||||
|
return len(probs)-1 |
||||||
|
|
||||||
|
|
||||||
|
def c_array(ctype, values): |
||||||
|
arr = (ctype*len(values))() |
||||||
|
arr[:] = values |
||||||
|
return arr |
||||||
|
|
||||||
|
|
||||||
|
class BOX(Structure): |
||||||
|
_fields_ = [("x", c_float), |
||||||
|
("y", c_float), |
||||||
|
("w", c_float), |
||||||
|
("h", c_float)] |
||||||
|
|
||||||
|
|
||||||
|
class DETECTION(Structure): |
||||||
|
_fields_ = [("bbox", BOX), |
||||||
|
("classes", c_int), |
||||||
|
("prob", POINTER(c_float)), |
||||||
|
("mask", POINTER(c_float)), |
||||||
|
("objectness", c_float), |
||||||
|
("sort_class", c_int)] |
||||||
|
|
||||||
|
|
||||||
|
class IMAGE(Structure): |
||||||
|
_fields_ = [("w", c_int), |
||||||
|
("h", c_int), |
||||||
|
("c", c_int), |
||||||
|
("data", POINTER(c_float))] |
||||||
|
|
||||||
|
|
||||||
|
class METADATA(Structure): |
||||||
|
_fields_ = [("classes", c_int), |
||||||
|
("names", POINTER(c_char_p))] |
||||||
|
|
||||||
|
|
||||||
|
hasGPU = True |
||||||
|
|
||||||
|
lib = CDLL("yolo_cpp_dll.dll", RTLD_GLOBAL) |
||||||
|
lib.network_width.argtypes = [c_void_p] |
||||||
|
lib.network_width.restype = c_int |
||||||
|
lib.network_height.argtypes = [c_void_p] |
||||||
|
lib.network_height.restype = c_int |
||||||
|
|
||||||
|
predict = lib.network_predict |
||||||
|
predict.argtypes = [c_void_p, POINTER(c_float)] |
||||||
|
predict.restype = POINTER(c_float) |
||||||
|
|
||||||
|
if hasGPU: |
||||||
|
set_gpu = lib.cuda_set_device |
||||||
|
set_gpu.argtypes = [c_int] |
||||||
|
|
||||||
|
make_image = lib.make_image |
||||||
|
make_image.argtypes = [c_int, c_int, c_int] |
||||||
|
make_image.restype = IMAGE |
||||||
|
|
||||||
|
get_network_boxes = lib.get_network_boxes |
||||||
|
get_network_boxes.argtypes = \ |
||||||
|
[c_void_p, c_int, c_int, c_float, c_float, POINTER( |
||||||
|
c_int), c_int, POINTER(c_int), c_int] |
||||||
|
get_network_boxes.restype = POINTER(DETECTION) |
||||||
|
|
||||||
|
make_network_boxes = lib.make_network_boxes |
||||||
|
make_network_boxes.argtypes = [c_void_p] |
||||||
|
make_network_boxes.restype = POINTER(DETECTION) |
||||||
|
|
||||||
|
free_detections = lib.free_detections |
||||||
|
free_detections.argtypes = [POINTER(DETECTION), c_int] |
||||||
|
|
||||||
|
free_ptrs = lib.free_ptrs |
||||||
|
free_ptrs.argtypes = [POINTER(c_void_p), c_int] |
||||||
|
|
||||||
|
network_predict = lib.network_predict |
||||||
|
network_predict.argtypes = [c_void_p, POINTER(c_float)] |
||||||
|
|
||||||
|
reset_rnn = lib.reset_rnn |
||||||
|
reset_rnn.argtypes = [c_void_p] |
||||||
|
|
||||||
|
load_net = lib.load_network |
||||||
|
load_net.argtypes = [c_char_p, c_char_p, c_int] |
||||||
|
load_net.restype = c_void_p |
||||||
|
|
||||||
|
load_net_custom = lib.load_network_custom |
||||||
|
load_net_custom.argtypes = [c_char_p, c_char_p, c_int, c_int] |
||||||
|
load_net_custom.restype = c_void_p |
||||||
|
|
||||||
|
do_nms_obj = lib.do_nms_obj |
||||||
|
do_nms_obj.argtypes = [POINTER(DETECTION), c_int, c_int, c_float] |
||||||
|
|
||||||
|
do_nms_sort = lib.do_nms_sort |
||||||
|
do_nms_sort.argtypes = [POINTER(DETECTION), c_int, c_int, c_float] |
||||||
|
|
||||||
|
free_image = lib.free_image |
||||||
|
free_image.argtypes = [IMAGE] |
||||||
|
|
||||||
|
letterbox_image = lib.letterbox_image |
||||||
|
letterbox_image.argtypes = [IMAGE, c_int, c_int] |
||||||
|
letterbox_image.restype = IMAGE |
||||||
|
|
||||||
|
load_meta = lib.get_metadata |
||||||
|
lib.get_metadata.argtypes = [c_char_p] |
||||||
|
lib.get_metadata.restype = METADATA |
||||||
|
|
||||||
|
load_image = lib.load_image_color |
||||||
|
load_image.argtypes = [c_char_p, c_int, c_int] |
||||||
|
load_image.restype = IMAGE |
||||||
|
|
||||||
|
rgbgr_image = lib.rgbgr_image |
||||||
|
rgbgr_image.argtypes = [IMAGE] |
||||||
|
|
||||||
|
predict_image = lib.network_predict_image |
||||||
|
predict_image.argtypes = [c_void_p, IMAGE] |
||||||
|
predict_image.restype = POINTER(c_float) |
||||||
|
|
||||||
|
|
||||||
|
def array_to_image(arr): |
||||||
|
import numpy as np |
||||||
|
arr = arr.transpose(2, 0, 1) |
||||||
|
c = arr.shape[0] |
||||||
|
h = arr.shape[1] |
||||||
|
w = arr.shape[2] |
||||||
|
arr = np.ascontiguousarray(arr.flat, dtype=np.float32) / 255.0 |
||||||
|
data = arr.ctypes.data_as(POINTER(c_float)) |
||||||
|
im = IMAGE(w, h, c, data) |
||||||
|
return im, arr |
||||||
|
|
||||||
|
|
||||||
|
def classify(net, meta, im): |
||||||
|
out = predict_image(net, im) |
||||||
|
res = [] |
||||||
|
for i in range(meta.classes): |
||||||
|
if altNames is None: |
||||||
|
nameTag = meta.names[i] |
||||||
|
else: |
||||||
|
nameTag = altNames[i] |
||||||
|
res.append((nameTag, out[i])) |
||||||
|
res = sorted(res, key=lambda x: -x[1]) |
||||||
|
return res |
||||||
|
|
||||||
|
|
||||||
|
def detect(net, meta, image, thresh=.5, hier_thresh=.5, nms=.45, debug=False): |
||||||
|
im, arr = array_to_image(image) |
||||||
|
if debug: |
||||||
|
print("Loaded image") |
||||||
|
num = c_int(0) |
||||||
|
if debug: |
||||||
|
print("Assigned num") |
||||||
|
pnum = pointer(num) |
||||||
|
if debug: |
||||||
|
print("Assigned pnum") |
||||||
|
predict_image(net, im) |
||||||
|
if debug: |
||||||
|
print("did prediction") |
||||||
|
# dets = get_network_boxes( |
||||||
|
# net, image.shape[1], image.shape[0], |
||||||
|
# thresh, hier_thresh, |
||||||
|
# None, 0, pnum, 0) # OpenCV |
||||||
|
dets = get_network_boxes(net, im.w, im.h, |
||||||
|
thresh, hier_thresh, None, 0, pnum, 0) |
||||||
|
if debug: |
||||||
|
print("Got dets") |
||||||
|
num = pnum[0] |
||||||
|
if debug: |
||||||
|
print("got zeroth index of pnum") |
||||||
|
if nms: |
||||||
|
do_nms_sort(dets, num, meta.classes, nms) |
||||||
|
if debug: |
||||||
|
print("did sort") |
||||||
|
res = [] |
||||||
|
if debug: |
||||||
|
print("about to range") |
||||||
|
for j in range(num): |
||||||
|
if debug: |
||||||
|
print("Ranging on "+str(j)+" of "+str(num)) |
||||||
|
if debug: |
||||||
|
print("Classes: "+str(meta), meta.classes, meta.names) |
||||||
|
for i in range(meta.classes): |
||||||
|
if debug: |
||||||
|
print("Class-ranging on "+str(i)+" of " + |
||||||
|
str(meta.classes)+"= "+str(dets[j].prob[i])) |
||||||
|
if dets[j].prob[i] > 0: |
||||||
|
b = dets[j].bbox |
||||||
|
if altNames is None: |
||||||
|
nameTag = meta.names[i] |
||||||
|
else: |
||||||
|
nameTag = altNames[i] |
||||||
|
if debug: |
||||||
|
print("Got bbox", b) |
||||||
|
print(nameTag) |
||||||
|
print(dets[j].prob[i]) |
||||||
|
print((b.x, b.y, b.w, b.h)) |
||||||
|
res.append((nameTag, dets[j].prob[i], (b.x, b.y, b.w, b.h))) |
||||||
|
if debug: |
||||||
|
print("did range") |
||||||
|
res = sorted(res, key=lambda x: -x[1]) |
||||||
|
if debug: |
||||||
|
print("did sort") |
||||||
|
# free_image(im) |
||||||
|
if debug: |
||||||
|
print("freed image") |
||||||
|
free_detections(dets, num) |
||||||
|
if debug: |
||||||
|
print("freed detections") |
||||||
|
return res |
||||||
|
|
||||||
|
|
||||||
|
def convertBack(x, y, w, h): |
||||||
|
xmin = int(round(x - (w / 2))) |
||||||
|
xmax = int(round(x + (w / 2))) |
||||||
|
ymin = int(round(y - (h / 2))) |
||||||
|
ymax = int(round(y + (h / 2))) |
||||||
|
return xmin, ymin, xmax, ymax |
||||||
|
|
||||||
|
|
||||||
|
def cvDrawBoxes(detections, img): |
||||||
|
for detection in detections: |
||||||
|
x, y, w, h = detection[2][0],\ |
||||||
|
detection[2][1],\ |
||||||
|
detection[2][2],\ |
||||||
|
detection[2][3] |
||||||
|
xmin, ymin, xmax, ymax = convertBack( |
||||||
|
float(x), float(y), float(w), float(h)) |
||||||
|
pt1 = (xmin, ymin) |
||||||
|
pt2 = (xmax, ymax) |
||||||
|
cv2.rectangle(img, pt1, pt2, (0, 255, 0), 2) |
||||||
|
cv2.putText(img, |
||||||
|
detection[0].decode() + |
||||||
|
" [" + str(round(detection[1] * 100, 2)) + "]", |
||||||
|
(pt1[0], pt1[1] + 20), cv2.FONT_HERSHEY_SIMPLEX, 1, |
||||||
|
[0, 255, 0], 4) |
||||||
|
return img |
||||||
|
|
||||||
|
|
||||||
|
netMain = None |
||||||
|
metaMain = None |
||||||
|
altNames = None |
||||||
|
|
||||||
|
|
||||||
|
def YOLO(): |
||||||
|
global metaMain, netMain, altNames |
||||||
|
configPath = "./cfg/yolov3.cfg" |
||||||
|
weightPath = "./yolov3.weights" |
||||||
|
metaPath = "./cfg/coco.data" |
||||||
|
if not os.path.exists(configPath): |
||||||
|
raise ValueError("Invalid config path `" + |
||||||
|
os.path.abspath(configPath)+"`") |
||||||
|
if not os.path.exists(weightPath): |
||||||
|
raise ValueError("Invalid weight path `" + |
||||||
|
os.path.abspath(weightPath)+"`") |
||||||
|
if not os.path.exists(metaPath): |
||||||
|
raise ValueError("Invalid data file path `" + |
||||||
|
os.path.abspath(metaPath)+"`") |
||||||
|
if netMain is None: |
||||||
|
netMain = load_net_custom(configPath.encode( |
||||||
|
"ascii"), weightPath.encode("ascii"), 0, 1) # batch size = 1 |
||||||
|
if metaMain is None: |
||||||
|
metaMain = load_meta(metaPath.encode("ascii")) |
||||||
|
if altNames is None: |
||||||
|
try: |
||||||
|
with open(metaPath) as metaFH: |
||||||
|
metaContents = metaFH.read() |
||||||
|
import re |
||||||
|
match = re.search("names *= *(.*)$", metaContents, |
||||||
|
re.IGNORECASE | re.MULTILINE) |
||||||
|
if match: |
||||||
|
result = match.group(1) |
||||||
|
else: |
||||||
|
result = None |
||||||
|
try: |
||||||
|
if os.path.exists(result): |
||||||
|
with open(result) as namesFH: |
||||||
|
namesList = namesFH.read().strip().split("\n") |
||||||
|
altNames = [x.strip() for x in namesList] |
||||||
|
except TypeError: |
||||||
|
pass |
||||||
|
except Exception: |
||||||
|
pass |
||||||
|
#cap = cv2.VideoCapture(0) |
||||||
|
cap = cv2.VideoCapture("test.mp4") |
||||||
|
cap.set(3, 1280) |
||||||
|
cap.set(4, 720) |
||||||
|
out = cv2.VideoWriter( |
||||||
|
"output.avi", cv2.VideoWriter_fourcc(*"MJPG"), 10.0, |
||||||
|
(lib.network_width(netMain), lib.network_height(netMain))) |
||||||
|
print("Starting the YOLO loop...") |
||||||
|
while True: |
||||||
|
prev_time = time.time() |
||||||
|
ret, frame_read = cap.read() |
||||||
|
frame_rgb = cv2.cvtColor(frame_read, cv2.COLOR_BGR2RGB) |
||||||
|
frame_resized = cv2.resize(frame_rgb, |
||||||
|
(lib.network_width(netMain), |
||||||
|
lib.network_height(netMain)), |
||||||
|
interpolation=cv2.INTER_LINEAR) |
||||||
|
detections = detect(netMain, metaMain, frame_resized, thresh=0.25) |
||||||
|
image = cvDrawBoxes(detections, frame_resized) |
||||||
|
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) |
||||||
|
print(1/(time.time()-prev_time)) |
||||||
|
cap.release() |
||||||
|
out.release() |
||||||
|
|
||||||
|
if __name__ == "__main__": |
||||||
|
YOLO() |
@ -0,0 +1,320 @@ |
|||||||
|
from ctypes import * |
||||||
|
import math |
||||||
|
import random |
||||||
|
import os |
||||||
|
import cv2 |
||||||
|
import numpy as np |
||||||
|
import time |
||||||
|
|
||||||
|
|
||||||
|
def sample(probs): |
||||||
|
s = sum(probs) |
||||||
|
probs = [a/s for a in probs] |
||||||
|
r = random.uniform(0, 1) |
||||||
|
for i in range(len(probs)): |
||||||
|
r = r - probs[i] |
||||||
|
if r <= 0: |
||||||
|
return i |
||||||
|
return len(probs)-1 |
||||||
|
|
||||||
|
|
||||||
|
def c_array(ctype, values): |
||||||
|
arr = (ctype*len(values))() |
||||||
|
arr[:] = values |
||||||
|
return arr |
||||||
|
|
||||||
|
|
||||||
|
class BOX(Structure): |
||||||
|
_fields_ = [("x", c_float), |
||||||
|
("y", c_float), |
||||||
|
("w", c_float), |
||||||
|
("h", c_float)] |
||||||
|
|
||||||
|
|
||||||
|
class DETECTION(Structure): |
||||||
|
_fields_ = [("bbox", BOX), |
||||||
|
("classes", c_int), |
||||||
|
("prob", POINTER(c_float)), |
||||||
|
("mask", POINTER(c_float)), |
||||||
|
("objectness", c_float), |
||||||
|
("sort_class", c_int)] |
||||||
|
|
||||||
|
|
||||||
|
class IMAGE(Structure): |
||||||
|
_fields_ = [("w", c_int), |
||||||
|
("h", c_int), |
||||||
|
("c", c_int), |
||||||
|
("data", POINTER(c_float))] |
||||||
|
|
||||||
|
|
||||||
|
class METADATA(Structure): |
||||||
|
_fields_ = [("classes", c_int), |
||||||
|
("names", POINTER(c_char_p))] |
||||||
|
|
||||||
|
|
||||||
|
hasGPU = True |
||||||
|
|
||||||
|
lib = CDLL("./libdarknet.so", RTLD_GLOBAL) |
||||||
|
lib.network_width.argtypes = [c_void_p] |
||||||
|
lib.network_width.restype = c_int |
||||||
|
lib.network_height.argtypes = [c_void_p] |
||||||
|
lib.network_height.restype = c_int |
||||||
|
|
||||||
|
predict = lib.network_predict |
||||||
|
predict.argtypes = [c_void_p, POINTER(c_float)] |
||||||
|
predict.restype = POINTER(c_float) |
||||||
|
|
||||||
|
if hasGPU: |
||||||
|
set_gpu = lib.cuda_set_device |
||||||
|
set_gpu.argtypes = [c_int] |
||||||
|
|
||||||
|
make_image = lib.make_image |
||||||
|
make_image.argtypes = [c_int, c_int, c_int] |
||||||
|
make_image.restype = IMAGE |
||||||
|
|
||||||
|
get_network_boxes = lib.get_network_boxes |
||||||
|
get_network_boxes.argtypes = \ |
||||||
|
[c_void_p, c_int, c_int, c_float, c_float, POINTER( |
||||||
|
c_int), c_int, POINTER(c_int), c_int] |
||||||
|
get_network_boxes.restype = POINTER(DETECTION) |
||||||
|
|
||||||
|
make_network_boxes = lib.make_network_boxes |
||||||
|
make_network_boxes.argtypes = [c_void_p] |
||||||
|
make_network_boxes.restype = POINTER(DETECTION) |
||||||
|
|
||||||
|
free_detections = lib.free_detections |
||||||
|
free_detections.argtypes = [POINTER(DETECTION), c_int] |
||||||
|
|
||||||
|
free_ptrs = lib.free_ptrs |
||||||
|
free_ptrs.argtypes = [POINTER(c_void_p), c_int] |
||||||
|
|
||||||
|
network_predict = lib.network_predict |
||||||
|
network_predict.argtypes = [c_void_p, POINTER(c_float)] |
||||||
|
|
||||||
|
reset_rnn = lib.reset_rnn |
||||||
|
reset_rnn.argtypes = [c_void_p] |
||||||
|
|
||||||
|
load_net = lib.load_network |
||||||
|
load_net.argtypes = [c_char_p, c_char_p, c_int] |
||||||
|
load_net.restype = c_void_p |
||||||
|
|
||||||
|
load_net_custom = lib.load_network_custom |
||||||
|
load_net_custom.argtypes = [c_char_p, c_char_p, c_int, c_int] |
||||||
|
load_net_custom.restype = c_void_p |
||||||
|
|
||||||
|
do_nms_obj = lib.do_nms_obj |
||||||
|
do_nms_obj.argtypes = [POINTER(DETECTION), c_int, c_int, c_float] |
||||||
|
|
||||||
|
do_nms_sort = lib.do_nms_sort |
||||||
|
do_nms_sort.argtypes = [POINTER(DETECTION), c_int, c_int, c_float] |
||||||
|
|
||||||
|
free_image = lib.free_image |
||||||
|
free_image.argtypes = [IMAGE] |
||||||
|
|
||||||
|
letterbox_image = lib.letterbox_image |
||||||
|
letterbox_image.argtypes = [IMAGE, c_int, c_int] |
||||||
|
letterbox_image.restype = IMAGE |
||||||
|
|
||||||
|
load_meta = lib.get_metadata |
||||||
|
lib.get_metadata.argtypes = [c_char_p] |
||||||
|
lib.get_metadata.restype = METADATA |
||||||
|
|
||||||
|
load_image = lib.load_image_color |
||||||
|
load_image.argtypes = [c_char_p, c_int, c_int] |
||||||
|
load_image.restype = IMAGE |
||||||
|
|
||||||
|
rgbgr_image = lib.rgbgr_image |
||||||
|
rgbgr_image.argtypes = [IMAGE] |
||||||
|
|
||||||
|
predict_image = lib.network_predict_image |
||||||
|
predict_image.argtypes = [c_void_p, IMAGE] |
||||||
|
predict_image.restype = POINTER(c_float) |
||||||
|
|
||||||
|
|
||||||
|
def array_to_image(arr): |
||||||
|
import numpy as np |
||||||
|
arr = arr.transpose(2, 0, 1) |
||||||
|
c = arr.shape[0] |
||||||
|
h = arr.shape[1] |
||||||
|
w = arr.shape[2] |
||||||
|
arr = np.ascontiguousarray(arr.flat, dtype=np.float32) / 255.0 |
||||||
|
data = arr.ctypes.data_as(POINTER(c_float)) |
||||||
|
im = IMAGE(w, h, c, data) |
||||||
|
return im, arr |
||||||
|
|
||||||
|
|
||||||
|
def classify(net, meta, im): |
||||||
|
out = predict_image(net, im) |
||||||
|
res = [] |
||||||
|
for i in range(meta.classes): |
||||||
|
if altNames is None: |
||||||
|
nameTag = meta.names[i] |
||||||
|
else: |
||||||
|
nameTag = altNames[i] |
||||||
|
res.append((nameTag, out[i])) |
||||||
|
res = sorted(res, key=lambda x: -x[1]) |
||||||
|
return res |
||||||
|
|
||||||
|
|
||||||
|
def detect(net, meta, image, thresh=.5, hier_thresh=.5, nms=.45, debug=False): |
||||||
|
im, arr = array_to_image(image) |
||||||
|
if debug: |
||||||
|
print("Loaded image") |
||||||
|
num = c_int(0) |
||||||
|
if debug: |
||||||
|
print("Assigned num") |
||||||
|
pnum = pointer(num) |
||||||
|
if debug: |
||||||
|
print("Assigned pnum") |
||||||
|
predict_image(net, im) |
||||||
|
if debug: |
||||||
|
print("did prediction") |
||||||
|
# dets = get_network_boxes( |
||||||
|
# net, image.shape[1], image.shape[0], |
||||||
|
# thresh, hier_thresh, |
||||||
|
# None, 0, pnum, 0) # OpenCV |
||||||
|
dets = get_network_boxes(net, im.w, im.h, |
||||||
|
thresh, hier_thresh, None, 0, pnum, 0) |
||||||
|
if debug: |
||||||
|
print("Got dets") |
||||||
|
num = pnum[0] |
||||||
|
if debug: |
||||||
|
print("got zeroth index of pnum") |
||||||
|
if nms: |
||||||
|
do_nms_sort(dets, num, meta.classes, nms) |
||||||
|
if debug: |
||||||
|
print("did sort") |
||||||
|
res = [] |
||||||
|
if debug: |
||||||
|
print("about to range") |
||||||
|
for j in range(num): |
||||||
|
if debug: |
||||||
|
print("Ranging on "+str(j)+" of "+str(num)) |
||||||
|
if debug: |
||||||
|
print("Classes: "+str(meta), meta.classes, meta.names) |
||||||
|
for i in range(meta.classes): |
||||||
|
if debug: |
||||||
|
print("Class-ranging on "+str(i)+" of " + |
||||||
|
str(meta.classes)+"= "+str(dets[j].prob[i])) |
||||||
|
if dets[j].prob[i] > 0: |
||||||
|
b = dets[j].bbox |
||||||
|
if altNames is None: |
||||||
|
nameTag = meta.names[i] |
||||||
|
else: |
||||||
|
nameTag = altNames[i] |
||||||
|
if debug: |
||||||
|
print("Got bbox", b) |
||||||
|
print(nameTag) |
||||||
|
print(dets[j].prob[i]) |
||||||
|
print((b.x, b.y, b.w, b.h)) |
||||||
|
res.append((nameTag, dets[j].prob[i], (b.x, b.y, b.w, b.h))) |
||||||
|
if debug: |
||||||
|
print("did range") |
||||||
|
res = sorted(res, key=lambda x: -x[1]) |
||||||
|
if debug: |
||||||
|
print("did sort") |
||||||
|
# free_image(im) |
||||||
|
if debug: |
||||||
|
print("freed image") |
||||||
|
free_detections(dets, num) |
||||||
|
if debug: |
||||||
|
print("freed detections") |
||||||
|
return res |
||||||
|
|
||||||
|
|
||||||
|
def convertBack(x, y, w, h): |
||||||
|
xmin = int(round(x - (w / 2))) |
||||||
|
xmax = int(round(x + (w / 2))) |
||||||
|
ymin = int(round(y - (h / 2))) |
||||||
|
ymax = int(round(y + (h / 2))) |
||||||
|
return xmin, ymin, xmax, ymax |
||||||
|
|
||||||
|
|
||||||
|
def cvDrawBoxes(detections, img): |
||||||
|
for detection in detections: |
||||||
|
x, y, w, h = detection[2][0],\ |
||||||
|
detection[2][1],\ |
||||||
|
detection[2][2],\ |
||||||
|
detection[2][3] |
||||||
|
xmin, ymin, xmax, ymax = convertBack( |
||||||
|
float(x), float(y), float(w), float(h)) |
||||||
|
pt1 = (xmin, ymin) |
||||||
|
pt2 = (xmax, ymax) |
||||||
|
cv2.rectangle(img, pt1, pt2, (0, 255, 0), 2) |
||||||
|
cv2.putText(img, |
||||||
|
detection[0].decode() + |
||||||
|
" [" + str(round(detection[1] * 100, 2)) + "]", |
||||||
|
(pt1[0], pt1[1] + 20), cv2.FONT_HERSHEY_SIMPLEX, 1, |
||||||
|
[0, 255, 0], 4) |
||||||
|
return img |
||||||
|
|
||||||
|
|
||||||
|
netMain = None |
||||||
|
metaMain = None |
||||||
|
altNames = None |
||||||
|
|
||||||
|
|
||||||
|
def YOLO(): |
||||||
|
global metaMain, netMain, altNames |
||||||
|
configPath = "./cfg/yolov3.cfg" |
||||||
|
weightPath = "./yolov3.weights" |
||||||
|
metaPath = "./cfg/coco.data" |
||||||
|
if not os.path.exists(configPath): |
||||||
|
raise ValueError("Invalid config path `" + |
||||||
|
os.path.abspath(configPath)+"`") |
||||||
|
if not os.path.exists(weightPath): |
||||||
|
raise ValueError("Invalid weight path `" + |
||||||
|
os.path.abspath(weightPath)+"`") |
||||||
|
if not os.path.exists(metaPath): |
||||||
|
raise ValueError("Invalid data file path `" + |
||||||
|
os.path.abspath(metaPath)+"`") |
||||||
|
if netMain is None: |
||||||
|
netMain = load_net_custom(configPath.encode( |
||||||
|
"ascii"), weightPath.encode("ascii"), 0, 1) # batch size = 1 |
||||||
|
if metaMain is None: |
||||||
|
metaMain = load_meta(metaPath.encode("ascii")) |
||||||
|
if altNames is None: |
||||||
|
try: |
||||||
|
with open(metaPath) as metaFH: |
||||||
|
metaContents = metaFH.read() |
||||||
|
import re |
||||||
|
match = re.search("names *= *(.*)$", metaContents, |
||||||
|
re.IGNORECASE | re.MULTILINE) |
||||||
|
if match: |
||||||
|
result = match.group(1) |
||||||
|
else: |
||||||
|
result = None |
||||||
|
try: |
||||||
|
if os.path.exists(result): |
||||||
|
with open(result) as namesFH: |
||||||
|
namesList = namesFH.read().strip().split("\n") |
||||||
|
altNames = [x.strip() for x in namesList] |
||||||
|
except TypeError: |
||||||
|
pass |
||||||
|
except Exception: |
||||||
|
pass |
||||||
|
#cap = cv2.VideoCapture(0) |
||||||
|
cap = cv2.VideoCapture("test.mp4") |
||||||
|
cap.set(3, 1280) |
||||||
|
cap.set(4, 720) |
||||||
|
out = cv2.VideoWriter( |
||||||
|
"output.avi", cv2.VideoWriter_fourcc(*"MJPG"), 10.0, |
||||||
|
(lib.network_width(netMain), lib.network_height(netMain))) |
||||||
|
print("Starting the YOLO loop...") |
||||||
|
while True: |
||||||
|
prev_time = time.time() |
||||||
|
ret, frame_read = cap.read() |
||||||
|
frame_rgb = cv2.cvtColor(frame_read, cv2.COLOR_BGR2RGB) |
||||||
|
frame_resized = cv2.resize(frame_rgb, |
||||||
|
(lib.network_width(netMain), |
||||||
|
lib.network_height(netMain)), |
||||||
|
interpolation=cv2.INTER_LINEAR) |
||||||
|
detections = detect(netMain, metaMain, frame_resized, thresh=0.25) |
||||||
|
image = cvDrawBoxes(detections, frame_resized) |
||||||
|
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) |
||||||
|
print(1/(time.time()-prev_time)) |
||||||
|
cap.release() |
||||||
|
out.release() |
||||||
|
|
||||||
|
if __name__ == "__main__": |
||||||
|
YOLO() |
@ -0,0 +1,793 @@ |
|||||||
|
#ifndef DARKNET_API |
||||||
|
#define DARKNET_API |
||||||
|
|
||||||
|
#if defined(_MSC_VER) && _MSC_VER < 1900 |
||||||
|
#define inline __inline |
||||||
|
#endif |
||||||
|
|
||||||
|
#include <stdlib.h> |
||||||
|
#include <stdio.h> |
||||||
|
#include <string.h> |
||||||
|
#include <pthread.h> |
||||||
|
#include <stdint.h> |
||||||
|
|
||||||
|
#ifdef LIB_EXPORTS |
||||||
|
#if defined(_MSC_VER) |
||||||
|
#define LIB_API __declspec(dllexport) |
||||||
|
#else |
||||||
|
#define LIB_API __attribute__((visibility("default"))) |
||||||
|
#endif |
||||||
|
#else |
||||||
|
#if defined(_MSC_VER) |
||||||
|
#define LIB_API |
||||||
|
#else |
||||||
|
#define LIB_API |
||||||
|
#endif |
||||||
|
#endif |
||||||
|
|
||||||
|
#ifdef GPU |
||||||
|
#define BLOCK 512 |
||||||
|
|
||||||
|
#include "cuda_runtime.h" |
||||||
|
#include "curand.h" |
||||||
|
#include "cublas_v2.h" |
||||||
|
|
||||||
|
#ifdef CUDNN |
||||||
|
#include "cudnn.h" |
||||||
|
#endif |
||||||
|
#endif |
||||||
|
|
||||||
|
#ifdef __cplusplus |
||||||
|
extern "C" { |
||||||
|
#endif |
||||||
|
|
||||||
|
struct network; |
||||||
|
typedef struct network network; |
||||||
|
|
||||||
|
struct network_state; |
||||||
|
typedef struct network_state; |
||||||
|
|
||||||
|
struct layer; |
||||||
|
typedef struct layer layer; |
||||||
|
|
||||||
|
struct image; |
||||||
|
typedef struct image image; |
||||||
|
|
||||||
|
struct detection; |
||||||
|
typedef struct detection detection; |
||||||
|
|
||||||
|
struct load_args; |
||||||
|
typedef struct load_args load_args; |
||||||
|
|
||||||
|
struct data; |
||||||
|
typedef struct data data; |
||||||
|
|
||||||
|
struct metadata; |
||||||
|
typedef struct metadata metadata; |
||||||
|
|
||||||
|
struct tree; |
||||||
|
typedef struct tree tree; |
||||||
|
|
||||||
|
|
||||||
|
#define SECRET_NUM -1234 |
||||||
|
extern int gpu_index; |
||||||
|
|
||||||
|
// option_list.h
|
||||||
|
typedef struct metadata { |
||||||
|
int classes; |
||||||
|
char **names; |
||||||
|
} metadata; |
||||||
|
|
||||||
|
|
||||||
|
// tree.h
|
||||||
|
typedef struct tree { |
||||||
|
int *leaf; |
||||||
|
int n; |
||||||
|
int *parent; |
||||||
|
int *child; |
||||||
|
int *group; |
||||||
|
char **name; |
||||||
|
|
||||||
|
int groups; |
||||||
|
int *group_size; |
||||||
|
int *group_offset; |
||||||
|
} tree; |
||||||
|
|
||||||
|
|
||||||
|
// activations.h
|
||||||
|
typedef enum { |
||||||
|
LOGISTIC, RELU, RELIE, LINEAR, RAMP, TANH, PLSE, LEAKY, ELU, LOGGY, STAIR, HARDTAN, LHTAN, SELU |
||||||
|
}ACTIVATION; |
||||||
|
|
||||||
|
// image.h
|
||||||
|
typedef enum{ |
||||||
|
PNG, BMP, TGA, JPG |
||||||
|
} IMTYPE; |
||||||
|
|
||||||
|
// activations.h
|
||||||
|
typedef enum{ |
||||||
|
MULT, ADD, SUB, DIV |
||||||
|
} BINARY_ACTIVATION; |
||||||
|
|
||||||
|
// layer.h
|
||||||
|
typedef enum { |
||||||
|
CONVOLUTIONAL, |
||||||
|
DECONVOLUTIONAL, |
||||||
|
CONNECTED, |
||||||
|
MAXPOOL, |
||||||
|
SOFTMAX, |
||||||
|
DETECTION, |
||||||
|
DROPOUT, |
||||||
|
CROP, |
||||||
|
ROUTE, |
||||||
|
COST, |
||||||
|
NORMALIZATION, |
||||||
|
AVGPOOL, |
||||||
|
LOCAL, |
||||||
|
SHORTCUT, |
||||||
|
ACTIVE, |
||||||
|
RNN, |
||||||
|
GRU, |
||||||
|
LSTM, |
||||||
|
CRNN, |
||||||
|
BATCHNORM, |
||||||
|
NETWORK, |
||||||
|
XNOR, |
||||||
|
REGION, |
||||||
|
YOLO, |
||||||
|
ISEG, |
||||||
|
REORG, |
||||||
|
REORG_OLD, |
||||||
|
UPSAMPLE, |
||||||
|
LOGXENT, |
||||||
|
L2NORM, |
||||||
|
BLANK |
||||||
|
} LAYER_TYPE; |
||||||
|
|
||||||
|
// layer.h
|
||||||
|
typedef enum{ |
||||||
|
SSE, MASKED, L1, SEG, SMOOTH,WGAN |
||||||
|
} COST_TYPE; |
||||||
|
|
||||||
|
// layer.h
|
||||||
|
typedef struct update_args { |
||||||
|
int batch; |
||||||
|
float learning_rate; |
||||||
|
float momentum; |
||||||
|
float decay; |
||||||
|
int adam; |
||||||
|
float B1; |
||||||
|
float B2; |
||||||
|
float eps; |
||||||
|
int t; |
||||||
|
} update_args; |
||||||
|
|
||||||
|
// layer.h
|
||||||
|
struct layer { |
||||||
|
LAYER_TYPE type; |
||||||
|
ACTIVATION activation; |
||||||
|
COST_TYPE cost_type; |
||||||
|
void(*forward) (struct layer, struct network_state); |
||||||
|
void(*backward) (struct layer, struct network_state); |
||||||
|
void(*update) (struct layer, int, float, float, float); |
||||||
|
void(*forward_gpu) (struct layer, struct network_state); |
||||||
|
void(*backward_gpu) (struct layer, struct network_state); |
||||||
|
void(*update_gpu) (struct layer, int, float, float, float); |
||||||
|
int batch_normalize; |
||||||
|
int shortcut; |
||||||
|
int batch; |
||||||
|
int forced; |
||||||
|
int flipped; |
||||||
|
int inputs; |
||||||
|
int outputs; |
||||||
|
int nweights; |
||||||
|
int nbiases; |
||||||
|
int extra; |
||||||
|
int truths; |
||||||
|
int h, w, c; |
||||||
|
int out_h, out_w, out_c; |
||||||
|
int n; |
||||||
|
int max_boxes; |
||||||
|
int groups; |
||||||
|
int size; |
||||||
|
int side; |
||||||
|
int stride; |
||||||
|
int reverse; |
||||||
|
int flatten; |
||||||
|
int spatial; |
||||||
|
int pad; |
||||||
|
int sqrt; |
||||||
|
int flip; |
||||||
|
int index; |
||||||
|
int binary; |
||||||
|
int xnor; |
||||||
|
int use_bin_output; |
||||||
|
int steps; |
||||||
|
int hidden; |
||||||
|
int truth; |
||||||
|
float smooth; |
||||||
|
float dot; |
||||||
|
float angle; |
||||||
|
float jitter; |
||||||
|
float saturation; |
||||||
|
float exposure; |
||||||
|
float shift; |
||||||
|
float ratio; |
||||||
|
float learning_rate_scale; |
||||||
|
float clip; |
||||||
|
int focal_loss; |
||||||
|
int noloss; |
||||||
|
int softmax; |
||||||
|
int classes; |
||||||
|
int coords; |
||||||
|
int background; |
||||||
|
int rescore; |
||||||
|
int objectness; |
||||||
|
int does_cost; |
||||||
|
int joint; |
||||||
|
int noadjust; |
||||||
|
int reorg; |
||||||
|
int log; |
||||||
|
int tanh; |
||||||
|
int *mask; |
||||||
|
int total; |
||||||
|
float bflops; |
||||||
|
|
||||||
|
int adam; |
||||||
|
float B1; |
||||||
|
float B2; |
||||||
|
float eps; |
||||||
|
|
||||||
|
int t; |
||||||
|
|
||||||
|
float alpha; |
||||||
|
float beta; |
||||||
|
float kappa; |
||||||
|
|
||||||
|
float coord_scale; |
||||||
|
float object_scale; |
||||||
|
float noobject_scale; |
||||||
|
float mask_scale; |
||||||
|
float class_scale; |
||||||
|
int bias_match; |
||||||
|
int random; |
||||||
|
float ignore_thresh; |
||||||
|
float truth_thresh; |
||||||
|
float thresh; |
||||||
|
float focus; |
||||||
|
int classfix; |
||||||
|
int absolute; |
||||||
|
|
||||||
|
int onlyforward; |
||||||
|
int stopbackward; |
||||||
|
int dontload; |
||||||
|
int dontsave; |
||||||
|
int dontloadscales; |
||||||
|
int numload; |
||||||
|
|
||||||
|
float temperature; |
||||||
|
float probability; |
||||||
|
float scale; |
||||||
|
|
||||||
|
char * cweights; |
||||||
|
int * indexes; |
||||||
|
int * input_layers; |
||||||
|
int * input_sizes; |
||||||
|
int * map; |
||||||
|
int * counts; |
||||||
|
float ** sums; |
||||||
|
float * rand; |
||||||
|
float * cost; |
||||||
|
float * state; |
||||||
|
float * prev_state; |
||||||
|
float * forgot_state; |
||||||
|
float * forgot_delta; |
||||||
|
float * state_delta; |
||||||
|
float * combine_cpu; |
||||||
|
float * combine_delta_cpu; |
||||||
|
|
||||||
|
float *concat; |
||||||
|
float *concat_delta; |
||||||
|
|
||||||
|
float *binary_weights; |
||||||
|
|
||||||
|
float *biases; |
||||||
|
float *bias_updates; |
||||||
|
|
||||||
|
float *scales; |
||||||
|
float *scale_updates; |
||||||
|
|
||||||
|
float *weights; |
||||||
|
float *weight_updates; |
||||||
|
|
||||||
|
char *align_bit_weights_gpu; |
||||||
|
float *mean_arr_gpu; |
||||||
|
float *align_workspace_gpu; |
||||||
|
float *transposed_align_workspace_gpu; |
||||||
|
int align_workspace_size; |
||||||
|
|
||||||
|
char *align_bit_weights; |
||||||
|
float *mean_arr; |
||||||
|
int align_bit_weights_size; |
||||||
|
int lda_align; |
||||||
|
int new_lda; |
||||||
|
int bit_align; |
||||||
|
|
||||||
|
float *col_image; |
||||||
|
float * delta; |
||||||
|
float * output; |
||||||
|
float * loss; |
||||||
|
float * squared; |
||||||
|
float * norms; |
||||||
|
|
||||||
|
float * spatial_mean; |
||||||
|
float * mean; |
||||||
|
float * variance; |
||||||
|
|
||||||
|
float * mean_delta; |
||||||
|
float * variance_delta; |
||||||
|
|
||||||
|
float * rolling_mean; |
||||||
|
float * rolling_variance; |
||||||
|
|
||||||
|
float * x; |
||||||
|
float * x_norm; |
||||||
|
|
||||||
|
float * m; |
||||||
|
float * v; |
||||||
|
|
||||||
|
float * bias_m; |
||||||
|
float * bias_v; |
||||||
|
float * scale_m; |
||||||
|
float * scale_v; |
||||||
|
|
||||||
|
|
||||||
|
float *z_cpu; |
||||||
|
float *r_cpu; |
||||||
|
float *h_cpu; |
||||||
|
float * prev_state_cpu; |
||||||
|
|
||||||
|
float *temp_cpu; |
||||||
|
float *temp2_cpu; |
||||||
|
float *temp3_cpu; |
||||||
|
|
||||||
|
float *dh_cpu; |
||||||
|
float *hh_cpu; |
||||||
|
float *prev_cell_cpu; |
||||||
|
float *cell_cpu; |
||||||
|
float *f_cpu; |
||||||
|
float *i_cpu; |
||||||
|
float *g_cpu; |
||||||
|
float *o_cpu; |
||||||
|
float *c_cpu; |
||||||
|
float *dc_cpu; |
||||||
|
|
||||||
|
float * binary_input; |
||||||
|
|
||||||
|
struct layer *input_layer; |
||||||
|
struct layer *self_layer; |
||||||
|
struct layer *output_layer; |
||||||
|
|
||||||
|
struct layer *reset_layer; |
||||||
|
struct layer *update_layer; |
||||||
|
struct layer *state_layer; |
||||||
|
|
||||||
|
struct layer *input_gate_layer; |
||||||
|
struct layer *state_gate_layer; |
||||||
|
struct layer *input_save_layer; |
||||||
|
struct layer *state_save_layer; |
||||||
|
struct layer *input_state_layer; |
||||||
|
struct layer *state_state_layer; |
||||||
|
|
||||||
|
struct layer *input_z_layer; |
||||||
|
struct layer *state_z_layer; |
||||||
|
|
||||||
|
struct layer *input_r_layer; |
||||||
|
struct layer *state_r_layer; |
||||||
|
|
||||||
|
struct layer *input_h_layer; |
||||||
|
struct layer *state_h_layer; |
||||||
|
|
||||||
|
struct layer *wz; |
||||||
|
struct layer *uz; |
||||||
|
struct layer *wr; |
||||||
|
struct layer *ur; |
||||||
|
struct layer *wh; |
||||||
|
struct layer *uh; |
||||||
|
struct layer *uo; |
||||||
|
struct layer *wo; |
||||||
|
struct layer *uf; |
||||||
|
struct layer *wf; |
||||||
|
struct layer *ui; |
||||||
|
struct layer *wi; |
||||||
|
struct layer *ug; |
||||||
|
struct layer *wg; |
||||||
|
|
||||||
|
tree *softmax_tree; |
||||||
|
|
||||||
|
size_t workspace_size; |
||||||
|
|
||||||
|
#ifdef GPU |
||||||
|
int *indexes_gpu; |
||||||
|
|
||||||
|
float *z_gpu; |
||||||
|
float *r_gpu; |
||||||
|
float *h_gpu; |
||||||
|
|
||||||
|
float *temp_gpu; |
||||||
|
float *temp2_gpu; |
||||||
|
float *temp3_gpu; |
||||||
|
|
||||||
|
float *dh_gpu; |
||||||
|
float *hh_gpu; |
||||||
|
float *prev_cell_gpu; |
||||||
|
float *cell_gpu; |
||||||
|
float *f_gpu; |
||||||
|
float *i_gpu; |
||||||
|
float *g_gpu; |
||||||
|
float *o_gpu; |
||||||
|
float *c_gpu; |
||||||
|
float *dc_gpu; |
||||||
|
|
||||||
|
// adam
|
||||||
|
float *m_gpu; |
||||||
|
float *v_gpu; |
||||||
|
float *bias_m_gpu; |
||||||
|
float *scale_m_gpu; |
||||||
|
float *bias_v_gpu; |
||||||
|
float *scale_v_gpu; |
||||||
|
|
||||||
|
float * combine_gpu; |
||||||
|
float * combine_delta_gpu; |
||||||
|
|
||||||
|
float * prev_state_gpu; |
||||||
|
float * forgot_state_gpu; |
||||||
|
float * forgot_delta_gpu; |
||||||
|
float * state_gpu; |
||||||
|
float * state_delta_gpu; |
||||||
|
float * gate_gpu; |
||||||
|
float * gate_delta_gpu; |
||||||
|
float * save_gpu; |
||||||
|
float * save_delta_gpu; |
||||||
|
float * concat_gpu; |
||||||
|
float * concat_delta_gpu; |
||||||
|
|
||||||
|
float *binary_input_gpu; |
||||||
|
float *binary_weights_gpu; |
||||||
|
|
||||||
|
float * mean_gpu; |
||||||
|
float * variance_gpu; |
||||||
|
|
||||||
|
float * rolling_mean_gpu; |
||||||
|
float * rolling_variance_gpu; |
||||||
|
|
||||||
|
float * variance_delta_gpu; |
||||||
|
float * mean_delta_gpu; |
||||||
|
|
||||||
|
float * col_image_gpu; |
||||||
|
|
||||||
|
float * x_gpu; |
||||||
|
float * x_norm_gpu; |
||||||
|
float * weights_gpu; |
||||||
|
float * weight_updates_gpu; |
||||||
|
float * weight_change_gpu; |
||||||
|
|
||||||
|
float * weights_gpu16; |
||||||
|
float * weight_updates_gpu16; |
||||||
|
|
||||||
|
float * biases_gpu; |
||||||
|
float * bias_updates_gpu; |
||||||
|
float * bias_change_gpu; |
||||||
|
|
||||||
|
float * scales_gpu; |
||||||
|
float * scale_updates_gpu; |
||||||
|
float * scale_change_gpu; |
||||||
|
|
||||||
|
float * output_gpu; |
||||||
|
float * loss_gpu; |
||||||
|
float * delta_gpu; |
||||||
|
float * rand_gpu; |
||||||
|
float * squared_gpu; |
||||||
|
float * norms_gpu; |
||||||
|
#ifdef CUDNN |
||||||
|
cudnnTensorDescriptor_t srcTensorDesc, dstTensorDesc; |
||||||
|
cudnnTensorDescriptor_t srcTensorDesc16, dstTensorDesc16; |
||||||
|
cudnnTensorDescriptor_t dsrcTensorDesc, ddstTensorDesc; |
||||||
|
cudnnTensorDescriptor_t dsrcTensorDesc16, ddstTensorDesc16; |
||||||
|
cudnnTensorDescriptor_t normTensorDesc, normDstTensorDesc, normDstTensorDescF16; |
||||||
|
cudnnFilterDescriptor_t weightDesc, weightDesc16; |
||||||
|
cudnnFilterDescriptor_t dweightDesc, dweightDesc16; |
||||||
|
cudnnConvolutionDescriptor_t convDesc; |
||||||
|
cudnnConvolutionFwdAlgo_t fw_algo, fw_algo16; |
||||||
|
cudnnConvolutionBwdDataAlgo_t bd_algo, bd_algo16; |
||||||
|
cudnnConvolutionBwdFilterAlgo_t bf_algo, bf_algo16; |
||||||
|
cudnnPoolingDescriptor_t poolingDesc; |
||||||
|
#endif // CUDNN
|
||||||
|
#endif // GPU
|
||||||
|
}; |
||||||
|
|
||||||
|
|
||||||
|
// network.h
|
||||||
|
typedef enum { |
||||||
|
CONSTANT, STEP, EXP, POLY, STEPS, SIG, RANDOM |
||||||
|
} learning_rate_policy; |
||||||
|
|
||||||
|
// network.h
|
||||||
|
typedef struct network { |
||||||
|
int n; |
||||||
|
int batch; |
||||||
|
uint64_t *seen; |
||||||
|
int *t; |
||||||
|
float epoch; |
||||||
|
int subdivisions; |
||||||
|
layer *layers; |
||||||
|
float *output; |
||||||
|
learning_rate_policy policy; |
||||||
|
|
||||||
|
float learning_rate; |
||||||
|
float momentum; |
||||||
|
float decay; |
||||||
|
float gamma; |
||||||
|
float scale; |
||||||
|
float power; |
||||||
|
int time_steps; |
||||||
|
int step; |
||||||
|
int max_batches; |
||||||
|
float *scales; |
||||||
|
int *steps; |
||||||
|
int num_steps; |
||||||
|
int burn_in; |
||||||
|
int cudnn_half; |
||||||
|
|
||||||
|
int adam; |
||||||
|
float B1; |
||||||
|
float B2; |
||||||
|
float eps; |
||||||
|
|
||||||
|
int inputs; |
||||||
|
int outputs; |
||||||
|
int truths; |
||||||
|
int notruth; |
||||||
|
int h, w, c; |
||||||
|
int max_crop; |
||||||
|
int min_crop; |
||||||
|
float max_ratio; |
||||||
|
float min_ratio; |
||||||
|
int center; |
||||||
|
int flip; // horizontal flip 50% probability augmentaiont for classifier training (default = 1)
|
||||||
|
float angle; |
||||||
|
float aspect; |
||||||
|
float exposure; |
||||||
|
float saturation; |
||||||
|
float hue; |
||||||
|
int random; |
||||||
|
int small_object; |
||||||
|
|
||||||
|
int gpu_index; |
||||||
|
tree *hierarchy; |
||||||
|
|
||||||
|
float *input; |
||||||
|
float *truth; |
||||||
|
float *delta; |
||||||
|
float *workspace; |
||||||
|
int train; |
||||||
|
int index; |
||||||
|
float *cost; |
||||||
|
float clip; |
||||||
|
|
||||||
|
#ifdef GPU |
||||||
|
//float *input_gpu;
|
||||||
|
//float *truth_gpu;
|
||||||
|
float *delta_gpu; |
||||||
|
float *output_gpu; |
||||||
|
|
||||||
|
float *input_state_gpu; |
||||||
|
|
||||||
|
float **input_gpu; |
||||||
|
float **truth_gpu; |
||||||
|
float **input16_gpu; |
||||||
|
float **output16_gpu; |
||||||
|
size_t *max_input16_size; |
||||||
|
size_t *max_output16_size; |
||||||
|
int wait_stream; |
||||||
|
#endif |
||||||
|
} network; |
||||||
|
|
||||||
|
// network.h
|
||||||
|
typedef struct network_state { |
||||||
|
float *truth; |
||||||
|
float *input; |
||||||
|
float *delta; |
||||||
|
float *workspace; |
||||||
|
int train; |
||||||
|
int index; |
||||||
|
network net; |
||||||
|
} network_state; |
||||||
|
|
||||||
|
//typedef struct {
|
||||||
|
// int w;
|
||||||
|
// int h;
|
||||||
|
// float scale;
|
||||||
|
// float rad;
|
||||||
|
// float dx;
|
||||||
|
// float dy;
|
||||||
|
// float aspect;
|
||||||
|
//} augment_args;
|
||||||
|
|
||||||
|
// image.h
|
||||||
|
typedef struct image { |
||||||
|
int w; |
||||||
|
int h; |
||||||
|
int c; |
||||||
|
float *data; |
||||||
|
} image; |
||||||
|
|
||||||
|
//typedef struct {
|
||||||
|
// int w;
|
||||||
|
// int h;
|
||||||
|
// int c;
|
||||||
|
// float *data;
|
||||||
|
//} image;
|
||||||
|
|
||||||
|
// box.h
|
||||||
|
typedef struct box { |
||||||
|
float x, y, w, h; |
||||||
|
} box; |
||||||
|
|
||||||
|
// box.h
|
||||||
|
typedef struct detection{ |
||||||
|
box bbox; |
||||||
|
int classes; |
||||||
|
float *prob; |
||||||
|
float *mask; |
||||||
|
float objectness; |
||||||
|
int sort_class; |
||||||
|
} detection; |
||||||
|
|
||||||
|
// matrix.h
|
||||||
|
typedef struct matrix { |
||||||
|
int rows, cols; |
||||||
|
float **vals; |
||||||
|
} matrix; |
||||||
|
|
||||||
|
// data.h
|
||||||
|
typedef struct data { |
||||||
|
int w, h; |
||||||
|
matrix X; |
||||||
|
matrix y; |
||||||
|
int shallow; |
||||||
|
int *num_boxes; |
||||||
|
box **boxes; |
||||||
|
} data; |
||||||
|
|
||||||
|
// data.h
|
||||||
|
typedef enum { |
||||||
|
CLASSIFICATION_DATA, DETECTION_DATA, CAPTCHA_DATA, REGION_DATA, IMAGE_DATA, COMPARE_DATA, WRITING_DATA, SWAG_DATA, TAG_DATA, OLD_CLASSIFICATION_DATA, STUDY_DATA, DET_DATA, SUPER_DATA, LETTERBOX_DATA, REGRESSION_DATA, SEGMENTATION_DATA, INSTANCE_DATA, ISEG_DATA |
||||||
|
} data_type; |
||||||
|
|
||||||
|
// data.h
|
||||||
|
typedef struct load_args { |
||||||
|
int threads; |
||||||
|
char **paths; |
||||||
|
char *path; |
||||||
|
int n; |
||||||
|
int m; |
||||||
|
char **labels; |
||||||
|
int h; |
||||||
|
int w; |
||||||
|
int c; // color depth
|
||||||
|
int out_w; |
||||||
|
int out_h; |
||||||
|
int nh; |
||||||
|
int nw; |
||||||
|
int num_boxes; |
||||||
|
int min, max, size; |
||||||
|
int classes; |
||||||
|
int background; |
||||||
|
int scale; |
||||||
|
int center; |
||||||
|
int coords; |
||||||
|
int small_object; |
||||||
|
float jitter; |
||||||
|
int flip; |
||||||
|
float angle; |
||||||
|
float aspect; |
||||||
|
float saturation; |
||||||
|
float exposure; |
||||||
|
float hue; |
||||||
|
data *d; |
||||||
|
image *im; |
||||||
|
image *resized; |
||||||
|
data_type type; |
||||||
|
tree *hierarchy; |
||||||
|
} load_args; |
||||||
|
|
||||||
|
// data.h
|
||||||
|
typedef struct box_label { |
||||||
|
int id; |
||||||
|
float x, y, w, h; |
||||||
|
float left, right, top, bottom; |
||||||
|
} box_label; |
||||||
|
|
||||||
|
// list.h
|
||||||
|
//typedef struct node {
|
||||||
|
// void *val;
|
||||||
|
// struct node *next;
|
||||||
|
// struct node *prev;
|
||||||
|
//} node;
|
||||||
|
|
||||||
|
// list.h
|
||||||
|
//typedef struct list {
|
||||||
|
// int size;
|
||||||
|
// node *front;
|
||||||
|
// node *back;
|
||||||
|
//} list;
|
||||||
|
|
||||||
|
// -----------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
// parser.c
|
||||||
|
LIB_API network *load_network(char *cfg, char *weights, int clear); |
||||||
|
LIB_API network *load_network_custom(char *cfg, char *weights, int clear, int batch); |
||||||
|
LIB_API network *load_network(char *cfg, char *weights, int clear); |
||||||
|
|
||||||
|
// network.c
|
||||||
|
LIB_API load_args get_base_args(network *net); |
||||||
|
|
||||||
|
// box.h
|
||||||
|
LIB_API void do_nms_sort(detection *dets, int total, int classes, float thresh); |
||||||
|
LIB_API void do_nms_obj(detection *dets, int total, int classes, float thresh); |
||||||
|
|
||||||
|
// network.h
|
||||||
|
LIB_API float *network_predict(network net, float *input); |
||||||
|
LIB_API detection *get_network_boxes(network *net, int w, int h, float thresh, float hier, int *map, int relative, int *num, int letter); |
||||||
|
LIB_API void free_detections(detection *dets, int n); |
||||||
|
LIB_API void fuse_conv_batchnorm(network net); |
||||||
|
LIB_API void calculate_binary_weights(network net); |
||||||
|
|
||||||
|
LIB_API layer* get_network_layer(network* net, int i); |
||||||
|
LIB_API detection *get_network_boxes(network *net, int w, int h, float thresh, float hier, int *map, int relative, int *num, int letter); |
||||||
|
LIB_API detection *make_network_boxes(network *net, float thresh, int *num); |
||||||
|
LIB_API void reset_rnn(network *net); |
||||||
|
LIB_API float *network_predict_image(network *net, image im); |
||||||
|
LIB_API float validate_detector_map(char *datacfg, char *cfgfile, char *weightfile, float thresh_calc_avg_iou, const float iou_thresh, network *existing_net); |
||||||
|
LIB_API void train_detector(char *datacfg, char *cfgfile, char *weightfile, int *gpus, int ngpus, int clear, int dont_show, int calc_map); |
||||||
|
LIB_API int network_width(network *net); |
||||||
|
LIB_API int network_height(network *net); |
||||||
|
LIB_API void optimize_picture(network *net, image orig, int max_layer, float scale, float rate, float thresh, int norm); |
||||||
|
|
||||||
|
// image.h
|
||||||
|
LIB_API image resize_image(image im, int w, int h); |
||||||
|
LIB_API image letterbox_image(image im, int w, int h); |
||||||
|
LIB_API void rgbgr_image(image im); |
||||||
|
LIB_API image make_image(int w, int h, int c); |
||||||
|
LIB_API image load_image_color(char *filename, int w, int h); |
||||||
|
LIB_API void free_image(image m); |
||||||
|
|
||||||
|
// layer.h
|
||||||
|
LIB_API void free_layer(layer); |
||||||
|
|
||||||
|
// data.c
|
||||||
|
LIB_API void free_data(data d); |
||||||
|
LIB_API pthread_t load_data(load_args args); |
||||||
|
LIB_API pthread_t load_data_in_thread(load_args args); |
||||||
|
|
||||||
|
// cuda.h
|
||||||
|
LIB_API void cuda_pull_array(float *x_gpu, float *x, size_t n); |
||||||
|
LIB_API void cuda_set_device(int n); |
||||||
|
|
||||||
|
// utils.h
|
||||||
|
LIB_API void free_ptrs(void **ptrs, int n); |
||||||
|
LIB_API void top_k(float *a, int n, int k, int *index); |
||||||
|
|
||||||
|
// tree.h
|
||||||
|
LIB_API tree *read_tree(char *filename); |
||||||
|
|
||||||
|
// option_list.h
|
||||||
|
LIB_API metadata get_metadata(char *file); |
||||||
|
|
||||||
|
|
||||||
|
#ifdef __cplusplus |
||||||
|
} |
||||||
|
#endif // __cplusplus
|
||||||
|
#endif // DARKNET_API
|
Loading…
Reference in new issue