So there WAS this huge bug. Gone now

Joseph Redmon 11 years ago
parent 5ef74c2031
commit cd8d53df21
Changed files (lines changed):

  1. Makefile (5)
  2. src/activations.c (56)
  3. src/activations.cl (28)
  4. src/activations.h (8)
  5. src/axpy.c (14)
  6. src/axpy.cl (0)
  7. src/col2im.c (0)
  8. src/col2im.cl (0)
  9. src/connected_layer.c (32)
  10. src/connected_layer.h (7)
  11. src/convolutional_layer.c (123)
  12. src/convolutional_layer.h (23)
  13. src/gemm.c (283)
  14. src/im2col.c (121)
  15. src/im2col.cl (26)
  16. src/mini_blas.h (33)
  17. src/network.c (103)
  18. src/network.h (9)
  19. src/opencl.c (14)
  20. src/opencl.h (8)
  21. src/parser.c (3)
  22. src/tests.c (32)

Makefile
@@ -1,18 +1,19 @@
 CC=gcc
-GPU=1
+GPU=0
 COMMON=-Wall -Werror -Wfatal-errors `pkg-config --cflags opencv` -I/usr/local/cuda/include/
 ifeq ($(GPU), 1)
 COMMON+=-DGPU
 else
 endif
 UNAME = $(shell uname)
-OPTS=-O3 -flto
+OPTS=-Ofast -flto
 ifeq ($(UNAME), Darwin)
 COMMON+= -isystem /usr/local/Cellar/opencv/2.4.6.1/include/opencv -isystem /usr/local/Cellar/opencv/2.4.6.1/include
 ifeq ($(GPU), 1)
 LDFLAGS= -framework OpenCL
 endif
 else
+OPTS+= -march=native
 ifeq ($(GPU), 1)
 LDFLAGS= -lOpenCL
 endif
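Note: the OpenCL path is controlled entirely by this GPU flag; with GPU=1 the -DGPU define is added and OpenCL is linked (a framework on Darwin, -lOpenCL elsewhere). Since the flag is a plain = assignment, `make GPU=1` on the command line should override the default of 0. A minimal sketch of the same compile-time toggle pattern (an illustrative standalone file, not part of the repo):

/* toggle_demo.c -- illustrative only; build with `gcc -DGPU toggle_demo.c`
   to take the GPU branch, or without -DGPU for the CPU-only branch */
#include <stdio.h>

int main(void)
{
#ifdef GPU
    printf("compiled with -DGPU: OpenCL code paths enabled\n");
#else
    printf("compiled without -DGPU: CPU-only code paths\n");
#endif
    return 0;
}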

src/activations.c
@@ -2,6 +2,7 @@
 #include <math.h>
 #include <stdio.h>
+#include <stdlib.h>
 #include <string.h>

 char *get_activation_string(ACTIVATION a)
@@ -40,27 +41,29 @@ float relu_activate(float x){return x*(x>0);}
 float ramp_activate(float x){return x*(x>0)+.1*x;}
 float tanh_activate(float x){return (exp(2*x)-1)/(exp(2*x)+1);}

-float activate(float x, ACTIVATION a){
+float activate(float x, ACTIVATION a, float dropout)
+{
+    if((float)rand()/RAND_MAX < dropout) return 0;
     switch(a){
         case LINEAR:
-            return linear_activate(x);
+            return linear_activate(x)/(1-dropout);
         case SIGMOID:
-            return sigmoid_activate(x);
+            return sigmoid_activate(x)/(1-dropout);
         case RELU:
-            return relu_activate(x);
+            return relu_activate(x)/(1-dropout);
         case RAMP:
-            return ramp_activate(x);
+            return ramp_activate(x)/(1-dropout);
         case TANH:
-            return tanh_activate(x);
+            return tanh_activate(x)/(1-dropout);
     }
     return 0;
 }

-void activate_array(float *x, const int n, const ACTIVATION a)
+void activate_array(float *x, const int n, const ACTIVATION a, float dropout)
 {
     int i;
     for(i = 0; i < n; ++i){
-        x[i] = activate(x[i], a);
+        x[i] = activate(x[i], a, dropout);
     }
 }
@@ -89,3 +92,40 @@ void gradient_array(const float *x, const int n, const ACTIVATION a, float *delta)
     }
 }
#ifdef GPU
#include "opencl.h"
#include <math.h>
cl_kernel get_activation_kernel()
{
static int init = 0;
static cl_kernel kernel;
if(!init){
kernel = get_kernel("src/activations.cl", "activate_array", 0);
init = 1;
}
return kernel;
}
void activate_array_ongpu(cl_mem x, int n, ACTIVATION a, float dropout)
{
cl_setup();
cl_kernel kernel = get_activation_kernel();
cl_command_queue queue = cl.queue;
cl_uint i = 0;
cl.error = clSetKernelArg(kernel, i++, sizeof(x), (void*) &x);
cl.error = clSetKernelArg(kernel, i++, sizeof(n), (void*) &n);
cl.error = clSetKernelArg(kernel, i++, sizeof(a), (void*) &a);
cl.error = clSetKernelArg(kernel, i++, sizeof(dropout),
(void*) &dropout);
check_error(cl);
size_t gsize = n;
clEnqueueNDRangeKernel(queue, kernel, 1, 0, &gsize, 0, 0, 0, 0);
check_error(cl);
}
#endif
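The dropout handling above is inverted dropout: each activation is zeroed with probability dropout and the survivors are scaled by 1/(1-dropout), so the expected activation matches the no-dropout case and nothing has to be rescaled at test time (forward_connected_layer simply forces dropout to 0 when train is false). A small standalone check of that expectation, a sketch using only the scaling rule from the patch (names and values here are illustrative):

/* dropout_mean.c -- verifies E[y] stays close to x under inverted dropout */
#include <stdio.h>
#include <stdlib.h>

int main(void)
{
    float dropout = 0.5f, x = 1.0f;
    double sum = 0;
    int n = 1000000, i;
    srand(0);
    for(i = 0; i < n; ++i){
        /* same rule as the patched activate(): drop, else scale by 1/(1-p) */
        float y = ((float)rand()/RAND_MAX < dropout) ? 0 : x/(1-dropout);
        sum += y;
    }
    printf("mean with dropout %.2f: %f (expected ~%f)\n", dropout, sum/n, x);
    return 0;
}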

src/activations.cl (new file)
@@ -0,0 +1,28 @@
typedef enum{
SIGMOID, RELU, LINEAR, RAMP, TANH
}ACTIVATION;
float activate(float x, ACTIVATION a, float dropout)
{
//if((float)rand()/RAND_MAX < dropout) return 0;
switch(a){
case LINEAR:
return linear_activate(x)/(1-dropout);
case SIGMOID:
return sigmoid_activate(x)/(1-dropout);
case RELU:
return relu_activate(x)/(1-dropout);
case RAMP:
return ramp_activate(x)/(1-dropout);
case TANH:
return tanh_activate(x)/(1-dropout);
}
return 0;
}
__kernel void activate_array(__global float *x,
const int n, const ACTIVATION a, const float dropout)
{
int i = get_global_id(0);
x[i] = activate(x[i], a, dropout);
}

src/activations.h
@@ -1,3 +1,4 @@
+#include "opencl.h"
 #ifndef ACTIVATIONS_H
 #define ACTIVATIONS_H
@@ -8,10 +9,13 @@ typedef enum{
 ACTIVATION get_activation(char *s);
 char *get_activation_string(ACTIVATION a);
-float activate(float x, ACTIVATION a);
+float activate(float x, ACTIVATION a, float dropout);
 float gradient(float x, ACTIVATION a);
 void gradient_array(const float *x, const int n, const ACTIVATION a, float *delta);
-void activate_array(float *x, const int n, const ACTIVATION a);
+void activate_array(float *x, const int n, const ACTIVATION a, float dropout);
+#ifdef GPU
+void activate_array_ongpu(cl_mem x, int n, ACTIVATION a, float dropout);
+#endif
 #endif

src/axpy.c (new file)
@@ -0,0 +1,14 @@
#include "mini_blas.h"
void axpy_cpu(int N, float ALPHA, float *X, int INCX, float *Y, int INCY)
{
int i;
for(i = 0; i < N; ++i) Y[i*INCY] += ALPHA*X[i*INCX];
}
void scal_cpu(int N, float ALPHA, float *X, int INCX)
{
int i;
for(i = 0; i < N; ++i) X[i*INCX] *= ALPHA;
}
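axpy_cpu (Y += ALPHA*X, with strides) and scal_cpu (X *= ALPHA) are the BLAS-style helpers that the update code elsewhere in this commit builds on (update_convolutional_layer calls scal_cpu on the filter updates). A hedged usage sketch, compiled together with src/axpy.c; the buffers and learning rate are made up for illustration:

/* axpy_demo.c -- toy SGD-style update built from the two helpers */
#include <stdio.h>

void axpy_cpu(int N, float ALPHA, float *X, int INCX, float *Y, int INCY);
void scal_cpu(int N, float ALPHA, float *X, int INCX);

int main(void)
{
    float weights[3]  = {1.0f, 2.0f, 3.0f};
    float gradient[3] = {0.5f, -1.0f, 0.25f};
    float lr = 0.1f;

    axpy_cpu(3, -lr, gradient, 1, weights, 1);  /* weights += -lr * gradient */
    scal_cpu(3, 0.9f, gradient, 1);             /* decay the stored gradient */

    printf("%f %f %f\n", weights[0], weights[1], weights[2]);
    return 0;
}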

src/connected_layer.c
@@ -7,7 +7,7 @@
 #include <stdlib.h>
 #include <string.h>

-connected_layer *make_connected_layer(int batch, int inputs, int outputs, ACTIVATION activation)
+connected_layer *make_connected_layer(int batch, int inputs, int outputs, float dropout, ACTIVATION activation)
 {
     fprintf(stderr, "Connected Layer: %d inputs, %d outputs\n", inputs, outputs);
     int i;
@@ -15,6 +15,7 @@ connected_layer *make_connected_layer(int batch, int inputs, int outputs, ACTIVA
     layer->inputs = inputs;
     layer->outputs = outputs;
     layer->batch=batch;
+    layer->dropout = dropout;

     layer->output = calloc(batch*outputs, sizeof(float*));
     layer->delta = calloc(batch*outputs, sizeof(float*));
@@ -54,9 +55,9 @@ void update_connected_layer(connected_layer layer, float step, float momentum, f
     memset(layer.weight_updates, 0, layer.outputs*layer.inputs*sizeof(float));
 }

-void forward_connected_layer(connected_layer layer, float *input)
+void forward_connected_layer(connected_layer layer, float *input, int train)
 {
-    int i;
+    if(!train) layer.dropout = 0;
     memcpy(layer.output, layer.biases, layer.outputs*sizeof(float));
     int m = layer.batch;
     int k = layer.inputs;
@@ -65,17 +66,15 @@ void forward_connected_layer(connected_layer layer, float *input)
     float *b = layer.weights;
     float *c = layer.output;
     gemm(0,0,m,n,k,1,a,k,b,n,1,c,n);
-    for(i = 0; i < layer.outputs*layer.batch; ++i){
-        layer.output[i] = activate(layer.output[i], layer.activation);
-    }
+    activate_array(layer.output, layer.outputs*layer.batch, layer.activation, layer.dropout);
 }

-void learn_connected_layer(connected_layer layer, float *input)
+void backward_connected_layer(connected_layer layer, float *input, float *delta)
 {
     int i;
     for(i = 0; i < layer.outputs*layer.batch; ++i){
         layer.delta[i] *= gradient(layer.output[i], layer.activation);
-        layer.bias_updates[i%layer.batch] += layer.delta[i]/layer.batch;
+        layer.bias_updates[i%layer.batch] += layer.delta[i];
     }
     int m = layer.inputs;
     int k = layer.batch;
@@ -84,18 +83,15 @@ void learn_connected_layer(connected_layer layer, float *input)
     float *b = layer.delta;
     float *c = layer.weight_updates;
     gemm(0,0,m,n,k,1,a,k,b,n,1,c,n);
-}
-
-void backward_connected_layer(connected_layer layer, float *input, float *delta)
-{
-    int m = layer.inputs;
-    int k = layer.outputs;
-    int n = layer.batch;
+    m = layer.inputs;
+    k = layer.outputs;
+    n = layer.batch;

-    float *a = layer.weights;
-    float *b = layer.delta;
-    float *c = delta;
+    a = layer.weights;
+    b = layer.delta;
+    c = delta;

-    gemm(0,0,m,n,k,1,a,k,b,n,0,c,n);
+    if(c) gemm(0,0,m,n,k,1,a,k,b,n,0,c,n);
 }
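The rewritten forward pass is one GEMM over the whole batch: with m = batch and k = inputs as above (n is presumably layer.outputs, which falls outside the visible context), it computes output[batch x outputs] += input[batch x inputs] * weights[inputs x outputs] on top of the copied-in biases. A tiny shape check using gemm_cpu from src/gemm.c below; the sizes and values are made up:

/* gemm_shapes.c -- 2-sample batch, 3 inputs, 2 outputs; link with src/gemm.c */
#include <stdio.h>

void gemm_cpu(int TA, int TB, int M, int N, int K, float ALPHA,
        float *A, int lda, float *B, int ldb, float BETA, float *C, int ldc);

int main(void)
{
    float input[2*3]   = {1,2,3,  4,5,6};    /* batch x inputs   */
    float weights[3*2] = {1,0, 0,1, 1,1};    /* inputs x outputs */
    float output[2*2]  = {0};                /* batch x outputs (biases would be copied in first) */

    gemm_cpu(0,0, 2,2,3, 1, input,3, weights,2, 1, output,2);
    printf("%g %g\n%g %g\n", output[0], output[1], output[2], output[3]);  /* 4 5 / 10 11 */
    return 0;
}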

src/connected_layer.h
@@ -21,16 +21,17 @@ typedef struct{
     float *output;
     float *delta;
+    float dropout;

     ACTIVATION activation;
 } connected_layer;

-connected_layer *make_connected_layer(int batch, int inputs, int outputs, ACTIVATION activation);
+connected_layer *make_connected_layer(int batch, int inputs, int outputs, float dropout, ACTIVATION activation);

-void forward_connected_layer(connected_layer layer, float *input);
+void forward_connected_layer(connected_layer layer, float *input, int train);
 void backward_connected_layer(connected_layer layer, float *input, float *delta);
-void learn_connected_layer(connected_layer layer, float *input);
 void update_connected_layer(connected_layer layer, float step, float momentum, float decay);
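With the new signatures, the dropout rate is fixed when the layer is created and the train flag decides per call whether it is applied. A call-order sketch against these declarations (compiled with the rest of the darknet sources from this commit; the sizes are illustrative, and passing 0 as the delta of the first layer works because backward_connected_layer now guards that GEMM with if(c)):

/* connected_usage.c -- call-order sketch only */
#include <stdlib.h>
#include "connected_layer.h"

int main(void)
{
    int batch = 4, inputs = 784, outputs = 100;
    float *input = calloc(batch*inputs, sizeof(float));

    connected_layer *layer = make_connected_layer(batch, inputs, outputs, 0.5, RELU);

    forward_connected_layer(*layer, input, 1);   /* train=1: dropout active   */
    /* ... fill layer->delta from a loss here ... */
    backward_connected_layer(*layer, input, 0);  /* no previous delta to fill */
    forward_connected_layer(*layer, input, 0);   /* train=0: dropout disabled */

    free(input);
    return 0;
}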

src/convolutional_layer.c
@@ -55,7 +55,7 @@ convolutional_layer *make_convolutional_layer(int batch, int h, int w, int c, in
     for(i = 0; i < c*n*size*size; ++i) layer->filters[i] = scale*(rand_uniform());
     for(i = 0; i < n; ++i){
         //layer->biases[i] = rand_normal()*scale + scale;
-        layer->biases[i] = 0;
+        layer->biases[i] = .5;
     }
     int out_h = convolutional_out_height(*layer);
     int out_w = convolutional_out_width(*layer);
@@ -63,6 +63,8 @@ convolutional_layer *make_convolutional_layer(int batch, int h, int w, int c, in
     layer->col_image = calloc(layer->batch*out_h*out_w*size*size*c, sizeof(float));
     layer->output = calloc(layer->batch*out_h * out_w * n, sizeof(float));
     layer->delta = calloc(layer->batch*out_h * out_w * n, sizeof(float));
+#ifdef GPU
+#endif
     layer->activation = activation;
     fprintf(stderr, "Convolutional Layer: %d x %d x %d image, %d filters -> %d x %d x %d image\n", h,w,c,n, out_h, out_w, n);
@@ -87,48 +89,76 @@ void resize_convolutional_layer(convolutional_layer *layer, int h, int w, int c)
         layer->batch*out_h * out_w * layer->n*sizeof(float));
 }
void bias_output(const convolutional_layer layer)
{
int i,j;
int out_h = convolutional_out_height(layer);
int out_w = convolutional_out_width(layer);
for(i = 0; i < layer.n; ++i){
for(j = 0; j < out_h*out_w; ++j){
layer.output[i*out_h*out_w + j] = layer.biases[i];
}
}
}
 void forward_convolutional_layer(const convolutional_layer layer, float *in)
 {
-    int i;
+    int out_h = convolutional_out_height(layer);
+    int out_w = convolutional_out_width(layer);
     int m = layer.n;
     int k = layer.size*layer.size*layer.c;
-    int n = convolutional_out_height(layer)*
-        convolutional_out_width(layer)*
-        layer.batch;
+    int n = out_h*out_w*layer.batch;

     float *a = layer.filters;
     float *b = layer.col_image;
     float *c = layer.output;

-    for(i = 0; i < layer.batch; ++i){
-        im2col_gpu(in+i*(n/layer.batch), layer.c, layer.h, layer.w, layer.size, layer.stride, b+i*(n/layer.batch));
-    }
-    gemm(0,0,m,n,k,1,a,k,b,n,0,c,n);
-    activate_array(layer.output, m*n, layer.activation);
+    im2col_cpu(in,layer.batch, layer.c, layer.h, layer.w,
+        layer.size, layer.stride, b);
+    bias_output(layer);
+    gemm(0,0,m,n,k,1,a,k,b,n,1,c,n);
+    activate_array(layer.output, m*n, layer.activation, 0.);
 }
#ifdef GPU
void forward_convolutional_layer_gpu(convolutional_layer layer, cl_mem in)
{
int m = layer.n;
int k = layer.size*layer.size*layer.c;
int n = convolutional_out_height(layer)*
convolutional_out_width(layer)*
layer.batch;
cl_write_array(layer.filters_cl, layer.filters, m*k);
cl_mem a = layer.filters_cl;
cl_mem b = layer.col_image_cl;
cl_mem c = layer.output_cl;
im2col_ongpu(in, layer.batch, layer.c, layer.h, layer.w, layer.size, layer.stride, b);
gemm_ongpu(0,0,m,n,k,1,a,k,b,n,0,c,n);
activate_array_ongpu(layer.output_cl, m*n, layer.activation, 0.);
cl_read_array(layer.output_cl, layer.output, m*n);
}
#endif
 void learn_bias_convolutional_layer(convolutional_layer layer)
 {
-    int i,j,b;
+    int i,b;
     int size = convolutional_out_height(layer)
         *convolutional_out_width(layer);
     for(b = 0; b < layer.batch; ++b){
         for(i = 0; i < layer.n; ++i){
-            float sum = 0;
-            for(j = 0; j < size; ++j){
-                sum += layer.delta[j+size*(i+b*layer.n)];
-            }
-            layer.bias_updates[i] += sum/size;
+            layer.bias_updates[i] += mean_array(layer.delta+size*(i+b*layer.n), size);
         }
     }
 }

-void learn_convolutional_layer(convolutional_layer layer)
+void backward_convolutional_layer(convolutional_layer layer, float *delta)
 {
     int m = layer.n;
     int n = layer.size*layer.size*layer.c;
     int k = convolutional_out_height(layer)*
         convolutional_out_width(layer)*
         layer.batch;
     gradient_array(layer.output, m*k, layer.activation, layer.delta);
     learn_bias_convolutional_layer(layer);
@@ -137,26 +167,25 @@ void learn_convolutional_layer(convolutional_layer layer)
     float *c = layer.filter_updates;
     gemm(0,1,m,n,k,1,a,k,b,k,1,c,n);
-}
-
-void backward_convolutional_layer(convolutional_layer layer, float *delta)
-{
-    int i;
-    int m = layer.size*layer.size*layer.c;
-    int k = layer.n;
-    int n = convolutional_out_height(layer)*
+    if(delta){
+        int i;
+        m = layer.size*layer.size*layer.c;
+        k = layer.n;
+        n = convolutional_out_height(layer)*
         convolutional_out_width(layer)*
         layer.batch;
-    float *a = layer.filters;
-    float *b = layer.delta;
-    float *c = layer.col_image;
+        a = layer.filters;
+        b = layer.delta;
+        c = layer.col_image;
     gemm(1,0,m,n,k,1,a,m,b,n,0,c,n);
     memset(delta, 0, layer.batch*layer.h*layer.w*layer.c*sizeof(float));
     for(i = 0; i < layer.batch; ++i){
         col2im_cpu(c+i*n/layer.batch, layer.c, layer.h, layer.w, layer.size, layer.stride, delta+i*n/layer.batch);
+    }
     }
 }
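The rewritten forward pass above is one im2col over the whole batch, a bias fill via bias_output(), then a single GEMM of the filters against the column buffer (gemm is now called with BETA=1 so the biases survive). The shapes work out to filters[n x size*size*c] times columns[size*size*c x out_h*out_w*batch]. A quick standalone check of those dimensions, assuming the no-padding output size (h - size)/stride + 1 that im2col_cpu in this commit also uses; the concrete numbers are made up:

/* conv_shapes.c -- GEMM dimensions for the convolutional forward pass */
#include <stdio.h>

int main(void)
{
    int batch = 2, c = 3, h = 28, w = 28, n = 16, size = 3, stride = 1;
    int out_h = (h - size)/stride + 1;
    int out_w = (w - size)/stride + 1;
    int M = n;                  /* number of filters          */
    int K = size*size*c;        /* unrolled patch length      */
    int N = out_h*out_w*batch;  /* output positions per batch */
    printf("C[%d x %d] += A[%d x %d] * B[%d x %d]\n", M, N, M, K, K, N);
    return 0;
}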
@@ -171,32 +200,6 @@ void update_convolutional_layer(convolutional_layer layer, float step, float mom
     scal_cpu(size, momentum, layer.filter_updates, 1);
 }

-void test_convolutional_layer()
-{
-    convolutional_layer l = *make_convolutional_layer(1,4,4,1,1,3,1,LINEAR);
-    float input[] = {1,2,3,4,
-                    5,6,7,8,
-                    9,10,11,12,
-                    13,14,15,16};
-    float filter[] = {.5, 0, .3,
-                    0 , 1, 0,
-                    .2 , 0, 1};
-    float delta[] = {1, 2,
-                    3, 4};
-    float in_delta[] = {.5,1,.3,.6,
-                    5,6,7,8,
-                    9,10,11,12,
-                    13,14,15,16};
-    l.filters = filter;
-    forward_convolutional_layer(l, input);
-    l.delta = delta;
-    learn_convolutional_layer(l);
-    image filter_updates = float_to_image(3,3,1,l.filter_updates);
-    print_image(filter_updates);
-    printf("Delta:\n");
-    backward_convolutional_layer(l, in_delta);
-    pm(4,4,in_delta);
-}
 image get_convolutional_filter(convolutional_layer layer, int i)
 {

src/convolutional_layer.h
@@ -1,6 +1,10 @@
 #ifndef CONVOLUTIONAL_LAYER_H
 #define CONVOLUTIONAL_LAYER_H

+#ifdef GPU
+#include "opencl.h"
+#endif
 #include "image.h"
 #include "activations.h"
@@ -22,13 +26,30 @@ typedef struct {
     float *delta;
     float *output;
#ifdef GPU
cl_mem filters_cl;
cl_mem filter_updates_cl;
cl_mem filter_momentum_cl;
cl_mem biases_cl;
cl_mem bias_updates_cl;
cl_mem bias_momentum_cl;
cl_mem col_image_cl;
cl_mem delta_cl;
cl_mem output_cl;
#endif
     ACTIVATION activation;
 } convolutional_layer;
#ifdef GPU
void forward_convolutional_layer_gpu(convolutional_layer layer, cl_mem in);
#endif
 convolutional_layer *make_convolutional_layer(int batch, int h, int w, int c, int n, int size, int stride, ACTIVATION activation);
 void resize_convolutional_layer(convolutional_layer *layer, int h, int w, int c);
 void forward_convolutional_layer(const convolutional_layer layer, float *in);
-void learn_convolutional_layer(convolutional_layer layer);
 void update_convolutional_layer(convolutional_layer layer, float step, float momentum, float decay);
 image *visualize_convolutional_layer(convolutional_layer layer, char *window, image *prev_filters);

src/gemm.c (new file)
@@ -0,0 +1,283 @@
#include "mini_blas.h"
void gemm(int TA, int TB, int M, int N, int K, float ALPHA,
float *A, int lda,
float *B, int ldb,
float BETA,
float *C, int ldc)
{
#ifdef GPU
gemm_gpu( TA, TB, M, N, K, ALPHA,A,lda, B, ldb,BETA,C,ldc);
#else
gemm_cpu( TA, TB, M, N, K, ALPHA,A,lda, B, ldb,BETA,C,ldc);
#endif
}
void gemm_nn(int M, int N, int K, float ALPHA,
float *A, int lda,
float *B, int ldb,
float *C, int ldc)
{
int i,j,k;
for(i = 0; i < M; ++i){
for(k = 0; k < K; ++k){
register float A_PART = ALPHA*A[i*lda+k];
for(j = 0; j < N; ++j){
C[i*ldc+j] += A_PART*B[k*ldb+j];
}
}
}
}
void gemm_nt(int M, int N, int K, float ALPHA,
float *A, int lda,
float *B, int ldb,
float *C, int ldc)
{
int i,j,k;
for(i = 0; i < M; ++i){
for(j = 0; j < N; ++j){
register float sum = 0;
for(k = 0; k < K; ++k){
sum += ALPHA*A[i*lda+k]*B[k+j*ldb];
}
C[i*ldc+j] += sum;
}
}
}
void gemm_tn(int M, int N, int K, float ALPHA,
float *A, int lda,
float *B, int ldb,
float *C, int ldc)
{
int i,j,k;
for(i = 0; i < M; ++i){
for(k = 0; k < K; ++k){
register float A_PART = ALPHA*A[k*lda+i];
for(j = 0; j < N; ++j){
C[i*ldc+j] += A_PART*B[k*ldb+j];
}
}
}
}
void gemm_tt(int M, int N, int K, float ALPHA,
float *A, int lda,
float *B, int ldb,
float *C, int ldc)
{
int i,j,k;
for(i = 0; i < M; ++i){
for(j = 0; j < N; ++j){
for(k = 0; k < K; ++k){
C[i*ldc+j] += ALPHA*A[i+k*lda]*B[k+j*ldb];
}
}
}
}
void gemm_cpu(int TA, int TB, int M, int N, int K, float ALPHA,
float *A, int lda,
float *B, int ldb,
float BETA,
float *C, int ldc)
{
int i, j;
for(i = 0; i < M; ++i){
for(j = 0; j < N; ++j){
C[i*ldc + j] *= BETA;
}
}
if(!TA && !TB)
gemm_nn(M, N, K, ALPHA,A,lda, B, ldb,C,ldc);
else if(TA && !TB)
gemm_tn(M, N, K, ALPHA,A,lda, B, ldb,C,ldc);
else if(!TA && TB)
gemm_nt(M, N, K, ALPHA,A,lda, B, ldb,C,ldc);
else
gemm_tt(M, N, K, ALPHA,A,lda, B, ldb,C,ldc);
}
#ifdef GPU
#include "opencl.h"
#include <math.h>
#define STR_HELPER(x) #x
#define STR(x) STR_HELPER(x)
#define BLOCK 8
cl_kernel get_gemm_kernel()
{
static int init = 0;
static cl_kernel gemm_kernel;
if(!init){
gemm_kernel = get_kernel("src/gemm.cl", "gemm", "-D BLOCK=" STR(BLOCK) );
init = 1;
}
return gemm_kernel;
}
void gemm_ongpu(int TA, int TB, int M, int N, int K, float ALPHA,
cl_mem A_gpu, int lda,
cl_mem B_gpu, int ldb,
float BETA,
cl_mem C_gpu, int ldc)
{
cl_setup();
cl_kernel gemm_kernel = get_gemm_kernel();
cl_command_queue queue = cl.queue;
cl_uint i = 0;
cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(TA), (void*) &TA);
cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(TB), (void*) &TB);
cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(M), (void*) &M);
cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(N), (void*) &N);
cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(K), (void*) &K);
cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(ALPHA), (void*) &ALPHA);
cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(A_gpu), (void*) &A_gpu);
cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(lda), (void*) &lda);
cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(B_gpu), (void*) &B_gpu);
cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(ldb), (void*) &ldb);
cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(BETA), (void*) &BETA);
cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(C_gpu), (void*) &C_gpu);
cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(ldc), (void*) &ldc);
check_error(cl);
const size_t global_size[] = {ceil((float)M/BLOCK)*BLOCK, ceil((float)N/BLOCK)*BLOCK};
const size_t local_size[] = {BLOCK, BLOCK};
clEnqueueNDRangeKernel(queue, gemm_kernel, 2, 0, global_size, local_size, 0, 0, 0);
check_error(cl);
}
void gemm_gpu(int TA, int TB, int M, int N, int K, float ALPHA,
float *A, int lda,
float *B, int ldb,
float BETA,
float *C, int ldc)
{
cl_setup();
cl_context context = cl.context;
cl_command_queue queue = cl.queue;
size_t size = sizeof(float)*(TA ? lda*K:lda*M);
cl_mem A_gpu = clCreateBuffer(context,
CL_MEM_READ_ONLY|CL_MEM_COPY_HOST_PTR,
size, A, &cl.error);
check_error(cl);
size = sizeof(float)*(TB ? ldb*N:ldb*K);
cl_mem B_gpu = clCreateBuffer(context,
CL_MEM_READ_ONLY|CL_MEM_COPY_HOST_PTR,
size, B, &cl.error);
check_error(cl);
size = sizeof(float)*(ldc*M);
cl_mem C_gpu = clCreateBuffer(context,
CL_MEM_READ_WRITE|CL_MEM_COPY_HOST_PTR,
size, C, &cl.error);
check_error(cl);
gemm_ongpu(TA, TB, M, N, K, ALPHA, A_gpu, lda, B_gpu, ldb, BETA, C_gpu, ldc);
clEnqueueReadBuffer(queue, C_gpu, CL_TRUE, 0, size, C, 0, 0, 0);
check_error(cl);
clReleaseMemObject(A_gpu);
clReleaseMemObject(B_gpu);
clReleaseMemObject(C_gpu);
}
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>
void time_gpu_random_matrix(int TA, int TB, int m, int k, int n)
{
float *a;
if(!TA) a = random_matrix(m,k);
else a = random_matrix(k,m);
int lda = (!TA)?k:m;
float *b;
if(!TB) b = random_matrix(k,n);
else b = random_matrix(n,k);
int ldb = (!TB)?n:k;
float *c = random_matrix(m,n);
int i;
clock_t start = clock(), end;
for(i = 0; i<1000; ++i){
gemm_gpu(TA,TB,m,n,k,1,a,lda,b,ldb,1,c,n);
}
end = clock();
printf("Matrix Multiplication %dx%d * %dx%d, TA=%d, TB=%d: %lf ms\n",m,k,k,n, TA, TB, (float)(end-start)/CLOCKS_PER_SEC);
free(a);
free(b);
free(c);
}
void test_gpu_accuracy(int TA, int TB, int m, int k, int n)
{
srand(0);
float *a;
if(!TA) a = random_matrix(m,k);
else a = random_matrix(k,m);
int lda = (!TA)?k:m;
float *b;
if(!TB) b = random_matrix(k,n);
else b = random_matrix(n,k);
int ldb = (!TB)?n:k;
float *c = random_matrix(m,n);
float *c_gpu = random_matrix(m,n);
memset(c, 0, m*n*sizeof(float));
memset(c_gpu, 0, m*n*sizeof(float));
int i;
//pm(m,k,b);
gemm_gpu(TA,TB,m,n,k,1,a,lda,b,ldb,1,c_gpu,n);
//pm(m, n, c_gpu);
gemm_cpu(TA,TB,m,n,k,1,a,lda,b,ldb,1,c,n);
//pm(m, n, c);
double sse = 0;
for(i = 0; i < m*n; ++i) {
//printf("%f %f\n", c[i], c_gpu[i]);
sse += pow(c[i]-c_gpu[i], 2);
}
printf("Matrix Multiplication %dx%d * %dx%d, TA=%d, TB=%d: %g MSE\n",m,k,k,n, TA, TB, sse/(m*n));
free(a);
free(b);
free(c);
}
void test_gpu_blas()
{
test_gpu_accuracy(0,0,17,10,10);
test_gpu_accuracy(1,0,17,10,10);
test_gpu_accuracy(0,1,17,10,10);
test_gpu_accuracy(1,1,17,10,10);
test_gpu_accuracy(0,0,1000,10,100);
test_gpu_accuracy(1,0,1000,10,100);
test_gpu_accuracy(0,1,1000,10,100);
test_gpu_accuracy(1,1,1000,10,100);
time_gpu_random_matrix(0,0,1000,1000,100);
time_random_matrix(0,0,1000,1000,100);
time_gpu_random_matrix(0,1,1000,1000,100);
time_random_matrix(0,1,1000,1000,100);
time_gpu_random_matrix(1,0,1000,1000,100);
time_random_matrix(1,0,1000,1000,100);
time_gpu_random_matrix(1,1,1000,1000,100);
time_random_matrix(1,1,1000,1000,100);
}
#endif
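gemm_ongpu launches BLOCK x BLOCK work-groups, so the global work size is padded up to the next multiple of BLOCK in each dimension (presumably the kernel guards out-of-range work-items, though gemm.cl itself is not shown in this diff). A standalone check of that rounding with BLOCK=8 as #define'd above; the sample sizes are arbitrary:

/* round_up.c -- global-size padding used for the OpenCL launch (build with -lm) */
#include <math.h>
#include <stdio.h>

int main(void)
{
    int BLOCK = 8;
    int dims[] = {1, 8, 17, 1000};
    int i;
    for(i = 0; i < 4; ++i){
        size_t padded = (size_t)(ceil((float)dims[i]/BLOCK)*BLOCK);
        printf("%4d -> %zu\n", dims[i], padded);   /* 8, 8, 24, 1000 */
    }
    return 0;
}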

src/im2col.c (new file)
@@ -0,0 +1,121 @@
#include "mini_blas.h"
//From Berkeley Vision's Caffe!
//https://github.com/BVLC/caffe/blob/master/LICENSE
void im2col_cpu(float* data_im,
const int batch, const int channels, const int height, const int width,
const int ksize, const int stride, float* data_col)
{
int c,h,w,b;
int height_col = (height - ksize) / stride + 1;
int width_col = (width - ksize) / stride + 1;
int channels_col = channels * ksize * ksize;
int im_size = height*width*channels;
int col_size = height_col*width_col*channels_col;
for(b = 0; b < batch; ++b){
for ( c = 0; c < channels_col; ++c) {
int w_offset = c % ksize;
int h_offset = (c / ksize) % ksize;
int c_im = c / ksize / ksize;
for ( h = 0; h < height_col; ++h) {
for ( w = 0; w < width_col; ++w) {
data_col[(c * height_col + h) * width_col + w] =
data_im[(c_im * height + h * stride + h_offset) * width
+ w * stride + w_offset];
}
}
}
data_im += im_size;
data_col+= col_size;
}
}
#ifdef GPU
#include "opencl.h"
#include <math.h>
cl_kernel get_im2col_kernel()
{
static int init = 0;
static cl_kernel im2col_kernel;
if(!init){
im2col_kernel = get_kernel("src/im2col.cl", "im2col", 0);
init = 1;
}
return im2col_kernel;
}
void im2col_ongpu(cl_mem data_im, const int batch,
const int channels, const int height, const int width,
const int ksize, const int stride, cl_mem data_col)
{
cl_setup();
cl_kernel im2col_kernel = get_im2col_kernel();
cl_command_queue queue = cl.queue;
cl_uint i = 0;
cl.error = clSetKernelArg(im2col_kernel, i++, sizeof(data_im), (void*) &data_im);
cl.error = clSetKernelArg(im2col_kernel, i++, sizeof(batch), (void*) &batch);
cl.error = clSetKernelArg(im2col_kernel, i++, sizeof(channels), (void*) &channels);
cl.error = clSetKernelArg(im2col_kernel, i++, sizeof(height), (void*) &height);
cl.error = clSetKernelArg(im2col_kernel, i++, sizeof(width), (void*) &width);
cl.error = clSetKernelArg(im2col_kernel, i++, sizeof(ksize), (void*) &ksize);
cl.error = clSetKernelArg(im2col_kernel, i++, sizeof(stride), (void*) &stride);
cl.error = clSetKernelArg(im2col_kernel, i++, sizeof(data_col), (void*) &data_col);
check_error(cl);
int height_col = (height - ksize) / stride + 1;
int width_col = (width - ksize) / stride + 1;
int channels_col = channels * ksize * ksize;
size_t global_size[2];
size_t local_size[2];
global_size[0] = batch;
global_size[1] = channels_col;
local_size[0] = height_col;
local_size[1] = width_col;
clEnqueueNDRangeKernel(queue, im2col_kernel, 2, 0,
global_size, local_size, 0, 0, 0);
check_error(cl);
}
void im2col_gpu(float *data_im,
const int batch, const int channels, const int height, const int width,
const int ksize, const int stride,
float *data_col)
{
cl_setup();
cl_context context = cl.context;
cl_command_queue queue = cl.queue;
size_t size = sizeof(float)*(channels*height*width*batch);
cl_mem im_gpu = clCreateBuffer(context,
CL_MEM_READ_ONLY|CL_MEM_COPY_HOST_PTR,
size, data_im, &cl.error);
check_error(cl);
int height_col = (height - ksize) / stride + 1;
int width_col = (width - ksize) / stride + 1;
int channels_col = channels * ksize * ksize;
size = sizeof(float)*(height_col*width_col*channels_col*batch);
cl_mem col_gpu = clCreateBuffer(context,
CL_MEM_WRITE_ONLY|CL_MEM_COPY_HOST_PTR,
size, data_col, &cl.error);
check_error(cl);
im2col_ongpu(im_gpu, batch, channels, height, width,
ksize, stride, col_gpu);
clEnqueueReadBuffer(queue, col_gpu, CL_TRUE, 0, size, data_col, 0, 0, 0);
check_error(cl);
clReleaseMemObject(col_gpu);
clReleaseMemObject(im_gpu);
}
#endif
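im2col_cpu unrolls every ksize x ksize patch into a column, so per image the column buffer is channels*ksize*ksize rows by height_col*width_col columns, and the batched version above simply advances data_im and data_col by im_size and col_size. A small check of that geometry (the input sizes are made up):

/* im2col_size.c -- output geometry of im2col_cpu */
#include <stdio.h>

int main(void)
{
    int height = 28, width = 28, channels = 3, ksize = 3, stride = 1;
    int height_col = (height - ksize)/stride + 1;   /* 26 */
    int width_col  = (width  - ksize)/stride + 1;   /* 26 */
    int channels_col = channels*ksize*ksize;        /* 27 */
    printf("col buffer per image: %d x %d = %d floats\n",
            channels_col, height_col*width_col,
            channels_col*height_col*width_col);
    return 0;
}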

src/im2col.cl (new file)
@@ -0,0 +1,26 @@
__kernel void im2col(__global float *data_im,
const int batch, const int channels, const int height, const int width,
const int ksize, const int stride, __global float *data_col)
{
int b = get_global_id(0);
int c = get_global_id(1);
int h = get_local_id(0);
int w = get_local_id(1);
int height_col = (height - ksize) / stride + 1;
int width_col = (width - ksize) / stride + 1;
int channels_col = channels * ksize * ksize;
int im_offset = height*width*channels*b;
int col_offset = height_col*width_col*channels_col*b;
int w_offset = c % ksize;
int h_offset = (c / ksize) % ksize;
int c_im = c / ksize / ksize;
data_col[(c * height_col + h) * width_col + w + col_offset] =
data_im[(c_im * height + h * stride + h_offset) * width
+ w * stride + w_offset + im_offset];
}

src/mini_blas.h
@@ -1,3 +1,5 @@
+#include "opencl.h"

 void pm(int M, int N, float *A);
 void gemm(int TA, int TB, int M, int N, int K, float ALPHA,
         float *A, int lda,
@@ -6,15 +8,30 @@ void gemm(int TA, int TB, int M, int N, int K, float ALPHA,
         float *C, int ldc);
 float *random_matrix(int rows, int cols);
 void time_random_matrix(int TA, int TB, int m, int k, int n);
-void im2col_gpu(float* data_im, const int channels,
-        const int height, const int width, const int ksize, const int stride,
-        float* data_col);
-void im2col_cpu(float* data_im, const int channels,
-        const int height, const int width, const int ksize, const int stride,
-        float* data_col);
+#ifdef GPU
+void im2col_ongpu(cl_mem data_im, const int batch,
+        const int channels, const int height, const int width,
+        const int ksize, const int stride, cl_mem data_col);
+void im2col_gpu(float *data_im,
+        const int batch, const int channels, const int height, const int width,
+        const int ksize, const int stride, float *data_col);
+void gemm_ongpu(int TA, int TB, int M, int N, int K, float ALPHA,
+        cl_mem A_gpu, int lda,
+        cl_mem B_gpu, int ldb,
+        float BETA,
+        cl_mem C_gpu, int ldc);
+#endif
+void im2col_cpu(float* data_im,
+        const int batch, const int channels, const int height, const int width,
+        const int ksize, const int stride, float* data_col);
 void col2im_cpu(float* data_col, const int channels,
         const int height, const int width, const int ksize, const int stride,
         float* data_im);
 void test_blas();
 void gemm_gpu(int TA, int TB, int M, int N, int K, float ALPHA,

src/network.c
@@ -19,6 +19,9 @@ network make_network(int n, int batch)
     net.types = calloc(net.n, sizeof(LAYER_TYPE));
     net.outputs = 0;
     net.output = 0;
+    #ifdef GPU
+    net.input_cl = 0;
+    #endif
     return net;
 }
@@ -40,17 +43,6 @@ void print_convolutional_cfg(FILE *fp, convolutional_layer *l, int first)
     fprintf(fp, "data=");
     for(i = 0; i < l->n; ++i) fprintf(fp, "%g,", l->biases[i]);
     for(i = 0; i < l->n*l->c*l->size*l->size; ++i) fprintf(fp, "%g,", l->filters[i]);
-    /*
-    int j,k;
-    for(i = 0; i < l->n; ++i) fprintf(fp, "%g,", l->biases[i]);
-    for(i = 0; i < l->n; ++i){
-        for(j = l->c-1; j >= 0; --j){
-            for(k = 0; k < l->size*l->size; ++k){
-                fprintf(fp, "%g,", l->filters[i*(l->c*l->size*l->size)+j*l->size*l->size+k]);
-            }
-        }
-    }
-    */
fprintf(fp, "\n\n"); fprintf(fp, "\n\n");
} }
void print_connected_cfg(FILE *fp, connected_layer *l, int first) void print_connected_cfg(FILE *fp, connected_layer *l, int first)
@ -121,18 +113,34 @@ void save_network(network net, char *filename)
fclose(fp); fclose(fp);
} }
void forward_network(network net, float *input) void forward_network(network net, float *input, int train)
{ {
int i; int i;
#ifdef GPU
cl_setup();
size_t size = get_network_input_size(net);
if(!net.input_cl){
net.input_cl = clCreateBuffer(cl.context,
CL_MEM_READ_WRITE, size*sizeof(float), 0, &cl.error);
check_error(cl);
}
cl_write_array(net.input_cl, input, size);
cl_mem input_cl = net.input_cl;
#endif
     for(i = 0; i < net.n; ++i){
         if(net.types[i] == CONVOLUTIONAL){
             convolutional_layer layer = *(convolutional_layer *)net.layers[i];
+            #ifdef GPU
+            forward_convolutional_layer_gpu(layer, input_cl);
+            input_cl = layer.output_cl;
+            #else
             forward_convolutional_layer(layer, input);
+            #endif
             input = layer.output;
         }
         else if(net.types[i] == CONNECTED){
             connected_layer layer = *(connected_layer *)net.layers[i];
-            forward_connected_layer(layer, input);
+            forward_connected_layer(layer, input, train);
             input = layer.output;
         }
         else if(net.types[i] == SOFTMAX){
@@ -263,9 +271,7 @@ float backward_network(network net, float *input, float *truth)
         }
         if(net.types[i] == CONVOLUTIONAL){
             convolutional_layer layer = *(convolutional_layer *)net.layers[i];
-            learn_convolutional_layer(layer);
-            //learn_convolutional_layer(layer);
-            if(i != 0) backward_convolutional_layer(layer, prev_delta);
+            backward_convolutional_layer(layer, prev_delta);
         }
         else if(net.types[i] == MAXPOOL){
             maxpool_layer layer = *(maxpool_layer *)net.layers[i];
@@ -281,8 +287,7 @@ float backward_network(network net, float *input, float *truth)
         }
         else if(net.types[i] == CONNECTED){
             connected_layer layer = *(connected_layer *)net.layers[i];
-            learn_connected_layer(layer, prev_input);
-            if(i != 0) backward_connected_layer(layer, prev_input, prev_delta);
+            backward_connected_layer(layer, prev_input, prev_delta);
         }
     }
     return error;
@@ -290,7 +295,7 @@ float backward_network(network net, float *input, float *truth)
 float train_network_datum(network net, float *x, float *y, float step, float momentum, float decay)
 {
-    forward_network(net, x);
+    forward_network(net, x, 1);
     //int class = get_predicted_class_network(net);
     float error = backward_network(net, x, y);
     update_network(net, step, momentum, decay);
@@ -332,7 +337,7 @@ float train_network_batch(network net, data d, int n, float step, float momentum
     int index = rand()%d.X.rows;
     float *x = d.X.vals[index];
     float *y = d.y.vals[index];
-    forward_network(net, x);
+    forward_network(net, x, 1);
     int class = get_predicted_class_network(net);
     backward_network(net, x, y);
     correct += (y[class]?1:0);
@@ -359,6 +364,27 @@ void train_network(network net, data d, float step, float momentum, float decay)
     fprintf(stderr, "Accuracy: %f\n", (float)correct/d.X.rows);
 }
int get_network_input_size_layer(network net, int i)
{
if(net.types[i] == CONVOLUTIONAL){
convolutional_layer layer = *(convolutional_layer *)net.layers[i];
return layer.h*layer.w*layer.c;
}
else if(net.types[i] == MAXPOOL){
maxpool_layer layer = *(maxpool_layer *)net.layers[i];
return layer.h*layer.w*layer.c;
}
else if(net.types[i] == CONNECTED){
connected_layer layer = *(connected_layer *)net.layers[i];
return layer.inputs;
}
else if(net.types[i] == SOFTMAX){
softmax_layer layer = *(softmax_layer *)net.layers[i];
return layer.inputs;
}
return 0;
}
 int get_network_output_size_layer(network net, int i)
 {
     if(net.types[i] == CONVOLUTIONAL){
@@ -382,36 +408,6 @@ int get_network_output_size_layer(network net, int i)
     return 0;
 }

-/*
-int resize_network(network net, int h, int w, int c)
-{
-    int i;
-    for (i = 0; i < net.n; ++i){
-        if(net.types[i] == CONVOLUTIONAL){
-            convolutional_layer *layer = (convolutional_layer *)net.layers[i];
-            layer->h = h;
-            layer->w = w;
-            layer->c = c;
-            image output = get_convolutional_image(*layer);
-            h = output.h;
-            w = output.w;
-            c = output.c;
-        }
-        else if(net.types[i] == MAXPOOL){
-            maxpool_layer *layer = (maxpool_layer *)net.layers[i];
-            layer->h = h;
-            layer->w = w;
-            layer->c = c;
-            image output = get_maxpool_image(*layer);
-            h = output.h;
-            w = output.w;
-            c = output.c;
-        }
-    }
-    return 0;
-}
-*/
 int resize_network(network net, int h, int w, int c)
 {
     int i;
@@ -450,6 +446,11 @@ int get_network_output_size(network net)
     return get_network_output_size_layer(net, i);
 }
int get_network_input_size(network net)
{
return get_network_output_size_layer(net, 0);
}
 image get_network_image_layer(network net, int i)
 {
     if(net.types[i] == CONVOLUTIONAL){
@@ -497,7 +498,7 @@ void visualize_network(network net)
 float *network_predict(network net, float *input)
 {
-    forward_network(net, input);
+    forward_network(net, input, 0);
     float *out = get_network_output(net);
     return out;
 }
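forward_network now threads a train flag down to the connected layers: train_network_datum and train_network_batch pass 1 (dropout active), while network_predict and the test paths pass 0. A hedged caller-side sketch; the cfg path and input size are illustrative, parse_network_cfg is assumed from src/parser.h, and this links against the full darknet sources of this commit:

/* train_flag.c -- sketch of the new forward_network train flag */
#include <stdlib.h>
#include "network.h"
#include "parser.h"

int main(void)
{
    network net = parse_network_cfg("cfg/nist.cfg");     /* illustrative path  */
    float *x = calloc(net.batch*28*28, sizeof(float));   /* made-up input size */

    forward_network(net, x, 1);              /* training pass: dropout applied */
    forward_network(net, x, 0);              /* test pass: dropout disabled    */
    float *out = network_predict(net, x);    /* calls forward_network(net, x, 0) internally */

    free(x);
    return out ? 0 : 1;
}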

src/network.h
@@ -2,6 +2,7 @@
 #ifndef NETWORK_H
 #define NETWORK_H

+#include "opencl.h"
 #include "image.h"
 #include "data.h"
@@ -20,10 +21,15 @@ typedef struct {
     LAYER_TYPE *types;
     int outputs;
     float *output;
+    #ifdef GPU
+    cl_mem input_cl;
+    cl_mem output_cl;
+    #endif
 } network;

 network make_network(int n, int batch);
-void forward_network(network net, float *input);
+void forward_network(network net, float *input, int train);
 float backward_network(network net, float *input, float *truth);
 void update_network(network net, float step, float momentum, float decay);
 float train_network_sgd(network net, data d, int n, float step, float momentum,float decay);
@@ -44,6 +50,7 @@ void print_network(network net);
 void visualize_network(network net);
 void save_network(network net, char *filename);
 int resize_network(network net, int h, int w, int c);
+int get_network_input_size(network net);
 #endif

src/opencl.c
@@ -88,4 +88,18 @@ cl_kernel get_kernel(char *filename, char *kernelname, char *options)
     return kernel;
 }
void cl_read_array(cl_mem mem, float *x, int n)
{
cl_setup();
clEnqueueReadBuffer(cl.queue, mem, CL_TRUE, 0, sizeof(float)*n,x,0,0,0);
check_error(cl);
}
void cl_write_array(cl_mem mem, float *x, int n)
{
cl_setup();
clEnqueueWriteBuffer(cl.queue, mem, CL_TRUE, 0,sizeof(float)*n,x,0,0,0);
check_error(cl);
}
 #endif
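cl_write_array and cl_read_array are blocking host-to-device and device-to-host copies of float buffers over the shared cl.queue. A round-trip sketch; it assumes a working OpenCL device, a build with -DGPU, and linking against src/opencl.c:

/* cl_roundtrip.c -- write 4 floats to the device and read them back */
#include <stdio.h>
#include "opencl.h"

int main(void)
{
    float in[4] = {1, 2, 3, 4}, out[4] = {0};

    cl_setup();
    cl_mem buf = clCreateBuffer(cl.context, CL_MEM_READ_WRITE, sizeof(in), 0, &cl.error);
    check_error(cl);

    cl_write_array(buf, in, 4);
    cl_read_array(buf, out, 4);

    printf("%g %g %g %g\n", out[0], out[1], out[2], out[3]);
    clReleaseMemObject(buf);
    return 0;
}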

src/opencl.h
@@ -1,3 +1,6 @@
+#ifdef GPU
+#ifndef OPENCL_H
+#define OPENCL_H
 #ifdef __APPLE__
 #include <OpenCL/opencl.h>
 #else
@@ -18,4 +21,7 @@ extern cl_info cl;
 void cl_setup();
 void check_error(cl_info info);
 cl_kernel get_kernel(char *filename, char *kernelname, char *options);
+void cl_read_array(cl_mem mem, float *x, int n);
+void cl_write_array(cl_mem mem, float *x, int n);
+#endif
 #endif

src/parser.c
@@ -89,6 +89,7 @@ connected_layer *parse_connected(list *options, network net, int count)
     int i;
     int input;
     int output = option_find_int(options, "output",1);
+    float dropout = option_find_float(options, "dropout", 0.);
     char *activation_s = option_find_str(options, "activation", "sigmoid");
     ACTIVATION activation = get_activation(activation_s);
     if(count == 0){
@@ -97,7 +98,7 @@ connected_layer *parse_connected(list *options, network net, int count)
     }else{
         input = get_network_output_size_layer(net, count-1);
     }
-    connected_layer *layer = make_connected_layer(net.batch, input, output, activation);
+    connected_layer *layer = make_connected_layer(net.batch, input, output, dropout, activation);
     char *data = option_find_str(options, "data", 0);
     if(data){
         char *curr = data;

src/tests.c
@@ -52,7 +52,7 @@ void test_convolve_matrix()
     int i;
     clock_t start = clock(), end;
     for(i = 0; i < 1000; ++i){
-        im2col_cpu(dog.data, dog.c, dog.h, dog.w, size, stride, matrix);
+        im2col_cpu(dog.data, 1, dog.c, dog.h, dog.w, size, stride, matrix);
         gemm(0,0,n,mw,mh,1,filters,mh,matrix,mw,1,edge.data,mw);
     }
     end = clock();
@@ -168,7 +168,7 @@ void test_parser()
     float v = ((float)rand()/RAND_MAX);
     float truth = v*v;
     input[0] = v;
-    forward_network(net, input);
+    forward_network(net, input, 1);
     float *out = get_network_output(net);
     float *delta = get_network_delta(net);
     float err = pow((out[0]-truth),2.);
@@ -245,7 +245,7 @@ void test_full()
     normalize_data_rows(test);
     for(j = 0; j < test.X.rows; ++j){
         float *x = test.X.vals[j];
-        forward_network(net, x);
+        forward_network(net, x, 0);
         int class = get_predicted_class_network(net);
         fprintf(fp, "%d\n", class);
     }
@@ -317,21 +317,13 @@ void test_nist()
     int batch = 10000;
     while(++count <= 10000){
         float loss = train_network_sgd(net, train, batch, lr, momentum, decay);
-        printf("%5f %5f\n",(double)count*batch/train.X.rows, loss);
+        float test_acc = network_accuracy(net, test);
+        printf("%3d %5f %5f\n",count, loss, test_acc);
         //printf("%5d Training Loss: %lf, Params: %f %f %f, ",count*1000, loss, lr, momentum, decay);
         //end = clock();
         //printf("Time: %lf seconds\n", (float)(end-start)/CLOCKS_PER_SEC);
         //start=end;
-        /*
-        if(count%5 == 0){
-            float train_acc = network_accuracy(net, train);
-            fprintf(stderr, "\nTRAIN: %f\n", train_acc);
-            float test_acc = network_accuracy(net, test);
-            fprintf(stderr, "TEST: %f\n\n", test_acc);
-            printf("%d, %f, %f\n", count, train_acc, test_acc);
             //lr *= .5;
-        }
-        */
     }
 }
@@ -387,7 +379,7 @@ void test_random_classify()
     int index = rand()%m.rows;
     //image p = float_to_image(1690,1,1,m.vals[index]);
     //normalize_image(p);
-    forward_network(net, m.vals[index]);
+    forward_network(net, m.vals[index], 1);
     float *out = get_network_output(net);
     float *delta = get_network_delta(net);
     //printf("%f\n", out[0]);
@@ -408,7 +400,7 @@ void test_random_classify()
     matrix test = csv_to_matrix("test.csv");
     truth = pop_column(&test, 0);
     for(i = 0; i < test.rows; ++i){
-        forward_network(net, test.vals[i]);
+        forward_network(net, test.vals[i], 0);
         float *out = get_network_output(net);
         if(fabs(out[0]) < .5) fprintf(fp, "0\n");
         else fprintf(fp, "1\n");
@@ -439,7 +431,7 @@ void test_im2row()
     float *matrix = calloc(msize, sizeof(float));
     int i;
     for(i = 0; i < 1000; ++i){
-        im2col_cpu(test.data, c, h, w, size, stride, matrix);
+        im2col_cpu(test.data, 1, c, h, w, size, stride, matrix);
         //image render = float_to_image(mh, mw, mc, matrix);
     }
 }
@@ -506,7 +498,7 @@ image features_output_size(network net, IplImage *src, int outh, int outw)
     //normalize_array(im.data, im.h*im.w*im.c);
     translate_image(im, -144);
     resize_network(net, im.h, im.w, im.c);
-    forward_network(net, im.data);
+    forward_network(net, im.data, 0);
     image out = get_network_image(net);
     free_image(im);
     cvReleaseImage(&sized);
@@ -558,7 +550,7 @@ void visualize_imagenet_topk(char *filename)
     resize_network(net, im.h, im.w, im.c);
     //scale_image(im, 1./255);
     translate_image(im, -144);
-    forward_network(net, im.data);
+    forward_network(net, im.data, 0);
     image out = get_network_image(net);
     int dh = (im.h - h)/(out.h-1);
@@ -620,7 +612,7 @@ void visualize_imagenet_features(char *filename)
     image im = load_image(image_path, 0, 0);
     printf("Processing %dx%d image\n", im.h, im.w);
     resize_network(net, im.h, im.w, im.c);
-    forward_network(net, im.data);
+    forward_network(net, im.data, 0);
     image out = get_network_image(net);
     int dh = (im.h - h)/h;
@@ -653,7 +645,7 @@ void visualize_cat()
     image im = load_image("data/cat.png", 0, 0);
     printf("Processing %dx%d image\n", im.h, im.w);
     resize_network(net, im.h, im.w, im.c);
-    forward_network(net, im.data);
+    forward_network(net, im.data, 0);
     visualize_network(net);
     cvWaitKey(0);
