From e5547aa4ec40613b9545009830dc2d0a6a943a8d Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Sat, 3 Sep 2022 20:26:18 -0400
Subject: [PATCH 1/3] =?UTF-8?q?unthunk=20for=20=E2=88=87batchnorm?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 src/cuda/cudnn.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/cuda/cudnn.jl b/src/cuda/cudnn.jl
index 3414939853..9e6bdb53a0 100644
--- a/src/cuda/cudnn.jl
+++ b/src/cuda/cudnn.jl
@@ -14,7 +14,7 @@ end
 function ChainRulesCore.rrule(::typeof(batchnorm), g, b, x, running_mean, running_var, momentum; kw...)
   y = batchnorm(g, b, x, running_mean, running_var, momentum; kw...)
   function batchnorm_pullback(Δ)
-    grad = ∇batchnorm(g, b, x, Δ, running_mean, running_var, momentum; kw...)
+    grad = ∇batchnorm(g, b, x, unthunk(Δ), running_mean, running_var, momentum; kw...)
     (NoTangent(), grad..., NoTangent(), NoTangent(), NoTangent())
   end
   y, batchnorm_pullback

From 3e12946401feac1c46c5022c7a64cd65d866da56 Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Sat, 3 Sep 2022 20:28:41 -0400
Subject: [PATCH 2/3] unthunk some rrules

---
 src/functor.jl | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/src/functor.jl b/src/functor.jl
index bfa075a6b8..13adbe13ff 100644
--- a/src/functor.jl
+++ b/src/functor.jl
@@ -119,11 +119,11 @@ adapt_storage(to::FluxCPUAdaptor, x::CUDA.RNG) = Random.default_rng()
 adapt_storage(to::FluxCPUAdaptor, x::AbstractRNG) = x
 
 function ChainRulesCore.rrule(::Type{Array}, x::CUDA.CuArray)
-  Array(x), d -> (NoTangent(), CUDA.cu(d),)
+  Array(x), dx -> (NoTangent(), CUDA.cu(unthunk(dx)),)
 end
 
 function ChainRulesCore.rrule(::typeof(Adapt.adapt_storage), to::FluxCPUAdaptor, x::CUDA.AbstractGPUArray)
-  adapt_storage(to, x), d -> (NoTangent(), NoTangent(), adapt_storage(FluxCUDAAdaptor(), d),)
+  adapt_storage(to, x), dx -> (NoTangent(), NoTangent(), adapt_storage(FluxCUDAAdaptor(), unthunk(dx)),)
 end
 
 # CPU/GPU movement conveniences
@@ -227,3 +227,4 @@ f64(m) = paramtype(Float64, m)
 # Functors for certain Julia data structures
 @functor Cholesky
 trainable(c::Cholesky) = ()
+

From e4c650fa098833ae3fb93ef3601fe4c3a6808276 Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Sat, 3 Sep 2022 20:30:13 -0400
Subject: [PATCH 3/3] unthunk in multigate rrule

---
 src/layers/recurrent.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl
index 760933bb96..c3b89f33a7 100644
--- a/src/layers/recurrent.jl
+++ b/src/layers/recurrent.jl
@@ -9,7 +9,7 @@ multigate(x::AbstractArray, h, ::Val{N}) where N = ntuple(n -> gate(x,h,n), N)
 function ChainRulesCore.rrule(::typeof(multigate), x::AbstractArray, h, c)
   function multigate_pullback(dy)
     dx = map!(zero, similar(x, float(eltype(x)), axes(x)), x)
-    foreach(multigate(dx, h, c), dy) do dxᵢ, dyᵢ
+    foreach(multigate(dx, h, c), unthunk(dy)) do dxᵢ, dyᵢ
       dyᵢ isa AbstractZero && return
       @. dxᵢ += dyᵢ
     end