@@ -6,6 +6,9 @@ using DiffResults
6
6
using ReverseDiff
7
7
using Test
8
8
9
+ struct MyStruct end
10
+ f (:: MyStruct , x) = sum (4 x .+ 1 )
11
+ f (x, y:: MyStruct ) = sum (4 x .+ 1 )
9
12
f (x) = sum (4 x .+ 1 )
10
13
11
14
function ChainRulesCore. rrule (:: typeof (f), x)
@@ -20,21 +23,37 @@ function ChainRulesCore.rrule(::typeof(f), x)
20
23
rather than 4 when we compute the derivative of `f`, it means
21
24
the importing mechanism works.
22
25
=#
23
- return ChainRulesCore. NoTangent (), fill (3 * d, size (x))
26
+ return NoTangent (), fill (3 * d, size (x))
27
+ end
28
+ return r, back
29
+ end
30
+ function ChainRulesCore. rrule (:: typeof (f), :: MyStruct , x)
31
+ r = f (MyStruct (), x)
32
+ function back (d)
33
+ return NoTangent (), NoTangent (), fill (3 * d, size (x))
34
+ end
35
+ return r, back
36
+ end
37
+ function ChainRulesCore. rrule (:: typeof (f), x, :: MyStruct )
38
+ r = f (x, MyStruct ())
39
+ function back (d)
40
+ return NoTangent (), fill (3 * d, size (x)), NoTangent ()
24
41
end
25
42
return r, back
26
43
end
27
44
28
45
ReverseDiff. @grad_from_chainrules f (x:: ReverseDiff.TrackedArray )
29
-
46
+ # test arg type hygiene
47
+ ReverseDiff. @grad_from_chainrules f (:: MyStruct , x:: ReverseDiff.TrackedArray )
48
+ ReverseDiff. @grad_from_chainrules f (x:: ReverseDiff.TrackedArray , y:: MyStruct )
30
49
31
50
g (x, y) = sum (4 x .+ 4 y)
32
51
33
52
function ChainRulesCore. rrule (:: typeof (g), x, y)
34
53
r = g (x, y)
35
54
function back (d)
36
55
# same as above, use 3 and 5 as the derivatives
37
- return ChainRulesCore . NoTangent (), fill (3 * d, size (x)), fill (5 * d, size (x))
56
+ return NoTangent (), fill (3 * d, size (x)), fill (5 * d, size (x))
38
57
end
39
58
return r, back
40
59
end
@@ -93,6 +112,19 @@ ReverseDiff.@grad_from_chainrules g(x::ReverseDiff.TrackedArray, y::ReverseDiff.
93
112
94
113
end
95
114
115
+ @testset " custom struct input" begin
116
+ input = rand (3 , 3 )
117
+ output, back = ChainRulesCore. rrule (f, MyStruct (), input);
118
+ _, _, d = back (1 )
119
+ @test output == f (MyStruct (), input)
120
+ @test d == fill (3 , size (input))
121
+
122
+ output, back = ChainRulesCore. rrule (f, input, MyStruct ());
123
+ _, d, _ = back (1 )
124
+ @test output == f (input, MyStruct ())
125
+ @test d == fill (3 , size (input))
126
+ end
127
+
96
128
# ## Tape test
97
129
@testset " Tape test: Ensure ordinary call is not tracked" begin
98
130
tp = ReverseDiff. InstructionTape ()
@@ -112,7 +144,7 @@ f_vararg(x, args...) = sum(4x .+ sum(args))
112
144
function ChainRulesCore. rrule (:: typeof (f_vararg), x, args... )
113
145
r = f_vararg (x, args... )
114
146
function back (d)
115
- return ChainRulesCore . NoTangent (), fill (3 * d, size (x))
147
+ return NoTangent (), fill (3 * d, size (x))
116
148
end
117
149
return r, back
118
150
end
@@ -136,7 +168,7 @@ f_kw(x, args...; k=1, kwargs...) = sum(4x .+ sum(args) .+ (k + kwargs[:j]))
136
168
function ChainRulesCore. rrule (:: typeof (f_kw), x, args... ; k= 1 , kwargs... )
137
169
r = f_kw (x, args... ; k= k, kwargs... )
138
170
function back (d)
139
- return ChainRulesCore . NoTangent (), fill (3 * d, size (x))
171
+ return NoTangent (), fill (3 * d, size (x))
140
172
end
141
173
return r, back
142
174
end
@@ -175,20 +207,20 @@ end
175
207
# ## Isolated Scope
176
208
module IsolatedModuleForTestingScoping
177
209
using ChainRulesCore
178
- using ReverseDiff: @grad_from_chainrules
210
+ using ReverseDiff: ReverseDiff, @grad_from_chainrules
179
211
180
212
f (x) = sum (4 x .+ 1 )
181
213
182
214
function ChainRulesCore. rrule (:: typeof (f), x)
183
215
r = f (x)
184
216
function back (d)
185
217
# return a distinguishable but improper grad
186
- return ChainRulesCore . NoTangent (), fill (3 * d, size (x))
218
+ return NoTangent (), fill (3 * d, size (x))
187
219
end
188
220
return r, back
189
221
end
190
222
191
- @grad_from_chainrules f (x:: TrackedArray )
223
+ @grad_from_chainrules f (x:: ReverseDiff. TrackedArray )
192
224
193
225
module SubModule
194
226
using Test
0 commit comments