|
|
@ -352,7 +352,7 @@ float train_network_datum_gpu(network net, float *x, float *y) |
|
|
|
if (net.adversarial_lr && rand_int(0, 1) == 1 && get_current_iteration(net) > net.burn_in) { |
|
|
|
if (net.adversarial_lr && rand_int(0, 1) == 1 && get_current_iteration(net) > net.burn_in) { |
|
|
|
net.adversarial = 1; |
|
|
|
net.adversarial = 1; |
|
|
|
float lr_old = net.learning_rate; |
|
|
|
float lr_old = net.learning_rate; |
|
|
|
float scale = ((float)net.max_batches) / get_current_iteration(net); |
|
|
|
float scale = 1.0 - (get_current_iteration(net) / ((float)net.max_batches)); |
|
|
|
net.learning_rate = net.adversarial_lr * scale; |
|
|
|
net.learning_rate = net.adversarial_lr * scale; |
|
|
|
layer l = net.layers[net.n - 1]; |
|
|
|
layer l = net.layers[net.n - 1]; |
|
|
|
int y_size = get_network_output_size(net)*net.batch; |
|
|
|
int y_size = get_network_output_size(net)*net.batch; |
|
|
|