diff --git a/Makefile b/Makefile
index 07cf79fb..640f3082 100644
--- a/Makefile
+++ b/Makefile
@@ -1,20 +1,21 @@
 CC=gcc
 COMMON=-Wall `pkg-config --cflags opencv`
 UNAME = $(shell uname)
+OPTS=-O3
 ifeq ($(UNAME), Darwin)
 COMMON+= -isystem /usr/local/Cellar/opencv/2.4.6.1/include/opencv -isystem /usr/local/Cellar/opencv/2.4.6.1/include
 LDFLAGS= -framework OpenCL
 else
-COMMON+= -march=native -flto
+OPTS+= -march=native -flto
 LDFLAGS= -lOpenCL
 endif
-CFLAGS= $(COMMON) -Ofast
+CFLAGS= $(COMMON) $(OPTS)
 #CFLAGS= $(COMMON) -O0 -g
 LDFLAGS+=`pkg-config --libs opencv` -lm
 VPATH=./src/
 EXEC=cnn
-OBJ=network.o image.o tests.o connected_layer.o maxpool_layer.o activations.o list.o option_list.o parser.o utils.o data.o matrix.o softmax_layer.o mini_blas.o convolutional_layer.o opencl.o gpu_gemm.o cpu_gemm.o
+OBJ=network.o image.o tests.o connected_layer.o maxpool_layer.o activations.o list.o option_list.o parser.o utils.o data.o matrix.o softmax_layer.o mini_blas.o convolutional_layer.o opencl.o gpu_gemm.o cpu_gemm.o normalization_layer.o

 all: $(EXEC)
diff --git a/src/convolutional_layer.c b/src/convolutional_layer.c
index 40d58584..6916eebc 100644
--- a/src/convolutional_layer.c
+++ b/src/convolutional_layer.c
@@ -320,11 +320,12 @@ image *visualize_convolutional_layer(convolutional_layer layer, char *window, im
     image *single_filters = weighted_sum_filters(layer, 0);
     show_images(single_filters, layer.n, window);

-    image delta = get_convolutional_delta(layer);
+    image delta = get_convolutional_image(layer);
     image dc = collapse_image_layers(delta, 1);
     char buff[256];
-    sprintf(buff, "%s: Delta", window);
-    //show_image(dc, buff);
+    sprintf(buff, "%s: Output", window);
+    show_image(dc, buff);
+    save_image(dc, buff);
     free_image(dc);
     return single_filters;
 }
diff --git a/src/image.c b/src/image.c
index 5c138d33..453919fb 100644
--- a/src/image.c
+++ b/src/image.c
@@ -264,7 +264,7 @@ void add_into_image(image src, image dest, int h, int w)
     }
 }

-void add_scalar_image(image m, float s)
+void translate_image(image m, float s)
 {
     int i;
     for(i = 0; i < m.h*m.w*m.c; ++i) m.data[i] += s;
@@ -645,15 +645,49 @@ void print_image(image m)
     for(i =0 ; i < m.h*m.w*m.c; ++i) printf("%lf, ", m.data[i]);
     printf("\n");
 }
+image collapse_images_vert(image *ims, int n)
+{
+    int color = 1;
+    int border = 1;
+    int h,w,c;
+    w = ims[0].w;
+    h = (ims[0].h + border) * n - border;
+    c = ims[0].c;
+    if(c != 3 || !color){
+        w = (w+border)*c - border;
+        c = 1;
+    }

-image collapse_images(image *ims, int n)
+    image filters = make_image(h,w,c);
+    int i,j;
+    for(i = 0; i < n; ++i){
+        int h_offset = i*(ims[0].h+border);
+        image copy = copy_image(ims[i]);
+        //normalize_image(copy);
+        if(c == 3 && color){
+            embed_image(copy, filters, h_offset, 0);
+        }
+        else{
+            for(j = 0; j < copy.c; ++j){
+                int w_offset = j*(ims[0].w+border);
+                image layer = get_image_layer(copy, j);
+                embed_image(layer, filters, h_offset, w_offset);
+                free_image(layer);
+            }
+        }
+        free_image(copy);
+    }
+    return filters;
+}
+
+image collapse_images_horz(image *ims, int n)
 {
     int color = 1;
     int border = 1;
     int h,w,c;
     int size = ims[0].h;
     h = size;
-    w = (size + border) * n - border;
+    w = (ims[0].w + border) * n - border;
     c = ims[0].c;
     if(c != 3 || !color){
         h = (h+border)*c - border;
@@ -665,7 +699,7 @@ image collapse_images(image *ims, int n)
     for(i = 0; i < n; ++i){
         int w_offset = i*(size+border);
         image copy = copy_image(ims[i]);
-        normalize_image(copy);
+        //normalize_image(copy);
         if(c == 3 && color){
             embed_image(copy, filters, 0, w_offset);
         }
@@ -684,11 +718,49 @@ image collapse_images(image *ims, int n)
 void show_images(image *ims, int n, char *window)
 {
-    image m = collapse_images(ims, n);
+    image m = collapse_images_vert(ims, n);
+    save_image(m, window);
     show_image(m, window);
     free_image(m);
 }

+image grid_images(image **ims, int h, int w)
+{
+    int i;
+    image *rows = calloc(h, sizeof(image));
+    for(i = 0; i < h; ++i){
+        rows[i] = collapse_images_horz(ims[i], w);
+    }
+    image out = collapse_images_vert(rows, h);
+    for(i = 0; i < h; ++i){
+        free_image(rows[i]);
+    }
+    free(rows);
+    return out;
+}
+
+void test_grid()
+{
+    int i,j;
+    int num = 3;
+    int topk = 3;
+    image **vizs = calloc(num, sizeof(image*));
+    for(i = 0; i < num; ++i){
+        vizs[i] = calloc(topk, sizeof(image));
+        for(j = 0; j < topk; ++j) vizs[i][j] = make_image(3,3,3);
+    }
+    image grid = grid_images(vizs, num, topk);
+    save_image(grid, "Test Grid");
+    free_image(grid);
+}
+
+void show_images_grid(image **ims, int h, int w, char *window)
+{
+    image out = grid_images(ims, h, w);
+    show_image(out, window);
+    free_image(out);
+}
+
 void free_image(image m)
 {
     free(m.data);
diff --git a/src/image.h b/src/image.h
index 9d064c36..fe257425 100644
--- a/src/image.h
+++ b/src/image.h
@@ -1,6 +1,7 @@
 #ifndef IMAGE_H
 #define IMAGE_H
+
 #include "opencv2/highgui/highgui_c.h"
 #include "opencv2/imgproc/imgproc_c.h"

 typedef struct {
@@ -12,7 +13,7 @@ typedef struct {

 image image_distance(image a, image b);
 void scale_image(image m, float s);
-void add_scalar_image(image m, float s);
+void translate_image(image m, float s);
 void normalize_image(image p);
 void z_normalize_image(image p);
 void threshold_image(image p, float t);
@@ -23,6 +24,8 @@ float avg_image_layer(image m, int l);
 void embed_image(image source, image dest, int h, int w);
 void add_into_image(image src, image dest, int h, int w);
 image collapse_image_layers(image source, int border);
+image collapse_images_horz(image *ims, int n);
+image collapse_images_vert(image *ims, int n);
 image get_sub_image(image m, int h, int w, int dh, int dw);

 void show_image(image p, char *name);
@@ -30,6 +33,9 @@ void save_image(image p, char *name);
 void show_images(image *ims, int n, char *window);
 void show_image_layers(image p, char *name);
 void show_image_collapsed(image p, char *name);
+void show_images_grid(image **ims, int h, int w, char *window);
+void test_grid();
+image grid_images(image **ims, int h, int w);
 void print_image(image m);

 image make_image(int h, int w, int c);
diff --git a/src/network.c b/src/network.c
index edae3c7b..7d4b1fac 100644
--- a/src/network.c
+++ b/src/network.c
@@ -8,6 +8,7 @@
 #include "convolutional_layer.h"
 //#include "old_conv.h"
 #include "maxpool_layer.h"
+#include "normalization_layer.h"
 #include "softmax_layer.h"

 network make_network(int n, int batch)
@@ -40,6 +41,17 @@ void print_convolutional_cfg(FILE *fp, convolutional_layer *l, int first)
     fprintf(fp, "data=");
     for(i = 0; i < l->n; ++i) fprintf(fp, "%g,", l->biases[i]);
     for(i = 0; i < l->n*l->c*l->size*l->size; ++i) fprintf(fp, "%g,", l->filters[i]);
+    /*
+       int j,k;
+       for(i = 0; i < l->n; ++i) fprintf(fp, "%g,", l->biases[i]);
+       for(i = 0; i < l->n; ++i){
+       for(j = l->c-1; j >= 0; --j){
+       for(k = 0; k < l->size*l->size; ++k){
+       fprintf(fp, "%g,", l->filters[i*(l->c*l->size*l->size)+j*l->size*l->size+k]);
+       }
+       }
+       }
+     */
     fprintf(fp, "\n\n");
 }
 void print_connected_cfg(FILE *fp, connected_layer *l, int first)
@@ -48,9 +60,9 @@ void print_connected_cfg(FILE *fp, connected_layer *l, int first)
     fprintf(fp, "[connected]\n");
     if(first) fprintf(fp, "batch=%d\ninput=%d\n", l->batch, l->inputs);
     fprintf(fp, "output=%d\n"
-                "activation=%s\n",
-                l->outputs,
-                get_activation_string(l->activation));
+            "activation=%s\n",
+            l->outputs,
+            get_activation_string(l->activation));
     fprintf(fp, "data=");
     for(i = 0; i < l->outputs; ++i) fprintf(fp, "%g,", l->biases[i]);
     for(i = 0; i < l->inputs*l->outputs; ++i) fprintf(fp, "%g,", l->weights[i]);
@@ -61,13 +73,27 @@ void print_maxpool_cfg(FILE *fp, maxpool_layer *l, int first)
 {
     fprintf(fp, "[maxpool]\n");
     if(first) fprintf(fp, "batch=%d\n"
-                          "height=%d\n"
-                          "width=%d\n"
-                          "channels=%d\n",
-                          l->batch,l->h, l->w, l->c);
+            "height=%d\n"
+            "width=%d\n"
+            "channels=%d\n",
+            l->batch,l->h, l->w, l->c);
     fprintf(fp, "stride=%d\n\n", l->stride);
 }

+void print_normalization_cfg(FILE *fp, normalization_layer *l, int first)
+{
+    fprintf(fp, "[localresponsenormalization]\n");
+    if(first) fprintf(fp, "batch=%d\n"
+            "height=%d\n"
+            "width=%d\n"
+            "channels=%d\n",
+            l->batch,l->h, l->w, l->c);
+    fprintf(fp, "size=%d\n"
+            "alpha=%g\n"
+            "beta=%g\n"
+            "kappa=%g\n\n", l->size, l->alpha, l->beta, l->kappa);
+}
+
 void print_softmax_cfg(FILE *fp, softmax_layer *l, int first)
 {
     fprintf(fp, "[softmax]\n");
@@ -88,6 +114,8 @@ void save_network(network net, char *filename)
             print_connected_cfg(fp, (connected_layer *)net.layers[i], i==0);
         else if(net.types[i] == MAXPOOL)
             print_maxpool_cfg(fp, (maxpool_layer *)net.layers[i], i==0);
+        else if(net.types[i] == NORMALIZATION)
+            print_normalization_cfg(fp, (normalization_layer *)net.layers[i], i==0);
         else if(net.types[i] == SOFTMAX)
             print_softmax_cfg(fp, (softmax_layer *)net.layers[i], i==0);
     }
@@ -118,6 +146,11 @@ void forward_network(network net, float *input)
             forward_maxpool_layer(layer, input);
             input = layer.output;
         }
+        else if(net.types[i] == NORMALIZATION){
+            normalization_layer layer = *(normalization_layer *)net.layers[i];
+            forward_normalization_layer(layer, input);
+            input = layer.output;
+        }
     }
 }
@@ -135,6 +168,9 @@ void update_network(network net, float step, float momentum, float decay)
         else if(net.types[i] == SOFTMAX){
             //maxpool_layer layer = *(maxpool_layer *)net.layers[i];
         }
+        else if(net.types[i] == NORMALIZATION){
+            //maxpool_layer layer = *(maxpool_layer *)net.layers[i];
+        }
         else if(net.types[i] == CONNECTED){
             connected_layer layer = *(connected_layer *)net.layers[i];
             update_connected_layer(layer, step, momentum, decay);
@@ -156,6 +192,9 @@ float *get_network_output_layer(network net, int i)
     } else if(net.types[i] == CONNECTED){
         connected_layer layer = *(connected_layer *)net.layers[i];
         return layer.output;
+    } else if(net.types[i] == NORMALIZATION){
+        normalization_layer layer = *(normalization_layer *)net.layers[i];
+        return layer.output;
     }
     return 0;
 }
@@ -233,6 +272,10 @@ float backward_network(network net, float *input, float *truth)
             maxpool_layer layer = *(maxpool_layer *)net.layers[i];
             if(i != 0) backward_maxpool_layer(layer, prev_input, prev_delta);
         }
+        else if(net.types[i] == NORMALIZATION){
+            normalization_layer layer = *(normalization_layer *)net.layers[i];
+            if(i != 0) backward_normalization_layer(layer, prev_input, prev_delta);
+        }
         else if(net.types[i] == SOFTMAX){
             softmax_layer layer = *(softmax_layer *)net.layers[i];
             if(i != 0) backward_softmax_layer(layer, prev_input, prev_delta);
@@ -272,7 +315,7 @@ float train_network_sgd(network net, data d, int n, float step, float momentum,f
             error += err;
             ++pos;
         }
-        
+
         //printf("%d %f %f\n", i,net.output[0], d.y.vals[index][0]);
         //if((i+1)%10 == 0){
@@ -341,34 +384,34 @@ int get_network_output_size_layer(network net, int i)
 }

 /*
-int resize_network(network net, int h, int w, int c)
-{
-    int i;
-    for (i = 0; i < net.n; ++i){
-        if(net.types[i] == CONVOLUTIONAL){
-            convolutional_layer *layer = (convolutional_layer *)net.layers[i];
-            layer->h = h;
-            layer->w = w;
-            layer->c = c;
-            image output = get_convolutional_image(*layer);
-            h = output.h;
-            w = output.w;
-            c = output.c;
-        }
-        else if(net.types[i] == MAXPOOL){
-            maxpool_layer *layer = (maxpool_layer *)net.layers[i];
-            layer->h = h;
-            layer->w = w;
-            layer->c = c;
-            image output = get_maxpool_image(*layer);
-            h = output.h;
-            w = output.w;
-            c = output.c;
-        }
-    }
-    return 0;
-}
-*/
+   int resize_network(network net, int h, int w, int c)
+   {
+       int i;
+       for (i = 0; i < net.n; ++i){
+           if(net.types[i] == CONVOLUTIONAL){
+               convolutional_layer *layer = (convolutional_layer *)net.layers[i];
+               layer->h = h;
+               layer->w = w;
+               layer->c = c;
+               image output = get_convolutional_image(*layer);
+               h = output.h;
+               w = output.w;
+               c = output.c;
+           }
+           else if(net.types[i] == MAXPOOL){
+               maxpool_layer *layer = (maxpool_layer *)net.layers[i];
+               layer->h = h;
+               layer->w = w;
+               layer->c = c;
+               image output = get_maxpool_image(*layer);
+               h = output.h;
+               w = output.w;
+               c = output.c;
+           }
+       }
+       return 0;
+   }
+   */

 int resize_network(network net, int h, int w, int c)
 {
@@ -381,16 +424,21 @@ int resize_network(network net, int h, int w, int c)
             h = output.h;
             w = output.w;
             c = output.c;
-        }
-        else if(net.types[i] == MAXPOOL){
+        }else if(net.types[i] == MAXPOOL){
             maxpool_layer *layer = (maxpool_layer *)net.layers[i];
             resize_maxpool_layer(layer, h, w, c);
             image output = get_maxpool_image(*layer);
             h = output.h;
             w = output.w;
             c = output.c;
-        }
-        else{
+        }else if(net.types[i] == NORMALIZATION){
+            normalization_layer *layer = (normalization_layer *)net.layers[i];
+            resize_normalization_layer(layer, h, w, c);
+            image output = get_normalization_image(*layer);
+            h = output.h;
+            w = output.w;
+            c = output.c;
+        }else{
             error("Cannot resize this type of layer");
         }
     }
@@ -413,6 +461,10 @@ image get_network_image_layer(network net, int i)
         maxpool_layer layer = *(maxpool_layer *)net.layers[i];
         return get_maxpool_image(layer);
     }
+    else if(net.types[i] == NORMALIZATION){
+        normalization_layer layer = *(normalization_layer *)net.layers[i];
+        return get_normalization_image(layer);
+    }
     return make_empty_image(0,0,0);
 }
@@ -437,6 +489,10 @@ void visualize_network(network net)
             convolutional_layer layer = *(convolutional_layer *)net.layers[i];
             prev = visualize_convolutional_layer(layer, buff, prev);
         }
+        if(net.types[i] == NORMALIZATION){
+            normalization_layer layer = *(normalization_layer *)net.layers[i];
+            visualize_normalization_layer(layer, buff);
+        }
     }
 }
diff --git a/src/network.h b/src/network.h
index 5acee61b..f6dac7e6 100644
--- a/src/network.h
+++ b/src/network.h
@@ -9,7 +9,8 @@ typedef enum {
     CONVOLUTIONAL,
     CONNECTED,
     MAXPOOL,
-    SOFTMAX
+    SOFTMAX,
+    NORMALIZATION
 } LAYER_TYPE;

 typedef struct {
diff --git a/src/normalization_layer.c b/src/normalization_layer.c
new file mode 100644
index 00000000..2d844e0e
--- /dev/null
+++ b/src/normalization_layer.c
@@ -0,0 +1,96 @@
+#include "normalization_layer.h"
+#include <stdio.h>
+
+image get_normalization_image(normalization_layer layer)
+{
+    int h = layer.h;
+    int w = layer.w;
+    int c = layer.c;
+    return float_to_image(h,w,c,layer.output);
+}
+
+image get_normalization_delta(normalization_layer layer)
+{
+    int h = layer.h;
+    int w = layer.w;
+    int c = layer.c;
+    return float_to_image(h,w,c,layer.delta);
+}
+
+normalization_layer *make_normalization_layer(int batch, int h, int w, int c, int size, float alpha, float beta, float kappa)
+{
+    fprintf(stderr, "Local Response Normalization Layer: %d x %d x %d image, %d size\n", h,w,c,size);
+    normalization_layer *layer = calloc(1, sizeof(normalization_layer));
+    layer->batch = batch;
+    layer->h = h;
+    layer->w = w;
+    layer->c = c;
+    layer->kappa = kappa;
+    layer->size = size;
+    layer->alpha = alpha;
+    layer->beta = beta;
+    layer->output = calloc(h * w * c * batch, sizeof(float));
+    layer->delta = calloc(h * w * c * batch, sizeof(float));
+    layer->sums = calloc(h*w, sizeof(float));
+    return layer;
+}
+
+void resize_normalization_layer(normalization_layer *layer, int h, int w, int c)
+{
+    layer->h = h;
+    layer->w = w;
+    layer->c = c;
+    layer->output = realloc(layer->output, h * w * c * layer->batch * sizeof(float));
+    layer->delta = realloc(layer->delta, h * w * c * layer->batch * sizeof(float));
+    layer->sums = realloc(layer->sums, h*w * sizeof(float));
+}
+
+void add_square_array(float *src, float *dest, int n)
+{
+    int i;
+    for(i = 0; i < n; ++i){
+        dest[i] += src[i]*src[i];
+    }
+}
+void sub_square_array(float *src, float *dest, int n)
+{
+    int i;
+    for(i = 0; i < n; ++i){
+        dest[i] -= src[i]*src[i];
+    }
+}
+
+void forward_normalization_layer(const normalization_layer layer, float *in)
+{
+    int i,j,k;
+    memset(layer.sums, 0, layer.h*layer.w*sizeof(float));
+    int imsize = layer.h*layer.w;
+    for(j = 0; j < layer.size/2; ++j){
+        if(j < layer.c) add_square_array(in+j*imsize, layer.sums, imsize);
+    }
+    for(k = 0; k < layer.c; ++k){
+        int next = k+layer.size/2;
+        int prev = k-layer.size/2-1;
+        if(next < layer.c) add_square_array(in+next*imsize, layer.sums, imsize);
+        if(prev > 0) sub_square_array(in+prev*imsize, layer.sums, imsize);
+        for(i = 0; i < imsize; ++i){
+            layer.output[k*imsize + i] = in[k*imsize+i] / pow(layer.kappa + layer.alpha * layer.sums[i], layer.beta);
+        }
+    }
+}
+
+void backward_normalization_layer(const normalization_layer layer, float *in, float *delta)
+{
+    //TODO!
+}
+
+void visualize_normalization_layer(normalization_layer layer, char *window)
+{
+    image delta = get_normalization_image(layer);
+    image dc = collapse_image_layers(delta, 1);
+    char buff[256];
+    sprintf(buff, "%s: Output", window);
+    show_image(dc, buff);
+    save_image(dc, buff);
+    free_image(dc);
+}
diff --git a/src/normalization_layer.h b/src/normalization_layer.h
new file mode 100644
index 00000000..fcf8af11
--- /dev/null
+++ b/src/normalization_layer.h
@@ -0,0 +1,26 @@
+#ifndef NORMALIZATION_LAYER_H
+#define NORMALIZATION_LAYER_H
+
+#include "image.h"
+
+typedef struct {
+    int batch;
+    int h,w,c;
+    int size;
+    float alpha;
+    float beta;
+    float kappa;
+    float *delta;
+    float *output;
+    float *sums;
+} normalization_layer;
+
+image get_normalization_image(normalization_layer layer);
+normalization_layer *make_normalization_layer(int batch, int h, int w, int c, int size, float alpha, float beta, float kappa);
+void resize_normalization_layer(normalization_layer *layer, int h, int w, int c);
+void forward_normalization_layer(const normalization_layer layer, float *in);
+void backward_normalization_layer(const normalization_layer layer, float *in, float *delta);
+void visualize_normalization_layer(normalization_layer layer, char *window);
+
+#endif
+
diff --git a/src/parser.c b/src/parser.c
index cf64b553..4aa0a79b 100644
--- a/src/parser.c
+++ b/src/parser.c
@@ -7,6 +7,7 @@
 #include "convolutional_layer.h"
 #include "connected_layer.h"
 #include "maxpool_layer.h"
+#include "normalization_layer.h"
 #include "softmax_layer.h"
 #include "list.h"
 #include "option_list.h"
@@ -21,6 +22,7 @@ int is_convolutional(section *s);
 int is_connected(section *s);
 int is_maxpool(section *s);
 int is_softmax(section *s);
+int is_normalization(section *s);
 list *read_cfg(char *filename);

 void free_section(section *s)
@@ -152,6 +154,30 @@ maxpool_layer *parse_maxpool(list *options, network net, int count)
     return layer;
 }

+normalization_layer *parse_normalization(list *options, network net, int count)
+{
+    int h,w,c;
+    int size = option_find_int(options, "size",1);
+    float alpha = option_find_float(options, "alpha", 0.);
+    float beta = option_find_float(options, "beta", 1.);
+    float kappa = option_find_float(options, "kappa", 1.);
+    if(count == 0){
+        h = option_find_int(options, "height",1);
+        w = option_find_int(options, "width",1);
+        c = option_find_int(options, "channels",1);
+        net.batch = option_find_int(options, "batch",1);
+    }else{
+        image m = get_network_image_layer(net, count-1);
+        h = m.h;
+        w = m.w;
+        c = m.c;
+        if(h == 0) error("Layer before convolutional layer must output image.");
+    }
+    normalization_layer *layer = make_normalization_layer(net.batch,h,w,c,size, alpha, beta, kappa);
+    option_unused(options);
+    return layer;
+}
+
 network parse_network_cfg(char *filename)
 {
     list *sections = read_cfg(filename);
@@ -182,6 +208,11 @@ network parse_network_cfg(char *filename)
             net.types[count] = MAXPOOL;
             net.layers[count] = layer;
             net.batch = layer->batch;
+        }else if(is_normalization(s)){
+            normalization_layer *layer = parse_normalization(options, net, count);
+            net.types[count] = NORMALIZATION;
+            net.layers[count] = layer;
+            net.batch = layer->batch;
         }else{
             fprintf(stderr, "Type not recognized: %s\n", s->type);
         }
@@ -216,6 +247,11 @@ int is_softmax(section *s)
     return (strcmp(s->type, "[soft]")==0
             || strcmp(s->type, "[softmax]")==0);
 }
+int is_normalization(section *s)
+{
+    return (strcmp(s->type, "[lrnorm]")==0
+            || strcmp(s->type, "[localresponsenormalization]")==0);
+}

 int read_option(char *s, list *options)
 {
diff --git a/src/tests.c b/src/tests.c
index 5d9136de..a6c3cd32 100644
--- a/src/tests.c
+++ b/src/tests.c
@@ -1,4 +1,5 @@
 #include "connected_layer.h"
+
 //#include "old_conv.h"
 #include "convolutional_layer.h"
 #include "maxpool_layer.h"
@@ -223,7 +224,7 @@ void train_full()

 void test_visualize()
 {
-    network net = parse_network_cfg("cfg/imagenet.cfg");
+    network net = parse_network_cfg("cfg/voc_imagenet.cfg");
     srand(2222222);
     visualize_network(net);
     cvWaitKey(0);
@@ -445,6 +446,12 @@ void test_im2row()
     }
 }

+void flip_network()
+{
+    network net = parse_network_cfg("cfg/voc_imagenet_orig.cfg");
+    save_network(net, "cfg/voc_imagenet_rev.cfg");
+}
+
 void train_VOC()
 {
     network net = parse_network_cfg("cfg/voc_start.cfg");
@@ -498,6 +505,7 @@ image features_output_size(network net, IplImage *src, int outh, int outw)
     IplImage *sized = cvCreateImage(cvSize(w,h), src->depth, src->nChannels);
     cvResize(src, sized, CV_INTER_LINEAR);
     image im = ipl_to_image(sized);
+    normalize_array(im.data, im.h*im.w*im.c);
     resize_network(net, im.h, im.w, im.c);
     forward_network(net, im.data);
     image out = get_network_image_layer(net, 6);
@@ -523,6 +531,69 @@ void features_VOC_image_size(char *image_path, int h, int w)
     free_image(out);
     cvReleaseImage(&src);
 }
+void visualize_imagenet_topk(char *filename)
+{
+    int i,j,k,l;
+    int topk = 10;
+    network net = parse_network_cfg("cfg/voc_imagenet.cfg");
+    list *plist = get_paths(filename);
+    node *n = plist->front;
+    int h = voc_size(1), w = voc_size(1);
+    int num = get_network_image(net).c;
+    image **vizs = calloc(num, sizeof(image*));
+    float **score = calloc(num, sizeof(float *));
+    for(i = 0; i < num; ++i){
+        vizs[i] = calloc(topk, sizeof(image));
+        for(j = 0; j < topk; ++j) vizs[i][j] = make_image(h,w,3);
+        score[i] = calloc(topk, sizeof(float));
+    }
+
+    while(n){
+        char *image_path = (char *)n->val;
+        image im = load_image(image_path, 0, 0);
+        n = n->next;
+        if(im.h < 200 || im.w < 200) continue;
+        printf("Processing %dx%d image\n", im.h, im.w);
+        resize_network(net, im.h, im.w, im.c);
+        //scale_image(im, 1./255);
+        translate_image(im, -144);
+        forward_network(net, im.data);
+        image out = get_network_image(net);
+
+        int dh = (im.h - h)/h;
+        int dw = (im.w - w)/w;
+        for(i = 0; i < out.h; ++i){
+            for(j = 0; j < out.w; ++j){
+                image sub = get_sub_image(im, dh*i, dw*j, h, w);
+                for(k = 0; k < out.c; ++k){
+                    float val = get_pixel(out, i, j, k);
+                    //printf("%f, ", val);
+                    image sub_c = copy_image(sub);
+                    for(l = 0; l < topk; ++l){
+                        if(val > score[k][l]){
+                            float swap = score[k][l];
+                            score[k][l] = val;
+                            val = swap;
+
+                            image swapi = vizs[k][l];
+                            vizs[k][l] = sub_c;
+                            sub_c = swapi;
+                        }
+                    }
+                    free_image(sub_c);
+                }
+                free_image(sub);
+            }
+        }
+        free_image(im);
+        //printf("\n");
+        image grid = grid_images(vizs, num, topk);
+        show_image(grid, "IMAGENET Visualization");
+        save_image(grid, "IMAGENET Grid");
+        free_image(grid);
+    }
+    //cvWaitKey(0);
+}

 void visualize_imagenet_features(char *filename)
 {
@@ -566,6 +637,20 @@ void visualize_imagenet_features(char *filename)
     cvWaitKey(0);
 }

+void visualize_cat()
+{
+    network net = parse_network_cfg("cfg/voc_imagenet.cfg");
+    image im = load_image("data/cat.png", 0, 0);
+    printf("Processing %dx%d image\n", im.h, im.w);
+    resize_network(net, im.h, im.w, im.c);
+    forward_network(net, im.data);
+
+    image out = get_network_image(net);
+    visualize_network(net);
+    cvWaitKey(1000);
+    cvWaitKey(0);
+}
+
 void features_VOC_image(char *image_file, char *image_dir, char *out_dir)
 {
     int i,j;
@@ -693,7 +778,10 @@ int main(int argc, char *argv[])
     //features_VOC_image(argv[1], argv[2], argv[3]);
     //features_VOC_image_size(argv[1], atoi(argv[2]), atoi(argv[3]));
     //visualize_imagenet_features("data/assira/train.list");
-    visualize_imagenet_features("data/VOC2011.list");
+    visualize_imagenet_topk("data/VOC2011.list");
+    //visualize_cat();
+    //flip_network();
+    //test_visualize();
     fprintf(stderr, "Success!\n");
     //test_random_preprocess();
    //test_random_classify();
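For reference, a minimal cfg section for the new local response normalization layer might look like the sketch below. The section name and option keys follow is_normalization() and parse_normalization() above; the numeric values are illustrative AlexNet-style settings, not anything shipped in this repository (the parser itself falls back to size=1, alpha=0, beta=1, kappa=1, accepts [lrnorm] as a shorter alias, and also expects batch, height, width, and channels when the layer comes first in the network):

[localresponsenormalization]
size=5
alpha=0.0001
beta=0.75
kappa=2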