diff --git a/darknet.py b/darknet.py index 3d885d20..5d0fe14a 100644 --- a/darknet.py +++ b/darknet.py @@ -125,7 +125,7 @@ lib.network_width.restype = c_int lib.network_height.argtypes = [c_void_p] lib.network_height.restype = c_int -predict = lib.network_predict +predict = lib.network_predict_ptr predict.argtypes = [c_void_p, POINTER(c_float)] predict.restype = POINTER(c_float) diff --git a/src/network.c b/src/network.c index d6bd820b..ef55f7f1 100644 --- a/src/network.c +++ b/src/network.c @@ -556,6 +556,12 @@ void top_predictions(network net, int k, int *index) top_k(out, size, k, index); } +// A version of network_predict that uses a pointer for the network +// struct to make the python binding work properly. +float *network_predict_ptr(network *net, float *input) +{ + return network_predict(*net, input); +} float *network_predict(network net, float *input) { diff --git a/src/network.h b/src/network.h index f71c56cb..4247efb0 100644 --- a/src/network.h +++ b/src/network.h @@ -122,6 +122,7 @@ float train_network_datum(network net, float *x, float *y); matrix network_predict_data(network net, data test); //LIB_API float *network_predict(network net, float *input); +LIB_API float *network_predict_ptr(network *net, float *input); float network_accuracy(network net, data d); float *network_accuracies(network net, data d, int n); float network_accuracy_multi(network net, data d, int n);