diff --git a/Makefile b/Makefile index dc08b468..fda7d888 100644 --- a/Makefile +++ b/Makefile @@ -1,6 +1,6 @@ CC=gcc COMMON=-Wall `pkg-config --cflags opencv` -CFLAGS= $(COMMON) -Ofast -ffast-math -flto +CFLAGS= $(COMMON) -O3 -ffast-math -flto UNAME = $(shell uname) ifeq ($(UNAME), Darwin) COMMON += -isystem /usr/local/Cellar/opencv/2.4.6.1/include/opencv -isystem /usr/local/Cellar/opencv/2.4.6.1/include diff --git a/nist_basic.cfg b/nist_basic.cfg index f5ea0a38..71427358 100644 --- a/nist_basic.cfg +++ b/nist_basic.cfg @@ -3,7 +3,7 @@ width=28 height=28 channels=1 filters=20 -size=5 +size=11 stride=1 activation=linear diff --git a/src/activations.c b/src/activations.c index b8bb79d9..cc923d0e 100644 --- a/src/activations.c +++ b/src/activations.c @@ -15,7 +15,7 @@ ACTIVATION get_activation(char *s) return RELU; } -double activate(double x, ACTIVATION a){ +float activate(float x, ACTIVATION a){ switch(a){ case LINEAR: return x; @@ -30,7 +30,7 @@ double activate(double x, ACTIVATION a){ } return 0; } -double gradient(double x, ACTIVATION a){ +float gradient(float x, ACTIVATION a){ switch(a){ case LINEAR: return 1; diff --git a/src/activations.h b/src/activations.h index 889453f6..fb2c54f4 100644 --- a/src/activations.h +++ b/src/activations.h @@ -7,8 +7,8 @@ typedef enum{ ACTIVATION get_activation(char *s); -double activate(double x, ACTIVATION a); -double gradient(double x, ACTIVATION a); +float activate(float x, ACTIVATION a); +float gradient(float x, ACTIVATION a); #endif diff --git a/src/connected_layer.c b/src/connected_layer.c index 6871b2ee..5f6631cb 100644 --- a/src/connected_layer.c +++ b/src/connected_layer.c @@ -15,19 +15,19 @@ connected_layer *make_connected_layer(int inputs, int outputs, ACTIVATION activa layer->inputs = inputs; layer->outputs = outputs; - layer->output = calloc(outputs, sizeof(double*)); - layer->delta = calloc(outputs, sizeof(double*)); + layer->output = calloc(outputs, sizeof(float*)); + layer->delta = calloc(outputs, sizeof(float*)); - layer->weight_updates = calloc(inputs*outputs, sizeof(double)); - layer->weight_momentum = calloc(inputs*outputs, sizeof(double)); - layer->weights = calloc(inputs*outputs, sizeof(double)); - double scale = 2./inputs; + layer->weight_updates = calloc(inputs*outputs, sizeof(float)); + layer->weight_momentum = calloc(inputs*outputs, sizeof(float)); + layer->weights = calloc(inputs*outputs, sizeof(float)); + float scale = 2./inputs; for(i = 0; i < inputs*outputs; ++i) layer->weights[i] = rand_normal()*scale; - layer->bias_updates = calloc(outputs, sizeof(double)); - layer->bias_momentum = calloc(outputs, sizeof(double)); - layer->biases = calloc(outputs, sizeof(double)); + layer->bias_updates = calloc(outputs, sizeof(float)); + layer->bias_momentum = calloc(outputs, sizeof(float)); + layer->biases = calloc(outputs, sizeof(float)); for(i = 0; i < outputs; ++i) //layer->biases[i] = rand_normal()*scale + scale; layer->biases[i] = 0; @@ -36,7 +36,7 @@ connected_layer *make_connected_layer(int inputs, int outputs, ACTIVATION activa return layer; } -void update_connected_layer(connected_layer layer, double step, double momentum, double decay) +void update_connected_layer(connected_layer layer, float step, float momentum, float decay) { int i; for(i = 0; i < layer.outputs; ++i){ @@ -47,27 +47,27 @@ void update_connected_layer(connected_layer layer, double step, double momentum, layer.weight_momentum[i] = step*(layer.weight_updates[i] - decay*layer.weights[i]) + momentum*layer.weight_momentum[i]; layer.weights[i] += layer.weight_momentum[i]; } - memset(layer.bias_updates, 0, layer.outputs*sizeof(double)); - memset(layer.weight_updates, 0, layer.outputs*layer.inputs*sizeof(double)); + memset(layer.bias_updates, 0, layer.outputs*sizeof(float)); + memset(layer.weight_updates, 0, layer.outputs*layer.inputs*sizeof(float)); } -void forward_connected_layer(connected_layer layer, double *input) +void forward_connected_layer(connected_layer layer, float *input) { int i; - memcpy(layer.output, layer.biases, layer.outputs*sizeof(double)); + memcpy(layer.output, layer.biases, layer.outputs*sizeof(float)); int m = 1; int k = layer.inputs; int n = layer.outputs; - double *a = input; - double *b = layer.weights; - double *c = layer.output; + float *a = input; + float *b = layer.weights; + float *c = layer.output; gemm(0,0,m,n,k,1,a,k,b,n,1,c,n); for(i = 0; i < layer.outputs; ++i){ layer.output[i] = activate(layer.output[i], layer.activation); } } -void learn_connected_layer(connected_layer layer, double *input) +void learn_connected_layer(connected_layer layer, float *input) { int i; for(i = 0; i < layer.outputs; ++i){ @@ -77,28 +77,28 @@ void learn_connected_layer(connected_layer layer, double *input) int m = layer.inputs; int k = 1; int n = layer.outputs; - double *a = input; - double *b = layer.delta; - double *c = layer.weight_updates; + float *a = input; + float *b = layer.delta; + float *c = layer.weight_updates; gemm(0,0,m,n,k,1,a,k,b,n,1,c,n); } -void backward_connected_layer(connected_layer layer, double *input, double *delta) +void backward_connected_layer(connected_layer layer, float *input, float *delta) { - memset(delta, 0, layer.inputs*sizeof(double)); + memset(delta, 0, layer.inputs*sizeof(float)); int m = layer.inputs; int k = layer.outputs; int n = 1; - double *a = layer.weights; - double *b = layer.delta; - double *c = delta; + float *a = layer.weights; + float *b = layer.delta; + float *c = delta; gemm(0,0,m,n,k,1,a,k,b,n,1,c,n); } /* - void forward_connected_layer(connected_layer layer, double *input) + void forward_connected_layer(connected_layer layer, float *input) { int i, j; for(i = 0; i < layer.outputs; ++i){ @@ -109,7 +109,7 @@ void backward_connected_layer(connected_layer layer, double *input, double *delt layer.output[i] = activate(layer.output[i], layer.activation); } } - void learn_connected_layer(connected_layer layer, double *input) + void learn_connected_layer(connected_layer layer, float *input) { int i, j; for(i = 0; i < layer.outputs; ++i){ @@ -120,7 +120,7 @@ void backward_connected_layer(connected_layer layer, double *input, double *delt } } } - void backward_connected_layer(connected_layer layer, double *input, double *delta) + void backward_connected_layer(connected_layer layer, float *input, float *delta) { int i, j; diff --git a/src/connected_layer.h b/src/connected_layer.h index 05fb2616..ce0181d4 100644 --- a/src/connected_layer.h +++ b/src/connected_layer.h @@ -6,17 +6,17 @@ typedef struct{ int inputs; int outputs; - double *weights; - double *biases; + float *weights; + float *biases; - double *weight_updates; - double *bias_updates; + float *weight_updates; + float *bias_updates; - double *weight_momentum; - double *bias_momentum; + float *weight_momentum; + float *bias_momentum; - double *output; - double *delta; + float *output; + float *delta; ACTIVATION activation; @@ -24,10 +24,10 @@ typedef struct{ connected_layer *make_connected_layer(int inputs, int outputs, ACTIVATION activation); -void forward_connected_layer(connected_layer layer, double *input); -void backward_connected_layer(connected_layer layer, double *input, double *delta); -void learn_connected_layer(connected_layer layer, double *input); -void update_connected_layer(connected_layer layer, double step, double momentum, double decay); +void forward_connected_layer(connected_layer layer, float *input); +void backward_connected_layer(connected_layer layer, float *input, float *delta); +void learn_connected_layer(connected_layer layer, float *input); +void update_connected_layer(connected_layer layer, float step, float momentum, float decay); #endif diff --git a/src/convolutional_layer.c b/src/convolutional_layer.c index 53eb7bf1..cdfe9e1a 100644 --- a/src/convolutional_layer.c +++ b/src/convolutional_layer.c @@ -9,7 +9,7 @@ image get_convolutional_image(convolutional_layer layer) h = layer.out_h; w = layer.out_w; c = layer.n; - return double_to_image(h,w,c,layer.output); + return float_to_image(h,w,c,layer.output); } image get_convolutional_delta(convolutional_layer layer) @@ -18,7 +18,7 @@ image get_convolutional_delta(convolutional_layer layer) h = layer.out_h; w = layer.out_w; c = layer.n; - return double_to_image(h,w,c,layer.delta); + return float_to_image(h,w,c,layer.delta); } convolutional_layer *make_convolutional_layer(int h, int w, int c, int n, int size, int stride, ACTIVATION activation) @@ -34,14 +34,14 @@ convolutional_layer *make_convolutional_layer(int h, int w, int c, int n, int si layer->stride = stride; layer->size = size; - layer->filters = calloc(c*n*size*size, sizeof(double)); - layer->filter_updates = calloc(c*n*size*size, sizeof(double)); - layer->filter_momentum = calloc(c*n*size*size, sizeof(double)); + layer->filters = calloc(c*n*size*size, sizeof(float)); + layer->filter_updates = calloc(c*n*size*size, sizeof(float)); + layer->filter_momentum = calloc(c*n*size*size, sizeof(float)); - layer->biases = calloc(n, sizeof(double)); - layer->bias_updates = calloc(n, sizeof(double)); - layer->bias_momentum = calloc(n, sizeof(double)); - double scale = 2./(size*size); + layer->biases = calloc(n, sizeof(float)); + layer->bias_updates = calloc(n, sizeof(float)); + layer->bias_momentum = calloc(n, sizeof(float)); + float scale = 2./(size*size); for(i = 0; i < c*n*size*size; ++i) layer->filters[i] = rand_normal()*scale; for(i = 0; i < n; ++i){ //layer->biases[i] = rand_normal()*scale + scale; @@ -50,9 +50,9 @@ convolutional_layer *make_convolutional_layer(int h, int w, int c, int n, int si out_h = (h-size)/stride + 1; out_w = (w-size)/stride + 1; - layer->col_image = calloc(out_h*out_w*size*size*c, sizeof(double)); - layer->output = calloc(out_h * out_w * n, sizeof(double)); - layer->delta = calloc(out_h * out_w * n, sizeof(double)); + layer->col_image = calloc(out_h*out_w*size*size*c, sizeof(float)); + layer->output = calloc(out_h * out_w * n, sizeof(float)); + layer->delta = calloc(out_h * out_w * n, sizeof(float)); layer->activation = activation; layer->out_h = out_h; layer->out_w = out_w; @@ -63,18 +63,18 @@ convolutional_layer *make_convolutional_layer(int h, int w, int c, int n, int si return layer; } -void forward_convolutional_layer(const convolutional_layer layer, double *in) +void forward_convolutional_layer(const convolutional_layer layer, float *in) { int m = layer.n; int k = layer.size*layer.size*layer.c; int n = ((layer.h-layer.size)/layer.stride + 1)* ((layer.w-layer.size)/layer.stride + 1); - memset(layer.output, 0, m*n*sizeof(double)); + memset(layer.output, 0, m*n*sizeof(float)); - double *a = layer.filters; - double *b = layer.col_image; - double *c = layer.output; + float *a = layer.filters; + float *b = layer.col_image; + float *c = layer.output; im2col_cpu(in, layer.c, layer.h, layer.w, layer.size, layer.stride, b); gemm(0,0,m,n,k,1,a,k,b,n,1,c,n); @@ -94,7 +94,7 @@ void learn_bias_convolutional_layer(convolutional_layer layer) int i,j; int size = layer.out_h*layer.out_w; for(i = 0; i < layer.n; ++i){ - double sum = 0; + float sum = 0; for(j = 0; j < size; ++j){ sum += layer.delta[j+i*size]; } @@ -111,14 +111,33 @@ void learn_convolutional_layer(convolutional_layer layer) int k = ((layer.h-layer.size)/layer.stride + 1)* ((layer.w-layer.size)/layer.stride + 1); - double *a = layer.delta; - double *b = layer.col_image; - double *c = layer.filter_updates; + float *a = layer.delta; + float *b = layer.col_image; + float *c = layer.filter_updates; gemm(0,1,m,n,k,1,a,k,b,k,1,c,n); } -void update_convolutional_layer(convolutional_layer layer, double step, double momentum, double decay) +void backward_convolutional_layer(convolutional_layer layer, float *delta) +{ + int m = layer.size*layer.size*layer.c; + int k = layer.n; + int n = ((layer.h-layer.size)/layer.stride + 1)* + ((layer.w-layer.size)/layer.stride + 1); + + float *a = layer.filters; + float *b = layer.delta; + float *c = layer.col_image; + + + memset(c, 0, m*n*sizeof(float)); + gemm(1,0,m,n,k,1,a,m,b,n,1,c,n); + + memset(delta, 0, layer.h*layer.w*layer.c*sizeof(float)); + col2im_cpu(c, layer.c, layer.h, layer.w, layer.size, layer.stride, delta); +} + +void update_convolutional_layer(convolutional_layer layer, float step, float momentum, float decay) { int i; int size = layer.size*layer.size*layer.c*layer.n; @@ -133,9 +152,9 @@ void update_convolutional_layer(convolutional_layer layer, double step, double m } /* -void backward_convolutional_layer2(convolutional_layer layer, double *input, double *delta) +void backward_convolutional_layer2(convolutional_layer layer, float *input, float *delta) { - image in_delta = double_to_image(layer.h, layer.w, layer.c, delta); + image in_delta = float_to_image(layer.h, layer.w, layer.c, delta); image out_delta = get_convolutional_delta(layer); int i,j; for(i = 0; i < layer.n; ++i){ @@ -156,10 +175,10 @@ void backward_convolutional_layer2(convolutional_layer layer, double *input, dou } -void learn_convolutional_layer(convolutional_layer layer, double *input) +void learn_convolutional_layer(convolutional_layer layer, float *input) { int i; - image in_image = double_to_image(layer.h, layer.w, layer.c, input); + image in_image = float_to_image(layer.h, layer.w, layer.c, input); image out_delta = get_convolutional_delta(layer); gradient_delta_convolutional_layer(layer); for(i = 0; i < layer.n; ++i){ @@ -168,7 +187,7 @@ void learn_convolutional_layer(convolutional_layer layer, double *input) } } -void update_convolutional_layer(convolutional_layer layer, double step, double momentum, double decay) +void update_convolutional_layer(convolutional_layer layer, float step, float momentum, float decay) { int i,j; for(i = 0; i < layer.n; ++i){ @@ -190,21 +209,28 @@ void update_convolutional_layer(convolutional_layer layer, double step, double m void test_convolutional_layer() { convolutional_layer l = *make_convolutional_layer(4,4,1,1,3,1,LINEAR); - double input[] = {1,2,3,4, + float input[] = {1,2,3,4, 5,6,7,8, 9,10,11,12, 13,14,15,16}; - double filter[] = {.5, 0, .3, + float filter[] = {.5, 0, .3, 0 , 1, 0, .2 , 0, 1}; - double delta[] = {1, 2, + float delta[] = {1, 2, 3, 4}; + float in_delta[] = {.5,1,.3,.6, + 5,6,7,8, + 9,10,11,12, + 13,14,15,16}; l.filters = filter; forward_convolutional_layer(l, input); l.delta = delta; learn_convolutional_layer(l); - image filter_updates = double_to_image(3,3,1,l.filter_updates); + image filter_updates = float_to_image(3,3,1,l.filter_updates); print_image(filter_updates); + printf("Delta:\n"); + backward_convolutional_layer(l, in_delta); + pm(4,4,in_delta); } image get_convolutional_filter(convolutional_layer layer, int i) @@ -212,7 +238,7 @@ image get_convolutional_filter(convolutional_layer layer, int i) int h = layer.size; int w = layer.size; int c = layer.c; - return double_to_image(h,w,c,layer.filters+i*h*w*c); + return float_to_image(h,w,c,layer.filters+i*h*w*c); } void visualize_convolutional_layer(convolutional_layer layer, char *window) diff --git a/src/convolutional_layer.h b/src/convolutional_layer.h index e2e6cdc4..c4de24e5 100644 --- a/src/convolutional_layer.h +++ b/src/convolutional_layer.h @@ -10,28 +10,28 @@ typedef struct { int n; int size; int stride; - double *filters; - double *filter_updates; - double *filter_momentum; + float *filters; + float *filter_updates; + float *filter_momentum; - double *biases; - double *bias_updates; - double *bias_momentum; + float *biases; + float *bias_updates; + float *bias_momentum; - double *col_image; - double *delta; - double *output; + float *col_image; + float *delta; + float *output; ACTIVATION activation; } convolutional_layer; convolutional_layer *make_convolutional_layer(int h, int w, int c, int n, int size, int stride, ACTIVATION activation); -void forward_convolutional_layer(const convolutional_layer layer, double *in); +void forward_convolutional_layer(const convolutional_layer layer, float *in); void learn_convolutional_layer(convolutional_layer layer); -void update_convolutional_layer(convolutional_layer layer, double step, double momentum, double decay); +void update_convolutional_layer(convolutional_layer layer, float step, float momentum, float decay); void visualize_convolutional_layer(convolutional_layer layer, char *window); -//void backward_convolutional_layer(convolutional_layer layer, double *input, double *delta); +void backward_convolutional_layer(convolutional_layer layer, float *delta); //void backpropagate_convolutional_layer_convolve(image input, convolutional_layer layer); //void visualize_convolutional_filters(convolutional_layer layer, char *window); diff --git a/src/data.c b/src/data.c index 0b396d70..2c5932b0 100644 --- a/src/data.c +++ b/src/data.c @@ -19,10 +19,10 @@ list *get_paths(char *filename) return lines; } -void fill_truth(char *path, char **labels, int k, double *truth) +void fill_truth(char *path, char **labels, int k, float *truth) { int i; - memset(truth, 0, k*sizeof(double)); + memset(truth, 0, k*sizeof(float)); for(i = 0; i < k; ++i){ if(strstr(path, labels[i])){ truth[i] = 1; @@ -36,7 +36,7 @@ data load_data_image_paths(char **paths, int n, char **labels, int k) data d; d.shallow = 0; d.X.rows = n; - d.X.vals = calloc(d.X.rows, sizeof(double*)); + d.X.vals = calloc(d.X.rows, sizeof(float*)); d.y = make_matrix(n, k); for(i = 0; i < n; ++i){ @@ -106,8 +106,8 @@ data load_categorical_data_csv(char *filename, int target, int k) data d; d.shallow = 0; matrix X = csv_to_matrix(filename); - double *truth_1d = pop_column(&X, target); - double **truth = one_hot_encode(truth_1d, X.rows, k); + float *truth_1d = pop_column(&X, target); + float **truth = one_hot_encode(truth_1d, X.rows, k); matrix y; y.rows = X.rows; y.cols = k; @@ -123,7 +123,7 @@ void randomize_data(data d) int i; for(i = d.X.rows-1; i > 0; --i){ int index = rand()%i; - double *swap = d.X.vals[index]; + float *swap = d.X.vals[index]; d.X.vals[index] = d.X.vals[i]; d.X.vals[i] = swap; @@ -156,10 +156,10 @@ data *split_data(data d, int part, int total) train.X.cols = test.X.cols = d.X.cols; train.y.cols = test.y.cols = d.y.cols; - train.X.vals = calloc(train.X.rows, sizeof(double*)); - test.X.vals = calloc(test.X.rows, sizeof(double*)); - train.y.vals = calloc(train.y.rows, sizeof(double*)); - test.y.vals = calloc(test.y.rows, sizeof(double*)); + train.X.vals = calloc(train.X.rows, sizeof(float*)); + test.X.vals = calloc(test.X.rows, sizeof(float*)); + train.y.vals = calloc(train.y.rows, sizeof(float*)); + test.y.vals = calloc(test.y.rows, sizeof(float*)); for(i = 0; i < start; ++i){ train.X.vals[i] = d.X.vals[i]; diff --git a/src/image.c b/src/image.c index df8e1b8f..62ee5f7e 100644 --- a/src/image.c +++ b/src/image.c @@ -16,7 +16,7 @@ void embed_image(image source, image dest, int h, int w) for(k = 0; k < source.c; ++k){ for(i = 0; i < source.h; ++i){ for(j = 0; j < source.w; ++j){ - double val = get_pixel(source, i,j,k); + float val = get_pixel(source, i,j,k); set_pixel(dest, h+i, w+j, k, val); } } @@ -45,14 +45,14 @@ void z_normalize_image(image p) void normalize_image(image p) { - double *min = calloc(p.c, sizeof(double)); - double *max = calloc(p.c, sizeof(double)); + float *min = calloc(p.c, sizeof(float)); + float *max = calloc(p.c, sizeof(float)); int i,j; for(i = 0; i < p.c; ++i) min[i] = max[i] = p.data[i*p.h*p.w]; for(j = 0; j < p.c; ++j){ for(i = 0; i < p.h*p.w; ++i){ - double v = p.data[i+j*p.h*p.w]; + float v = p.data[i+j*p.h*p.w]; if(v < min[j]) min[j] = v; if(v > max[j]) max[j] = v; } @@ -72,17 +72,17 @@ void normalize_image(image p) free(max); } -double avg_image_layer(image m, int l) +float avg_image_layer(image m, int l) { int i; - double sum = 0; + float sum = 0; for(i = 0; i < m.h*m.w; ++i){ sum += m.data[l*m.h*m.w + i]; } return sum/(m.h*m.w); } -void threshold_image(image p, double t) +void threshold_image(image p, float t) { int i; for(i = 0; i < p.w*p.h*p.c; ++i){ @@ -93,8 +93,8 @@ void threshold_image(image p, double t) image copy_image(image p) { image copy = p; - copy.data = calloc(p.h*p.w*p.c, sizeof(double)); - memcpy(copy.data, p.data, p.h*p.w*p.c*sizeof(double)); + copy.data = calloc(p.h*p.w*p.c, sizeof(float)); + memcpy(copy.data, p.data, p.h*p.w*p.c*sizeof(float)); return copy; } @@ -168,11 +168,11 @@ image make_empty_image(int h, int w, int c) image make_image(int h, int w, int c) { image out = make_empty_image(h,w,c); - out.data = calloc(h*w*c, sizeof(double)); + out.data = calloc(h*w*c, sizeof(float)); return out; } -image double_to_image(int h, int w, int c, double *data) +image float_to_image(int h, int w, int c, float *data) { image out = make_empty_image(h,w,c); out.data = data; @@ -181,12 +181,12 @@ image double_to_image(int h, int w, int c, double *data) void zero_image(image m) { - memset(m.data, 0, m.h*m.w*m.c*sizeof(double)); + memset(m.data, 0, m.h*m.w*m.c*sizeof(float)); } void zero_channel(image m, int c) { - memset(&(m.data[c*m.h*m.w]), 0, m.h*m.w*sizeof(double)); + memset(&(m.data[c*m.h*m.w]), 0, m.h*m.w*sizeof(float)); } void rotate_image(image m) @@ -194,7 +194,7 @@ void rotate_image(image m) int i,j; for(j = 0; j < m.c; ++j){ for(i = 0; i < m.h*m.w/2; ++i){ - double swap = m.data[j*m.h*m.w + i]; + float swap = m.data[j*m.h*m.w + i]; m.data[j*m.h*m.w + i] = m.data[j*m.h*m.w + (m.h*m.w-1 - i)]; m.data[j*m.h*m.w + (m.h*m.w-1 - i)] = swap; } @@ -212,19 +212,19 @@ image make_random_image(int h, int w, int c) return out; } -void add_scalar_image(image m, double s) +void add_scalar_image(image m, float s) { int i; for(i = 0; i < m.h*m.w*m.c; ++i) m.data[i] += s; } -void scale_image(image m, double s) +void scale_image(image m, float s) { int i; for(i = 0; i < m.h*m.w*m.c; ++i) m.data[i] *= s; } -image make_random_kernel(int size, int c, double scale) +image make_random_kernel(int size, int c, float scale) { int pad; if((pad=(size%2==0))) ++size; @@ -280,34 +280,34 @@ image get_image_layer(image m, int l) return out; } -double get_pixel(image m, int x, int y, int c) +float get_pixel(image m, int x, int y, int c) { assert(x < m.h && y < m.w && c < m.c); return m.data[c*m.h*m.w + x*m.w + y]; } -double get_pixel_extend(image m, int x, int y, int c) +float get_pixel_extend(image m, int x, int y, int c) { if(x < 0 || x >= m.h || y < 0 || y >= m.w || c < 0 || c >= m.c) return 0; return get_pixel(m, x, y, c); } -void set_pixel(image m, int x, int y, int c, double val) +void set_pixel(image m, int x, int y, int c, float val) { assert(x < m.h && y < m.w && c < m.c); m.data[c*m.h*m.w + x*m.w + y] = val; } -void set_pixel_extend(image m, int x, int y, int c, double val) +void set_pixel_extend(image m, int x, int y, int c, float val) { if(x < 0 || x >= m.h || y < 0 || y >= m.w || c < 0 || c >= m.c) return; set_pixel(m, x, y, c, val); } -void add_pixel(image m, int x, int y, int c, double val) +void add_pixel(image m, int x, int y, int c, float val) { assert(x < m.h && y < m.w && c < m.c); m.data[c*m.h*m.w + x*m.w + y] += val; } -void add_pixel_extend(image m, int x, int y, int c, double val) +void add_pixel_extend(image m, int x, int y, int c, float val) { if(x < 0 || x >= m.h || y < 0 || y >= m.w || c < 0 || c >= m.c) return; add_pixel(m, x, y, c, val); @@ -329,7 +329,7 @@ void two_d_convolve(image m, int mc, image kernel, int kc, int stride, image out } for(x = xstart; x < xend; x += stride){ for(y = ystart; y < yend; y += stride){ - double sum = 0; + float sum = 0; for(i = 0; i < kernel.h; ++i){ for(j = 0; j < kernel.w; ++j){ sum += get_pixel(kernel, i, j, kc)*get_pixel_extend(m, x+i-kernel.h/2, y+j-kernel.w/2, mc); @@ -340,9 +340,9 @@ void two_d_convolve(image m, int mc, image kernel, int kc, int stride, image out } } -double single_convolve(image m, image kernel, int x, int y) +float single_convolve(image m, image kernel, int x, int y) { - double sum = 0; + float sum = 0; int i, j, k; for(i = 0; i < kernel.h; ++i){ for(j = 0; j < kernel.w; ++j){ @@ -366,7 +366,7 @@ void convolve(image m, image kernel, int stride, int channel, image out, int edg int j; for(i = 0; i < m.h; i += stride){ for(j = 0; j < m.w; j += stride){ - double val = single_convolve(m, kernel, i, j); + float val = single_convolve(m, kernel, i, j); set_pixel(out, i/stride, j/stride, channel, val); } } @@ -380,20 +380,20 @@ void upsample_image(image m, int stride, image out) for(k = 0; k < m.c; ++k){ for(i = 0; i < m.h; ++i){ for(j = 0; j< m.w; ++j){ - double val = get_pixel(m, i, j, k); + float val = get_pixel(m, i, j, k); set_pixel(out, i*stride, j*stride, k, val); } } } } -void single_update(image m, image update, int x, int y, double error) +void single_update(image m, image update, int x, int y, float error) { int i, j, k; for(i = 0; i < update.h; ++i){ for(j = 0; j < update.w; ++j){ for(k = 0; k < update.c; ++k){ - double val = get_pixel_extend(m, x+i-update.h/2, y+j-update.w/2, k); + float val = get_pixel_extend(m, x+i-update.h/2, y+j-update.w/2, k); add_pixel(update, i, j, k, val*error); } } @@ -417,7 +417,7 @@ void kernel_update(image m, image update, int stride, int channel, image out, in } for(i = istart; i < iend; i += stride){ for(j = jstart; j < jend; j += stride){ - double error = get_pixel(out, (i-istart)/stride, (j-jstart)/stride, channel); + float error = get_pixel(out, (i-istart)/stride, (j-jstart)/stride, channel); single_update(m, update, i, j, error); } } @@ -428,13 +428,13 @@ void kernel_update(image m, image update, int stride, int channel, image out, in */ } -void single_back_convolve(image m, image kernel, int x, int y, double val) +void single_back_convolve(image m, image kernel, int x, int y, float val) { int i, j, k; for(i = 0; i < kernel.h; ++i){ for(j = 0; j < kernel.w; ++j){ for(k = 0; k < kernel.c; ++k){ - double pval = get_pixel(kernel, i, j, k) * val; + float pval = get_pixel(kernel, i, j, k) * val; add_pixel_extend(m, x+i-kernel.h/2, y+j-kernel.w/2, k, pval); } } @@ -457,7 +457,7 @@ void back_convolve(image m, image kernel, int stride, int channel, image out, in } for(i = istart; i < iend; i += stride){ for(j = jstart; j < jend; j += stride){ - double val = get_pixel(out, (i-istart)/stride, (j-jstart)/stride, channel); + float val = get_pixel(out, (i-istart)/stride, (j-jstart)/stride, channel); single_back_convolve(m, kernel, i, j, val); } } diff --git a/src/image.h b/src/image.h index 18658575..72c4b2c0 100644 --- a/src/image.h +++ b/src/image.h @@ -7,18 +7,18 @@ typedef struct { int h; int w; int c; - double *data; + float *data; } image; -void scale_image(image m, double s); -void add_scalar_image(image m, double s); +void scale_image(image m, float s); +void add_scalar_image(image m, float s); void normalize_image(image p); void z_normalize_image(image p); -void threshold_image(image p, double t); +void threshold_image(image p, float t); void zero_image(image m); void rotate_image(image m); void subtract_image(image a, image b); -double avg_image_layer(image m, int l); +float avg_image_layer(image m, int l); void embed_image(image source, image dest, int h, int w); image collapse_image_layers(image source, int border); @@ -30,14 +30,14 @@ void print_image(image m); image make_image(int h, int w, int c); image make_empty_image(int h, int w, int c); image make_random_image(int h, int w, int c); -image make_random_kernel(int size, int c, double scale); -image double_to_image(int h, int w, int c, double *data); +image make_random_kernel(int size, int c, float scale); +image float_to_image(int h, int w, int c, float *data); image copy_image(image p); image load_image(char *filename); -double get_pixel(image m, int x, int y, int c); -double get_pixel_extend(image m, int x, int y, int c); -void set_pixel(image m, int x, int y, int c, double val); +float get_pixel(image m, int x, int y, int c); +float get_pixel_extend(image m, int x, int y, int c); +void set_pixel(image m, int x, int y, int c, float val); image get_image_layer(image m, int l); diff --git a/src/matrix.c b/src/matrix.c index 68e6f8d1..96bd3323 100644 --- a/src/matrix.c +++ b/src/matrix.c @@ -13,7 +13,7 @@ void free_matrix(matrix m) free(m.vals); } -double matrix_accuracy(matrix truth, matrix guess) +float matrix_accuracy(matrix truth, matrix guess) { int k = truth.cols; int i; @@ -22,7 +22,7 @@ double matrix_accuracy(matrix truth, matrix guess) int class = max_index(guess.vals[i], k); if(truth.vals[i][class]) ++count; } - return (double)count/truth.rows; + return (float)count/truth.rows; } void matrix_add_matrix(matrix from, matrix to) @@ -42,9 +42,9 @@ matrix make_matrix(int rows, int cols) matrix m; m.rows = rows; m.cols = cols; - m.vals = calloc(m.rows, sizeof(double *)); + m.vals = calloc(m.rows, sizeof(float *)); for(i = 0; i < m.rows; ++i){ - m.vals[i] = calloc(m.cols, sizeof(double)); + m.vals[i] = calloc(m.cols, sizeof(float)); } return m; } @@ -55,7 +55,7 @@ matrix hold_out_matrix(matrix *m, int n) matrix h; h.rows = n; h.cols = m->cols; - h.vals = calloc(h.rows, sizeof(double *)); + h.vals = calloc(h.rows, sizeof(float *)); for(i = 0; i < n; ++i){ int index = rand()%m->rows; h.vals[i] = m->vals[index]; @@ -64,9 +64,9 @@ matrix hold_out_matrix(matrix *m, int n) return h; } -double *pop_column(matrix *m, int c) +float *pop_column(matrix *m, int c) { - double *col = calloc(m->rows, sizeof(double)); + float *col = calloc(m->rows, sizeof(float)); int i, j; for(i = 0; i < m->rows; ++i){ col[i] = m->vals[i][c]; @@ -90,18 +90,18 @@ matrix csv_to_matrix(char *filename) int n = 0; int size = 1024; - m.vals = calloc(size, sizeof(double*)); + m.vals = calloc(size, sizeof(float*)); while((line = fgetl(fp))){ if(m.cols == -1) m.cols = count_fields(line); if(n == size){ size *= 2; - m.vals = realloc(m.vals, size*sizeof(double*)); + m.vals = realloc(m.vals, size*sizeof(float*)); } m.vals[n] = parse_fields(line, m.cols); free(line); ++n; } - m.vals = realloc(m.vals, n*sizeof(double*)); + m.vals = realloc(m.vals, n*sizeof(float*)); m.rows = n; return m; } diff --git a/src/matrix.h b/src/matrix.h index 098eb9ec..01d825dc 100644 --- a/src/matrix.h +++ b/src/matrix.h @@ -2,7 +2,7 @@ #define MATRIX_H typedef struct matrix{ int rows, cols; - double **vals; + float **vals; } matrix; matrix make_matrix(int rows, int cols); @@ -11,9 +11,9 @@ void print_matrix(matrix m); matrix csv_to_matrix(char *filename); matrix hold_out_matrix(matrix *m, int n); -double matrix_accuracy(matrix truth, matrix guess); +float matrix_accuracy(matrix truth, matrix guess); void matrix_add_matrix(matrix from, matrix to); -double *pop_column(matrix *m, int c); +float *pop_column(matrix *m, int c); #endif diff --git a/src/maxpool_layer.c b/src/maxpool_layer.c index ccf9bee0..8c409b94 100644 --- a/src/maxpool_layer.c +++ b/src/maxpool_layer.c @@ -6,7 +6,7 @@ image get_maxpool_image(maxpool_layer layer) int h = (layer.h-1)/layer.stride + 1; int w = (layer.w-1)/layer.stride + 1; int c = layer.c; - return double_to_image(h,w,c,layer.output); + return float_to_image(h,w,c,layer.output); } image get_maxpool_delta(maxpool_layer layer) @@ -14,7 +14,7 @@ image get_maxpool_delta(maxpool_layer layer) int h = (layer.h-1)/layer.stride + 1; int w = (layer.w-1)/layer.stride + 1; int c = layer.c; - return double_to_image(h,w,c,layer.delta); + return float_to_image(h,w,c,layer.delta); } maxpool_layer *make_maxpool_layer(int h, int w, int c, int stride) @@ -25,41 +25,41 @@ maxpool_layer *make_maxpool_layer(int h, int w, int c, int stride) layer->w = w; layer->c = c; layer->stride = stride; - layer->output = calloc(((h-1)/stride+1) * ((w-1)/stride+1) * c, sizeof(double)); - layer->delta = calloc(((h-1)/stride+1) * ((w-1)/stride+1) * c, sizeof(double)); + layer->output = calloc(((h-1)/stride+1) * ((w-1)/stride+1) * c, sizeof(float)); + layer->delta = calloc(((h-1)/stride+1) * ((w-1)/stride+1) * c, sizeof(float)); return layer; } -void forward_maxpool_layer(const maxpool_layer layer, double *in) +void forward_maxpool_layer(const maxpool_layer layer, float *in) { - image input = double_to_image(layer.h, layer.w, layer.c, in); + image input = float_to_image(layer.h, layer.w, layer.c, in); image output = get_maxpool_image(layer); int i,j,k; for(i = 0; i < output.h*output.w*output.c; ++i) output.data[i] = -DBL_MAX; for(k = 0; k < input.c; ++k){ for(i = 0; i < input.h; ++i){ for(j = 0; j < input.w; ++j){ - double val = get_pixel(input, i, j, k); - double cur = get_pixel(output, i/layer.stride, j/layer.stride, k); + float val = get_pixel(input, i, j, k); + float cur = get_pixel(output, i/layer.stride, j/layer.stride, k); if(val > cur) set_pixel(output, i/layer.stride, j/layer.stride, k, val); } } } } -void backward_maxpool_layer(const maxpool_layer layer, double *in, double *delta) +void backward_maxpool_layer(const maxpool_layer layer, float *in, float *delta) { - image input = double_to_image(layer.h, layer.w, layer.c, in); - image input_delta = double_to_image(layer.h, layer.w, layer.c, delta); + image input = float_to_image(layer.h, layer.w, layer.c, in); + image input_delta = float_to_image(layer.h, layer.w, layer.c, delta); image output_delta = get_maxpool_delta(layer); image output = get_maxpool_image(layer); int i,j,k; for(k = 0; k < input.c; ++k){ for(i = 0; i < input.h; ++i){ for(j = 0; j < input.w; ++j){ - double val = get_pixel(input, i, j, k); - double cur = get_pixel(output, i/layer.stride, j/layer.stride, k); - double d = get_pixel(output_delta, i/layer.stride, j/layer.stride, k); + float val = get_pixel(input, i, j, k); + float cur = get_pixel(output, i/layer.stride, j/layer.stride, k); + float d = get_pixel(output_delta, i/layer.stride, j/layer.stride, k); if(val == cur) { set_pixel(input_delta, i, j, k, d); } diff --git a/src/maxpool_layer.h b/src/maxpool_layer.h index 0afe68a7..27d6f55a 100644 --- a/src/maxpool_layer.h +++ b/src/maxpool_layer.h @@ -6,14 +6,14 @@ typedef struct { int h,w,c; int stride; - double *delta; - double *output; + float *delta; + float *output; } maxpool_layer; image get_maxpool_image(maxpool_layer layer); maxpool_layer *make_maxpool_layer(int h, int w, int c, int stride); -void forward_maxpool_layer(const maxpool_layer layer, double *in); -void backward_maxpool_layer(const maxpool_layer layer, double *in, double *delta); +void forward_maxpool_layer(const maxpool_layer layer, float *in); +void backward_maxpool_layer(const maxpool_layer layer, float *in, float *delta); #endif diff --git a/src/mini_blas.c b/src/mini_blas.c index 3af36e5c..b9a43049 100644 --- a/src/mini_blas.c +++ b/src/mini_blas.c @@ -1,8 +1,10 @@ #include +#include #include +#include -void pm(int M, int N, double *A) +void pm(int M, int N, float *A) { int i,j; for(i =0 ; i < M; ++i){ @@ -14,28 +16,37 @@ void pm(int M, int N, double *A) printf("\n"); } -void gemm(int TA, int TB, int M, int N, int K, double ALPHA, - double *A, int lda, - double *B, int ldb, - double BETA, - double *C, int ldc) +void gemm(int TA, int TB, int M, int N, int K, float ALPHA, + float *A, int lda, + float *B, int ldb, + float BETA, + float *C, int ldc) { - // Assume TA = 0, beta = 1 LULZ + // Assume beta = 1 LULZ int i,j,k; if(TB && !TA){ for(i = 0; i < M; ++i){ for(j = 0; j < N; ++j){ - register double sum = 0; + register float sum = 0; for(k = 0; k < K; ++k){ sum += ALPHA*A[i*lda+k]*B[k+j*ldb]; } C[i*ldc+j] += sum; } } + }else if(TA && !TB){ + for(i = 0; i < M; ++i){ + for(k = 0; k < K; ++k){ + register float A_PART = ALPHA*A[k*lda+i]; + for(j = 0; j < N; ++j){ + C[i*ldc+j] += A_PART*B[k*ldb+j]; + } + } + } }else{ for(i = 0; i < M; ++i){ for(k = 0; k < K; ++k){ - register double A_PART = ALPHA*A[i*lda+k]; + register float A_PART = ALPHA*A[i*lda+k]; for(j = 0; j < N; ++j){ C[i*ldc+j] += A_PART*B[k*ldb+j]; } @@ -44,7 +55,7 @@ void gemm(int TA, int TB, int M, int N, int K, double ALPHA, } } -void im2row(double *image, int h, int w, int c, int size, int stride, double *matrix) +void im2row(float *image, int h, int w, int c, int size, int stride, float *matrix) { int i; int mc = c; @@ -64,7 +75,7 @@ void im2row(double *image, int h, int w, int c, int size, int stride, double *ma matrix[i] = image[pc*h*w+ph*w+pw]; } } -void im2col(double *image, int h, int w, int c, int size, int stride, double *matrix) +void im2col(float *image, int h, int w, int c, int size, int stride, float *matrix) { int b,p; int blocks = ((h-size)/stride+1)*((w-size)/stride+1); @@ -84,9 +95,9 @@ void im2col(double *image, int h, int w, int c, int size, int stride, double *ma } //From Berkeley Vision's Caffe! -void im2col_cpu(double* data_im, const int channels, +void im2col_cpu(float* data_im, const int channels, const int height, const int width, const int ksize, const int stride, - double* data_col) + float* data_col) { int c,h,w; int height_col = (height - ksize) / stride + 1; @@ -106,3 +117,59 @@ void im2col_cpu(double* data_im, const int channels, } } +void col2im_cpu(float* data_col, const int channels, + const int height, const int width, const int ksize, const int stride, + float* data_im) +{ + int c,h,w; + int height_col = (height - ksize) / stride + 1; + int width_col = (width - ksize) / stride + 1; + int channels_col = channels * ksize * ksize; + for ( c = 0; c < channels_col; ++c) { + int w_offset = c % ksize; + int h_offset = (c / ksize) % ksize; + int c_im = c / ksize / ksize; + for ( h = 0; h < height_col; ++h) { + for ( w = 0; w < width_col; ++w) { + data_im[(c_im * height + h * stride + h_offset) * width + + w * stride + w_offset]+= data_col[(c * height_col + h) * width_col + w]; + } + } + } +} + +float *random_matrix(int rows, int cols) +{ + int i; + float *m = calloc(rows*cols, sizeof(float)); + for(i = 0; i < rows*cols; ++i){ + m[i] = (float)rand()/RAND_MAX; + } + return m; +} + +void time_random_matrix(int TA, int TB, int m, int k, int n) +{ + float *a = random_matrix(m,k); + float *b = random_matrix(k,n); + float *c = random_matrix(m,n); + int i; + clock_t start = clock(), end; + for(i = 0; i<1000; ++i){ + gemm(TA,TB,m,n,k,1,a,k,b,n,1,c,n); + } + end = clock(); + printf("Matrix Multiplication %dx%d * %dx%d, TA=%d, TB=%d: %lf ms\n",m,k,k,n, TA, TB, (double)(end-start)/CLOCKS_PER_SEC); +} + +void test_blas() +{ + time_random_matrix(0,0,100,100,100); + time_random_matrix(1,0,100,100,100); + time_random_matrix(0,1,100,100,100); + + time_random_matrix(0,1,1000,100,100); + time_random_matrix(1,0,1000,100,100); + +} + diff --git a/src/mini_blas.h b/src/mini_blas.h index 46a37d3d..ff82a60c 100644 --- a/src/mini_blas.h +++ b/src/mini_blas.h @@ -1,11 +1,15 @@ -void pm(int M, int N, double *A); -void gemm(int TA, int TB, int M, int N, int K, double ALPHA, - double *A, int lda, - double *B, int ldb, - double BETA, - double *C, int ldc); -void im2row(double *image, int h, int w, int c, int size, int stride, double *matrix); -void im2col(double *image, int h, int w, int c, int size, int stride, double *matrix); -void im2col_cpu(double* data_im, const int channels, +void pm(int M, int N, float *A); +void gemm(int TA, int TB, int M, int N, int K, float ALPHA, + float *A, int lda, + float *B, int ldb, + float BETA, + float *C, int ldc); +void im2row(float *image, int h, int w, int c, int size, int stride, float *matrix); +void im2col(float *image, int h, int w, int c, int size, int stride, float *matrix); +void im2col_cpu(float* data_im, const int channels, const int height, const int width, const int ksize, const int stride, - double* data_col); + float* data_col); +void col2im_cpu(float* data_col, const int channels, + const int height, const int width, const int ksize, const int stride, + float* data_im); +void test_blas(); diff --git a/src/network.c b/src/network.c index 2ce13d83..29e22e4e 100644 --- a/src/network.c +++ b/src/network.c @@ -21,7 +21,7 @@ network make_network(int n) return net; } -void forward_network(network net, double *input) +void forward_network(network net, float *input) { int i; for(i = 0; i < net.n; ++i){ @@ -48,7 +48,7 @@ void forward_network(network net, double *input) } } -void update_network(network net, double step, double momentum, double decay) +void update_network(network net, float step, float momentum, float decay) { int i; for(i = 0; i < net.n; ++i){ @@ -69,7 +69,7 @@ void update_network(network net, double step, double momentum, double decay) } } -double *get_network_output_layer(network net, int i) +float *get_network_output_layer(network net, int i) { if(net.types[i] == CONVOLUTIONAL){ convolutional_layer layer = *(convolutional_layer *)net.layers[i]; @@ -86,12 +86,12 @@ double *get_network_output_layer(network net, int i) } return 0; } -double *get_network_output(network net) +float *get_network_output(network net) { return get_network_output_layer(net, net.n-1); } -double *get_network_delta_layer(network net, int i) +float *get_network_delta_layer(network net, int i) { if(net.types[i] == CONVOLUTIONAL){ convolutional_layer layer = *(convolutional_layer *)net.layers[i]; @@ -109,16 +109,16 @@ double *get_network_delta_layer(network net, int i) return 0; } -double *get_network_delta(network net) +float *get_network_delta(network net) { return get_network_delta_layer(net, net.n-1); } -double calculate_error_network(network net, double *truth) +float calculate_error_network(network net, float *truth) { - double sum = 0; - double *delta = get_network_delta(net); - double *out = get_network_output(net); + float sum = 0; + float *delta = get_network_delta(net); + float *out = get_network_output(net); int i, k = get_network_output_size(net); for(i = 0; i < k; ++i){ delta[i] = truth[i] - out[i]; @@ -129,17 +129,17 @@ double calculate_error_network(network net, double *truth) int get_predicted_class_network(network net) { - double *out = get_network_output(net); + float *out = get_network_output(net); int k = get_network_output_size(net); return max_index(out, k); } -double backward_network(network net, double *input, double *truth) +float backward_network(network net, float *input, float *truth) { - double error = calculate_error_network(net, truth); + float error = calculate_error_network(net, truth); int i; - double *prev_input; - double *prev_delta; + float *prev_input; + float *prev_delta; for(i = net.n-1; i >= 0; --i){ if(i == 0){ prev_input = input; @@ -152,7 +152,7 @@ double backward_network(network net, double *input, double *truth) convolutional_layer layer = *(convolutional_layer *)net.layers[i]; learn_convolutional_layer(layer); //learn_convolutional_layer(layer); - //if(i != 0) backward_convolutional_layer(layer, prev_input, prev_delta); + if(i != 0) backward_convolutional_layer(layer, prev_delta); } else if(net.types[i] == MAXPOOL){ maxpool_layer layer = *(maxpool_layer *)net.layers[i]; @@ -171,49 +171,49 @@ double backward_network(network net, double *input, double *truth) return error; } -double train_network_datum(network net, double *x, double *y, double step, double momentum, double decay) +float train_network_datum(network net, float *x, float *y, float step, float momentum, float decay) { forward_network(net, x); int class = get_predicted_class_network(net); - double error = backward_network(net, x, y); + float error = backward_network(net, x, y); update_network(net, step, momentum, decay); //return (y[class]?1:0); return error; } -double train_network_sgd(network net, data d, int n, double step, double momentum,double decay) +float train_network_sgd(network net, data d, int n, float step, float momentum,float decay) { int i; - double error = 0; + float error = 0; for(i = 0; i < n; ++i){ int index = rand()%d.X.rows; error += train_network_datum(net, d.X.vals[index], d.y.vals[index], step, momentum, decay); //if((i+1)%10 == 0){ - // printf("%d: %f\n", (i+1), (double)correct/(i+1)); + // printf("%d: %f\n", (i+1), (float)correct/(i+1)); //} } return error/n; } -double train_network_batch(network net, data d, int n, double step, double momentum,double decay) +float train_network_batch(network net, data d, int n, float step, float momentum,float decay) { int i; int correct = 0; for(i = 0; i < n; ++i){ int index = rand()%d.X.rows; - double *x = d.X.vals[index]; - double *y = d.y.vals[index]; + float *x = d.X.vals[index]; + float *y = d.y.vals[index]; forward_network(net, x); int class = get_predicted_class_network(net); backward_network(net, x, y); correct += (y[class]?1:0); } update_network(net, step, momentum, decay); - return (double)correct/n; + return (float)correct/n; } -void train_network(network net, data d, double step, double momentum, double decay) +void train_network(network net, data d, float step, float momentum, float decay) { int i; int correct = 0; @@ -226,7 +226,7 @@ void train_network(network net, data d, double step, double momentum, double dec } visualize_network(net); cvWaitKey(100); - printf("Accuracy: %f\n", (double)correct/d.X.rows); + printf("Accuracy: %f\n", (float)correct/d.X.rows); } int get_network_output_size_layer(network net, int i) @@ -294,10 +294,10 @@ void visualize_network(network net) } } -double *network_predict(network net, double *input) +float *network_predict(network net, float *input) { forward_network(net, input); - double *out = get_network_output(net); + float *out = get_network_output(net); return out; } @@ -307,7 +307,7 @@ matrix network_predict_data(network net, data test) int k = get_network_output_size(net); matrix pred = make_matrix(test.X.rows, k); for(i = 0; i < test.X.rows; ++i){ - double *out = network_predict(net, test.X.vals[i]); + float *out = network_predict(net, test.X.vals[i]); for(j = 0; j < k; ++j){ pred.vals[i][j] = out[j]; } @@ -319,7 +319,7 @@ void print_network(network net) { int i,j; for(i = 0; i < net.n; ++i){ - double *output = 0; + float *output = 0; int n = 0; if(net.types[i] == CONVOLUTIONAL){ convolutional_layer layer = *(convolutional_layer *)net.layers[i]; @@ -343,8 +343,8 @@ void print_network(network net) output = layer.output; n = layer.inputs; } - double mean = mean_array(output, n); - double vari = variance_array(output, n); + float mean = mean_array(output, n); + float vari = variance_array(output, n); fprintf(stderr, "Layer %d - Mean: %f, Variance: %f\n",i,mean, vari); if(n > 100) n = 100; for(j = 0; j < n; ++j) fprintf(stderr, "%f, ", output[j]); @@ -353,10 +353,10 @@ void print_network(network net) } } -double network_accuracy(network net, data d) +float network_accuracy(network net, data d) { matrix guess = network_predict_data(net, d); - double acc = matrix_accuracy(d.y, guess); + float acc = matrix_accuracy(d.y, guess); free_matrix(guess); return acc; } diff --git a/src/network.h b/src/network.h index fa109dd0..17cc10bb 100644 --- a/src/network.h +++ b/src/network.h @@ -17,22 +17,22 @@ typedef struct { void **layers; LAYER_TYPE *types; int outputs; - double *output; + float *output; } network; network make_network(int n); -void forward_network(network net, double *input); -double backward_network(network net, double *input, double *truth); -void update_network(network net, double step, double momentum, double decay); -double train_network_sgd(network net, data d, int n, double step, double momentum,double decay); -double train_network_batch(network net, data d, int n, double step, double momentum,double decay); -void train_network(network net, data d, double step, double momentum, double decay); +void forward_network(network net, float *input); +float backward_network(network net, float *input, float *truth); +void update_network(network net, float step, float momentum, float decay); +float train_network_sgd(network net, data d, int n, float step, float momentum,float decay); +float train_network_batch(network net, data d, int n, float step, float momentum,float decay); +void train_network(network net, data d, float step, float momentum, float decay); matrix network_predict_data(network net, data test); -double network_accuracy(network net, data d); -double *get_network_output(network net); -double *get_network_output_layer(network net, int i); -double *get_network_delta_layer(network net, int i); -double *get_network_delta(network net); +float network_accuracy(network net, data d); +float *get_network_output(network net); +float *get_network_output_layer(network net, int i); +float *get_network_delta_layer(network net, int i); +float *get_network_delta(network net); int get_network_output_size_layer(network net, int i); int get_network_output_size(network net); image get_network_image(network net); diff --git a/src/option_list.c b/src/option_list.c index 1b32ebbb..7902cd9c 100644 --- a/src/option_list.c +++ b/src/option_list.c @@ -59,7 +59,7 @@ int option_find_int(list *l, char *key, int def) return def; } -double option_find_double(list *l, char *key, double def) +float option_find_float(list *l, char *key, float def) { char *v = option_find(l, key); if(v) return atof(v); diff --git a/src/option_list.h b/src/option_list.h index 0270465b..60e37fec 100644 --- a/src/option_list.h +++ b/src/option_list.h @@ -6,7 +6,7 @@ void option_insert(list *l, char *key, char *val); char *option_find(list *l, char *key); char *option_find_str(list *l, char *key, char *def); int option_find_int(list *l, char *key, int def); -double option_find_double(list *l, char *key, double def); +float option_find_float(list *l, char *key, float def); void option_unused(list *l); #endif diff --git a/src/softmax_layer.c b/src/softmax_layer.c index b213e5b0..1e01bd20 100644 --- a/src/softmax_layer.c +++ b/src/softmax_layer.c @@ -8,15 +8,16 @@ softmax_layer *make_softmax_layer(int inputs) fprintf(stderr, "Softmax Layer: %d inputs\n", inputs); softmax_layer *layer = calloc(1, sizeof(softmax_layer)); layer->inputs = inputs; - layer->output = calloc(inputs, sizeof(double)); - layer->delta = calloc(inputs, sizeof(double)); + layer->output = calloc(inputs, sizeof(float)); + layer->delta = calloc(inputs, sizeof(float)); return layer; } -void forward_softmax_layer(const softmax_layer layer, double *input) +/* UNSTABLE! +void forward_softmax_layer(const softmax_layer layer, float *input) { int i; - double sum = 0; + float sum = 0; for(i = 0; i < layer.inputs; ++i){ sum += exp(input[i]); } @@ -24,8 +25,25 @@ void forward_softmax_layer(const softmax_layer layer, double *input) layer.output[i] = exp(input[i])/sum; } } +*/ +void forward_softmax_layer(const softmax_layer layer, float *input) +{ + int i; + float sum = 0; + float largest = 0; + for(i = 0; i < layer.inputs; ++i){ + if(input[i] > largest) largest = input[i]; + } + for(i = 0; i < layer.inputs; ++i){ + sum += exp(input[i]-largest); + } + sum = largest+log(sum); + for(i = 0; i < layer.inputs; ++i){ + layer.output[i] = exp(input[i]-sum); + } +} -void backward_softmax_layer(const softmax_layer layer, double *input, double *delta) +void backward_softmax_layer(const softmax_layer layer, float *input, float *delta) { int i; for(i = 0; i < layer.inputs; ++i){ diff --git a/src/softmax_layer.h b/src/softmax_layer.h index 1a0d7605..bfcd390f 100644 --- a/src/softmax_layer.h +++ b/src/softmax_layer.h @@ -3,12 +3,12 @@ typedef struct { int inputs; - double *delta; - double *output; + float *delta; + float *output; } softmax_layer; softmax_layer *make_softmax_layer(int inputs); -void forward_softmax_layer(const softmax_layer layer, double *input); -void backward_softmax_layer(const softmax_layer layer, double *input, double *delta); +void forward_softmax_layer(const softmax_layer layer, float *input); +void backward_softmax_layer(const softmax_layer layer, float *input, float *delta); #endif diff --git a/src/tests.c b/src/tests.c index af22ddb8..00cd1a12 100644 --- a/src/tests.c +++ b/src/tests.c @@ -14,6 +14,9 @@ #include #include +#define _GNU_SOURCE +#include + void test_convolve() { image dog = load_image("dog.jpg"); @@ -26,7 +29,7 @@ void test_convolve() convolve(dog, kernel, 1, 0, edge, 1); } end = clock(); - printf("Convolutions: %lf seconds\n", (double)(end-start)/CLOCKS_PER_SEC); + printf("Convolutions: %lf seconds\n", (float)(end-start)/CLOCKS_PER_SEC); show_image_layers(edge, "Test Convolve"); } @@ -38,11 +41,11 @@ void test_convolve_matrix() int size = 11; int stride = 4; int n = 40; - double *filters = make_random_image(size, size, dog.c*n).data; + float *filters = make_random_image(size, size, dog.c*n).data; int mw = ((dog.h-size)/stride+1)*((dog.w-size)/stride+1); int mh = (size*size*dog.c); - double *matrix = calloc(mh*mw, sizeof(double)); + float *matrix = calloc(mh*mw, sizeof(float)); image edge = make_image((dog.h-size)/stride+1, (dog.w-size)/stride+1, n); @@ -54,7 +57,7 @@ void test_convolve_matrix() gemm(0,0,n,mw,mh,1,filters,mh,matrix,mw,1,edge.data,mw); } end = clock(); - printf("Convolutions: %lf seconds\n", (double)(end-start)/CLOCKS_PER_SEC); + printf("Convolutions: %lf seconds\n", (float)(end-start)/CLOCKS_PER_SEC); show_image_layers(edge, "Test Convolve"); cvWaitKey(0); } @@ -72,11 +75,11 @@ void verify_convolutional_layer() int n = 1; int stride = 1; int size = 3; - double eps = .00000001; + float eps = .00000001; image test = make_random_image(5,5, 1); convolutional_layer layer = *make_convolutional_layer(test.h,test.w,test.c, n, size, stride, RELU); image out = get_convolutional_image(layer); - double **jacobian = calloc(test.h*test.w*test.c, sizeof(double)); + float **jacobian = calloc(test.h*test.w*test.c, sizeof(float)); forward_convolutional_layer(layer, test.data); image base = copy_image(out); @@ -90,19 +93,19 @@ void verify_convolutional_layer() jacobian[i] = partial.data; test.data[i] -= eps; } - double **jacobian2 = calloc(out.h*out.w*out.c, sizeof(double)); + float **jacobian2 = calloc(out.h*out.w*out.c, sizeof(float)); image in_delta = make_image(test.h, test.w, test.c); image out_delta = get_convolutional_delta(layer); for(i = 0; i < out.h*out.w*out.c; ++i){ out_delta.data[i] = 1; - //backward_convolutional_layer(layer, test.data, in_delta.data); + backward_convolutional_layer(layer, in_delta.data); image partial = copy_image(in_delta); jacobian2[i] = partial.data; out_delta.data[i] = 0; } int j; - double *j1 = calloc(test.h*test.w*test.c*out.h*out.w*out.c, sizeof(double)); - double *j2 = calloc(test.h*test.w*test.c*out.h*out.w*out.c, sizeof(double)); + float *j1 = calloc(test.h*test.w*test.c*out.h*out.w*out.c, sizeof(float)); + float *j2 = calloc(test.h*test.w*test.c*out.h*out.w*out.c, sizeof(float)); for(i = 0; i < test.h*test.w*test.c; ++i){ for(j =0 ; j < out.h*out.w*out.c; ++j){ j1[i*out.h*out.w*out.c + j] = jacobian[i][j]; @@ -112,12 +115,11 @@ void verify_convolutional_layer() } - image mj1 = double_to_image(test.w*test.h*test.c, out.w*out.h*out.c, 1, j1); - image mj2 = double_to_image(test.w*test.h*test.c, out.w*out.h*out.c, 1, j2); + image mj1 = float_to_image(test.w*test.h*test.c, out.w*out.h*out.c, 1, j1); + image mj2 = float_to_image(test.w*test.h*test.c, out.w*out.h*out.c, 1, j2); printf("%f %f\n", avg_image_layer(mj1,0), avg_image_layer(mj2,0)); show_image(mj1, "forward jacobian"); show_image(mj2, "backward jacobian"); - } void test_load() @@ -145,7 +147,7 @@ void test_rotate() rotate_image(dog); } end = clock(); - printf("Rotations: %lf seconds\n", (double)(end-start)/CLOCKS_PER_SEC); + printf("Rotations: %lf seconds\n", (float)(end-start)/CLOCKS_PER_SEC); show_image(dog, "Test Rotate"); image random = make_random_image(3,3,3); @@ -159,18 +161,18 @@ void test_rotate() void test_parser() { network net = parse_network_cfg("test_parser.cfg"); - double input[1]; + float input[1]; int count = 0; - double avgerr = 0; + float avgerr = 0; while(++count < 100000000){ - double v = ((double)rand()/RAND_MAX); - double truth = v*v; + float v = ((float)rand()/RAND_MAX); + float truth = v*v; input[0] = v; forward_network(net, input); - double *out = get_network_output(net); - double *delta = get_network_delta(net); - double err = pow((out[0]-truth),2.); + float *out = get_network_output(net); + float *delta = get_network_delta(net); + float err = pow((out[0]-truth),2.); avgerr = .99 * avgerr + .01 * err; if(count % 1000000 == 0) printf("%f %f :%f AVG %f \n", truth, out[0], err, avgerr); delta[0] = truth - out[0]; @@ -192,9 +194,9 @@ void test_full() srand(0); int i = 0; char *labels[] = {"cat","dog"}; - double lr = .00001; - double momentum = .9; - double decay = 0.01; + float lr = .00001; + float momentum = .9; + float decay = 0.01; while(i++ < 1000 || 1){ data train = load_data_image_pathfile_random("train_paths.txt", 1000, labels, 2); train_network(net, train, lr, momentum, decay); @@ -207,32 +209,33 @@ void test_nist() { srand(444444); srand(888888); - network net = parse_network_cfg("nist_basic.cfg"); + network net = parse_network_cfg("nist.cfg"); data train = load_categorical_data_csv("mnist/mnist_train.csv", 0, 10); data test = load_categorical_data_csv("mnist/mnist_test.csv",0,10); normalize_data_rows(train); normalize_data_rows(test); //randomize_data(train); int count = 0; - double lr = .0005; - double momentum = .9; - double decay = 0.01; + float lr = .0005; + float momentum = .9; + float decay = 0.01; clock_t start = clock(), end; while(++count <= 100){ - visualize_network(net); - double loss = train_network_sgd(net, train, 10000, lr, momentum, decay); + //visualize_network(net); + float loss = train_network_sgd(net, train, 1000, lr, momentum, decay); printf("%5d Training Loss: %lf, Params: %f %f %f, ",count*100, loss, lr, momentum, decay); end = clock(); - printf("Time: %lf seconds\n", (double)(end-start)/CLOCKS_PER_SEC); + printf("Time: %lf seconds\n", (float)(end-start)/CLOCKS_PER_SEC); start=end; cvWaitKey(100); //lr /= 2; if(count%5 == 0){ - double train_acc = network_accuracy(net, train); + float train_acc = network_accuracy(net, train); fprintf(stderr, "\nTRAIN: %f\n", train_acc); - double test_acc = network_accuracy(net, test); + float test_acc = network_accuracy(net, test); fprintf(stderr, "TEST: %f\n\n", test_acc); printf("%d, %f, %f\n", count, train_acc, test_acc); + lr *= .5; } } } @@ -253,24 +256,24 @@ void test_ensemble() int n = 30; for(i = 0; i < n; ++i){ int count = 0; - double lr = .0005; - double momentum = .9; - double decay = .01; + float lr = .0005; + float momentum = .9; + float decay = .01; network net = parse_network_cfg("nist.cfg"); while(++count <= 15){ - double acc = train_network_sgd(net, train, train.X.rows, lr, momentum, decay); + float acc = train_network_sgd(net, train, train.X.rows, lr, momentum, decay); printf("Training Accuracy: %lf Learning Rate: %f Momentum: %f Decay: %f\n", acc, lr, momentum, decay ); lr /= 2; } matrix partial = network_predict_data(net, test); - double acc = matrix_accuracy(test.y, partial); + float acc = matrix_accuracy(test.y, partial); printf("Model Accuracy: %lf\n", acc); matrix_add_matrix(partial, prediction); acc = matrix_accuracy(test.y, prediction); printf("Current Ensemble Accuracy: %lf\n", acc); free_matrix(partial); } - double acc = matrix_accuracy(test.y, prediction); + float acc = matrix_accuracy(test.y, prediction); printf("Full Ensemble Accuracy: %lf\n", acc); } @@ -279,19 +282,19 @@ void test_random_classify() network net = parse_network_cfg("connected.cfg"); matrix m = csv_to_matrix("train.csv"); //matrix ho = hold_out_matrix(&m, 2500); - double *truth = pop_column(&m, 0); - //double *ho_truth = pop_column(&ho, 0); + float *truth = pop_column(&m, 0); + //float *ho_truth = pop_column(&ho, 0); int i; clock_t start = clock(), end; int count = 0; while(++count <= 300){ for(i = 0; i < m.rows; ++i){ int index = rand()%m.rows; - //image p = double_to_image(1690,1,1,m.vals[index]); + //image p = float_to_image(1690,1,1,m.vals[index]); //normalize_image(p); forward_network(net, m.vals[index]); - double *out = get_network_output(net); - double *delta = get_network_delta(net); + float *out = get_network_output(net); + float *delta = get_network_delta(net); //printf("%f\n", out[0]); delta[0] = truth[index] - out[0]; // printf("%f\n", delta[0]); @@ -299,8 +302,8 @@ void test_random_classify() //backward_network(net, m.vals[index], ); update_network(net, .00001, 0,0); } - //double test_acc = error_network(net, m, truth); - //double valid_acc = error_network(net, ho, ho_truth); + //float test_acc = error_network(net, m, truth); + //float valid_acc = error_network(net, ho, ho_truth); //printf("%f, %f\n", test_acc, valid_acc); //fprintf(stderr, "%5d: %f Valid: %f\n",count, test_acc, valid_acc); //if(valid_acc > .70) break; @@ -311,12 +314,12 @@ void test_random_classify() truth = pop_column(&test, 0); for(i = 0; i < test.rows; ++i){ forward_network(net, test.vals[i]); - double *out = get_network_output(net); + float *out = get_network_output(net); if(fabs(out[0]) < .5) fprintf(fp, "0\n"); else fprintf(fp, "1\n"); } fclose(fp); - printf("Neural Net Learning: %lf seconds\n", (double)(end-start)/CLOCKS_PER_SEC); + printf("Neural Net Learning: %lf seconds\n", (float)(end-start)/CLOCKS_PER_SEC); } void test_split() @@ -326,30 +329,6 @@ void test_split() printf("%d, %d, %d\n", train.X.rows, split[0].X.rows, split[1].X.rows); } -double *random_matrix(int rows, int cols) -{ - int i, j; - double *m = calloc(rows*cols, sizeof(double)); - for(i = 0; i < rows; ++i){ - for(j = 0; j < cols; ++j){ - m[i*cols+j] = (double)rand()/RAND_MAX; - } - } - return m; -} - -void test_blas() -{ - int m = 1000, n = 1000, k = 1000; - double *a = random_matrix(m,k); - double *b = random_matrix(k,n); - double *c = random_matrix(m,n); - int i; - for(i = 0; i<1000; ++i){ - gemm(0,0,m,n,k,1,a,k,b,n,1,c,n); - } -} - void test_im2row() { int h = 20; @@ -362,16 +341,18 @@ void test_im2row() int mw = ((h-size)/stride+1)*((w-size)/stride+1); int mh = (size*size*c); int msize = mc*mw*mh; - double *matrix = calloc(msize, sizeof(double)); + float *matrix = calloc(msize, sizeof(float)); int i; for(i = 0; i < 1000; ++i){ im2col_cpu(test.data, c, h, w, size, stride, matrix); - image render = double_to_image(mh, mw, mc, matrix); + image render = float_to_image(mh, mw, mc, matrix); } } int main() { + //feenableexcept(FE_DIVBYZERO | FE_INVALID | FE_OVERFLOW); + //test_blas(); //test_convolve_matrix(); // test_im2row(); diff --git a/src/utils.c b/src/utils.c index 5180fe65..41ee7681 100644 --- a/src/utils.c +++ b/src/utils.c @@ -123,9 +123,9 @@ int count_fields(char *line) return count; } -double *parse_fields(char *line, int n) +float *parse_fields(char *line, int n) { - double *field = calloc(n, sizeof(double)); + float *field = calloc(n, sizeof(float)); char *c, *p, *end; int count = 0; int done = 0; @@ -143,36 +143,36 @@ double *parse_fields(char *line, int n) return field; } -double mean_array(double *a, int n) +float mean_array(float *a, int n) { int i; - double sum = 0; + float sum = 0; for(i = 0; i < n; ++i) sum += a[i]; return sum/n; } -double variance_array(double *a, int n) +float variance_array(float *a, int n) { int i; - double sum = 0; - double mean = mean_array(a, n); + float sum = 0; + float mean = mean_array(a, n); for(i = 0; i < n; ++i) sum += (a[i] - mean)*(a[i]-mean); - double variance = sum/n; + float variance = sum/n; return variance; } -double constrain(double a, double max) +float constrain(float a, float max) { if(a > abs(max)) return abs(max); if(a < -abs(max)) return -abs(max); return a; } -void normalize_array(double *a, int n) +void normalize_array(float *a, int n) { int i; - double mu = mean_array(a,n); - double sigma = sqrt(variance_array(a,n)); + float mu = mean_array(a,n); + float sigma = sqrt(variance_array(a,n)); for(i = 0; i < n; ++i){ a[i] = (a[i] - mu)/sigma; } @@ -180,7 +180,7 @@ void normalize_array(double *a, int n) sigma = sqrt(variance_array(a,n)); } -void translate_array(double *a, int n, double s) +void translate_array(float *a, int n, float s) { int i; for(i = 0; i < n; ++i){ @@ -188,18 +188,18 @@ void translate_array(double *a, int n, double s) } } -void scale_array(double *a, int n, double s) +void scale_array(float *a, int n, float s) { int i; for(i = 0; i < n; ++i){ a[i] *= s; } } -int max_index(double *a, int n) +int max_index(float *a, int n) { if(n <= 0) return -1; int i, max_i = 0; - double max = a[0]; + float max = a[0]; for(i = 1; i < n; ++i){ if(a[i] > max){ max = a[i]; @@ -209,20 +209,20 @@ int max_index(double *a, int n) return max_i; } -double rand_normal() +float rand_normal() { int i; - double sum= 0; - for(i = 0; i < 12; ++i) sum += (double)rand()/RAND_MAX; + float sum= 0; + for(i = 0; i < 12; ++i) sum += (float)rand()/RAND_MAX; return sum-6.; } -double **one_hot_encode(double *a, int n, int k) +float **one_hot_encode(float *a, int n, int k) { int i; - double **t = calloc(n, sizeof(double*)); + float **t = calloc(n, sizeof(float*)); for(i = 0; i < n; ++i){ - t[i] = calloc(k, sizeof(double)); + t[i] = calloc(k, sizeof(float)); int index = (int)a[i]; t[i][index] = 1; } diff --git a/src/utils.h b/src/utils.h index cf380166..8185107e 100644 --- a/src/utils.h +++ b/src/utils.h @@ -13,15 +13,15 @@ char *fgetl(FILE *fp); list *parse_csv_line(char *line); char *copy_string(char *s); int count_fields(char *line); -double *parse_fields(char *line, int n); -void normalize_array(double *a, int n); -void scale_array(double *a, int n, double s); -void translate_array(double *a, int n, double s); -int max_index(double *a, int n); -double constrain(double a, double max); -double rand_normal(); -double mean_array(double *a, int n); -double variance_array(double *a, int n); -double **one_hot_encode(double *a, int n, int k); +float *parse_fields(char *line, int n); +void normalize_array(float *a, int n); +void scale_array(float *a, int n, float s); +void translate_array(float *a, int n, float s); +int max_index(float *a, int n); +float constrain(float a, float max); +float rand_normal(); +float mean_array(float *a, int n); +float variance_array(float *a, int n); +float **one_hot_encode(float *a, int n, int k); #endif