@@ -84,6 +84,8 @@ function set_runtime_activity2(
84
84
a:: Mode1 , :: Enzyme.Mode{ABI, Err, RTA} ) where {Mode1, ABI, Err, RTA}
85
85
Enzyme. set_runtime_activity (a, RTA)
86
86
end
87
+ function_annotation (:: Nothing ) = Nothing
88
+ function_annotation (:: AutoEnzyme{<:Any, A} ) where A = A
87
89
function OptimizationBase. instantiate_function (f:: OptimizationFunction{true} , x,
88
90
adtype:: AutoEnzyme , p, num_cons = 0 ;
89
91
g = false , h = false , hv = false , fg = false , fgh = false ,
@@ -101,6 +103,8 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
101
103
set_runtime_activity2 (Enzyme. Forward, adtype. mode)
102
104
end
103
105
106
+ func_annot = function_annotation (adtype)
107
+
104
108
if g == true && f. grad === nothing
105
109
function grad (res, θ, p = p)
106
110
Enzyme. make_zero! (res)
@@ -217,6 +221,14 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
217
221
# if num_cons > length(x)
218
222
seeds = Enzyme. onehot (x)
219
223
Jaccache = Tuple (zeros (eltype (x), num_cons) for i in 1 : length (x))
224
+ basefunc = f. cons
225
+ if func_annot <: Enzyme.Const
226
+ basefunc = Enzyme. Const (basefunc)
227
+ elseif func_annot <: Enzyme.Duplicated || func_annot <: Enzyme.BatchDuplicated
228
+ basefunc = Enzyme. BatchDuplicated (basefunc, Tuple (make_zero (basefunc) for i in 1 : length (x)))
229
+ elseif func_annot <: Enzyme.DuplicatedNoNeed || func_annot <: Enzyme.BatchDuplicatedNoNeed
230
+ basefunc = Enzyme. BatchDuplicatedNoNeed (basefunc, Tuple (make_zero (basefunc) for i in 1 : length (x)))
231
+ end
220
232
# else
221
233
# seeds = Enzyme.onehot(zeros(eltype(x), num_cons))
222
234
# Jaccache = Tuple(zero(x) for i in 1:num_cons)
@@ -225,11 +237,16 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
225
237
y = zeros (eltype (x), num_cons)
226
238
227
239
function cons_j! (J, θ)
228
- for i in eachindex ( Jaccache)
229
- Enzyme. make_zero! (Jaccache[i] )
240
+ for jc in Jaccache
241
+ Enzyme. make_zero! (jc )
230
242
end
231
243
Enzyme. make_zero! (y)
232
- Enzyme. autodiff (fmode, f. cons, BatchDuplicated (y, Jaccache),
244
+ if func_annot <: Enzyme.Duplicated || func_annot <: Enzyme.BatchDuplicated || func_annot <: Enzyme.DuplicatedNoNeed || func_annot <: Enzyme.BatchDuplicatedNoNeed
245
+ for bf in basefunc. dval
246
+ Enzyme. make_zero! (bf)
247
+ end
248
+ end
249
+ Enzyme. autodiff (fmode, basefunc , BatchDuplicated (y, Jaccache),
233
250
BatchDuplicated (θ, seeds), Const (p))
234
251
for i in eachindex (θ)
235
252
if J isa Vector
0 commit comments