Skip to content

Commit 105290f

Browse files
Allow custom struct args to grad_from_chainrules macro (#232)
* allow custom struct args to grad_from_chainrules * fix test * bump version * fix test * Update test/ChainRulesTests.jl Co-authored-by: David Widmann <devmotion@users.noreply.github.com> * Update src/macros.jl Co-authored-by: David Widmann <devmotion@users.noreply.github.com> * Update src/macros.jl Co-authored-by: David Widmann <devmotion@users.noreply.github.com> --------- Co-authored-by: David Widmann <devmotion@users.noreply.github.com>
1 parent 65cd309 commit 105290f

File tree

3 files changed

+53
-11
lines changed

3 files changed

+53
-11
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ReverseDiff"
22
uuid = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
3-
version = "1.14.6"
3+
version = "1.15.0"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/macros.jl

+12-2
Original file line numberDiff line numberDiff line change
@@ -317,9 +317,19 @@ macro grad_from_chainrules(fcall)
317317
Meta.isexpr(fcall, :call) && length(fcall.args) >= 2 ||
318318
error("`@grad_from_chainrules` has to be applied to a function signature")
319319
f = esc(fcall.args[1])
320-
xs = fcall.args[2:end]
320+
xs = map(fcall.args[2:end]) do x
321+
if Meta.isexpr(x, :(::))
322+
if length(x.args) == 1 # ::T without var name
323+
return :($(gensym())::$(esc(x.args[1])))
324+
else # x::T
325+
@assert length(x.args) == 2
326+
return :($(x.args[1])::$(esc(x.args[2])))
327+
end
328+
else
329+
return x
330+
end
331+
end
321332
args_l, args_r, args_track, args_fixed, arg_types, kwargs = _make_fwd_args(f, xs)
322-
323333
return quote
324334
$f($(args_l...)) = ReverseDiff.track($(args_r...))
325335
function ReverseDiff.track($(args_track...))

test/ChainRulesTests.jl

+40-8
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@ using DiffResults
66
using ReverseDiff
77
using Test
88

9+
struct MyStruct end
10+
f(::MyStruct, x) = sum(4x .+ 1)
11+
f(x, y::MyStruct) = sum(4x .+ 1)
912
f(x) = sum(4x .+ 1)
1013

1114
function ChainRulesCore.rrule(::typeof(f), x)
@@ -20,21 +23,37 @@ function ChainRulesCore.rrule(::typeof(f), x)
2023
rather than 4 when we compute the derivative of `f`, it means
2124
the importing mechanism works.
2225
=#
23-
return ChainRulesCore.NoTangent(), fill(3 * d, size(x))
26+
return NoTangent(), fill(3 * d, size(x))
27+
end
28+
return r, back
29+
end
30+
function ChainRulesCore.rrule(::typeof(f), ::MyStruct, x)
31+
r = f(MyStruct(), x)
32+
function back(d)
33+
return NoTangent(), NoTangent(), fill(3 * d, size(x))
34+
end
35+
return r, back
36+
end
37+
function ChainRulesCore.rrule(::typeof(f), x, ::MyStruct)
38+
r = f(x, MyStruct())
39+
function back(d)
40+
return NoTangent(), fill(3 * d, size(x)), NoTangent()
2441
end
2542
return r, back
2643
end
2744

2845
ReverseDiff.@grad_from_chainrules f(x::ReverseDiff.TrackedArray)
29-
46+
# test arg type hygiene
47+
ReverseDiff.@grad_from_chainrules f(::MyStruct, x::ReverseDiff.TrackedArray)
48+
ReverseDiff.@grad_from_chainrules f(x::ReverseDiff.TrackedArray, y::MyStruct)
3049

3150
g(x, y) = sum(4x .+ 4y)
3251

3352
function ChainRulesCore.rrule(::typeof(g), x, y)
3453
r = g(x, y)
3554
function back(d)
3655
# same as above, use 3 and 5 as the derivatives
37-
return ChainRulesCore.NoTangent(), fill(3 * d, size(x)), fill(5 * d, size(x))
56+
return NoTangent(), fill(3 * d, size(x)), fill(5 * d, size(x))
3857
end
3958
return r, back
4059
end
@@ -93,6 +112,19 @@ ReverseDiff.@grad_from_chainrules g(x::ReverseDiff.TrackedArray, y::ReverseDiff.
93112

94113
end
95114

115+
@testset "custom struct input" begin
116+
input = rand(3, 3)
117+
output, back = ChainRulesCore.rrule(f, MyStruct(), input);
118+
_, _, d = back(1)
119+
@test output == f(MyStruct(), input)
120+
@test d == fill(3, size(input))
121+
122+
output, back = ChainRulesCore.rrule(f, input, MyStruct());
123+
_, d, _ = back(1)
124+
@test output == f(input, MyStruct())
125+
@test d == fill(3, size(input))
126+
end
127+
96128
### Tape test
97129
@testset "Tape test: Ensure ordinary call is not tracked" begin
98130
tp = ReverseDiff.InstructionTape()
@@ -112,7 +144,7 @@ f_vararg(x, args...) = sum(4x .+ sum(args))
112144
function ChainRulesCore.rrule(::typeof(f_vararg), x, args...)
113145
r = f_vararg(x, args...)
114146
function back(d)
115-
return ChainRulesCore.NoTangent(), fill(3 * d, size(x))
147+
return NoTangent(), fill(3 * d, size(x))
116148
end
117149
return r, back
118150
end
@@ -136,7 +168,7 @@ f_kw(x, args...; k=1, kwargs...) = sum(4x .+ sum(args) .+ (k + kwargs[:j]))
136168
function ChainRulesCore.rrule(::typeof(f_kw), x, args...; k=1, kwargs...)
137169
r = f_kw(x, args...; k=k, kwargs...)
138170
function back(d)
139-
return ChainRulesCore.NoTangent(), fill(3 * d, size(x))
171+
return NoTangent(), fill(3 * d, size(x))
140172
end
141173
return r, back
142174
end
@@ -175,20 +207,20 @@ end
175207
### Isolated Scope
176208
module IsolatedModuleForTestingScoping
177209
using ChainRulesCore
178-
using ReverseDiff: @grad_from_chainrules
210+
using ReverseDiff: ReverseDiff, @grad_from_chainrules
179211

180212
f(x) = sum(4x .+ 1)
181213

182214
function ChainRulesCore.rrule(::typeof(f), x)
183215
r = f(x)
184216
function back(d)
185217
# return a distinguishable but improper grad
186-
return ChainRulesCore.NoTangent(), fill(3 * d, size(x))
218+
return NoTangent(), fill(3 * d, size(x))
187219
end
188220
return r, back
189221
end
190222

191-
@grad_from_chainrules f(x::TrackedArray)
223+
@grad_from_chainrules f(x::ReverseDiff.TrackedArray)
192224

193225
module SubModule
194226
using Test

0 commit comments

Comments
 (0)