diff --git a/src/functor.jl b/src/functor.jl
index d05489104f..bfa075a6b8 100644
--- a/src/functor.jl
+++ b/src/functor.jl
@@ -88,6 +88,9 @@ function params(m...)
   return ps
 end
 
+# Allows caching of the parameters when params is called within gradient() to fix #2040.
+@non_differentiable params(m...)
+
 struct FluxCUDAAdaptor end
 adapt_storage(to::FluxCUDAAdaptor, x) = CUDA.cu(x)
 adapt_storage(to::FluxCUDAAdaptor, x::Zygote.FillArrays.AbstractFill) = CUDA.cu(collect(x))
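
For context on what this one-line change does: `@non_differentiable` is a ChainRulesCore macro that declares a call's output constant under differentiation, so Zygote returns no gradient through it and does not trace into the call. Below is a minimal sketch of the mechanism using a hypothetical `describe` helper (not part of Flux), rather than `params` itself:

```julia
using ChainRulesCore, Zygote

# Hypothetical helper (not in Flux): inspects an array but should never
# be differentiated.
describe(x) = (length(x), eltype(x))

# Declare every call to describe non-differentiable, in the same way the
# patch does for params(m...); Zygote then treats its result as a constant.
ChainRulesCore.@non_differentiable describe(x)

# The describe call contributes nothing to the pullback: the gradient is
# ones(3), coming entirely from sum(x).
g = Zygote.gradient(x -> sum(x) + first(describe(x)), ones(3))
```

Applied to `params(m...)`, the same rule should let Zygote skip differentiating the `Params` construction itself, so the parameter collection built by `params` can be cached and reused when `params` is called inside `gradient()`, which is the behavior #2040 asks for.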