From bb14037f717d3a15a98cc445a0e93fd55b687240 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Thu, 28 Oct 2021 23:35:25 +0200 Subject: [PATCH] Opt out of CR.rrule if pullback is not defined --- src/eval.jl | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/eval.jl b/src/eval.jl index d66ac55..68eadd1 100644 --- a/src/eval.jl +++ b/src/eval.jl @@ -48,15 +48,18 @@ import ChainRulesCore function ChainRulesCore.rrule(ev::Eval, args...) Z = ev.fwd(args...) - Z, function tullio_back(Δ) - isnothing(ev.rev) && error("no gradient definition here!") + function tullio_back(Δ) dxs = map(ev.rev(Δ, Z, args...)) do dx dx === nothing ? ChainRulesCore.ZeroTangent() : dx end - tuple(ChainRulesCore.ZeroTangent(), dxs...) + return (ChainRulesCore.NoTangent(), dxs...) end + return Z, tullio_back end +# without gradient definition we let the AD system differentiate the function +ChainRulesCore.@opt_out ChainRulesCore.rrule(ev::Eval{<:Any,Nothing}, args...) + @init @require FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" begin using .FillArrays: Fill # used by Zygote Tullio.promote_storage(::Type{T}, ::Type{F}) where {T, F<:Fill} = T