Skip to content

Commit 55d2871

Browse files
authored
Use broadcasting rules from ChainRules (#89)
* remove reverse mode broadcasting rules * fix some other rules * tests * update tests, CR version * delete commented lines * rm comment
1 parent bd4da5f commit 55d2871

File tree

6 files changed

+83
-72
lines changed

6 files changed

+83
-72
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
1313
StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
1414

1515
[compat]
16-
ChainRules = "1.5"
17-
ChainRulesCore = "1.2"
16+
ChainRules = "1.44.6"
17+
ChainRulesCore = "1.15.3"
1818
Combinatorics = "1"
1919
StaticArrays = "1"
2020
StatsBase = "0.33"

src/extra_rules.jl

Lines changed: 12 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ function ChainRulesCore.rrule(::DiffractorRuleConfig, g::∇getindex, Δ)
1616
g(Δ), Δ′′->(nothing, Δ′′[1][g.i...])
1717
end
1818

19-
function ChainRulesCore.rrule(::DiffractorRuleConfig, ::typeof(getindex), xs::Array, i...)
19+
function ChainRulesCore.rrule(::DiffractorRuleConfig, ::typeof(getindex), xs::Array{<:Number}, i...)
2020
xs[i...], ∇getindex(xs, i)
2121
end
2222

@@ -150,12 +150,6 @@ end
150150

151151
ChainRulesCore.canonicalize(::ChainRulesCore.ZeroTangent) = ChainRulesCore.ZeroTangent()
152152

153-
# Skip AD'ing through the axis computation
154-
function ChainRules.rrule(::DiffractorRuleConfig, ::typeof(Base.Broadcast.instantiate), bc::Base.Broadcast.Broadcasted)
155-
return Base.Broadcast.instantiate(bc), Δ->begin
156-
Core.tuple(NoTangent(), Δ)
157-
end
158-
end
159153

160154

161155
using StaticArrays
@@ -187,10 +181,6 @@ end
187181

188182
@ChainRulesCore.non_differentiable StaticArrays.promote_tuple_eltype(T)
189183

190-
function ChainRules.frule((_, ∂A), ::typeof(getindex), A::AbstractArray, args...)
191-
getindex(A, args...), getindex(∂A, args...)
192-
end
193-
194184
function ChainRules.rrule(::DiffractorRuleConfig, ::typeof(map), ::typeof(+), A::AbstractArray, B::AbstractArray)
195185
map(+, A, B), Δ->(NoTangent(), NoTangent(), Δ, Δ)
196186
end
@@ -225,27 +215,28 @@ struct BackMap{T}
225215
f::T
226216
end
227217
(f::BackMap{N})(args...) where {N} = ∂⃖¹(getfield(f, :f), args...)
228-
back_apply(x, y) = x(y)
229-
back_apply_zero(x) = x(Zero())
218+
back_apply(x, y) = x(y) # this is just |> with arguments reversed
219+
back_apply_zero(x) = x(Zero()) # Zero is not defined
230220

231221
function ChainRules.rrule(::DiffractorRuleConfig, ::typeof(map), f, args::Tuple)
232222
a, b = unzip_tuple(map(BackMap(f), args))
233-
function back(Δ)
223+
function map_back(Δ)
234224
(fs, xs) = unzip_tuple(map(back_apply, b, Δ))
235225
(NoTangent(), sum(fs), xs)
236226
end
237-
function back(Δ::ZeroTangent)
238-
(fs, xs) = unzip_tuple(map(back_apply_zero, b))
239-
(NoTangent(), sum(fs), xs)
240-
end
241-
a, back
227+
map_back(::AbstractZero) = (NoTangent(), NoTangent(), NoTangent())
228+
a, map_back
242229
end
243230

231+
ChainRules.rrule(::DiffractorRuleConfig, ::typeof(map), f, args::Tuple{}) = (), _ -> (NoTangent(), NoTangent(), NoTangent())
232+
244233
function ChainRules.rrule(::DiffractorRuleConfig, ::typeof(Base.ntuple), f, n)
245234
a, b = unzip_tuple(ntuple(BackMap(f), n))
246-
a, function (Δ)
235+
function ntuple_back(Δ)
247236
(NoTangent(), sum(map(back_apply, b, Δ)), NoTangent())
248237
end
238+
ntuple_back(::AbstractZero) = (NoTangent(), NoTangent(), NoTangent())
239+
a, ntuple_back
249240
end
250241

251242
function ChainRules.frule(::DiffractorRuleConfig, _, ::Type{Vector{T}}, undef::UndefInitializer, dims::Int...) where {T}
@@ -267,5 +258,4 @@ function ChainRulesCore.rrule(::DiffractorRuleConfig, ::Type{InplaceableThunk},
267258
val, Δ->(NoTangent(), NoTangent(), Δ)
268259
end
269260

270-
Base.real(z::ZeroTangent) = z # TODO should be in CRC
271-
Base.real(z::NoTangent) = z
261+
Base.real(z::NoTangent) = z # TODO should be in CRC, https://github.com/JuliaDiff/ChainRulesCore.jl/pull/581

src/runtime.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,3 +27,4 @@ accum(x::Tangent{T}, y::Tangent) where T = _tangent(T, accum(backing(x), backing
2727

2828
_tangent(::Type{T}, z) where T = Tangent{T,typeof(z)}(z)
2929
_tangent(::Type, ::NamedTuple{()}) = NoTangent()
30+
_tangent(::Type, ::NamedTuple{<:Any, <:Tuple{Vararg{AbstractZero}}}) = NoTangent()

src/stage1/broadcast.jl

Lines changed: 0 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -28,46 +28,3 @@ function (∂ₙ::∂☆{N})(zc::ZeroBundle{N, typeof(copy)},
2828
end
2929
return r
3030
end
31-
32-
# Broadcast over one element is just map
33-
function (∂⃖ₙ::∂⃖{N})(::typeof(broadcasted), f, a::Array) where {N}
34-
∂⃖ₙ(map, f, a)
35-
end
36-
37-
# The below is from Zygote: TODO: DO we want to do something better here?
38-
39-
accum_sum(xs::Nothing; dims = :) = NoTangent()
40-
accum_sum(xs::AbstractArray{Nothing}; dims = :) = NoTangent()
41-
accum_sum(xs::AbstractArray{<:Number}; dims = :) = sum(xs, dims = dims)
42-
accum_sum(xs::AbstractArray{<:AbstractArray{<:Number}}; dims = :) = sum(xs, dims = dims)
43-
accum_sum(xs::Number; dims = :) = xs
44-
45-
# https://github.com/FluxML/Zygote.jl/issues/594
46-
function Base.reducedim_init(::typeof(identity), ::typeof(accum), A::AbstractArray, region)
47-
Base.reducedim_initarray(A, region, NoTangent(), Union{Nothing,eltype(A)})
48-
end
49-
50-
trim(x, Δ) = reshape(Δ, ntuple(i -> size(Δ, i), Val(ndims(x))))
51-
52-
unbroadcast(x::AbstractArray, x̄) =
53-
size(x) == size(x̄) ? x̄ :
54-
length(x) == length(x̄) ? trim(x, x̄) :
55-
trim(x, accum_sum(x̄, dims = ntuple(i -> size(x, i) == 1 ? i : ndims(x̄)+1, Val(ndims(x̄)))))
56-
57-
unbroadcast(x::Number, x̄) = accum_sum(x̄)
58-
unbroadcast(x::Tuple{<:Any}, x̄) = (accum_sum(x̄),)
59-
unbroadcast(x::Base.RefValue, x̄) = (x=accum_sum(x̄),)
60-
61-
unbroadcast(x::AbstractArray, x̄::Nothing) = NoTangent()
62-
63-
const Numeric = Union{Number, AbstractArray{<:Number, N} where N}
64-
65-
function ChainRulesCore.rrule(::DiffractorRuleConfig, ::typeof(broadcasted), ::typeof(+), xs::Numeric...)
66-
broadcast(+, xs...), ȳ -> (NoTangent(), NoTangent(), map(x -> unbroadcast(x, unthunk(ȳ)), xs)...)
67-
end
68-
69-
ChainRulesCore.rrule(::DiffractorRuleConfig, ::typeof(broadcasted), ::typeof(-), x::Numeric, y::Numeric) = x .- y,
70-
Δ -> let Δ=unthunk(Δ); (NoTangent(), NoTangent(), unbroadcast(x, Δ), -unbroadcast(y, Δ)); end
71-
72-
ChainRulesCore.rrule(::DiffractorRuleConfig, ::typeof(broadcasted), ::typeof(*), x::Numeric, y::Numeric) = x.*y,
73-
z̄ -> let z̄=unthunk(z̄); (NoTangent(), NoTangent(), unbroadcast(x, z̄ .* conj.(y)), unbroadcast(y, z̄ .* conj.(x))); end

src/stage1/generated.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -315,13 +315,13 @@ function (::∂⃖{N})(::typeof(Core.getfield), s, field::Symbol) where {N}
315315
end
316316

317317
# TODO: Temporary - make better
318-
function (::∂⃖{N})(::typeof(Base.getindex), a::Array, inds...) where {N}
318+
function (::∂⃖{N})(::typeof(Base.getindex), a::Array{<:Number}, inds...) where {N}
319319
getindex(a, inds...), let
320320
EvenOddOdd{1, c_order(N)}(
321321
(@Base.constprop :aggressive Δ->begin
322322
Δ isa AbstractZero && return (NoTangent(), Δ, map(Returns(Δ), inds)...)
323323
BB = zero(a)
324-
BB[inds...] = Δ
324+
BB[inds...] = unthunk(Δ)
325325
(NoTangent(), BB, map(x->NoTangent(), inds)...)
326326
end),
327327
(@Base.constprop :aggressive (_, Δ, _)->begin
@@ -334,6 +334,7 @@ struct tuple_back{M}; end
334334
(::tuple_back)(Δ::Tuple) = Core.tuple(NoTangent(), Δ...)
335335
(::tuple_back{N})(Δ::AbstractZero) where {N} = Core.tuple(NoTangent(), ntuple(i->Δ, N)...)
336336
(::tuple_back{N})(Δ::Tangent) where {N} = Core.tuple(NoTangent(), ntuple(i->lifted_getfield(Δ, i), N)...)
337+
(t::tuple_back)(Δ::AbstractThunk) = t(unthunk(Δ))
337338

338339
function (::∂⃖{N})(::typeof(Core.tuple), args::Vararg{Any, M}) where {N, M}
339340
Core.tuple(args...),

test/runtests.jl

Lines changed: 65 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -92,9 +92,9 @@ let var"'" = Diffractor.PrimeDerivativeBack
9292
@test @inferred(sin'(1.0)) == cos(1.0)
9393
@test @inferred(sin''(1.0)) == -sin(1.0)
9494
@test sin'''(1.0) == -cos(1.0)
95-
@test sin''''(1.0) == sin(1.0) broken = VERSION >= v"1.8"
96-
@test sin'''''(1.0) == cos(1.0) broken = VERSION >= v"1.8"
97-
@test sin''''''(1.0) == -sin(1.0) broken = VERSION >= v"1.8"
95+
@test sin''''(1.0) == sin(1.0)
96+
@test sin'''''(1.0) == cos(1.0) # broken = VERSION >= v"1.8"
97+
@test sin''''''(1.0) == -sin(1.0) # broken = VERSION >= v"1.8"
9898

9999
f_getfield(x) = getfield((x,), 1)
100100
@test f_getfield'(1) == 1
@@ -219,6 +219,68 @@ z45, delta45 = frule_via_ad(DiffractorRuleConfig(), (0,1), x -> log(exp(x)), 2)
219219
@test z45 ≈ 2.0
220220
@test delta45 ≈ 1.0
221221

222+
# PR #82 - getindex on non-numeric arrays
223+
@test gradient(ls -> ls[1](1.), [Base.Fix1(*, 1.)])[1][1] isa Tangent{<:Base.Fix1}
224+
225+
@testset "broadcast" begin
226+
@test gradient(x -> sum(x ./ x), [1,2,3]) == ([0,0,0],) # derivatives_given_output
227+
@test gradient(x -> sum(sqrt.(atan.(x, transpose(x)))), [1,2,3])[1] ≈ [0.2338, -0.0177, -0.0661] atol=1e-3
228+
@test gradient(x -> sum(exp.(log.(x))), [1,2,3]) == ([1,1,1],)
229+
230+
@test gradient(x -> sum((exp∘log).(x)), [1,2,3]) == ([1,1,1],) # frule_via_ad
231+
exp_log(x) = exp(log(x))
232+
@test gradient(x -> sum(exp_log.(x)), [1,2,3]) == ([1,1,1],)
233+
@test gradient((x,y) -> sum(x ./ y), [1 2; 3 4], [1,2]) == ([1 1; 0.5 0.5], [-3, -1.75])
234+
@test gradient((x,y) -> sum(x ./ y), [1 2; 3 4], 5) == ([0.2 0.2; 0.2 0.2], -0.4)
235+
@test gradient(x -> sum((y -> y/x).([1,2,3])), 4) == (-0.375,) # closure
236+
237+
@test gradient(x -> sum(sum, (x,) ./ x), [1,2,3])[1] ≈ [-4.1666, 0.3333, 1.1666] atol=1e-3 # array of arrays
238+
@test gradient(x -> sum(sum, Ref(x) ./ x), [1,2,3])[1] ≈ [-4.1666, 0.3333, 1.1666] atol=1e-3
239+
@test gradient(x -> sum(sum, (x,) ./ x), [1,2,3])[1] ≈ [-4.1666, 0.3333, 1.1666] atol=1e-3
240+
@test gradient(x -> sum(sum, (x,) .* transpose(x)), [1,2,3])[1] ≈ [12, 12, 12] # must not take the * fast path
241+
242+
@test gradient(x -> sum(x ./ 4), [1,2,3]) == ([0.25, 0.25, 0.25],)
243+
@test gradient(x -> sum([1,2,3] ./ x), 4) == (-0.375,) # x/y rule
244+
@test gradient(x -> sum(x.^2), [1,2,3]) == ([2.0, 4.0, 6.0],) # x.^2 rule
245+
@test gradient(x -> sum([1,2,3] ./ x.^2), 4) == (-0.1875,) # scalar^2 rule
246+
247+
@test gradient(x -> sum((1,2,3) .- x), (1,2,3)) == (Tangent{Tuple{Int,Int,Int}}(-1.0, -1.0, -1.0),)
248+
@test gradient(x -> sum(transpose([1,2,3]) .- x), (1,2,3)) == (Tangent{Tuple{Int,Int,Int}}(-3.0, -3.0, -3.0),)
249+
@test gradient(x -> sum([1 2 3] .+ x .^ 2), (1,2,3)) == (Tangent{Tuple{Int,Int,Int}}(6.0, 12.0, 18.0),)
250+
251+
@test gradient(x -> sum(x .> 2), [1,2,3]) |> only |> iszero # Bool output
252+
@test gradient(x -> sum(1 .+ iseven.(x)), [1,2,3]) |> only |> iszero
253+
@test gradient((x,y) -> sum(x .== y), [1,2,3], [1 2 3]) == (NoTangent(), NoTangent())
254+
@test gradient(x -> sum(x .+ [1,2,3]), true) |> only |> iszero # Bool input
255+
@test gradient(x -> sum(x ./ [1,2,3]), [true false]) |> only |> iszero
256+
@test gradient(x -> sum(x .* transpose([1,2,3])), (true, false)) |> only |> iszero
257+
258+
tup_adj = gradient((x,y) -> sum(2 .* x .+ log.(y)), (1,2), transpose([3,4,5]))
259+
@test tup_adj[1] == Tangent{Tuple{Int64, Int64}}(6.0, 6.0)
260+
@test tup_adj[2] ≈ [0.6666666666666666 0.5 0.4]
261+
@test tup_adj[2] isa Transpose
262+
@test gradient(x -> sum(atan.(x, (1,2,3))), Diagonal([4,5,6]))[1] isa Diagonal
263+
264+
@test gradient(x -> sum((y -> (x*y)).([1,2,3])), 4.0) == (6.0,) # closure
265+
end
266+
267+
@testset "broadcast, 2nd order" begin
268+
@test gradient(x -> gradient(y -> sum(y .* y), x)[1] |> sum, [1,2,3.0])[1] == [2,2,2] # calls "split broadcasting generic" with f = unthunk
269+
@test gradient(x -> gradient(y -> sum(y .* x), x)[1].^3 |> sum, [1,2,3.0])[1] == [3,12,27]
270+
@test_broken gradient(x -> gradient(y -> sum(y .* 2 .* y'), x)[1] |> sum, [1,2,3.0])[1] == [12, 12, 12] # Control flow support not fully implemented yet for higher-order
271+
272+
@test_broken gradient(x -> sum(gradient(x -> sum(x .^ 2 .+ x'), x)[1]), [1,2,3.0])[1] == [6,6,6] # BoundsError: attempt to access 18-element Vector{Core.Compiler.BasicBlock} at index [0]
273+
@test_broken gradient(x -> sum(gradient(x -> sum((x .+ 1) .* x .- x), x)[1]), [1,2,3.0])[1] == [2,2,2]
274+
@test_broken gradient(x -> sum(gradient(x -> sum(x .* x ./ 2), x)[1]), [1,2,3.0])[1] == [1,1,1]
275+
276+
@test_broken gradient(x -> sum(gradient(x -> sum(exp.(x)), x)[1]), [1,2,3])[1] ≈ exp.(1:3) # MethodError: no method matching copy(::Nothing)
277+
@test_broken gradient(x -> sum(gradient(x -> sum(atan.(x, x')), x)[1]), [1,2,3.0])[1] ≈ [0,0,0]
278+
@test_broken gradient(x -> sum(gradient(x -> sum(transpose(x) .* x), x)[1]), [1,2,3]) == ([6,6,6],) # accum(a::Transpose{Float64, Vector{Float64}}, b::ChainRulesCore.Tangent{Transpose{Int64, Vector{Int64}}, NamedTuple{(:parent,), Tuple{ChainRulesCore.NoTangent}}})
279+
@test_broken gradient(x -> sum(gradient(x -> sum(transpose(x) ./ x.^2), x)[1]), [1,2,3])[1] ≈ [27.675925925925927, -0.824074074074074, -2.1018518518518516]
280+
281+
@test_broken gradient(z -> gradient(x -> sum((y -> (x^2*y)).([1,2,3])), z)[1], 5.0) == (12.0,)
282+
end
283+
222284
# Higher order control flow not yet supported (https://github.com/JuliaDiff/Diffractor.jl/issues/24)
223285
#include("pinn.jl")
224286

0 commit comments

Comments
 (0)