Skip to content

Commit 9ec3e40

Browse files
mode in stoch ad tests for enzyme
1 parent 7e75714 commit 9ec3e40

File tree

2 files changed

+13
-13
lines changed

2 files changed

+13
-13
lines changed

ext/OptimizationEnzymeExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
9898
fmode = if adtype.mode isa Nothing
9999
Enzyme.Forward
100100
else
101-
set_runtime_activity2(Enzyme.Forward.adtype.mode)
101+
set_runtime_activity2(Enzyme.Forward, adtype.mode)
102102
end
103103

104104
if g == true && f.grad === nothing

test/adtests.jl

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1172,17 +1172,17 @@ using MLUtils
11721172

11731173
optf = OptimizationFunction(loss, AutoEnzyme())
11741174
optf = OptimizationBase.instantiate_function(
1175-
optf, rand(3), AutoEnzyme(), iterate(data)[1], g = true, fg = true)
1175+
optf, rand(3), AutoEnzyme(mode = set_runtime_activity(Reverse)), iterate(data)[1], g = true, fg = true)
11761176
G0 = zeros(3)
1177-
@test_broken optf.grad(G0, ones(3), (x0, y0))
1178-
# stochgrads = []
1179-
# for (x,y) in data
1180-
# G = zeros(3)
1181-
# optf.grad(G, ones(3), (x,y))
1182-
# push!(stochgrads, copy(G))
1183-
# G1 = zeros(3)
1184-
# optf.fg(G1, ones(3), (x,y))
1185-
# @test G ≈ G1 rtol=1e-6
1186-
# end
1187-
# @test G0 ≈ sum(stochgrads)/length(stochgrads) rtol=1e-1
1177+
optf.grad(G0, ones(3), (x0, y0))
1178+
stochgrads = []
1179+
for (x,y) in data
1180+
G = zeros(3)
1181+
optf.grad(G, ones(3), (x,y))
1182+
push!(stochgrads, copy(G))
1183+
G1 = zeros(3)
1184+
optf.fg(G1, ones(3), (x,y))
1185+
@test G G1 rtol=1e-6
1186+
end
1187+
@test G0 sum(stochgrads) rtol=1e-1
11881188
end

0 commit comments

Comments
 (0)