@@ -17,8 +17,8 @@ using Core: Vararg
17
17
end
18
18
end
19
19
20
- function inner_grad (θ, bθ, f, p)
21
- Enzyme. autodiff_deferred (Enzyme . Reverse ,
20
+ function inner_grad (mode :: Mode , θ, bθ, f, p) where {Mode}
21
+ Enzyme. autodiff (mode ,
22
22
Const (firstapply),
23
23
Active,
24
24
Const (f),
@@ -28,19 +28,9 @@ function inner_grad(θ, bθ, f, p)
28
28
return nothing
29
29
end
30
30
31
- function inner_grad_primal (θ, bθ, f, p)
32
- Enzyme. autodiff_deferred (Enzyme. ReverseWithPrimal,
33
- Const (firstapply),
34
- Active,
35
- Const (f),
36
- Enzyme. Duplicated (θ, bθ),
37
- Const (p)
38
- )[2 ]
39
- end
40
-
41
- function hv_f2_alloc (x, f, p)
31
+ function hv_f2_alloc (mode:: Mode , x, f, p) where {Mode}
42
32
dx = Enzyme. make_zero (x)
43
- Enzyme. autodiff_deferred (Enzyme . Reverse ,
33
+ Enzyme. autodiff (mode ,
44
34
Const (firstapply),
45
35
Active,
46
36
Const (f),
@@ -57,9 +47,9 @@ function inner_cons(x, fcons::Function, p::Union{SciMLBase.NullParameters, Nothi
57
47
return res[i]
58
48
end
59
49
60
- function cons_f2 (x, dx, fcons, p, num_cons, i)
50
+ function cons_f2 (mode, x, dx, fcons, p, num_cons, i)
61
51
Enzyme. autodiff_deferred (
62
- Enzyme . Reverse , Const (inner_cons), Active, Enzyme. Duplicated (x, dx),
52
+ mode , Const (inner_cons), Active, Enzyme. Duplicated (x, dx),
63
53
Const (fcons), Const (p), Const (num_cons), Const (i))
64
54
return nothing
65
55
end
@@ -70,9 +60,9 @@ function inner_cons_oop(
70
60
return fcons (x, p)[i]
71
61
end
72
62
73
- function cons_f2_oop (x, dx, fcons, p, i)
63
+ function cons_f2_oop (mode, x, dx, fcons, p, i)
74
64
Enzyme. autodiff_deferred (
75
- Enzyme . Reverse , Const (inner_cons_oop), Active, Enzyme. Duplicated (x, dx),
65
+ mode , Const (inner_cons_oop), Active, Enzyme. Duplicated (x, dx),
76
66
Const (fcons), Const (p), Const (i))
77
67
return nothing
78
68
end
@@ -83,22 +73,38 @@ function lagrangian(x, _f::Function, cons::Function, p, λ, σ = one(eltype(x)))
83
73
return σ * _f (x, p) + dot (λ, res)
84
74
end
85
75
86
- function lag_grad (x, dx, lagrangian:: Function , _f:: Function , cons:: Function , p, σ, λ)
76
+ function lag_grad (mode, x, dx, lagrangian:: Function , _f:: Function , cons:: Function , p, σ, λ)
87
77
Enzyme. autodiff_deferred (
88
- Enzyme . Reverse , Const (lagrangian), Active, Enzyme. Duplicated (x, dx),
78
+ mode , Const (lagrangian), Active, Enzyme. Duplicated (x, dx),
89
79
Const (_f), Const (cons), Const (p), Const (λ), Const (σ))
90
80
return nothing
91
81
end
92
82
83
+ function set_runtime_activity2 (
84
+ a:: Mode1 , :: Enzyme.Mode{ABI, Err, RTA} ) where {Mode1, ABI, Err, RTA}
85
+ Enzyme. set_runtime_activity (a, RTA)
86
+ end
93
87
function OptimizationBase. instantiate_function (f:: OptimizationFunction{true} , x,
94
88
adtype:: AutoEnzyme , p, num_cons = 0 ;
95
89
g = false , h = false , hv = false , fg = false , fgh = false ,
96
90
cons_j = false , cons_vjp = false , cons_jvp = false , cons_h = false ,
97
91
lag_h = false )
92
+ rmode = if adtype. mode isa Nothing
93
+ Enzyme. Reverse
94
+ else
95
+ set_runtime_activity2 (Enzyme. Reverse, adtype. mode)
96
+ end
97
+
98
+ fmode = if adtype. mode isa Nothing
99
+ Enzyme. Forward
100
+ else
101
+ set_runtime_activity2 (Enzyme. Forward, adtype. mode)
102
+ end
103
+
98
104
if g == true && f. grad === nothing
99
105
function grad (res, θ, p = p)
100
106
Enzyme. make_zero! (res)
101
- Enzyme. autodiff (Enzyme . Reverse ,
107
+ Enzyme. autodiff (rmode ,
102
108
Const (firstapply),
103
109
Active,
104
110
Const (f. f),
@@ -115,7 +121,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
115
121
if fg == true && f. fg === nothing
116
122
function fg! (res, θ, p = p)
117
123
Enzyme. make_zero! (res)
118
- y = Enzyme. autodiff (Enzyme . ReverseWithPrimal ,
124
+ y = Enzyme. autodiff (WithPrimal (rmode) ,
119
125
Const (firstapply),
120
126
Active,
121
127
Const (f. f),
@@ -145,8 +151,9 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
145
151
Enzyme. make_zero! (bθ)
146
152
Enzyme. make_zero! .(vdbθ)
147
153
148
- Enzyme. autodiff (Enzyme . Forward ,
154
+ Enzyme. autodiff (fmode ,
149
155
inner_grad,
156
+ Const (rmode),
150
157
Enzyme. BatchDuplicated (θ, vdθ),
151
158
Enzyme. BatchDuplicatedNoNeed (bθ, vdbθ),
152
159
Const (f. f),
@@ -168,8 +175,9 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
168
175
vdθ = Tuple ((Array (r) for r in eachrow (I (length (θ)) * one (eltype (θ)))))
169
176
vdbθ = Tuple (zeros (eltype (θ), length (θ)) for i in eachindex (θ))
170
177
171
- Enzyme. autodiff (Enzyme . Forward ,
178
+ Enzyme. autodiff (fmode ,
172
179
inner_grad,
180
+ Const (rmode),
173
181
Enzyme. BatchDuplicated (θ, vdθ),
174
182
Enzyme. BatchDuplicatedNoNeed (G, vdbθ),
175
183
Const (f. f),
@@ -189,7 +197,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
189
197
if hv == true && f. hv === nothing
190
198
function hv! (H, θ, v, p = p)
191
199
H .= Enzyme. autodiff (
192
- Enzyme . Forward , hv_f2_alloc, Duplicated (θ, v),
200
+ fmode , hv_f2_alloc, Const (rmode) , Duplicated (θ, v),
193
201
Const (f. f), Const (p)
194
202
)[1 ]
195
203
end
@@ -221,7 +229,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
221
229
Enzyme. make_zero! (Jaccache[i])
222
230
end
223
231
Enzyme. make_zero! (y)
224
- Enzyme. autodiff (Enzyme . Forward , f. cons, BatchDuplicated (y, Jaccache),
232
+ Enzyme. autodiff (fmode , f. cons, BatchDuplicated (y, Jaccache),
225
233
BatchDuplicated (θ, seeds), Const (p))
226
234
for i in eachindex (θ)
227
235
if J isa Vector
@@ -254,7 +262,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
254
262
Enzyme. make_zero! (res)
255
263
Enzyme. make_zero! (cons_res)
256
264
257
- Enzyme. autodiff (Enzyme . Reverse ,
265
+ Enzyme. autodiff (rmode ,
258
266
f. cons,
259
267
Const,
260
268
Duplicated (cons_res, v),
@@ -275,7 +283,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
275
283
Enzyme. make_zero! (res)
276
284
Enzyme. make_zero! (cons_res)
277
285
278
- Enzyme. autodiff (Enzyme . Forward ,
286
+ Enzyme. autodiff (fmode ,
279
287
f. cons,
280
288
Duplicated (cons_res, res),
281
289
Duplicated (θ, v),
@@ -297,8 +305,9 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
297
305
for i in 1 : num_cons
298
306
Enzyme. make_zero! (cons_bθ)
299
307
Enzyme. make_zero! .(cons_vdbθ)
300
- Enzyme. autodiff (Enzyme . Forward ,
308
+ Enzyme. autodiff (fmode ,
301
309
cons_f2,
310
+ Const (rmode),
302
311
Enzyme. BatchDuplicated (θ, cons_vdθ),
303
312
Enzyme. BatchDuplicated (cons_bθ, cons_vdbθ),
304
313
Const (f. cons),
@@ -332,8 +341,9 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
332
341
Enzyme. make_zero! (lag_bθ)
333
342
Enzyme. make_zero! .(lag_vdbθ)
334
343
335
- Enzyme. autodiff (Enzyme . Forward ,
344
+ Enzyme. autodiff (fmode ,
336
345
lag_grad,
346
+ Const (rmode),
337
347
Enzyme. BatchDuplicated (θ, lag_vdθ),
338
348
Enzyme. BatchDuplicatedNoNeed (lag_bθ, lag_vdbθ),
339
349
Const (lagrangian),
@@ -357,8 +367,9 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
357
367
Enzyme. make_zero! (lag_bθ)
358
368
Enzyme. make_zero! .(lag_vdbθ)
359
369
360
- Enzyme. autodiff (Enzyme . Forward ,
370
+ Enzyme. autodiff (fmode ,
361
371
lag_grad,
372
+ Const (rmode),
362
373
Enzyme. BatchDuplicated (θ, lag_vdθ),
363
374
Enzyme. BatchDuplicatedNoNeed (lag_bθ, lag_vdbθ),
364
375
Const (lagrangian),
@@ -410,11 +421,23 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
410
421
g = false , h = false , hv = false , fg = false , fgh = false ,
411
422
cons_j = false , cons_vjp = false , cons_jvp = false , cons_h = false ,
412
423
lag_h = false )
424
+ rmode = if adtype. mode isa Nothing
425
+ Enzyme. Reverse
426
+ else
427
+ set_runtime_activity2 (Enzyme. Reverse, adtype. mode)
428
+ end
429
+
430
+ fmode = if adtype. mode isa Nothing
431
+ Enzyme. Forward
432
+ else
433
+ set_runtime_activity2 (Enzyme. Forward, adtype. mode)
434
+ end
435
+
413
436
if g == true && f. grad === nothing
414
437
res = zeros (eltype (x), size (x))
415
438
function grad (θ, p = p)
416
439
Enzyme. make_zero! (res)
417
- Enzyme. autodiff (Enzyme . Reverse ,
440
+ Enzyme. autodiff (rmode ,
418
441
Const (firstapply),
419
442
Active,
420
443
Const (f. f),
@@ -433,7 +456,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
433
456
res_fg = zeros (eltype (x), size (x))
434
457
function fg! (θ, p = p)
435
458
Enzyme. make_zero! (res_fg)
436
- y = Enzyme. autodiff (Enzyme . ReverseWithPrimal ,
459
+ y = Enzyme. autodiff (WithPrimal (rmode) ,
437
460
Const (firstapply),
438
461
Active,
439
462
Const (f. f),
@@ -457,8 +480,9 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
457
480
Enzyme. make_zero! (bθ)
458
481
Enzyme. make_zero! .(vdbθ)
459
482
460
- Enzyme. autodiff (Enzyme . Forward ,
483
+ Enzyme. autodiff (fmode ,
461
484
inner_grad,
485
+ Const (rmode),
462
486
Enzyme. BatchDuplicated (θ, vdθ),
463
487
Enzyme. BatchDuplicated (bθ, vdbθ),
464
488
Const (f. f),
@@ -485,8 +509,9 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
485
509
Enzyme. make_zero! (H_fgh)
486
510
Enzyme. make_zero! .(vdbθ_fgh)
487
511
488
- Enzyme. autodiff (Enzyme . Forward ,
512
+ Enzyme. autodiff (fmode ,
489
513
inner_grad,
514
+ Const (rmode),
490
515
Enzyme. BatchDuplicated (θ, vdθ_fgh),
491
516
Enzyme. BatchDuplicatedNoNeed (G_fgh, vdbθ_fgh),
492
517
Const (f. f),
@@ -507,7 +532,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
507
532
if hv == true && f. hv === nothing
508
533
function hv! (θ, v, p = p)
509
534
return Enzyme. autodiff (
510
- Enzyme . Forward , hv_f2_alloc, DuplicatedNoNeed, Duplicated (θ, v),
535
+ fmode , hv_f2_alloc, DuplicatedNoNeed, Const (rmode) , Duplicated (θ, v),
511
536
Const (_f), Const (f. f), Const (p)
512
537
)[1 ]
513
538
end
@@ -533,7 +558,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
533
558
for i in eachindex (Jaccache)
534
559
Enzyme. make_zero! (Jaccache[i])
535
560
end
536
- Jaccache, y = Enzyme. autodiff (Enzyme . ForwardWithPrimal , f. cons, Duplicated,
561
+ Jaccache, y = Enzyme. autodiff (WithPrimal (fmode) , f. cons, Duplicated,
537
562
BatchDuplicated (θ, seeds), Const (p))
538
563
if size (y, 1 ) == 1
539
564
return reduce (vcat, Jaccache)
@@ -555,7 +580,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
555
580
Enzyme. make_zero! (res_vjp)
556
581
Enzyme. make_zero! (cons_vjp_res)
557
582
558
- Enzyme. autodiff (Enzyme . Reverse ,
583
+ Enzyme. autodiff (WithPrimal (rmode) ,
559
584
f. cons,
560
585
Const,
561
586
Duplicated (cons_vjp_res, v),
@@ -578,7 +603,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
578
603
Enzyme. make_zero! (res_jvp)
579
604
Enzyme. make_zero! (cons_jvp_res)
580
605
581
- Enzyme. autodiff (Enzyme . Forward ,
606
+ Enzyme. autodiff (fmode ,
582
607
f. cons,
583
608
Duplicated (cons_jvp_res, res_jvp),
584
609
Duplicated (θ, v),
@@ -601,8 +626,9 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
601
626
return map (1 : num_cons) do i
602
627
Enzyme. make_zero! (cons_bθ)
603
628
Enzyme. make_zero! .(cons_vdbθ)
604
- Enzyme. autodiff (Enzyme . Forward ,
629
+ Enzyme. autodiff (fmode ,
605
630
cons_f2_oop,
631
+ Const (rmode),
606
632
Enzyme. BatchDuplicated (θ, cons_vdθ),
607
633
Enzyme. BatchDuplicated (cons_bθ, cons_vdbθ),
608
634
Const (f. cons),
@@ -631,8 +657,9 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
631
657
Enzyme. make_zero! (lag_bθ)
632
658
Enzyme. make_zero! .(lag_vdbθ)
633
659
634
- Enzyme. autodiff (Enzyme . Forward ,
660
+ Enzyme. autodiff (fmode ,
635
661
lag_grad,
662
+ Const (rmode),
636
663
Enzyme. BatchDuplicated (θ, lag_vdθ),
637
664
Enzyme. BatchDuplicatedNoNeed (lag_bθ, lag_vdbθ),
638
665
Const (lagrangian),
0 commit comments