mirror of https://github.com/AlexeyAB/darknet.git
parent
23955b9fa0
commit
16d06ec0db
30 changed files with 1453 additions and 148 deletions
@ -0,0 +1,95 @@ |
||||
#include "network.h" |
||||
#include "utils.h" |
||||
#include "parser.h" |
||||
#include "option_list.h" |
||||
#include "blas.h" |
||||
|
||||
#ifdef OPENCV |
||||
#include "opencv2/highgui/highgui_c.h" |
||||
#endif |
||||
|
||||
void train_cifar(char *cfgfile, char *weightfile) |
||||
{ |
||||
data_seed = time(0); |
||||
srand(time(0)); |
||||
float avg_loss = -1; |
||||
char *base = basecfg(cfgfile); |
||||
printf("%s\n", base); |
||||
network net = parse_network_cfg(cfgfile); |
||||
if(weightfile){ |
||||
load_weights(&net, weightfile); |
||||
} |
||||
printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay); |
||||
|
||||
char *backup_directory = "/home/pjreddie/backup/"; |
||||
int classes = 10; |
||||
int N = 50000; |
||||
|
||||
char **labels = get_labels("data/cifar/labels.txt"); |
||||
int epoch = (*net.seen)/N; |
||||
data train = load_all_cifar10(); |
||||
while(get_current_batch(net) < net.max_batches || net.max_batches == 0){ |
||||
clock_t time=clock(); |
||||
|
||||
float loss = train_network_sgd(net, train, 1); |
||||
if(avg_loss == -1) avg_loss = loss; |
||||
avg_loss = avg_loss*.9 + loss*.1; |
||||
printf("%d, %.3f: %f, %f avg, %f rate, %lf seconds, %d images\n", get_current_batch(net), (float)(*net.seen)/N, loss, avg_loss, get_current_rate(net), sec(clock()-time), *net.seen); |
||||
if(*net.seen/N > epoch){ |
||||
epoch = *net.seen/N; |
||||
char buff[256]; |
||||
sprintf(buff, "%s/%s_%d.weights",backup_directory,base, epoch); |
||||
save_weights(net, buff); |
||||
} |
||||
if(get_current_batch(net)%100 == 0){ |
||||
char buff[256]; |
||||
sprintf(buff, "%s/%s.backup",backup_directory,base); |
||||
save_weights(net, buff); |
||||
} |
||||
} |
||||
char buff[256]; |
||||
sprintf(buff, "%s/%s.weights", backup_directory, base); |
||||
save_weights(net, buff); |
||||
|
||||
free_network(net); |
||||
free_ptrs((void**)labels, classes); |
||||
free(base); |
||||
free_data(train); |
||||
} |
||||
|
||||
void test_cifar(char *filename, char *weightfile) |
||||
{ |
||||
network net = parse_network_cfg(filename); |
||||
if(weightfile){ |
||||
load_weights(&net, weightfile); |
||||
} |
||||
srand(time(0)); |
||||
|
||||
clock_t time; |
||||
float avg_acc = 0; |
||||
float avg_top5 = 0; |
||||
data test = load_cifar10_data("data/cifar/cifar-10-batches-bin/test_batch.bin"); |
||||
|
||||
time=clock(); |
||||
|
||||
float *acc = network_accuracies(net, test, 2); |
||||
avg_acc += acc[0]; |
||||
avg_top5 += acc[1]; |
||||
printf("top1: %f, %lf seconds, %d images\n", avg_acc, sec(clock()-time), test.X.rows); |
||||
free_data(test); |
||||
} |
||||
|
||||
/*
 * Command-line dispatcher for the CIFAR sub-commands.
 *
 * argv[2] selects the action ("train" or "test"), argv[3] is the cfg
 * file and argv[4], when present, the weights file.  Prints a usage
 * line and returns when too few arguments are given or the action is
 * not recognised.
 */
void run_cifar(int argc, char **argv)
{
    if(argc < 4){
        fprintf(stderr, "usage: %s %s [train/test] [cfg] [weights (optional)]\n", argv[0], argv[1]);
        return;
    }

    char *cfg = argv[3];
    char *weights = (argc > 4) ? argv[4] : 0;

    if(0==strcmp(argv[2], "train")){
        train_cifar(cfg, weights);
    }else if(0==strcmp(argv[2], "test")){
        test_cifar(cfg, weights);
    }else{
        /* Previously an unknown action was ignored without any feedback. */
        fprintf(stderr, "%s %s: unknown command '%s'\n", argv[0], argv[1], argv[2]);
        fprintf(stderr, "usage: %s %s [train/test] [cfg] [weights (optional)]\n", argv[0], argv[1]);
    }
}
||||
|
||||
|
@ -0,0 +1,152 @@ |
||||
#include "network.h" |
||||
#include "detection_layer.h" |
||||
#include "cost_layer.h" |
||||
#include "utils.h" |
||||
#include "parser.h" |
||||
#include "box.h" |
||||
#include "image.h" |
||||
#include <sys/time.h> |
||||
|
||||
#define FRAMES 1 |
||||
|
||||
#ifdef OPENCV |
||||
#include "opencv2/highgui/highgui.hpp" |
||||
#include "opencv2/imgproc/imgproc.hpp" |
||||
void convert_coco_detections(float *predictions, int classes, int num, int square, int side, int w, int h, float thresh, float **probs, box *boxes, int only_objectness); |
||||
|
||||
extern char *coco_classes[]; |
||||
extern image coco_labels[]; |
||||
|
||||
static float **probs; |
||||
static box *boxes; |
||||
static network net; |
||||
static image in ; |
||||
static image in_s ; |
||||
static image det ; |
||||
static image det_s; |
||||
static image disp ; |
||||
static CvCapture * cap; |
||||
static float fps = 0; |
||||
static float demo_thresh = 0; |
||||
|
||||
static float *predictions[FRAMES]; |
||||
static int demo_index = 0; |
||||
static image images[FRAMES]; |
||||
static float *avg; |
||||
|
||||
void *fetch_in_thread_coco(void *ptr) |
||||
{ |
||||
in = get_image_from_stream(cap); |
||||
in_s = resize_image(in, net.w, net.h); |
||||
return 0; |
||||
} |
||||
|
||||
void *detect_in_thread_coco(void *ptr) |
||||
{ |
||||
float nms = .4; |
||||
|
||||
detection_layer l = net.layers[net.n-1]; |
||||
float *X = det_s.data; |
||||
float *prediction = network_predict(net, X); |
||||
|
||||
memcpy(predictions[demo_index], prediction, l.outputs*sizeof(float)); |
||||
mean_arrays(predictions, FRAMES, l.outputs, avg); |
||||
|
||||
free_image(det_s); |
||||
convert_coco_detections(avg, l.classes, l.n, l.sqrt, l.side, 1, 1, demo_thresh, probs, boxes, 0); |
||||
if (nms > 0) do_nms(boxes, probs, l.side*l.side*l.n, l.classes, nms); |
||||
printf("\033[2J"); |
||||
printf("\033[1;1H"); |
||||
printf("\nFPS:%.0f\n",fps); |
||||
printf("Objects:\n\n"); |
||||
|
||||
images[demo_index] = det; |
||||
det = images[(demo_index + FRAMES/2 + 1)%FRAMES]; |
||||
demo_index = (demo_index + 1)%FRAMES; |
||||
|
||||
draw_detections(det, l.side*l.side*l.n, demo_thresh, boxes, probs, coco_classes, coco_labels, 80); |
||||
return 0; |
||||
} |
||||
|
||||
void demo_coco(char *cfgfile, char *weightfile, float thresh, int cam_index, const char *filename) |
||||
{ |
||||
demo_thresh = thresh; |
||||
printf("YOLO demo\n"); |
||||
net = parse_network_cfg(cfgfile); |
||||
if(weightfile){ |
||||
load_weights(&net, weightfile); |
||||
} |
||||
set_batch_network(&net, 1); |
||||
|
||||
srand(2222222); |
||||
|
||||
if(filename){ |
||||
cap = cvCaptureFromFile(filename); |
||||
}else{ |
||||
cap = cvCaptureFromCAM(cam_index); |
||||
} |
||||
|
||||
if(!cap) error("Couldn't connect to webcam.\n"); |
||||
cvNamedWindow("YOLO", CV_WINDOW_NORMAL);
|
||||
cvResizeWindow("YOLO", 512, 512); |
||||
|
||||
detection_layer l = net.layers[net.n-1]; |
||||
int j; |
||||
|
||||
avg = (float *) calloc(l.outputs, sizeof(float)); |
||||
for(j = 0; j < FRAMES; ++j) predictions[j] = (float *) calloc(l.outputs, sizeof(float)); |
||||
for(j = 0; j < FRAMES; ++j) images[j] = make_image(1,1,3); |
||||
|
||||
boxes = (box *)calloc(l.side*l.side*l.n, sizeof(box)); |
||||
probs = (float **)calloc(l.side*l.side*l.n, sizeof(float *)); |
||||
for(j = 0; j < l.side*l.side*l.n; ++j) probs[j] = (float *)calloc(l.classes, sizeof(float *)); |
||||
|
||||
pthread_t fetch_thread; |
||||
pthread_t detect_thread; |
||||
|
||||
fetch_in_thread_coco(0); |
||||
det = in; |
||||
det_s = in_s; |
||||
|
||||
fetch_in_thread_coco(0); |
||||
detect_in_thread_coco(0); |
||||
disp = det; |
||||
det = in; |
||||
det_s = in_s; |
||||
|
||||
for(j = 0; j < FRAMES/2; ++j){ |
||||
fetch_in_thread_coco(0); |
||||
detect_in_thread_coco(0); |
||||
disp = det; |
||||
det = in; |
||||
det_s = in_s; |
||||
} |
||||
|
||||
while(1){ |
||||
struct timeval tval_before, tval_after, tval_result; |
||||
gettimeofday(&tval_before, NULL); |
||||
if(pthread_create(&fetch_thread, 0, fetch_in_thread_coco, 0)) error("Thread creation failed"); |
||||
if(pthread_create(&detect_thread, 0, detect_in_thread_coco, 0)) error("Thread creation failed"); |
||||
show_image(disp, "YOLO"); |
||||
save_image(disp, "YOLO"); |
||||
free_image(disp); |
||||
cvWaitKey(10); |
||||
pthread_join(fetch_thread, 0); |
||||
pthread_join(detect_thread, 0); |
||||
|
||||
disp = det; |
||||
det = in; |
||||
det_s = in_s; |
||||
|
||||
gettimeofday(&tval_after, NULL); |
||||
timersub(&tval_after, &tval_before, &tval_result); |
||||
float curr = 1000000.f/((long int)tval_result.tv_usec); |
||||
fps = .9*fps + .1*curr; |
||||
} |
||||
} |
||||
#else |
||||
/*
 * Fallback used when the file is built without OPENCV: only reports that
 * the demo is unavailable.  The parameter list matches the OPENCV build's
 * definition (which also takes `filename`) so both builds expose the
 * same function signature; all arguments are ignored.
 */
void demo_coco(char *cfgfile, char *weightfile, float thresh, int cam_index, const char *filename){
    (void)cfgfile;
    (void)weightfile;
    (void)thresh;
    (void)cam_index;
    (void)filename;
    fprintf(stderr, "YOLO-COCO demo needs OpenCV for webcam images.\n");
}
||||
#endif |
||||
|
@ -0,0 +1,277 @@ |
||||
#include "crnn_layer.h" |
||||
#include "convolutional_layer.h" |
||||
#include "utils.h" |
||||
#include "cuda.h" |
||||
#include "blas.h" |
||||
#include "gemm.h" |
||||
|
||||
#include <math.h> |
||||
#include <stdio.h> |
||||
#include <stdlib.h> |
||||
#include <string.h> |
||||
|
||||
/*
 * Shift a layer's per-step buffer pointers by `steps` time steps.
 * One step spans outputs*batch floats; a negative `steps` moves the
 * pointers backwards.  The GPU pointers are shifted too when built with GPU.
 */
static void increment_layer(layer *l, int steps)
{
    const int offset = l->outputs * l->batch * steps;

    l->output += offset;
    l->delta  += offset;
    l->x      += offset;
    l->x_norm += offset;

#ifdef GPU
    l->output_gpu += offset;
    l->delta_gpu  += offset;
    l->x_gpu      += offset;
    l->x_norm_gpu += offset;
#endif
}
||||
|
||||
/*
 * Allocate one convolutional sub-layer of a CRNN layer: 3x3 kernel,
 * stride 1, padding 1, sized for batch*steps images, with its `batch`
 * field then reset to the per-step batch.
 */
static layer *make_crnn_sublayer(int batch, int steps, int h, int w, int c_in, int c_out, ACTIVATION activation, int batch_normalize)
{
    layer *sub = malloc(sizeof(layer));
    fprintf(stderr, "\t\t");
    *sub = make_convolutional_layer(batch*steps, h, w, c_in, c_out, 3, 1, 1, activation, batch_normalize, 0);
    sub->batch = batch;
    return sub;
}

/*
 * Build a convolutional recurrent (CRNN) layer.
 *
 * batch           - total batch size; it is divided by `steps` to get the
 *                   per-step batch stored in the layer.
 * h, w, c         - input height, width and channels.
 * hidden_filters  - number of filters in the hidden state.
 * output_filters  - number of filters in the output.
 * steps           - number of time steps.
 * activation, batch_normalize - forwarded to the three convolutional
 *                   sub-layers (input, self, output).
 *
 * The layer's output/delta pointers alias those of the output sub-layer.
 */
layer make_crnn_layer(int batch, int h, int w, int c, int hidden_filters, int output_filters, int steps, ACTIVATION activation, int batch_normalize)
{
    fprintf(stderr, "CRNN Layer: %d x %d x %d image, %d filters\n", h,w,c,output_filters);
    batch = batch / steps;

    layer l = {0};
    l.type = CRNN;
    l.batch = batch;
    l.steps = steps;

    l.h = h;
    l.w = w;
    l.c = c;
    l.out_h = h;
    l.out_w = w;
    l.out_c = output_filters;

    l.inputs = h*w*c;
    l.hidden = h * w * hidden_filters;
    l.outputs = l.out_h * l.out_w * l.out_c;

    /* One hidden-state slot per step plus one extra slot. */
    l.state = calloc(l.hidden*batch*(steps+1), sizeof(float));

    l.input_layer  = make_crnn_sublayer(batch, steps, h, w, c, hidden_filters, activation, batch_normalize);
    l.self_layer   = make_crnn_sublayer(batch, steps, h, w, hidden_filters, hidden_filters, activation, batch_normalize);
    l.output_layer = make_crnn_sublayer(batch, steps, h, w, hidden_filters, output_filters, activation, batch_normalize);

    l.output = l.output_layer->output;
    l.delta = l.output_layer->delta;

#ifdef GPU
    l.state_gpu = cuda_make_array(l.state, l.hidden*batch*(steps+1));
    l.output_gpu = l.output_layer->output_gpu;
    l.delta_gpu = l.output_layer->delta_gpu;
#endif

    return l;
}
||||
|
||||
/*
 * Apply the weight update to the three convolutional sub-layers
 * (input, self, output) with the same hyper-parameters.
 */
void update_crnn_layer(layer l, int batch, float learning_rate, float momentum, float decay)
{
    layer *in_l = l.input_layer;
    layer *self_l = l.self_layer;
    layer *out_l = l.output_layer;

    update_convolutional_layer(*in_l, batch, learning_rate, momentum, decay);
    update_convolutional_layer(*self_l, batch, learning_rate, momentum, decay);
    update_convolutional_layer(*out_l, batch, learning_rate, momentum, decay);
}
||||
|
||||
/*
 * CPU forward pass of the CRNN layer over l.steps time steps.
 *
 * For each step the input and self sub-layers are run, their outputs are
 * summed into the hidden state (on top of the previous state when
 * l.shortcut is set), and the output sub-layer is run on that state.
 * When training, l.state advances one slot per step so every step's
 * state is kept; `l` is a by-value copy, so the caller's pointer is
 * unaffected.
 */
void forward_crnn_layer(layer l, network_state state)
{
    int step;
    const int hidden_n = l.hidden * l.batch;   /* floats in one hidden-state slot */
    network_state sub = {0};
    layer in_l = *(l.input_layer);
    layer self_l = *(l.self_layer);
    layer out_l = *(l.output_layer);

    sub.train = state.train;

    /* Clear all sub-layer deltas, and the initial state when training. */
    fill_cpu(l.outputs * l.batch * l.steps, 0, out_l.delta, 1);
    fill_cpu(hidden_n * l.steps, 0, self_l.delta, 1);
    fill_cpu(hidden_n * l.steps, 0, in_l.delta, 1);
    if(state.train) fill_cpu(hidden_n, 0, l.state, 1);

    for (step = 0; step < l.steps; ++step) {
        float *prev_state;

        sub.input = state.input;
        forward_convolutional_layer(in_l, sub);

        sub.input = l.state;
        forward_convolutional_layer(self_l, sub);

        prev_state = l.state;
        if(state.train) l.state += hidden_n;
        if(l.shortcut){
            copy_cpu(hidden_n, prev_state, 1, l.state, 1);
        }else{
            fill_cpu(hidden_n, 0, l.state, 1);
        }
        axpy_cpu(hidden_n, 1, in_l.output, 1, l.state, 1);
        axpy_cpu(hidden_n, 1, self_l.output, 1, l.state, 1);

        sub.input = l.state;
        forward_convolutional_layer(out_l, sub);

        /* Move on to the next time step's input and sub-layer buffers. */
        state.input += l.inputs*l.batch;
        increment_layer(&in_l, 1);
        increment_layer(&self_l, 1);
        increment_layer(&out_l, 1);
    }
}
||||
|
||||
/*
 * CPU backward pass of the CRNN layer, walking the time steps from last
 * to first.
 *
 * For each step the hidden state is rebuilt from the input and self
 * sub-layer outputs, the output sub-layer is back-propagated into the
 * self sub-layer's delta, then the self and input sub-layers are
 * back-propagated.  No delta is propagated to an earlier step at step 0.
 */
void backward_crnn_layer(layer l, network_state state)
{
    int step;
    const int hidden_n = l.hidden * l.batch;   /* floats in one hidden-state slot */
    network_state sub = {0};
    layer in_l = *(l.input_layer);
    layer self_l = *(l.self_layer);
    layer out_l = *(l.output_layer);

    sub.train = state.train;

    /* Start at the buffers of the final time step. */
    increment_layer(&in_l, l.steps-1);
    increment_layer(&self_l, l.steps-1);
    increment_layer(&out_l, l.steps-1);

    l.state += hidden_n*l.steps;
    for (step = l.steps-1; step >= 0; --step) {
        const int in_off = step*l.inputs*l.batch;

        copy_cpu(hidden_n, in_l.output, 1, l.state, 1);
        axpy_cpu(hidden_n, 1, self_l.output, 1, l.state, 1);

        sub.input = l.state;
        sub.delta = self_l.delta;
        backward_convolutional_layer(out_l, sub);

        /* Step back to the state slot that fed the self sub-layer. */
        l.state -= hidden_n;

        sub.input = l.state;
        sub.delta = (step == 0) ? 0 : self_l.delta - hidden_n;
        backward_convolutional_layer(self_l, sub);

        copy_cpu(hidden_n, self_l.delta, 1, in_l.delta, 1);
        if (step > 0 && l.shortcut) {
            axpy_cpu(hidden_n, 1, self_l.delta, 1, self_l.delta - hidden_n, 1);
        }

        sub.input = state.input + in_off;
        sub.delta = state.delta ? state.delta + in_off : 0;
        backward_convolutional_layer(in_l, sub);

        increment_layer(&in_l, -1);
        increment_layer(&self_l, -1);
        increment_layer(&out_l, -1);
    }
}
||||
|
||||
#ifdef GPU |
||||
|
||||
/* Call pull_convolutional_layer() on the input, self and output sub-layers. */
void pull_crnn_layer(layer l)
{
    layer *in_l = l.input_layer;
    layer *self_l = l.self_layer;
    layer *out_l = l.output_layer;

    pull_convolutional_layer(*in_l);
    pull_convolutional_layer(*self_l);
    pull_convolutional_layer(*out_l);
}
||||
|
||||
/* Call push_convolutional_layer() on the input, self and output sub-layers. */
void push_crnn_layer(layer l)
{
    layer *in_l = l.input_layer;
    layer *self_l = l.self_layer;
    layer *out_l = l.output_layer;

    push_convolutional_layer(*in_l);
    push_convolutional_layer(*self_l);
    push_convolutional_layer(*out_l);
}
||||
|
||||
/*
 * GPU counterpart of update_crnn_layer(): apply the weight update to the
 * input, self and output sub-layers with the same hyper-parameters.
 */
void update_crnn_layer_gpu(layer l, int batch, float learning_rate, float momentum, float decay)
{
    layer *in_l = l.input_layer;
    layer *self_l = l.self_layer;
    layer *out_l = l.output_layer;

    update_convolutional_layer_gpu(*in_l, batch, learning_rate, momentum, decay);
    update_convolutional_layer_gpu(*self_l, batch, learning_rate, momentum, decay);
    update_convolutional_layer_gpu(*out_l, batch, learning_rate, momentum, decay);
}
||||
|
||||
/*
 * GPU forward pass of the CRNN layer over l.steps time steps.
 *
 * Same sequence as the CPU version, using the *_gpu buffers: run the
 * input and self sub-layers, sum their outputs into the hidden state
 * (on top of the previous state when l.shortcut is set), then run the
 * output sub-layer.  When training, l.state_gpu advances one slot per step.
 */
void forward_crnn_layer_gpu(layer l, network_state state)
{
    int step;
    const int hidden_n = l.hidden * l.batch;   /* floats in one hidden-state slot */
    network_state sub = {0};
    layer in_l = *(l.input_layer);
    layer self_l = *(l.self_layer);
    layer out_l = *(l.output_layer);

    sub.train = state.train;

    /* Clear all sub-layer deltas, and the initial state when training. */
    fill_ongpu(l.outputs * l.batch * l.steps, 0, out_l.delta_gpu, 1);
    fill_ongpu(hidden_n * l.steps, 0, self_l.delta_gpu, 1);
    fill_ongpu(hidden_n * l.steps, 0, in_l.delta_gpu, 1);
    if(state.train) fill_ongpu(hidden_n, 0, l.state_gpu, 1);

    for (step = 0; step < l.steps; ++step) {
        float *prev_state;

        sub.input = state.input;
        forward_convolutional_layer_gpu(in_l, sub);

        sub.input = l.state_gpu;
        forward_convolutional_layer_gpu(self_l, sub);

        prev_state = l.state_gpu;
        if(state.train) l.state_gpu += hidden_n;
        if(l.shortcut){
            copy_ongpu(hidden_n, prev_state, 1, l.state_gpu, 1);
        }else{
            fill_ongpu(hidden_n, 0, l.state_gpu, 1);
        }
        axpy_ongpu(hidden_n, 1, in_l.output_gpu, 1, l.state_gpu, 1);
        axpy_ongpu(hidden_n, 1, self_l.output_gpu, 1, l.state_gpu, 1);

        sub.input = l.state_gpu;
        forward_convolutional_layer_gpu(out_l, sub);

        /* Move on to the next time step's input and sub-layer buffers. */
        state.input += l.inputs*l.batch;
        increment_layer(&in_l, 1);
        increment_layer(&self_l, 1);
        increment_layer(&out_l, 1);
    }
}
||||
|
||||
/*
 * GPU backward pass of the CRNN layer, walking the time steps from last
 * to first.
 *
 * Same sequence as the CPU version, using the *_gpu buffers: rebuild the
 * hidden state, back-propagate the output sub-layer into the self
 * sub-layer's delta, then back-propagate the self and input sub-layers.
 * No delta is propagated to an earlier step at step 0.
 */
void backward_crnn_layer_gpu(layer l, network_state state)
{
    int step;
    const int hidden_n = l.hidden * l.batch;   /* floats in one hidden-state slot */
    network_state sub = {0};
    layer in_l = *(l.input_layer);
    layer self_l = *(l.self_layer);
    layer out_l = *(l.output_layer);

    sub.train = state.train;

    /* Start at the buffers of the final time step. */
    increment_layer(&in_l, l.steps - 1);
    increment_layer(&self_l, l.steps - 1);
    increment_layer(&out_l, l.steps - 1);

    l.state_gpu += hidden_n*l.steps;
    for (step = l.steps-1; step >= 0; --step) {
        const int in_off = step*l.inputs*l.batch;

        copy_ongpu(hidden_n, in_l.output_gpu, 1, l.state_gpu, 1);
        axpy_ongpu(hidden_n, 1, self_l.output_gpu, 1, l.state_gpu, 1);

        sub.input = l.state_gpu;
        sub.delta = self_l.delta_gpu;
        backward_convolutional_layer_gpu(out_l, sub);

        /* Step back to the state slot that fed the self sub-layer. */
        l.state_gpu -= hidden_n;

        sub.input = l.state_gpu;
        sub.delta = (step == 0) ? 0 : self_l.delta_gpu - hidden_n;
        backward_convolutional_layer_gpu(self_l, sub);

        copy_ongpu(hidden_n, self_l.delta_gpu, 1, in_l.delta_gpu, 1);
        if (step > 0 && l.shortcut) {
            axpy_ongpu(hidden_n, 1, self_l.delta_gpu, 1, self_l.delta_gpu - hidden_n, 1);
        }

        sub.input = state.input + in_off;
        sub.delta = state.delta ? state.delta + in_off : 0;
        backward_convolutional_layer_gpu(in_l, sub);

        increment_layer(&in_l, -1);
        increment_layer(&self_l, -1);
        increment_layer(&out_l, -1);
    }
}
||||
#endif |
@ -0,0 +1,24 @@ |
||||
|
||||
#ifndef CRNN_LAYER_H
#define CRNN_LAYER_H

#include "activations.h"
#include "layer.h"
#include "network.h"

/* Construct a convolutional recurrent layer; `batch` is the total batch
 * size and `steps` the number of time steps. */
layer make_crnn_layer(int batch, int h, int w, int c, int hidden_filters, int output_filters, int steps, ACTIVATION activation, int batch_normalize);

/* CPU forward pass, backward pass and weight update. */
void forward_crnn_layer(layer l, network_state state);
void backward_crnn_layer(layer l, network_state state);
void update_crnn_layer(layer l, int batch, float learning_rate, float momentum, float decay);

#ifdef GPU
/* GPU forward pass, backward pass and weight update. */
void forward_crnn_layer_gpu(layer l, network_state state);
void backward_crnn_layer_gpu(layer l, network_state state);
void update_crnn_layer_gpu(layer l, int batch, float learning_rate, float momentum, float decay);

/* Push/pull the three convolutional sub-layers. */
void push_crnn_layer(layer l);
void pull_crnn_layer(layer l);
#endif

#endif
||||
|
@ -0,0 +1,210 @@ |
||||
#include "network.h" |
||||
#include "cost_layer.h" |
||||
#include "utils.h" |
||||
#include "parser.h" |
||||
#include "blas.h" |
||||
|
||||
#ifdef OPENCV |
||||
#include "opencv2/highgui/highgui_c.h" |
||||
|
||||
void reconstruct_picture(network net, float *features, image recon, image update, float rate, float momentum, float lambda, int smooth_size, int iters); |
||||
|
||||
|
||||
/* A pair of float pointers, returned by get_rnn_vid_data() as the input
 * (x) and target (y) views into one buffer. */
typedef struct {
    float *x;   /* start of the input features */
    float *y;   /* start of the target features */
} float_pair;
||||
|
||||
/*
 * Build one training sample of extractor features from random video clips.
 *
 * net   - feature-extractor network; net.batch must equal steps + 1.
 * files - array of n video paths to sample from.
 * batch - number of clips to sample.
 * steps - number of time steps per clip.
 *
 * For each clip, net.batch consecutive frames starting at a random
 * position are resized to the network input and run through `net`; the
 * outputs are stored step-major in one calloc'd buffer.  Videos with
 * fewer than steps + 4 frames (or that fail to open) are skipped and
 * another file is drawn.
 *
 * Returns x pointing at the start of that buffer and y pointing one time
 * step (batch*output_size floats) later.  The caller frees the buffer
 * through x.
 */
float_pair get_rnn_vid_data(network net, char **files, int n, int batch, int steps)
{
    int b;
    assert(net.batch == steps + 1);
    image out_im = get_network_image(net);
    int output_size = out_im.w*out_im.h*out_im.c;
    int input_size = net.w*net.h*net.c;
    printf("%d %d %d\n", out_im.w, out_im.h, out_im.c);
    float *feats = calloc(net.batch*batch*output_size, sizeof(float));
    for(b = 0; b < batch; ++b){
        char *filename = files[rand()%n];
        CvCapture *cap = cvCaptureFromFile(filename);
        int frames = cap ? (int)cvGetCaptureProperty(cap, CV_CAP_PROP_FRAME_COUNT) : 0;

        /* Reject short clips BEFORE taking the modulus below: with too few
         * frames the divisor could be zero or negative.  The capture is
         * released here as well; it used to leak on this path. */
        if (frames < (steps + 4)){
            if (cap) cvReleaseCapture(&cap);
            --b;
            continue;
        }
        int index = rand() % (frames - steps - 2);

        printf("frames: %d, index: %d\n", frames, index);
        cvSetCaptureProperty(cap, CV_CAP_PROP_POS_FRAMES, index);

        float *input = calloc(input_size*net.batch, sizeof(float));
        int i;
        for(i = 0; i < net.batch; ++i){
            IplImage* src = cvQueryFrame(cap);
            image im = ipl_to_image(src);
            rgbgr_image(im);
            image re = resize_image(im, net.w, net.h);
            memcpy(input + i*input_size, re.data, input_size*sizeof(float));
            free_image(im);
            free_image(re);
        }
        float *output = network_predict(net, input);

        free(input);

        /* Store step i of clip b at row (b + i*batch). */
        for(i = 0; i < net.batch; ++i){
            memcpy(feats + (b + i*batch)*output_size, output + i*output_size, output_size*sizeof(float));
        }

        cvReleaseCapture(&cap);
    }

    float_pair p = {0};
    p.x = feats;
    p.y = feats + output_size*batch;

    return p;
}
||||
|
||||
|
||||
void train_vid_rnn(char *cfgfile, char *weightfile) |
||||
{ |
||||
char *train_videos = "data/vid/train.txt"; |
||||
char *backup_directory = "/home/pjreddie/backup/"; |
||||
srand(time(0)); |
||||
data_seed = time(0); |
||||
char *base = basecfg(cfgfile); |
||||
printf("%s\n", base); |
||||
float avg_loss = -1; |
||||
network net = parse_network_cfg(cfgfile); |
||||
if(weightfile){ |
||||
load_weights(&net, weightfile); |
||||
} |
||||
printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay); |
||||
int imgs = net.batch*net.subdivisions; |
||||
int i = *net.seen/imgs; |
||||
|
||||
list *plist = get_paths(train_videos); |
||||
int N = plist->size; |
||||
char **paths = (char **)list_to_array(plist); |
||||
clock_t time; |
||||
int steps = net.time_steps; |
||||
int batch = net.batch / net.time_steps; |
||||
|
||||
network extractor = parse_network_cfg("cfg/extractor.cfg"); |
||||
load_weights(&extractor, "/home/pjreddie/trained/yolo-coco.conv"); |
||||
|
||||
while(get_current_batch(net) < net.max_batches){ |
||||
i += 1; |
||||
time=clock(); |
||||
float_pair p = get_rnn_vid_data(extractor, paths, N, batch, steps); |
||||
|
||||
float loss = train_network_datum(net, p.x, p.y) / (net.batch); |
||||
|
||||
|
||||
free(p.x); |
||||
if (avg_loss < 0) avg_loss = loss; |
||||
avg_loss = avg_loss*.9 + loss*.1; |
||||
|
||||
fprintf(stderr, "%d: %f, %f avg, %f rate, %lf seconds\n", i, loss, avg_loss, get_current_rate(net), sec(clock()-time)); |
||||
if(i%100==0){ |
||||
char buff[256]; |
||||
sprintf(buff, "%s/%s_%d.weights", backup_directory, base, i); |
||||
save_weights(net, buff); |
||||
} |
||||
if(i%10==0){ |
||||
char buff[256]; |
||||
sprintf(buff, "%s/%s.backup", backup_directory, base); |
||||
save_weights(net, buff); |
||||
} |
||||
} |
||||
char buff[256]; |
||||
sprintf(buff, "%s/%s_final.weights", backup_directory, base); |
||||
save_weights(net, buff); |
||||
} |
||||
|
||||
|
||||
/*
 * Reconstruct an image from a feature vector and save it.
 *
 * net  - network passed to reconstruct_picture().
 * init - optional starting image; when null a random net.w x net.h
 *        3-channel image is used instead.
 * feat - feature vector to reconstruct from.
 * name - file name prefix; the saved name is name followed by i.
 * i    - index appended to the name.
 *
 * Returns the reconstructed image; the caller owns it and must free it.
 */
image save_reconstruction(network net, image *init, float *feat, char *name, int i)
{
    image recon;
    if (init) {
        recon = copy_image(*init);
    } else {
        recon = make_random_image(net.w, net.h, 3);
    }

    image update = make_image(net.w, net.h, 3);
    reconstruct_picture(net, feat, recon, update, .01, .9, .1, 2, 50);

    char buff[256];
    /* Bounded formatting: name is caller-supplied and could exceed the buffer. */
    snprintf(buff, sizeof(buff), "%s%d", name, i);
    save_image(recon, buff);

    free_image(update);
    return recon;
}
||||
|
||||
void generate_vid_rnn(char *cfgfile, char *weightfile) |
||||
{ |
||||
network extractor = parse_network_cfg("cfg/extractor.recon.cfg"); |
||||
load_weights(&extractor, "/home/pjreddie/trained/yolo-coco.conv"); |
||||
|
||||
network net = parse_network_cfg(cfgfile); |
||||
if(weightfile){ |
||||
load_weights(&net, weightfile); |
||||
} |
||||
set_batch_network(&extractor, 1); |
||||
set_batch_network(&net, 1); |
||||
|
||||
int i; |
||||
CvCapture *cap = cvCaptureFromFile("/extra/vid/ILSVRC2015/Data/VID/snippets/val/ILSVRC2015_val_00007030.mp4"); |
||||
float *feat; |
||||
float *next; |
||||
image last; |
||||
for(i = 0; i < 25; ++i){ |
||||
image im = get_image_from_stream(cap); |
||||
image re = resize_image(im, extractor.w, extractor.h); |
||||
feat = network_predict(extractor, re.data); |
||||
if(i > 0){ |
||||
printf("%f %f\n", mean_array(feat, 14*14*512), variance_array(feat, 14*14*512)); |
||||
printf("%f %f\n", mean_array(next, 14*14*512), variance_array(next, 14*14*512)); |
||||
printf("%f\n", mse_array(feat, 14*14*512)); |
||||
axpy_cpu(14*14*512, -1, feat, 1, next, 1); |
||||
printf("%f\n", mse_array(next, 14*14*512)); |
||||
} |
||||
next = network_predict(net, feat); |
||||
|
||||
free_image(im); |
||||
|
||||
free_image(save_reconstruction(extractor, 0, feat, "feat", i)); |
||||
free_image(save_reconstruction(extractor, 0, next, "next", i)); |
||||
if (i==24) last = copy_image(re); |
||||
free_image(re); |
||||
} |
||||
for(i = 0; i < 30; ++i){ |
||||
next = network_predict(net, next); |
||||
image new = save_reconstruction(extractor, &last, next, "new", i); |
||||
free_image(last); |
||||
last = new; |
||||
} |
||||
} |
||||
|
||||
/*
 * Command-line dispatcher for the video RNN sub-commands.
 *
 * argv[2] selects the action ("train" or "generate"), argv[3] is the cfg
 * file and argv[4], when present, the weights file.  Prints a usage line
 * and returns when too few arguments are given or the action is not
 * recognised.
 */
void run_vid_rnn(int argc, char **argv)
{
    if(argc < 4){
        fprintf(stderr, "usage: %s %s [train/generate] [cfg] [weights (optional)]\n", argv[0], argv[1]);
        return;
    }

    char *cfg = argv[3];
    char *weights = (argc > 4) ? argv[4] : 0;

    if(0==strcmp(argv[2], "train")){
        train_vid_rnn(cfg, weights);
    }else if(0==strcmp(argv[2], "generate")){
        generate_vid_rnn(cfg, weights);
    }else{
        /* Previously an unknown action was ignored without any feedback. */
        fprintf(stderr, "%s %s: unknown command '%s'\n", argv[0], argv[1], argv[2]);
        fprintf(stderr, "usage: %s %s [train/generate] [cfg] [weights (optional)]\n", argv[0], argv[1]);
    }
}
||||
#else |
||||
/* Build without OPENCV: the video RNN commands are unavailable, so this
 * entry point does nothing. */
void run_vid_rnn(int argc, char **argv)
{
    (void)argc;
    (void)argv;
}
||||
#endif |
||||
|
@ -0,0 +1,144 @@ |
||||
#include "network.h" |
||||
#include "utils.h" |
||||
#include "parser.h" |
||||
|
||||
#ifdef OPENCV |
||||
#include "opencv2/highgui/highgui_c.h" |
||||
#endif |
||||
|
||||
void train_tag(char *cfgfile, char *weightfile) |
||||
{ |
||||
data_seed = time(0); |
||||
srand(time(0)); |
||||
float avg_loss = -1; |
||||
char *base = basecfg(cfgfile); |
||||
char *backup_directory = "/home/pjreddie/backup/"; |
||||
printf("%s\n", base); |
||||
network net = parse_network_cfg(cfgfile); |
||||
if(weightfile){ |
||||
load_weights(&net, weightfile); |
||||
} |
||||
printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay); |
||||
int imgs = 1024; |
||||
list *plist = get_paths("/home/pjreddie/tag/train.list"); |
||||
char **paths = (char **)list_to_array(plist); |
||||
printf("%d\n", plist->size); |
||||
int N = plist->size; |
||||
clock_t time; |
||||
pthread_t load_thread; |
||||
data train; |
||||
data buffer; |
||||
|
||||
load_args args = {0}; |
||||
args.w = net.w; |
||||
args.h = net.h; |
||||
|
||||
args.min = net.w; |
||||
args.max = net.max_crop; |
||||
args.size = net.w; |
||||
|
||||
args.paths = paths; |
||||
args.classes = net.outputs; |
||||
args.n = imgs; |
||||
args.m = N; |
||||
args.d = &buffer; |
||||
args.type = TAG_DATA; |
||||
|
||||
fprintf(stderr, "%d classes\n", net.outputs); |
||||
|
||||
load_thread = load_data_in_thread(args); |
||||
int epoch = (*net.seen)/N; |
||||
while(get_current_batch(net) < net.max_batches || net.max_batches == 0){ |
||||
time=clock(); |
||||
pthread_join(load_thread, 0); |
||||
train = buffer; |
||||
|
||||
load_thread = load_data_in_thread(args); |
||||
printf("Loaded: %lf seconds\n", sec(clock()-time)); |
||||
time=clock(); |
||||
float loss = train_network(net, train); |
||||
if(avg_loss == -1) avg_loss = loss; |
||||
avg_loss = avg_loss*.9 + loss*.1; |
||||
printf("%d, %.3f: %f, %f avg, %f rate, %lf seconds, %d images\n", get_current_batch(net), (float)(*net.seen)/N, loss, avg_loss, get_current_rate(net), sec(clock()-time), *net.seen); |
||||
free_data(train); |
||||
if(*net.seen/N > epoch){ |
||||
epoch = *net.seen/N; |
||||
char buff[256]; |
||||
sprintf(buff, "%s/%s_%d.weights",backup_directory,base, epoch); |
||||
save_weights(net, buff); |
||||
} |
||||
if(get_current_batch(net)%100 == 0){ |
||||
char buff[256]; |
||||
sprintf(buff, "%s/%s.backup",backup_directory,base); |
||||
save_weights(net, buff); |
||||
} |
||||
} |
||||
char buff[256]; |
||||
sprintf(buff, "%s/%s.weights", backup_directory, base); |
||||
save_weights(net, buff); |
||||
|
||||
pthread_join(load_thread, 0); |
||||
free_data(buffer); |
||||
free_network(net); |
||||
free_ptrs((void**)paths, plist->size); |
||||
free_list(plist); |
||||
free(base); |
||||
} |
||||
|
||||
void test_tag(char *cfgfile, char *weightfile, char *filename) |
||||
{ |
||||
network net = parse_network_cfg(cfgfile); |
||||
if(weightfile){ |
||||
load_weights(&net, weightfile); |
||||
} |
||||
set_batch_network(&net, 1); |
||||
srand(2222222); |
||||
int i = 0; |
||||
char **names = get_labels("data/tags.txt"); |
||||
clock_t time; |
||||
int indexes[10]; |
||||
char buff[256]; |
||||
char *input = buff; |
||||
while(1){ |
||||
if(filename){ |
||||
strncpy(input, filename, 256); |
||||
}else{ |
||||
printf("Enter Image Path: "); |
||||
fflush(stdout); |
||||
input = fgets(input, 256, stdin); |
||||
if(!input) return; |
||||
strtok(input, "\n"); |
||||
} |
||||
image im = load_image_color(input, net.w, net.h); |
||||
//resize_network(&net, im.w, im.h);
|
||||
printf("%d %d\n", im.w, im.h); |
||||
|
||||
float *X = im.data; |
||||
time=clock(); |
||||
float *predictions = network_predict(net, X); |
||||
top_predictions(net, 10, indexes); |
||||
printf("%s: Predicted in %f seconds.\n", input, sec(clock()-time)); |
||||
for(i = 0; i < 10; ++i){ |
||||
int index = indexes[i]; |
||||
printf("%.1f%%: %s\n", predictions[index]*100, names[index]); |
||||
} |
||||
free_image(im); |
||||
if (filename) break; |
||||
} |
||||
} |
||||
|
||||
|
||||
/*
 * Command-line dispatcher for the tag sub-commands.
 *
 * argv[2] selects the action ("train" or "test"), argv[3] is the cfg
 * file, argv[4] the optional weights file and argv[5] the optional image
 * passed to the test action.  Prints a usage line and returns when too
 * few arguments are given or the action is not recognised.
 */
void run_tag(int argc, char **argv)
{
    if(argc < 4){
        fprintf(stderr, "usage: %s %s [train/test] [cfg] [weights (optional)] [image (optional)]\n", argv[0], argv[1]);
        return;
    }

    char *cfg = argv[3];
    char *weights = (argc > 4) ? argv[4] : 0;
    char *filename = (argc > 5) ? argv[5] : 0;

    if(0==strcmp(argv[2], "train")){
        train_tag(cfg, weights);
    }else if(0==strcmp(argv[2], "test")){
        test_tag(cfg, weights, filename);
    }else{
        /* Previously an unknown action was ignored without any feedback. */
        fprintf(stderr, "%s %s: unknown command '%s'\n", argv[0], argv[1], argv[2]);
        fprintf(stderr, "usage: %s %s [train/test] [cfg] [weights (optional)] [image (optional)]\n", argv[0], argv[1]);
    }
}
||||
|
Loading…
Reference in new issue