|
|
|
@ -71,11 +71,11 @@ void draw_detection(image im, float *box, int side) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
void train_detection_net() |
|
|
|
|
void train_detection_net(char *cfgfile) |
|
|
|
|
{ |
|
|
|
|
float avg_loss = 1; |
|
|
|
|
//network net = parse_network_cfg("/home/pjreddie/imagenet_backup/alexnet_1270.cfg");
|
|
|
|
|
network net = parse_network_cfg("cfg/detnet.cfg"); |
|
|
|
|
network net = parse_network_cfg(cfgfile); |
|
|
|
|
printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay); |
|
|
|
|
int imgs = 1024; |
|
|
|
|
srand(time(0)); |
|
|
|
@ -115,6 +115,57 @@ void train_detection_net() |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
void validate_detection_net(char *cfgfile) |
|
|
|
|
{ |
|
|
|
|
network net = parse_network_cfg(cfgfile); |
|
|
|
|
fprintf(stderr, "Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay); |
|
|
|
|
srand(time(0)); |
|
|
|
|
|
|
|
|
|
list *plist = get_paths("/home/pjreddie/data/imagenet/detection.val"); |
|
|
|
|
char **paths = (char **)list_to_array(plist); |
|
|
|
|
|
|
|
|
|
int m = plist->size; |
|
|
|
|
int i = 0; |
|
|
|
|
int splits = 50; |
|
|
|
|
int num = (i+1)*m/splits - i*m/splits; |
|
|
|
|
|
|
|
|
|
fprintf(stderr, "%d\n", m); |
|
|
|
|
data val, buffer; |
|
|
|
|
pthread_t load_thread = load_data_thread(paths, num, 0, 0, 245, 224, 224, &buffer); |
|
|
|
|
clock_t time; |
|
|
|
|
for(i = 1; i <= splits; ++i){ |
|
|
|
|
time=clock(); |
|
|
|
|
pthread_join(load_thread, 0); |
|
|
|
|
val = buffer; |
|
|
|
|
normalize_data_rows(val); |
|
|
|
|
|
|
|
|
|
num = (i+1)*m/splits - i*m/splits; |
|
|
|
|
char **part = paths+(i*m/splits); |
|
|
|
|
if(i != splits) load_thread = load_data_thread(part, num, 0, 0, 245, 224, 224, &buffer); |
|
|
|
|
|
|
|
|
|
fprintf(stderr, "Loaded: %lf seconds\n", sec(clock()-time)); |
|
|
|
|
matrix pred = network_predict_data(net, val); |
|
|
|
|
int j, k; |
|
|
|
|
for(j = 0; j < pred.rows; ++j){ |
|
|
|
|
for(k = 0; k < pred.cols; k += 5){ |
|
|
|
|
if (pred.vals[j][k] > .005){ |
|
|
|
|
int index = k/5;
|
|
|
|
|
int r = index/7; |
|
|
|
|
int c = index%7; |
|
|
|
|
float y = (32.*(r + pred.vals[j][k+1]))/224.; |
|
|
|
|
float x = (32.*(c + pred.vals[j][k+2]))/224.; |
|
|
|
|
float h = (256.*(pred.vals[j][k+3]))/224.; |
|
|
|
|
float w = (256.*(pred.vals[j][k+4]))/224.; |
|
|
|
|
printf("%d %f %f %f %f %f\n", (i-1)*m/splits + j + 1, pred.vals[j][k], y, x, h, w); |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
time=clock(); |
|
|
|
|
free_data(val); |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
void train_imagenet_distributed(char *address) |
|
|
|
|
{ |
|
|
|
|
float avg_loss = 1; |
|
|
|
@ -159,10 +210,10 @@ void train_imagenet(char *cfgfile) |
|
|
|
|
//network net = parse_network_cfg("/home/pjreddie/imagenet_backup/alexnet_1270.cfg");
|
|
|
|
|
srand(time(0)); |
|
|
|
|
network net = parse_network_cfg(cfgfile); |
|
|
|
|
//set_learning_network(&net, net.learning_rate, 0, .0005);
|
|
|
|
|
set_learning_network(&net, net.learning_rate, 0, net.decay); |
|
|
|
|
printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay); |
|
|
|
|
int imgs = 1024; |
|
|
|
|
int i = 47900; |
|
|
|
|
int i = 77700; |
|
|
|
|
char **labels = get_labels("/home/pjreddie/data/imagenet/cls.labels.list"); |
|
|
|
|
list *plist = get_paths("/data/imagenet/cls.train.list"); |
|
|
|
|
char **paths = (char **)list_to_array(plist); |
|
|
|
@ -177,7 +228,9 @@ void train_imagenet(char *cfgfile) |
|
|
|
|
time=clock(); |
|
|
|
|
pthread_join(load_thread, 0); |
|
|
|
|
train = buffer; |
|
|
|
|
normalize_data_rows(train); |
|
|
|
|
//normalize_data_rows(train);
|
|
|
|
|
translate_data_rows(train, -128); |
|
|
|
|
scale_data_rows(train, 1./128); |
|
|
|
|
load_thread = load_data_thread(paths, imgs, plist->size, labels, 1000, 256, 256, &buffer); |
|
|
|
|
printf("Loaded: %lf seconds\n", sec(clock()-time)); |
|
|
|
|
time=clock(); |
|
|
|
@ -265,8 +318,10 @@ void test_init(char *cfgfile) |
|
|
|
|
int i = 0; |
|
|
|
|
char *filename = "data/test.jpg"; |
|
|
|
|
|
|
|
|
|
image im = load_image_color(filename, 224, 224); |
|
|
|
|
z_normalize_image(im); |
|
|
|
|
image im = load_image_color(filename, 256, 256); |
|
|
|
|
//z_normalize_image(im);
|
|
|
|
|
translate_image(im, -128); |
|
|
|
|
scale_image(im, 1/128.); |
|
|
|
|
float *X = im.data; |
|
|
|
|
forward_network(net, X, 0, 1); |
|
|
|
|
for(i = 0; i < net.n; ++i){ |
|
|
|
@ -352,9 +407,9 @@ void train_cifar10() |
|
|
|
|
if(count%10 == 0){ |
|
|
|
|
float test_acc = network_accuracy(net, test); |
|
|
|
|
printf("%d: Loss: %f, Test Acc: %f, Time: %lf seconds\n", count, loss, test_acc,sec(clock()-time)); |
|
|
|
|
char buff[256]; |
|
|
|
|
sprintf(buff, "unikitty/cifar10_%d.cfg", count); |
|
|
|
|
save_network(net, buff); |
|
|
|
|
//char buff[256];
|
|
|
|
|
//sprintf(buff, "unikitty/cifar10_%d.cfg", count);
|
|
|
|
|
//save_network(net, buff);
|
|
|
|
|
}else{ |
|
|
|
|
printf("%d: Loss: %f, Time: %lf seconds\n", count, loss, sec(clock()-time)); |
|
|
|
|
} |
|
|
|
@ -482,7 +537,7 @@ void visualize_cat() |
|
|
|
|
cvWaitKey(0); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
void test_gpu_net() |
|
|
|
|
void test_correct_nist() |
|
|
|
|
{ |
|
|
|
|
srand(222222); |
|
|
|
|
network net = parse_network_cfg("cfg/nist.cfg"); |
|
|
|
@ -523,11 +578,12 @@ void test_correct_alexnet() |
|
|
|
|
clock_t time; |
|
|
|
|
int count = 0; |
|
|
|
|
network net; |
|
|
|
|
int imgs = net.batch; |
|
|
|
|
|
|
|
|
|
count = 0; |
|
|
|
|
srand(222222); |
|
|
|
|
net = parse_network_cfg("cfg/net.cfg"); |
|
|
|
|
int imgs = net.batch; |
|
|
|
|
|
|
|
|
|
count = 0; |
|
|
|
|
while(++count <= 5){ |
|
|
|
|
time=clock(); |
|
|
|
|
data train = load_data(paths, imgs, plist->size, labels, 1000, 256, 256); |
|
|
|
@ -624,9 +680,9 @@ int main(int argc, char **argv) |
|
|
|
|
} |
|
|
|
|
#endif |
|
|
|
|
|
|
|
|
|
if(0==strcmp(argv[1], "detection")) train_detection_net(); |
|
|
|
|
else if(0==strcmp(argv[1], "cifar")) train_cifar10(); |
|
|
|
|
if(0==strcmp(argv[1], "cifar")) train_cifar10(); |
|
|
|
|
else if(0==strcmp(argv[1], "test_correct")) test_correct_alexnet(); |
|
|
|
|
else if(0==strcmp(argv[1], "test_correct_nist")) test_correct_nist(); |
|
|
|
|
else if(0==strcmp(argv[1], "test")) test_imagenet(); |
|
|
|
|
else if(0==strcmp(argv[1], "server")) run_server(); |
|
|
|
|
|
|
|
|
@ -638,6 +694,7 @@ int main(int argc, char **argv) |
|
|
|
|
fprintf(stderr, "usage: %s <function> <filename>\n", argv[0]); |
|
|
|
|
return 0; |
|
|
|
|
} |
|
|
|
|
else if(0==strcmp(argv[1], "detection")) train_detection_net(argv[2]); |
|
|
|
|
else if(0==strcmp(argv[1], "nist")) train_nist(argv[2]); |
|
|
|
|
else if(0==strcmp(argv[1], "train")) train_imagenet(argv[2]); |
|
|
|
|
else if(0==strcmp(argv[1], "client")) train_imagenet_distributed(argv[2]); |
|
|
|
@ -646,6 +703,7 @@ int main(int argc, char **argv) |
|
|
|
|
else if(0==strcmp(argv[1], "visualize")) test_visualize(argv[2]); |
|
|
|
|
else if(0==strcmp(argv[1], "valid")) validate_imagenet(argv[2]); |
|
|
|
|
else if(0==strcmp(argv[1], "testnist")) test_nist(argv[2]); |
|
|
|
|
else if(0==strcmp(argv[1], "validetect")) validate_detection_net(argv[2]); |
|
|
|
|
else if(argc < 4){ |
|
|
|
|
fprintf(stderr, "usage: %s <function> <filename> <filename>\n", argv[0]); |
|
|
|
|
return 0; |
|
|
|
|