|
|
|
@ -176,7 +176,7 @@ void flip_board(float *board) |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
void test_go(char *filename, char *weightfile) |
|
|
|
|
void test_go(char *filename, char *weightfile, int multi) |
|
|
|
|
{ |
|
|
|
|
network net = parse_network_cfg(filename); |
|
|
|
|
if(weightfile){ |
|
|
|
@ -191,25 +191,25 @@ void test_go(char *filename, char *weightfile) |
|
|
|
|
float *output = network_predict(net, board); |
|
|
|
|
copy_cpu(19*19, output, 1, move, 1); |
|
|
|
|
int i; |
|
|
|
|
#ifdef GPU |
|
|
|
|
image bim = float_to_image(19, 19, 1, board); |
|
|
|
|
for(i = 1; i < 8; ++i){ |
|
|
|
|
rotate_image_cw(bim, i); |
|
|
|
|
if(i >= 4) flip_image(bim); |
|
|
|
|
if(multi){ |
|
|
|
|
image bim = float_to_image(19, 19, 1, board); |
|
|
|
|
for(i = 1; i < 8; ++i){ |
|
|
|
|
rotate_image_cw(bim, i); |
|
|
|
|
if(i >= 4) flip_image(bim); |
|
|
|
|
|
|
|
|
|
float *output = network_predict(net, board); |
|
|
|
|
image oim = float_to_image(19, 19, 1, output); |
|
|
|
|
float *output = network_predict(net, board); |
|
|
|
|
image oim = float_to_image(19, 19, 1, output); |
|
|
|
|
|
|
|
|
|
if(i >= 4) flip_image(oim); |
|
|
|
|
rotate_image_cw(oim, -i); |
|
|
|
|
if(i >= 4) flip_image(oim); |
|
|
|
|
rotate_image_cw(oim, -i); |
|
|
|
|
|
|
|
|
|
axpy_cpu(19*19, 1, output, 1, move, 1); |
|
|
|
|
axpy_cpu(19*19, 1, output, 1, move, 1); |
|
|
|
|
|
|
|
|
|
if(i >= 4) flip_image(bim); |
|
|
|
|
rotate_image_cw(bim, -i); |
|
|
|
|
if(i >= 4) flip_image(bim); |
|
|
|
|
rotate_image_cw(bim, -i); |
|
|
|
|
} |
|
|
|
|
scal_cpu(19*19, 1./8., move, 1); |
|
|
|
|
} |
|
|
|
|
scal_cpu(19*19, 1./8., move, 1); |
|
|
|
|
#endif |
|
|
|
|
for(i = 0; i < 19*19; ++i){ |
|
|
|
|
if(board[i]) move[i] = 0; |
|
|
|
|
} |
|
|
|
@ -282,8 +282,9 @@ void run_go(int argc, char **argv) |
|
|
|
|
|
|
|
|
|
char *cfg = argv[3]; |
|
|
|
|
char *weights = (argc > 4) ? argv[4] : 0; |
|
|
|
|
int multi = find_arg(argc, argv, "-multi"); |
|
|
|
|
if(0==strcmp(argv[2], "train")) train_go(cfg, weights); |
|
|
|
|
else if(0==strcmp(argv[2], "test")) test_go(cfg, weights); |
|
|
|
|
else if(0==strcmp(argv[2], "test")) test_go(cfg, weights, multi); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|