Skip to content

Commit f06b776

Browse files
authored
Improve DiffRules integration and tests (#209)
* Improve DiffRules integration and tests * Bump version * Try to remove suspicious line
1 parent 8ac1f7d commit f06b776

File tree

7 files changed

+168
-59
lines changed

7 files changed

+168
-59
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ReverseDiff"
22
uuid = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
3-
version = "1.14.3"
3+
version = "1.14.4"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/ReverseDiff.jl

+11
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,17 @@ const REAL_TYPES = (:Bool, :Integer, :(Irrational{:ℯ}), :(Irrational{:π}), :R
2929
const SKIPPED_UNARY_SCALAR_FUNCS = Symbol[:isinf, :isnan, :isfinite, :iseven, :isodd, :isreal, :isinteger]
3030
const SKIPPED_BINARY_SCALAR_FUNCS = Symbol[:isequal, :isless, :<, :>, :(==), :(!=), :(<=), :(>=)]
3131

32+
# Some functions with derivatives in DiffRules are not supported
33+
# For instance, ReverseDiff does not support functions with complex results and derivatives
34+
const SKIPPED_DIFFRULES = Tuple{Symbol,Symbol}[
35+
(:SpecialFunctions, :hankelh1),
36+
(:SpecialFunctions, :hankelh1x),
37+
(:SpecialFunctions, :hankelh2),
38+
(:SpecialFunctions, :hankelh2x),
39+
(:SpecialFunctions, :besselh),
40+
(:SpecialFunctions, :besselhx),
41+
]
42+
3243
include("tape.jl")
3344
include("tracked.jl")
3445
include("macros.jl")

src/derivatives/elementwise.jl

+106-22
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ for g! in (:map!, :broadcast!), (M, f, arity) in DiffRules.diffrules(; filter_mo
7272
@warn "$M.$f is not available and hence rule for it can not be defined"
7373
continue # Skip rules for methods not defined in the current scope
7474
end
75+
(M, f) in SKIPPED_DIFFRULES && continue
7576
if arity == 1
7677
@eval @inline Base.$(g!)(f::typeof($M.$f), out::TrackedArray, t::TrackedArray) = $(g!)(ForwardOptimize(f), out, t)
7778
elseif arity == 2
@@ -122,23 +123,53 @@ for (g!, g) in ((:map!, :map), (:broadcast!, :broadcast))
122123
return out
123124
end
124125
end
125-
for A in ARRAY_TYPES, T in (:TrackedArray, :TrackedReal)
126-
@eval function Base.$(g!)(f::ForwardOptimize{F}, out::TrackedArray{S}, x::$(T){X}, y::$A) where {F,S,X}
127-
result = DiffResults.GradientResult(SVector(zero(S), zero(S)))
128-
df = (vx, vy) -> ForwardDiff.gradient!(result, s -> f.f(s[1], s[2]), SVector(vx, vy))
126+
for A in ARRAY_TYPES
127+
@eval function Base.$(g!)(f::ForwardOptimize{F}, out::TrackedArray, x::TrackedReal{X,D}, y::$A) where {F,X,D}
128+
result = DiffResults.DiffResult(zero(X), zero(D))
129+
df = let result=result
130+
(vx, vy) -> let vy=vy
131+
ForwardDiff.derivative!(result, s -> f.f(s, vy), vx)
132+
end
133+
end
129134
results = $(g)(df, value(x), value(y))
130135
map!(DiffResult.value, value(out), results)
131136
cache = (results, df, index_bound(x, out), index_bound(y, out))
132-
record!(tape(x, y), SpecialInstruction, $(g), (x, y), out, cache)
137+
record!(tape(x), SpecialInstruction, $(g), (x, y), out, cache)
133138
return out
134139
end
135-
@eval function Base.$(g!)(f::ForwardOptimize{F}, out::TrackedArray, x::$A, y::$(T){Y}) where {F,Y}
136-
result = DiffResults.GradientResult(SVector(zero(S), zero(S)))
137-
df = (vx, vy) -> ForwardDiff.gradient!(result, s -> f.f(s[1], s[2]), SVector(vx, vy))
140+
@eval function Base.$(g!)(f::ForwardOptimize{F}, out::TrackedArray, x::$A, y::TrackedReal{Y,D}) where {F,Y,D}
141+
result = DiffResults.DiffResult(zero(Y), zero(D))
142+
df = let result=result
143+
(vx, vy) -> let vx=vx
144+
ForwardDiff.derivative!(result, s -> f.f(vx, s), vy)
145+
end
146+
end
138147
results = $(g)(df, value(x), value(y))
139148
map!(DiffResult.value, value(out), results)
140149
cache = (results, df, index_bound(x, out), index_bound(y, out))
141-
record!(tape(x, y), SpecialInstruction, $(g), (x, y), out, cache)
150+
record!(tape(y), SpecialInstruction, $(g), (x, y), out, cache)
151+
return out
152+
end
153+
@eval function Base.$(g!)(f::ForwardOptimize{F}, out::TrackedArray, x::TrackedArray{X}, y::$A) where {F,X}
154+
result = DiffResults.GradientResult(SVector(zero(X)))
155+
df = (vx, vy) -> let vy=vy
156+
ForwardDiff.gradient!(result, s -> f.f(s[1], vy), SVector(vx))
157+
end
158+
results = $(g)(df, value(x), value(y))
159+
map!(DiffResult.value, value(out), results)
160+
cache = (results, df, index_bound(x, out), index_bound(y, out))
161+
record!(tape(x), SpecialInstruction, $(g), (x, y), out, cache)
162+
return out
163+
end
164+
@eval function Base.$(g!)(f::ForwardOptimize{F}, out::TrackedArray, x::$A, y::TrackedArray{Y}) where {F,Y}
165+
result = DiffResults.GradientResult(SVector(zero(Y)))
166+
df = let vx=vx
167+
(vx, vy) -> ForwardDiff.gradient!(result, s -> f.f(vx, s[1]), SVector(vy))
168+
end
169+
results = $(g)(df, value(x), value(y))
170+
map!(DiffResult.value, value(out), results)
171+
cache = (results, df, index_bound(x, out), index_bound(y, out))
172+
record!(tape(y), SpecialInstruction, $(g), (x, y), out, cache)
142173
return out
143174
end
144175
end
@@ -166,6 +197,7 @@ for g in (:map, :broadcast), (M, f, arity) in DiffRules.diffrules(; filter_modul
166197
if arity == 1
167198
@eval @inline Base.$(g)(f::typeof($M.$f), t::TrackedArray) = $(g)(ForwardOptimize(f), t)
168199
elseif arity == 2
200+
(M, f) in SKIPPED_DIFFRULES && continue
169201
# skip these definitions if `f` is one of the functions
170202
# that will get a manually defined broadcast definition
171203
# later (see "built-in infix operations" below)
@@ -207,20 +239,52 @@ for g in (:map, :broadcast)
207239
record!(tp, SpecialInstruction, $(g), x, out, cache)
208240
return out
209241
end
210-
for A in ARRAY_TYPES, T in (:TrackedArray, :TrackedReal)
211-
@eval function Base.$(g)(f::ForwardOptimize{F}, x::$(T){X,D}, y::$A) where {F,X,D}
212-
result = DiffResults.GradientResult(SVector(zero(X), zero(D)))
213-
df = (vx, vy) -> ForwardDiff.gradient!(result, s -> f.f(s[1], s[2]), SVector(vx, vy))
242+
for A in ARRAY_TYPES
243+
@eval function Base.$(g)(f::ForwardOptimize{F}, x::TrackedReal{X,D}, y::$A) where {F,X,D}
244+
result = DiffResults.DiffResult(zero(X), zero(D))
245+
df = let result=result
246+
(vx, vy) -> let vy=vy
247+
ForwardDiff.derivative!(result, s -> f.f(s, vy), vx)
248+
end
249+
end
214250
results = $(g)(df, value(x), value(y))
215251
tp = tape(x)
216252
out = track(DiffResults.value.(results), D, tp)
217253
cache = (results, df, index_bound(x, out), index_bound(y, out))
218254
record!(tp, SpecialInstruction, $(g), (x, y), out, cache)
219255
return out
220256
end
221-
@eval function Base.$(g)(f::ForwardOptimize{F}, x::$A, y::$(T){Y,D}) where {F,Y,D}
222-
result = DiffResults.GradientResult(SVector(zero(Y), zero(D)))
223-
df = (vx, vy) -> ForwardDiff.gradient!(result, s -> f.f(s[1], s[2]), SVector(vx, vy))
257+
@eval function Base.$(g)(f::ForwardOptimize{F}, x::$A, y::TrackedReal{Y,D}) where {F,Y,D}
258+
result = DiffResults.DiffResult(zero(Y), zero(D))
259+
df = let result=result
260+
(vx, vy) -> let vx=vx
261+
ForwardDiff.derivative!(result, s -> f.f(vx, s), vy)
262+
end
263+
end
264+
results = $(g)(df, value(x), value(y))
265+
tp = tape(y)
266+
out = track(DiffResults.value.(results), D, tp)
267+
cache = (results, df, index_bound(x, out), index_bound(y, out))
268+
record!(tp, SpecialInstruction, $(g), (x, y), out, cache)
269+
return out
270+
end
271+
@eval function Base.$(g)(f::ForwardOptimize{F}, x::TrackedArray{X,D}, y::$A) where {F,X,D}
272+
result = DiffResults.GradientResult(SVector(zero(X)))
273+
df = (vx, vy) -> let vy=vy
274+
ForwardDiff.gradient!(result, s -> f.f(s[1], vy), SVector(vx))
275+
end
276+
results = $(g)(df, value(x), value(y))
277+
tp = tape(x)
278+
out = track(DiffResults.value.(results), D, tp)
279+
cache = (results, df, index_bound(x, out), index_bound(y, out))
280+
record!(tp, SpecialInstruction, $(g), (x, y), out, cache)
281+
return out
282+
end
283+
@eval function Base.$(g)(f::ForwardOptimize{F}, x::$A, y::TrackedArray{Y,D}) where {F,Y,D}
284+
result = DiffResults.GradientResult(SVector(zero(Y)))
285+
df = (vx, vy) -> let vx=vx
286+
ForwardDiff.gradient!(result, s -> f.f(vx, s[1]), SVector(vy))
287+
end
224288
results = $(g)(df, value(x), value(y))
225289
tp = tape(y)
226290
out = track(DiffResults.value.(results), D, tp)
@@ -291,8 +355,15 @@ end
291355
diffresult_increment_deriv!(input, output_deriv, results, 1)
292356
else
293357
a, b = input
294-
istracked(a) && diffresult_increment_deriv!(a, output_deriv, results, 1)
295-
istracked(b) && diffresult_increment_deriv!(b, output_deriv, results, 2)
358+
p = 0
359+
if istracked(a)
360+
p += 1
361+
diffresult_increment_deriv!(a, output_deriv, results, p)
362+
end
363+
if istracked(b)
364+
p += 1
365+
diffresult_increment_deriv!(b, output_deriv, results, p)
366+
end
296367
end
297368
unseed!(output)
298369
return nothing
@@ -311,12 +382,25 @@ end
311382
end
312383
else
313384
a, b = input
385+
p = 0
314386
if size(a) == size(b)
315-
istracked(a) && diffresult_increment_deriv!(a, output_deriv, results, 1)
316-
istracked(b) && diffresult_increment_deriv!(b, output_deriv, results, 2)
387+
if istracked(a)
388+
p += 1
389+
diffresult_increment_deriv!(a, output_deriv, results, p)
390+
end
391+
if istracked(b)
392+
p += 1
393+
diffresult_increment_deriv!(b, output_deriv, results, p)
394+
end
317395
else
318-
istracked(a) && diffresult_increment_deriv!(a, output_deriv, results, 1, a_bound)
319-
istracked(b) && diffresult_increment_deriv!(b, output_deriv, results, 2, b_bound)
396+
if istracked(a)
397+
p += 1
398+
diffresult_increment_deriv!(a, output_deriv, results, p, a_bound)
399+
end
400+
if istracked(b)
401+
p += 1
402+
diffresult_increment_deriv!(b, output_deriv, results, p, b_bound)
403+
end
320404
end
321405
end
322406
unseed!(output)

src/derivatives/scalars.jl

+1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ for (M, f, arity) in DiffRules.diffrules(; filter_modules=nothing)
77
@warn "$M.$f is not available and hence rule for it can not be defined"
88
continue # Skip rules for methods not defined in the current scope
99
end
10+
(M, f) in SKIPPED_DIFFRULES && continue
1011
if arity == 1
1112
@eval @inline $M.$(f)(t::TrackedReal) = ForwardOptimize($M.$(f))(t)
1213
elseif arity == 2

test/derivatives/ElementwiseTests.jl

+39-24
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ function test_elementwise(f, fopt, x, tp)
3434
# reverse
3535
out = similar(y, (length(x), length(x)))
3636
ReverseDiff.seeded_reverse_pass!(out, yt, xt, tp)
37-
test_approx(out, ForwardDiff.jacobian(z -> map(f, z), x))
37+
test_approx(out, ForwardDiff.jacobian(z -> map(f, z), x); nans=true)
3838

3939
# forward
4040
x2 = x .- offset
@@ -57,7 +57,7 @@ function test_elementwise(f, fopt, x, tp)
5757
# reverse
5858
out = similar(y, (length(x), length(x)))
5959
ReverseDiff.seeded_reverse_pass!(out, yt, xt, tp)
60-
test_approx(out, ForwardDiff.jacobian(z -> broadcast(f, z), x))
60+
test_approx(out, ForwardDiff.jacobian(z -> broadcast(f, z), x); nans=true)
6161

6262
# forward
6363
x2 = x .- offset
@@ -81,9 +81,9 @@ function test_map(f, fopt, a, b, tp)
8181
@test length(tp) == 1
8282

8383
# reverse
84-
out = similar(c, (length(a), length(a)))
84+
out = similar(c, (length(c), length(a)))
8585
ReverseDiff.seeded_reverse_pass!(out, ct, at, tp)
86-
test_approx(out, ForwardDiff.jacobian(x -> map(f, x, b), a))
86+
test_approx(out, ForwardDiff.jacobian(x -> map(f, x, b), a); nans=true)
8787

8888
# forward
8989
a2 = a .- offset
@@ -102,9 +102,9 @@ function test_map(f, fopt, a, b, tp)
102102
@test length(tp) == 1
103103

104104
# reverse
105-
out = similar(c, (length(a), length(a)))
105+
out = similar(c, (length(c), length(b)))
106106
ReverseDiff.seeded_reverse_pass!(out, ct, bt, tp)
107-
test_approx(out, ForwardDiff.jacobian(x -> map(f, a, x), b))
107+
test_approx(out, ForwardDiff.jacobian(x -> map(f, a, x), b); nans=true)
108108

109109
# forward
110110
b2 = b .- offset
@@ -123,13 +123,17 @@ function test_map(f, fopt, a, b, tp)
123123
@test length(tp) == 1
124124

125125
# reverse
126-
out_a = similar(c, (length(a), length(a)))
127-
out_b = similar(c, (length(a), length(a)))
126+
out_a = similar(c, (length(c), length(a)))
127+
out_b = similar(c, (length(c), length(b)))
128128
ReverseDiff.seeded_reverse_pass!(out_a, ct, at, tp)
129129
ReverseDiff.seeded_reverse_pass!(out_b, ct, bt, tp)
130-
test_approx(out_a, ForwardDiff.jacobian(x -> map(f, x, b), a))
131-
test_approx(out_b, ForwardDiff.jacobian(x -> map(f, a, x), b))
132-
130+
jac = let a=a, b=b, f=f
131+
ForwardDiff.jacobian(vcat(vec(a), vec(b))) do x
132+
map(f, reshape(x[1:length(a)], size(a)), reshape(x[(length(a) + 1):end], size(b)))
133+
end
134+
end
135+
test_approx(out_a, jac[:, 1:length(a)]; nans=true)
136+
test_approx(out_b, jac[:, (length(a) + 1):end]; nans=true)
133137
# forward
134138
a2, b2 = a .- offset, b .- offset
135139
ReverseDiff.value!(at, a2)
@@ -163,7 +167,7 @@ function test_broadcast(f, fopt, a::AbstractArray, b::AbstractArray, tp, builtin
163167
# reverse
164168
out = similar(c, (length(c), length(a)))
165169
ReverseDiff.seeded_reverse_pass!(out, ct, at, tp)
166-
test_approx(out, ForwardDiff.jacobian(x -> g(x, b), a))
170+
test_approx(out, ForwardDiff.jacobian(x -> g(x, b), a); nans=true)
167171

168172
# forward
169173
a2 = a .- offset
@@ -184,7 +188,7 @@ function test_broadcast(f, fopt, a::AbstractArray, b::AbstractArray, tp, builtin
184188
# reverse
185189
out = similar(c, (length(c), length(b)))
186190
ReverseDiff.seeded_reverse_pass!(out, ct, bt, tp)
187-
test_approx(out, ForwardDiff.jacobian(x -> g(a, x), b))
191+
test_approx(out, ForwardDiff.jacobian(x -> g(a, x), b); nans=true)
188192

189193
# forward
190194
b2 = b .- offset
@@ -207,8 +211,13 @@ function test_broadcast(f, fopt, a::AbstractArray, b::AbstractArray, tp, builtin
207211
out_b = similar(c, (length(c), length(b)))
208212
ReverseDiff.seeded_reverse_pass!(out_a, ct, at, tp)
209213
ReverseDiff.seeded_reverse_pass!(out_b, ct, bt, tp)
210-
test_approx(out_a, ForwardDiff.jacobian(x -> g(x, b), a))
211-
test_approx(out_b, ForwardDiff.jacobian(x -> g(a, x), b))
214+
jac = let a=a, b=b, g=g
215+
ForwardDiff.jacobian(vcat(vec(a), vec(b))) do x
216+
g(reshape(x[1:length(a)], size(a)), reshape(x[(length(a) + 1):end], size(b)))
217+
end
218+
end
219+
test_approx(out_a, jac[:, 1:length(a)]; nans=true)
220+
test_approx(out_b, jac[:, (length(a) + 1):end]; nans=true)
212221

213222
# forward
214223
a2, b2 = a .- offset, b .- offset
@@ -243,7 +252,7 @@ function test_broadcast(f, fopt, n::Number, x::AbstractArray, tp, builtin::Bool
243252
# reverse
244253
out = similar(y)
245254
ReverseDiff.seeded_reverse_pass!(out, yt, nt, tp)
246-
test_approx(out, ForwardDiff.derivative(z -> g(z, x), n))
255+
test_approx(out, ForwardDiff.derivative(z -> g(z, x), n); nans=true)
247256

248257
# forward
249258
n2 = n + offset
@@ -264,7 +273,7 @@ function test_broadcast(f, fopt, n::Number, x::AbstractArray, tp, builtin::Bool
264273
# reverse
265274
out = similar(y, (length(y), length(x)))
266275
ReverseDiff.seeded_reverse_pass!(out, yt, xt, tp)
267-
test_approx(out, ForwardDiff.jacobian(z -> g(n, z), x))
276+
test_approx(out, ForwardDiff.jacobian(z -> g(n, z), x); nans=true)
268277

269278
# forward
270279
x2 = x .- offset
@@ -287,8 +296,11 @@ function test_broadcast(f, fopt, n::Number, x::AbstractArray, tp, builtin::Bool
287296
out_x = similar(y, (length(y), length(x)))
288297
ReverseDiff.seeded_reverse_pass!(out_n, yt, nt, tp)
289298
ReverseDiff.seeded_reverse_pass!(out_x, yt, xt, tp)
290-
test_approx(out_n, ForwardDiff.derivative(z -> g(z, x), n))
291-
test_approx(out_x, ForwardDiff.jacobian(z -> g(n, z), x))
299+
jac = let x=x, g=g
300+
ForwardDiff.jacobian(z -> g(z[1], reshape(z[2:end], size(x))), vcat(n, vec(x)))
301+
end
302+
test_approx(out_n, reshape(jac[:, 1], size(y)); nans=true)
303+
test_approx(out_x, jac[:, 2:end]; nans=true)
292304

293305
# forward
294306
n2, x2 = n + offset , x .- offset
@@ -323,7 +335,7 @@ function test_broadcast(f, fopt, x::AbstractArray, n::Number, tp, builtin::Bool
323335
# reverse
324336
out = similar(y)
325337
ReverseDiff.seeded_reverse_pass!(out, yt, nt, tp)
326-
test_approx(out, ForwardDiff.derivative(z -> g(x, z), n))
338+
test_approx(out, ForwardDiff.derivative(z -> g(x, z), n); nans=true)
327339

328340
# forward
329341
n2 = n + offset
@@ -344,7 +356,7 @@ function test_broadcast(f, fopt, x::AbstractArray, n::Number, tp, builtin::Bool
344356
# reverse
345357
out = similar(y, (length(y), length(x)))
346358
ReverseDiff.seeded_reverse_pass!(out, yt, xt, tp)
347-
test_approx(out, ForwardDiff.jacobian(z -> g(z, n), x))
359+
test_approx(out, ForwardDiff.jacobian(z -> g(z, n), x); nans=true)
348360

349361
# forward
350362
x2 = x .- offset
@@ -367,8 +379,11 @@ function test_broadcast(f, fopt, x::AbstractArray, n::Number, tp, builtin::Bool
367379
out_x = similar(y, (length(y), length(x)))
368380
ReverseDiff.seeded_reverse_pass!(out_n, yt, nt, tp)
369381
ReverseDiff.seeded_reverse_pass!(out_x, yt, xt, tp)
370-
test_approx(out_n, ForwardDiff.derivative(z -> g(x, z), n))
371-
test_approx(out_x, ForwardDiff.jacobian(z -> g(z, n), x))
382+
jac = let x=x, g=g
383+
ForwardDiff.jacobian(z -> g(reshape(z[1:(end - 1)], size(x)), z[end]), vcat(vec(x), n))
384+
end
385+
test_approx(out_x, jac[:, 1:(end - 1)]; nans=true)
386+
test_approx(out_n, reshape(jac[:, end], size(y)); nans=true)
372387

373388
# forward
374389
x2, n2 = x .- offset, n + offset
@@ -393,7 +408,7 @@ for (M, fsym, arity) in DiffRules.diffrules(; filter_modules=nothing)
393408
if !(isdefined(@__MODULE__, M) && isdefined(getfield(@__MODULE__, M), fsym))
394409
error("$M.$fsym is not available")
395410
end
396-
fsym === :rem2pi && continue
411+
(M, fsym) in ReverseDiff.SKIPPED_DIFFRULES && continue
397412
if arity == 1
398413
f = eval(:($M.$fsym))
399414
test_println("forward-mode unary scalar functions", f)

0 commit comments

Comments
 (0)