Skip to content

Commit 88cc300

Browse files
Merge pull request #142 from wsmoses/bfunc
Enzyme: add func_annotation
2 parents 6d96450 + d11da6c commit 88cc300

File tree

1 file changed

+20
-3
lines changed

1 file changed

+20
-3
lines changed

ext/OptimizationEnzymeExt.jl

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,8 @@ function set_runtime_activity2(
8484
a::Mode1, ::Enzyme.Mode{ABI, Err, RTA}) where {Mode1, ABI, Err, RTA}
8585
Enzyme.set_runtime_activity(a, RTA)
8686
end
87+
function_annotation(::Nothing) = Nothing
88+
function_annotation(::AutoEnzyme{<:Any, A}) where A = A
8789
function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
8890
adtype::AutoEnzyme, p, num_cons = 0;
8991
g = false, h = false, hv = false, fg = false, fgh = false,
@@ -101,6 +103,8 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
101103
set_runtime_activity2(Enzyme.Forward, adtype.mode)
102104
end
103105

106+
func_annot = function_annotation(adtype)
107+
104108
if g == true && f.grad === nothing
105109
function grad(res, θ, p = p)
106110
Enzyme.make_zero!(res)
@@ -217,6 +221,14 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
217221
# if num_cons > length(x)
218222
seeds = Enzyme.onehot(x)
219223
Jaccache = Tuple(zeros(eltype(x), num_cons) for i in 1:length(x))
224+
basefunc = f.cons
225+
if func_annot <: Enzyme.Const
226+
basefunc = Enzyme.Const(basefunc)
227+
elseif func_annot <: Enzyme.Duplicated || func_annot <: Enzyme.BatchDuplicated
228+
basefunc = Enzyme.BatchDuplicated(basefunc, Tuple(make_zero(basefunc) for i in 1:length(x)))
229+
elseif func_annot <: Enzyme.DuplicatedNoNeed || func_annot <: Enzyme.BatchDuplicatedNoNeed
230+
basefunc = Enzyme.BatchDuplicatedNoNeed(basefunc, Tuple(make_zero(basefunc) for i in 1:length(x)))
231+
end
220232
# else
221233
# seeds = Enzyme.onehot(zeros(eltype(x), num_cons))
222234
# Jaccache = Tuple(zero(x) for i in 1:num_cons)
@@ -225,11 +237,16 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
225237
y = zeros(eltype(x), num_cons)
226238

227239
function cons_j!(J, θ)
228-
for i in eachindex(Jaccache)
229-
Enzyme.make_zero!(Jaccache[i])
240+
for jc in Jaccache
241+
Enzyme.make_zero!(jc)
230242
end
231243
Enzyme.make_zero!(y)
232-
Enzyme.autodiff(fmode, f.cons, BatchDuplicated(y, Jaccache),
244+
if func_annot <: Enzyme.Duplicated || func_annot <: Enzyme.BatchDuplicated || func_annot <: Enzyme.DuplicatedNoNeed || func_annot <: Enzyme.BatchDuplicatedNoNeed
245+
for bf in basefunc.dval
246+
Enzyme.make_zero!(bf)
247+
end
248+
end
249+
Enzyme.autodiff(fmode, basefunc , BatchDuplicated(y, Jaccache),
233250
BatchDuplicated(θ, seeds), Const(p))
234251
for i in eachindex(θ)
235252
if J isa Vector

0 commit comments

Comments
 (0)