Skip to content

Commit d7ed5de

Browse files
Merge branch 'main' into compathelper/new_version/2024-10-03-00-13-16-506-00533635199
2 parents 2ab2e8f + 7b77f3a commit d7ed5de

15 files changed

+280
-238
lines changed

.github/dependabot.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,6 @@ updates:
55
directory: "/" # Location of package manifests
66
schedule:
77
interval: "weekly"
8+
ignore:
9+
- dependency-name: "crate-ci/typos"
10+
update-types: ["version-update:semver-patch", "version-update:semver-minor"]

.github/workflows/CI.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ jobs:
2323
fail-fast: false
2424
matrix:
2525
version:
26+
- '1.10'
2627
- '1'
2728
os:
2829
- ubuntu-latest

.github/workflows/SpellCheck.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,4 @@ jobs:
1010
- name: Checkout Actions Repository
1111
uses: actions/checkout@v4
1212
- name: Check spelling
13-
uses: crate-ci/typos@v1.24.6
13+
uses: crate-ci/typos@v1.26.0

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "OptimizationBase"
22
uuid = "bca83a33-5cc9-4baa-983d-23429ab6bcbb"
33
authors = ["Vaibhav Dixit <vaibhavyashdixit@gmail.com> and contributors"]
4-
version = "2.2.0"
4+
version = "2.3.0"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -50,6 +50,7 @@ FastClosures = "0.3"
5050
FiniteDiff = "2.12"
5151
ForwardDiff = "0.10.26"
5252
LinearAlgebra = "1.9, 1.10"
53+
MLDataDevices = "1"
5354
MLUtils = "0.4"
5455
ModelingToolkit = "9"
5556
PDMats = "0.11"

ext/OptimizationEnzymeExt.jl

Lines changed: 68 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@ using Core: Vararg
1717
end
1818
end
1919

20-
function inner_grad(θ, bθ, f, p)
21-
Enzyme.autodiff_deferred(Enzyme.Reverse,
20+
function inner_grad(mode::Mode, θ, bθ, f, p) where {Mode}
21+
Enzyme.autodiff(mode,
2222
Const(firstapply),
2323
Active,
2424
Const(f),
@@ -28,19 +28,9 @@ function inner_grad(θ, bθ, f, p)
2828
return nothing
2929
end
3030

31-
function inner_grad_primal(θ, bθ, f, p)
32-
Enzyme.autodiff_deferred(Enzyme.ReverseWithPrimal,
33-
Const(firstapply),
34-
Active,
35-
Const(f),
36-
Enzyme.Duplicated(θ, bθ),
37-
Const(p)
38-
)[2]
39-
end
40-
41-
function hv_f2_alloc(x, f, p)
31+
function hv_f2_alloc(mode::Mode, x, f, p) where {Mode}
4232
dx = Enzyme.make_zero(x)
43-
Enzyme.autodiff_deferred(Enzyme.Reverse,
33+
Enzyme.autodiff(mode,
4434
Const(firstapply),
4535
Active,
4636
Const(f),
@@ -57,9 +47,9 @@ function inner_cons(x, fcons::Function, p::Union{SciMLBase.NullParameters, Nothi
5747
return res[i]
5848
end
5949

60-
function cons_f2(x, dx, fcons, p, num_cons, i)
50+
function cons_f2(mode, x, dx, fcons, p, num_cons, i)
6151
Enzyme.autodiff_deferred(
62-
Enzyme.Reverse, Const(inner_cons), Active, Enzyme.Duplicated(x, dx),
52+
mode, Const(inner_cons), Active, Enzyme.Duplicated(x, dx),
6353
Const(fcons), Const(p), Const(num_cons), Const(i))
6454
return nothing
6555
end
@@ -70,9 +60,9 @@ function inner_cons_oop(
7060
return fcons(x, p)[i]
7161
end
7262

73-
function cons_f2_oop(x, dx, fcons, p, i)
63+
function cons_f2_oop(mode, x, dx, fcons, p, i)
7464
Enzyme.autodiff_deferred(
75-
Enzyme.Reverse, Const(inner_cons_oop), Active, Enzyme.Duplicated(x, dx),
65+
mode, Const(inner_cons_oop), Active, Enzyme.Duplicated(x, dx),
7666
Const(fcons), Const(p), Const(i))
7767
return nothing
7868
end
@@ -83,22 +73,38 @@ function lagrangian(x, _f::Function, cons::Function, p, λ, σ = one(eltype(x)))
8373
return σ * _f(x, p) + dot(λ, res)
8474
end
8575

86-
function lag_grad(x, dx, lagrangian::Function, _f::Function, cons::Function, p, σ, λ)
76+
function lag_grad(mode, x, dx, lagrangian::Function, _f::Function, cons::Function, p, σ, λ)
8777
Enzyme.autodiff_deferred(
88-
Enzyme.Reverse, Const(lagrangian), Active, Enzyme.Duplicated(x, dx),
78+
mode, Const(lagrangian), Active, Enzyme.Duplicated(x, dx),
8979
Const(_f), Const(cons), Const(p), Const(λ), Const(σ))
9080
return nothing
9181
end
9282

83+
function set_runtime_activity2(
84+
a::Mode1, ::Enzyme.Mode{ABI, Err, RTA}) where {Mode1, ABI, Err, RTA}
85+
Enzyme.set_runtime_activity(a, RTA)
86+
end
9387
function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
9488
adtype::AutoEnzyme, p, num_cons = 0;
9589
g = false, h = false, hv = false, fg = false, fgh = false,
9690
cons_j = false, cons_vjp = false, cons_jvp = false, cons_h = false,
9791
lag_h = false)
92+
rmode = if adtype.mode isa Nothing
93+
Enzyme.Reverse
94+
else
95+
set_runtime_activity2(Enzyme.Reverse, adtype.mode)
96+
end
97+
98+
fmode = if adtype.mode isa Nothing
99+
Enzyme.Forward
100+
else
101+
set_runtime_activity2(Enzyme.Forward, adtype.mode)
102+
end
103+
98104
if g == true && f.grad === nothing
99105
function grad(res, θ, p = p)
100106
Enzyme.make_zero!(res)
101-
Enzyme.autodiff(Enzyme.Reverse,
107+
Enzyme.autodiff(rmode,
102108
Const(firstapply),
103109
Active,
104110
Const(f.f),
@@ -115,7 +121,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
115121
if fg == true && f.fg === nothing
116122
function fg!(res, θ, p = p)
117123
Enzyme.make_zero!(res)
118-
y = Enzyme.autodiff(Enzyme.ReverseWithPrimal,
124+
y = Enzyme.autodiff(WithPrimal(rmode),
119125
Const(firstapply),
120126
Active,
121127
Const(f.f),
@@ -145,8 +151,9 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
145151
Enzyme.make_zero!(bθ)
146152
Enzyme.make_zero!.(vdbθ)
147153

148-
Enzyme.autodiff(Enzyme.Forward,
154+
Enzyme.autodiff(fmode,
149155
inner_grad,
156+
Const(rmode),
150157
Enzyme.BatchDuplicated(θ, vdθ),
151158
Enzyme.BatchDuplicatedNoNeed(bθ, vdbθ),
152159
Const(f.f),
@@ -168,8 +175,9 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
168175
vdθ = Tuple((Array(r) for r in eachrow(I(length(θ)) * one(eltype(θ)))))
169176
vdbθ = Tuple(zeros(eltype(θ), length(θ)) for i in eachindex(θ))
170177

171-
Enzyme.autodiff(Enzyme.Forward,
178+
Enzyme.autodiff(fmode,
172179
inner_grad,
180+
Const(rmode),
173181
Enzyme.BatchDuplicated(θ, vdθ),
174182
Enzyme.BatchDuplicatedNoNeed(G, vdbθ),
175183
Const(f.f),
@@ -189,7 +197,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
189197
if hv == true && f.hv === nothing
190198
function hv!(H, θ, v, p = p)
191199
H .= Enzyme.autodiff(
192-
Enzyme.Forward, hv_f2_alloc, Duplicated(θ, v),
200+
fmode, hv_f2_alloc, Const(rmode), Duplicated(θ, v),
193201
Const(f.f), Const(p)
194202
)[1]
195203
end
@@ -221,7 +229,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
221229
Enzyme.make_zero!(Jaccache[i])
222230
end
223231
Enzyme.make_zero!(y)
224-
Enzyme.autodiff(Enzyme.Forward, f.cons, BatchDuplicated(y, Jaccache),
232+
Enzyme.autodiff(fmode, f.cons, BatchDuplicated(y, Jaccache),
225233
BatchDuplicated(θ, seeds), Const(p))
226234
for i in eachindex(θ)
227235
if J isa Vector
@@ -254,7 +262,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
254262
Enzyme.make_zero!(res)
255263
Enzyme.make_zero!(cons_res)
256264

257-
Enzyme.autodiff(Enzyme.Reverse,
265+
Enzyme.autodiff(rmode,
258266
f.cons,
259267
Const,
260268
Duplicated(cons_res, v),
@@ -275,7 +283,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
275283
Enzyme.make_zero!(res)
276284
Enzyme.make_zero!(cons_res)
277285

278-
Enzyme.autodiff(Enzyme.Forward,
286+
Enzyme.autodiff(fmode,
279287
f.cons,
280288
Duplicated(cons_res, res),
281289
Duplicated(θ, v),
@@ -297,8 +305,9 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
297305
for i in 1:num_cons
298306
Enzyme.make_zero!(cons_bθ)
299307
Enzyme.make_zero!.(cons_vdbθ)
300-
Enzyme.autodiff(Enzyme.Forward,
308+
Enzyme.autodiff(fmode,
301309
cons_f2,
310+
Const(rmode),
302311
Enzyme.BatchDuplicated(θ, cons_vdθ),
303312
Enzyme.BatchDuplicated(cons_bθ, cons_vdbθ),
304313
Const(f.cons),
@@ -332,8 +341,9 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
332341
Enzyme.make_zero!(lag_bθ)
333342
Enzyme.make_zero!.(lag_vdbθ)
334343

335-
Enzyme.autodiff(Enzyme.Forward,
344+
Enzyme.autodiff(fmode,
336345
lag_grad,
346+
Const(rmode),
337347
Enzyme.BatchDuplicated(θ, lag_vdθ),
338348
Enzyme.BatchDuplicatedNoNeed(lag_bθ, lag_vdbθ),
339349
Const(lagrangian),
@@ -357,8 +367,9 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
357367
Enzyme.make_zero!(lag_bθ)
358368
Enzyme.make_zero!.(lag_vdbθ)
359369

360-
Enzyme.autodiff(Enzyme.Forward,
370+
Enzyme.autodiff(fmode,
361371
lag_grad,
372+
Const(rmode),
362373
Enzyme.BatchDuplicated(θ, lag_vdθ),
363374
Enzyme.BatchDuplicatedNoNeed(lag_bθ, lag_vdbθ),
364375
Const(lagrangian),
@@ -410,11 +421,23 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
410421
g = false, h = false, hv = false, fg = false, fgh = false,
411422
cons_j = false, cons_vjp = false, cons_jvp = false, cons_h = false,
412423
lag_h = false)
424+
rmode = if adtype.mode isa Nothing
425+
Enzyme.Reverse
426+
else
427+
set_runtime_activity2(Enzyme.Reverse, adtype.mode)
428+
end
429+
430+
fmode = if adtype.mode isa Nothing
431+
Enzyme.Forward
432+
else
433+
set_runtime_activity2(Enzyme.Forward, adtype.mode)
434+
end
435+
413436
if g == true && f.grad === nothing
414437
res = zeros(eltype(x), size(x))
415438
function grad(θ, p = p)
416439
Enzyme.make_zero!(res)
417-
Enzyme.autodiff(Enzyme.Reverse,
440+
Enzyme.autodiff(rmode,
418441
Const(firstapply),
419442
Active,
420443
Const(f.f),
@@ -433,7 +456,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
433456
res_fg = zeros(eltype(x), size(x))
434457
function fg!(θ, p = p)
435458
Enzyme.make_zero!(res_fg)
436-
y = Enzyme.autodiff(Enzyme.ReverseWithPrimal,
459+
y = Enzyme.autodiff(WithPrimal(rmode),
437460
Const(firstapply),
438461
Active,
439462
Const(f.f),
@@ -457,8 +480,9 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
457480
Enzyme.make_zero!(bθ)
458481
Enzyme.make_zero!.(vdbθ)
459482

460-
Enzyme.autodiff(Enzyme.Forward,
483+
Enzyme.autodiff(fmode,
461484
inner_grad,
485+
Const(rmode),
462486
Enzyme.BatchDuplicated(θ, vdθ),
463487
Enzyme.BatchDuplicated(bθ, vdbθ),
464488
Const(f.f),
@@ -485,8 +509,9 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
485509
Enzyme.make_zero!(H_fgh)
486510
Enzyme.make_zero!.(vdbθ_fgh)
487511

488-
Enzyme.autodiff(Enzyme.Forward,
512+
Enzyme.autodiff(fmode,
489513
inner_grad,
514+
Const(rmode),
490515
Enzyme.BatchDuplicated(θ, vdθ_fgh),
491516
Enzyme.BatchDuplicatedNoNeed(G_fgh, vdbθ_fgh),
492517
Const(f.f),
@@ -507,7 +532,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
507532
if hv == true && f.hv === nothing
508533
function hv!(θ, v, p = p)
509534
return Enzyme.autodiff(
510-
Enzyme.Forward, hv_f2_alloc, DuplicatedNoNeed, Duplicated(θ, v),
535+
fmode, hv_f2_alloc, DuplicatedNoNeed, Const(rmode), Duplicated(θ, v),
511536
Const(_f), Const(f.f), Const(p)
512537
)[1]
513538
end
@@ -533,7 +558,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
533558
for i in eachindex(Jaccache)
534559
Enzyme.make_zero!(Jaccache[i])
535560
end
536-
Jaccache, y = Enzyme.autodiff(Enzyme.ForwardWithPrimal, f.cons, Duplicated,
561+
Jaccache, y = Enzyme.autodiff(WithPrimal(fmode), f.cons, Duplicated,
537562
BatchDuplicated(θ, seeds), Const(p))
538563
if size(y, 1) == 1
539564
return reduce(vcat, Jaccache)
@@ -555,7 +580,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
555580
Enzyme.make_zero!(res_vjp)
556581
Enzyme.make_zero!(cons_vjp_res)
557582

558-
Enzyme.autodiff(Enzyme.Reverse,
583+
Enzyme.autodiff(WithPrimal(rmode),
559584
f.cons,
560585
Const,
561586
Duplicated(cons_vjp_res, v),
@@ -578,7 +603,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
578603
Enzyme.make_zero!(res_jvp)
579604
Enzyme.make_zero!(cons_jvp_res)
580605

581-
Enzyme.autodiff(Enzyme.Forward,
606+
Enzyme.autodiff(fmode,
582607
f.cons,
583608
Duplicated(cons_jvp_res, res_jvp),
584609
Duplicated(θ, v),
@@ -601,8 +626,9 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
601626
return map(1:num_cons) do i
602627
Enzyme.make_zero!(cons_bθ)
603628
Enzyme.make_zero!.(cons_vdbθ)
604-
Enzyme.autodiff(Enzyme.Forward,
629+
Enzyme.autodiff(fmode,
605630
cons_f2_oop,
631+
Const(rmode),
606632
Enzyme.BatchDuplicated(θ, cons_vdθ),
607633
Enzyme.BatchDuplicated(cons_bθ, cons_vdbθ),
608634
Const(f.cons),
@@ -631,8 +657,9 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
631657
Enzyme.make_zero!(lag_bθ)
632658
Enzyme.make_zero!.(lag_vdbθ)
633659

634-
Enzyme.autodiff(Enzyme.Forward,
660+
Enzyme.autodiff(fmode,
635661
lag_grad,
662+
Const(rmode),
636663
Enzyme.BatchDuplicated(θ, lag_vdθ),
637664
Enzyme.BatchDuplicatedNoNeed(lag_bθ, lag_vdbθ),
638665
Const(lagrangian),

0 commit comments

Comments
 (0)