Skip to content

Commit b13c812

Browse files
Apply suggestions from code review
1 parent 97ec67b commit b13c812

File tree

3 files changed

+7
-8
lines changed

3 files changed

+7
-8
lines changed

Project.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ uuid = "bca83a33-5cc9-4baa-983d-23429ab6bcbb"
33
authors = ["Vaibhav Dixit <vaibhavyashdixit@gmail.com> and contributors"]
44
version = "2.3.0"
55

6-
76
[deps]
87
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
98
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"

ext/OptimizationEnzymeExt.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -91,13 +91,13 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
9191
rmode = if adtype.mode isa Nothing
9292
Enzyme.Reverse
9393
else
94-
set_runtime_activity2(Enzyme.Reverse)
94+
set_runtime_activity2(Enzyme.Reverse, adtype.mode)
9595
end
9696

9797
fmode = if adtype.mode isa Nothing
9898
Enzyme.Forward
9999
else
100-
set_runtime_activity2(Enzyme.Forward)
100+
set_runtime_activity2(Enzyme.Forward. adtype.mode)
101101
end
102102

103103
if g == true && f.grad === nothing
@@ -423,13 +423,13 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
423423
rmode = if adtype.mode isa Nothing
424424
Enzyme.Reverse
425425
else
426-
set_runtime_activity2(Enzyme.Reverse)
426+
set_runtime_activity2(Enzyme.Reverse, adtype.mode)
427427
end
428428

429429
fmode = if adtype.mode isa Nothing
430430
Enzyme.Forward
431431
else
432-
set_runtime_activity2(Enzyme.Forward)
432+
set_runtime_activity2(Enzyme.Forward, adtype.mode)
433433
end
434434

435435
if g == true && f.grad === nothing

test/adtests.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
using OptimizationBase, Test, DifferentiationInterface, SparseArrays, Symbolics
22
using ForwardDiff, Zygote, ReverseDiff, FiniteDiff, Tracker
3-
using ModelingToolkit, Enzyme, Random
3+
using Enzyme, Random
44

55
x0 = zeros(2)
66
rosenbrock(x, p = nothing) = (1 - x[1])^2 + 100 * (x[2] - x[1]^2)^2
@@ -1174,8 +1174,8 @@ using MLUtils
11741174
optf = OptimizationBase.instantiate_function(
11751175
optf, rand(3), AutoEnzyme(), iterate(data)[1], g = true, fg = true)
11761176
G0 = zeros(3)
1177-
@test_broken optf.grad(G0, ones(3), (x, y))
1178-
stochgrads = []
1177+
@test_broken optf.grad(G0, ones(3), (x0, y0))
1178+
# stochgrads = []
11791179
# for (x,y) in data
11801180
# G = zeros(3)
11811181
# optf.grad(G, ones(3), (x,y))

0 commit comments

Comments
 (0)