@@ -39,9 +39,9 @@ function instantiate_function(
39
39
adtype, soadtype = generate_adtype (adtype)
40
40
41
41
if g == true && f. grad === nothing
42
- extras_grad = prepare_gradient (_f, adtype, x)
42
+ prep_grad = prepare_gradient (_f, adtype, x)
43
43
function grad (res, θ)
44
- gradient! (_f, res, adtype, θ, extras_grad )
44
+ gradient! (_f, res, prep_grad, adtype, θ )
45
45
end
46
46
if p != = SciMLBase. NullParameters () && p != = nothing
47
47
function grad (res, θ, p)
@@ -57,10 +57,10 @@ function instantiate_function(
57
57
58
58
if fg == true && f. fg === nothing
59
59
if g == false
60
- extras_grad = prepare_gradient (_f, adtype, x)
60
+ prep_grad = prepare_gradient (_f, adtype, x)
61
61
end
62
62
function fg! (res, θ)
63
- (y, _) = value_and_gradient! (_f, res, adtype, θ, extras_grad )
63
+ (y, _) = value_and_gradient! (_f, res, prep_grad, adtype, θ )
64
64
return y
65
65
end
66
66
if p != = SciMLBase. NullParameters () && p != = nothing
@@ -79,9 +79,9 @@ function instantiate_function(
79
79
hess_sparsity = f. hess_prototype
80
80
hess_colors = f. hess_colorvec
81
81
if h == true && f. hess === nothing
82
- extras_hess = prepare_hessian (_f, soadtype, x)
82
+ prep_hess = prepare_hessian (_f, soadtype, x)
83
83
function hess (res, θ)
84
- hessian! (_f, res, soadtype, θ, extras_hess )
84
+ hessian! (_f, res, prep_hess, soadtype, θ )
85
85
end
86
86
if p != = SciMLBase. NullParameters () && p != = nothing
87
87
function hess (res, θ, p)
@@ -98,7 +98,7 @@ function instantiate_function(
98
98
if fgh == true && f. fgh === nothing
99
99
function fgh! (G, H, θ)
100
100
(y, _, _) = value_derivative_and_second_derivative! (
101
- _f, G, H, soadtype, θ, extras_hess )
101
+ _f, G, H, prep_hess, soadtype, θ )
102
102
return y
103
103
end
104
104
if p != = SciMLBase. NullParameters () && p != = nothing
@@ -116,14 +116,14 @@ function instantiate_function(
116
116
end
117
117
118
118
if hv == true && f. hv === nothing
119
- extras_hvp = prepare_hvp (_f, soadtype, x, zeros (eltype (x), size (x)))
119
+ prep_hvp = prepare_hvp (_f, soadtype, x, ( zeros (eltype (x), size (x)), ))
120
120
function hv! (H, θ, v)
121
- hvp! (_f, H, soadtype, θ, v, extras_hvp )
121
+ only ( hvp! (_f, (H,), prep_hvp, soadtype, θ, (v,)) )
122
122
end
123
123
if p != = SciMLBase. NullParameters () && p != = nothing
124
124
function hv! (H, θ, v, p)
125
125
global _p = p
126
- hvp! (_f, H, soadtype, θ, v )
126
+ only ( hvp! (_f, (H,), soadtype, θ, (v,)) )
127
127
end
128
128
end
129
129
elseif hv == true
@@ -156,9 +156,9 @@ function instantiate_function(
156
156
cons_jac_prototype = f. cons_jac_prototype
157
157
cons_jac_colorvec = f. cons_jac_colorvec
158
158
if cons != = nothing && cons_j == true && f. cons_j === nothing
159
- extras_jac = prepare_jacobian (cons_oop, adtype, x)
159
+ prep_jac = prepare_jacobian (cons_oop, adtype, x)
160
160
function cons_j! (J, θ)
161
- jacobian! (cons_oop, J, adtype, θ, extras_jac )
161
+ jacobian! (cons_oop, J, prep_jac, adtype, θ )
162
162
if size (J, 1 ) == 1
163
163
J = vec (J)
164
164
end
@@ -170,9 +170,9 @@ function instantiate_function(
170
170
end
171
171
172
172
if f. cons_vjp === nothing && cons_vjp == true && cons != = nothing
173
- extras_pullback = prepare_pullback (cons_oop, adtype, x, ones (eltype (x), num_cons))
173
+ prep_pullback = prepare_pullback (cons_oop, adtype, x, ( ones (eltype (x), num_cons), ))
174
174
function cons_vjp! (J, θ, v)
175
- pullback! (cons_oop, J, adtype, θ, v, extras_pullback )
175
+ only ( pullback! (cons_oop, (J,), prep_pullback, adtype, θ, (v,)) )
176
176
end
177
177
elseif cons_vjp == true && cons != = nothing
178
178
cons_vjp! = (J, θ, v) -> f. cons_vjp (J, θ, v, p)
@@ -181,10 +181,10 @@ function instantiate_function(
181
181
end
182
182
183
183
if f. cons_jvp === nothing && cons_jvp == true && cons != = nothing
184
- extras_pushforward = prepare_pushforward (
185
- cons_oop, adtype, x, ones (eltype (x), length (x)))
184
+ prep_pushforward = prepare_pushforward (
185
+ cons_oop, adtype, x, ( ones (eltype (x), length (x)), ))
186
186
function cons_jvp! (J, θ, v)
187
- pushforward! (cons_oop, J, adtype, θ, v, extras_pushforward )
187
+ only ( pushforward! (cons_oop, (J,), prep_pushforward, adtype, θ, (v,)) )
188
188
end
189
189
elseif cons_jvp == true && cons != = nothing
190
190
cons_jvp! = (J, θ, v) -> f. cons_jvp (J, θ, v, p)
@@ -196,11 +196,11 @@ function instantiate_function(
196
196
conshess_colors = f. cons_hess_colorvec
197
197
if cons != = nothing && f. cons_h === nothing && cons_h == true
198
198
fncs = [(x) -> cons_oop (x)[i] for i in 1 : num_cons]
199
- extras_cons_hess = prepare_hessian .(fncs, Ref (soadtype), Ref (x))
199
+ prep_cons_hess = prepare_hessian .(fncs, Ref (soadtype), Ref (x))
200
200
201
201
function cons_h! (H, θ)
202
202
for i in 1 : num_cons
203
- hessian! (fncs[i], H[i], soadtype, θ, extras_cons_hess[i] )
203
+ hessian! (fncs[i], H[i], prep_cons_hess[i], soadtype, θ )
204
204
end
205
205
end
206
206
elseif cons_h == true && cons != = nothing
@@ -212,7 +212,7 @@ function instantiate_function(
212
212
lag_hess_prototype = f. lag_hess_prototype
213
213
214
214
if cons != = nothing && lag_h == true && f. lag_h === nothing
215
- lag_extras = prepare_hessian (
215
+ lag_prep = prepare_hessian (
216
216
lagrangian, soadtype, vcat (x, [one (eltype (x))], ones (eltype (x), num_cons)))
217
217
lag_hess_prototype = zeros (Bool, length (x) + num_cons + 1 , length (x) + num_cons + 1 )
218
218
@@ -221,13 +221,13 @@ function instantiate_function(
221
221
cons_h (H, θ)
222
222
H *= λ
223
223
else
224
- H .= @view (hessian (lagrangian, soadtype, vcat (θ, [σ], λ), lag_extras )[
224
+ H .= @view (hessian (lagrangian, lag_prep, soadtype, vcat (θ, [σ], λ))[
225
225
1 : length (θ), 1 : length (θ)])
226
226
end
227
227
end
228
228
229
229
function lag_h! (h:: AbstractVector , θ, σ, λ)
230
- H = hessian (lagrangian, soadtype, vcat (θ, [σ], λ), lag_extras )
230
+ H = hessian (lagrangian, lag_prep, soadtype, vcat (θ, [σ], λ))
231
231
k = 0
232
232
for i in 1 : length (θ)
233
233
for j in 1 : i
@@ -244,14 +244,14 @@ function instantiate_function(
244
244
H *= λ
245
245
else
246
246
global _p = p
247
- H .= @view (hessian (lagrangian, soadtype, vcat (θ, [σ], λ), lag_extras )[
247
+ H .= @view (hessian (lagrangian, lag_prep, soadtype, vcat (θ, [σ], λ))[
248
248
1 : length (θ), 1 : length (θ)])
249
249
end
250
250
end
251
251
252
252
function lag_h! (h:: AbstractVector , θ, σ, λ, p)
253
253
global _p = p
254
- H = hessian (lagrangian, soadtype, vcat (θ, [σ], λ), lag_extras )
254
+ H = hessian (lagrangian, lag_prep, soadtype, vcat (θ, [σ], λ))
255
255
k = 0
256
256
for i in 1 : length (θ)
257
257
for j in 1 : i
@@ -308,9 +308,9 @@ function instantiate_function(
308
308
adtype, soadtype = generate_adtype (adtype)
309
309
310
310
if g == true && f. grad === nothing
311
- extras_grad = prepare_gradient (_f, adtype, x)
311
+ prep_grad = prepare_gradient (_f, adtype, x)
312
312
function grad (θ)
313
- gradient (_f, adtype, θ, extras_grad )
313
+ gradient (_f, prep_grad, adtype, θ )
314
314
end
315
315
if p != = SciMLBase. NullParameters () && p != = nothing
316
316
function grad (θ, p)
@@ -326,10 +326,10 @@ function instantiate_function(
326
326
327
327
if fg == true && f. fg === nothing
328
328
if g == false
329
- extras_grad = prepare_gradient (_f, adtype, x)
329
+ prep_grad = prepare_gradient (_f, adtype, x)
330
330
end
331
331
function fg! (θ)
332
- (y, res) = value_and_gradient (_f, adtype, θ, extras_grad )
332
+ (y, res) = value_and_gradient (_f, prep_grad, adtype, θ )
333
333
return y, res
334
334
end
335
335
if p != = SciMLBase. NullParameters () && p != = nothing
@@ -348,9 +348,9 @@ function instantiate_function(
348
348
hess_sparsity = f. hess_prototype
349
349
hess_colors = f. hess_colorvec
350
350
if h == true && f. hess === nothing
351
- extras_hess = prepare_hessian (_f, soadtype, x)
351
+ prep_hess = prepare_hessian (_f, soadtype, x)
352
352
function hess (θ)
353
- hessian (_f, soadtype, θ, extras_hess )
353
+ hessian (_f, prep_hess, soadtype, θ )
354
354
end
355
355
if p != = SciMLBase. NullParameters () && p != = nothing
356
356
function hess (θ, p)
@@ -366,7 +366,7 @@ function instantiate_function(
366
366
367
367
if fgh == true && f. fgh === nothing
368
368
function fgh! (θ)
369
- (y, G, H) = value_derivative_and_second_derivative (_f, adtype, θ, extras_hess )
369
+ (y, G, H) = value_derivative_and_second_derivative (_f, prep_hess, adtype, θ )
370
370
return y, G, H
371
371
end
372
372
if p != = SciMLBase. NullParameters () && p != = nothing
@@ -383,14 +383,14 @@ function instantiate_function(
383
383
end
384
384
385
385
if hv == true && f. hv === nothing
386
- extras_hvp = prepare_hvp (_f, soadtype, x, zeros (eltype (x), size (x)))
386
+ prep_hvp = prepare_hvp (_f, soadtype, x, ( zeros (eltype (x), size (x)), ))
387
387
function hv! (θ, v)
388
- hvp (_f, soadtype, θ, v, extras_hvp )
388
+ only ( hvp (_f, prep_hvp, soadtype, θ, (v)) )
389
389
end
390
390
if p != = SciMLBase. NullParameters () && p != = nothing
391
391
function hv! (θ, v, p)
392
392
global _p = p
393
- hvp ( _f, soadtype, θ, v, extras_hvp )
393
+ only ( vp ( _f, prep_hvp, soadtype, θ, (v,)) )
394
394
end
395
395
end
396
396
elseif hv == true
@@ -417,9 +417,9 @@ function instantiate_function(
417
417
cons_jac_prototype = f. cons_jac_prototype
418
418
cons_jac_colorvec = f. cons_jac_colorvec
419
419
if cons != = nothing && cons_j == true && f. cons_j === nothing
420
- extras_jac = prepare_jacobian (cons, adtype, x)
420
+ prep_jac = prepare_jacobian (cons, adtype, x)
421
421
function cons_j! (θ)
422
- J = jacobian (cons, adtype, θ, extras_jac )
422
+ J = jacobian (cons, prep_jac, adtype, θ )
423
423
if size (J, 1 ) == 1
424
424
J = vec (J)
425
425
end
@@ -432,9 +432,9 @@ function instantiate_function(
432
432
end
433
433
434
434
if f. cons_vjp === nothing && cons_vjp == true && cons != = nothing
435
- extras_pullback = prepare_pullback (cons, adtype, x, ones (eltype (x), num_cons))
435
+ prep_pullback = prepare_pullback (cons, adtype, x, ( ones (eltype (x), num_cons), ))
436
436
function cons_vjp! (θ, v)
437
- return pullback (cons, adtype, θ, v, extras_pullback )
437
+ return only ( pullback (cons, prep_pullback, adtype, θ, (v,)) )
438
438
end
439
439
elseif cons_vjp == true && cons != = nothing
440
440
cons_vjp! = (θ, v) -> f. cons_vjp (θ, v, p)
@@ -443,10 +443,10 @@ function instantiate_function(
443
443
end
444
444
445
445
if f. cons_jvp === nothing && cons_jvp == true && cons != = nothing
446
- extras_pushforward = prepare_pushforward (
447
- cons, adtype, x, ones (eltype (x), length (x)))
446
+ prep_pushforward = prepare_pushforward (
447
+ cons, adtype, x, ( ones (eltype (x), length (x)), ))
448
448
function cons_jvp! (θ, v)
449
- return pushforward (cons, adtype, θ, v, extras_pushforward )
449
+ return only ( pushforward (cons, prep_pushforward, adtype, θ, (v,)) )
450
450
end
451
451
elseif cons_jvp == true && cons != = nothing
452
452
cons_jvp! = (θ, v) -> f. cons_jvp (θ, v, p)
@@ -458,11 +458,11 @@ function instantiate_function(
458
458
conshess_colors = f. cons_hess_colorvec
459
459
if cons != = nothing && cons_h == true && f. cons_h === nothing
460
460
fncs = [(x) -> cons (x)[i] for i in 1 : num_cons]
461
- extras_cons_hess = prepare_hessian .(fncs, Ref (soadtype), Ref (x))
461
+ prep_cons_hess = prepare_hessian .(fncs, Ref (soadtype), Ref (x))
462
462
463
463
function cons_h! (θ)
464
464
H = map (1 : num_cons) do i
465
- hessian (fncs[i], soadtype, θ, extras_cons_hess[i] )
465
+ hessian (fncs[i], prep_cons_hess[i], soadtype, θ )
466
466
end
467
467
return H
468
468
end
@@ -475,15 +475,15 @@ function instantiate_function(
475
475
lag_hess_prototype = f. lag_hess_prototype
476
476
477
477
if cons != = nothing && lag_h == true && f. lag_h === nothing
478
- lag_extras = prepare_hessian (
478
+ lag_prep = prepare_hessian (
479
479
lagrangian, soadtype, vcat (x, [one (eltype (x))], ones (eltype (x), num_cons)))
480
480
lag_hess_prototype = zeros (Bool, length (x) + num_cons + 1 , length (x) + num_cons + 1 )
481
481
482
482
function lag_h! (θ, σ, λ)
483
483
if σ == zero (eltype (θ))
484
484
return λ .* cons_h (θ)
485
485
else
486
- return hessian (lagrangian, soadtype, vcat (θ, [σ], λ), lag_extras )[
486
+ return hessian (lagrangian, lag_prep, soadtype, vcat (θ, [σ], λ))[
487
487
1 : length (θ), 1 : length (θ)]
488
488
end
489
489
end
@@ -494,7 +494,7 @@ function instantiate_function(
494
494
return λ .* cons_h (θ)
495
495
else
496
496
global _p = p
497
- return hessian (lagrangian, soadtype, vcat (θ, [σ], λ), lag_extras )[
497
+ return hessian (lagrangian, lag_prep, soadtype, vcat (θ, [σ], λ))[
498
498
1 : length (θ), 1 : length (θ)]
499
499
end
500
500
end
0 commit comments