Skip to content

Commit a78ae74

Browse files
committed
broadcasting perf fix
1 parent d1b9f22 commit a78ae74

File tree

3 files changed

+320
-3
lines changed

3 files changed

+320
-3
lines changed

src/ReverseDiff.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ include("tape.jl")
2929
include("tracked.jl")
3030
include("macros.jl")
3131
include("derivatives/arrays.jl")
32+
include("derivatives/broadcast.jl")
3233
include("derivatives/propagation.jl")
3334
include("derivatives/scalars.jl")
3435
include("derivatives/elementwise.jl")

src/derivatives/broadcast.jl

Lines changed: 316 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,316 @@
1+
##################
2+
## Broadcasting ##
3+
##################
4+
5+
using Base.Broadcast: BroadcastStyle, ArrayStyle, Broadcasted, broadcasted
6+
using ForwardDiff: ForwardDiff, Dual
7+
import Base.Broadcast: materialize
8+
const RDBroadcasted{F, T} = Broadcasted{<:Any, <:Any, F, T}
9+
"""
    NotTracked(f)

A struct that can be used to wrap around closures, structs and arrays of structs
declaring that they do not contain tracked variables. This enables a more efficient
broadcasting of such functions and structs when doing automatic differentiation with
`ReverseDiff` producing a `TrackedArray` instead of an `Array{<:TrackedReal}`.

The wrapped object is stored in the field `f`. The wrapper is callable only when the
wrapped object is a `Function` or a `Type`.
"""
struct NotTracked{F} <: Function
    f::F
end

# Forward positional and keyword arguments to the wrapped callable.
(f::NotTracked{<:Union{Function, Type}})(args...; kwargs...) = f.f(args...; kwargs...)
19+
20+
# `istypeorclosure(x)` reports whether `x` may carry state. The generic methods defer
# to `_istypeorclosure` on the value's type, or on the element type for arrays and
# `RefValue`s.
istypeorclosure(::F) where {F} = _istypeorclosure(F)
istypeorclosure(::AbstractArray{F}) where {F} = _istypeorclosure(F)
istypeorclosure(::Base.RefValue{F}) where {F} = _istypeorclosure(F)
# Reals, arrays of reals and `TrackedArray`s are answered directly with `false`;
# other arrays whose element type is a `TrackedReal` are answered with `true`.
istypeorclosure(::AbstractArray{<:Real}) = false
istypeorclosure(::TrackedArray) = false
istypeorclosure(::AbstractArray{<:TrackedReal}) = true
istypeorclosure(::Real) = false
27+
# Return `true` when type `F` has at least one field. The answer is computed while the
# method is generated, so the call compiles down to a constant.
@generated function _istypeorclosure(::Type{F}) where {F}
    # `fieldcount` throws an `ArgumentError` for types without a definite number of
    # fields (abstract types, `Union`s, ...), e.g. the element type of a `Vector{Any}`.
    # Such types cannot be ruled out, so answer `true` instead of failing.
    hasfields = try
        fieldcount(F) > 0
    catch err
        err isa ArgumentError || rethrow()
        true
    end
    return :($hasfields)
end
28+
29+
# `mayhavetracked(b)` decides whether `b` could hide tracked values.
# By default this is whatever `istypeorclosure` reports.
mayhavetracked(b) = istypeorclosure(b)
# Types, `NotTracked` wrappers and `RefValue`s of `NotTracked` wrappers never do.
mayhavetracked(b::Type) = false
mayhavetracked(b::NotTracked) = false
mayhavetracked(b::Base.RefValue{<:NotTracked}) = false
# A broadcast expression does if its function or any of its arguments does.
mayhavetracked(b::Broadcasted) = mayhavetracked(b.f) || any(mayhavetracked, b.args)
34+
35+
# Broadcast style used for `TrackedArray` and `TrackedReal` arguments.
struct TrackedStyle <: BroadcastStyle end

Broadcast.BroadcastStyle(::Type{<:Union{TrackedArray, TrackedReal}}) = TrackedStyle()
# Combining `TrackedStyle` with any other style yields `TrackedStyle`.
Broadcast.BroadcastStyle(::TrackedStyle, b::BroadcastStyle) = TrackedStyle()
39+
40+
# We have to re-build the original broadcast struct to get the appropriate array
# style. We need this primarily to support CuArrays' broadcasting fixes.
# Leaves are replaced by `value(xs)`; nested `Broadcasted` nodes are rebuilt
# recursively through `broadcasted`.
broadcast_rebuild(xs) = value(xs)
function broadcast_rebuild(bc::Broadcasted)
    broadcasted(bc.f, broadcast_rebuild.(bc.args)...)
end
46+
47+
# Extract the style type parameter of a `Broadcasted`.
getstyle(::Broadcasted{Style}) where {Style} = Style

# `remove_not_tracked` strips `NotTracked` wrappers. Anything else passes through.
remove_not_tracked(f) = f
remove_not_tracked(f::NotTracked) = f.f
# A `RefValue` holding a wrapper is re-wrapped in a fresh `Ref` around the unwrapped
# object ...
remove_not_tracked(f::Base.RefValue{<:NotTracked}) = Ref(remove_not_tracked(f[]))
# ... except when the wrapper holds an array: then the bare array is returned.
remove_not_tracked(f::Base.RefValue{<:NotTracked{<:AbstractArray}}) = remove_not_tracked(f[])
# For a broadcast expression, unwrap the function and every argument, keeping the
# style and axes.
function remove_not_tracked(b::Broadcasted{style}) where {style}
    return Broadcasted{style}(remove_not_tracked(b.f), remove_not_tracked.(b.args), b.axes)
end
55+
56+
# `onlyrealarrays(args)` is `true` when no element of the tuple is an array with a
# non-`Real` element type. Non-array elements count as acceptable. The tuple is walked
# recursively, one element at a time.
onlyrealarrays(args::Tuple) = onlyrealarray(first(args)) && onlyrealarrays(Base.tail(args))
onlyrealarrays(::Tuple{}) = true
onlyrealarray(::AbstractArray{<:Real}) = true
onlyrealarray(::AbstractArray) = false
onlyrealarray(::Any) = true

# `anyreals(args)` is `true` when at least one element of the tuple is a `Real`.
anyreals(args::Tuple) = first(args) isa Real || anyreals(Base.tail(args))
anyreals(args::Tuple{}) = false
64+
65+
# Choose the broadcasting strategy, returned as `Val(:tracker)`, `Val(:reversediff)` or
# `Val(:fallback)`.
#   bc   - the broadcast expression inspected by `mayhavetracked`
#   f    - the broadcast function (not inspected here)
#   T    - the inferred type of the untracked result
#   args - the tuple of flattened broadcast arguments
function get_implementation(bc, f, T, args)
    # `Union{}` is a subtype of everything, so it has to be excluded explicitly.
    outputisreal = (T <: AbstractArray{<:Real}) && (T !== Union{})
    # Each arg is either a real number, an array of untracked reals, a tracked array of
    # reals or an array of untracked non-reals,
    # Output is real, and
    # No tracked closure or arguments, except TrackedReal and TrackedArray.
    if !mayhavetracked(bc) && outputisreal && (anyreals(args) || !onlyrealarrays(args))
        return Val(:tracker)
    # No arg is a real number and array args must be arrays of untracked reals or
    # tracked arrays of reals,
    # Output is real, and
    # No tracked closure or arguments, except TrackedReal and TrackedArray.
    elseif !mayhavetracked(bc) && outputisreal
        return Val(:reversediff)
    # Function or any arg is possibly a tracked non-real or an array of tracked
    # reals/non-reals,
    # Or output is not an array of reals
    else
        return Val(:fallback)
    end
end
83+
# Materialize a broadcast whose style is `TrackedStyle`, picking one of three
# implementations through `get_implementation`.
function Base.copy(_bc::Broadcasted{TrackedStyle})
    # Strip `NotTracked` wrappers before evaluating anything.
    bc = remove_not_tracked(_bc)
    flattened_bc = Broadcast.flatten(bc)
    # Same broadcast with every leaf replaced by its `value`, used for type inference
    # and to recover the untracked style.
    untracked_bc = broadcast_rebuild(bc)
    flattened_untracked_bc = Broadcast.flatten(untracked_bc)
    T = Core.Compiler.return_type(copy, Tuple{typeof(untracked_bc)})
    # The function comes from the untracked broadcast, the arguments from the tracked one.
    f, args = flattened_untracked_bc.f, flattened_bc.args
    # The original `_bc` is passed on, i.e. the expression before `NotTracked` wrappers
    # were removed.
    implementation = get_implementation(_bc, f, T, args)
    if implementation isa Val{:reversediff}
        return ∇broadcast(f, args...)
    elseif implementation isa Val{:tracker}
        return tracker_∇broadcast(f, args...)
    else
        # Fallback: ordinary broadcast over the tracked arguments using the style of the
        # untracked broadcast.
        style, axes = getstyle(flattened_untracked_bc), flattened_bc.axes
        return copy(Broadcasted{style, typeof(axes), typeof(f), typeof(args)}(f, args, axes))
    end
end
100+
101+
# https://github.com/FluxML/Flux.jl/issues/353
# On Julia versions older than 1.1.0-DEV.548, define `flatten` and the
# `Broadcasted`-headed `make_makeargs` method inside `Base.Broadcast` (see the issue
# linked above). Newer versions skip this block entirely.
if VERSION < v"1.1.0-DEV.548"
    @eval Base.Broadcast begin
        function flatten(bc::Broadcasted{Style}) where {Style}
            isflat(bc) && return bc
            args = cat_nested(bc)
            # `let` gives the closure its own bindings of `makeargs` and `f`.
            let makeargs = make_makeargs(bc), f = bc.f
                newf = @inline function(args::Vararg{Any,N}) where N
                    f(makeargs(args...)...)
                end
                return Broadcasted{Style}(newf, args, bc.axes)
            end
        end
        @inline function make_makeargs(makeargs, t::Tuple{<:Broadcasted,Vararg{Any}})
            bc = t[1]
            let makeargs = make_makeargs(makeargs, tail(t)), f = bc.f
                let makeargs = make_makeargs(makeargs, bc.args)
                    headargs, tailargs = make_headargs(bc.args), make_tailargs(bc.args)
                    return @inline function(args::Vararg{Any,N}) where N
                        args1 = makeargs(args...)
                        a, b = headargs(args1...), tailargs(args1...)
                        (f(a...), b...)
                    end
                end
            end
        end
    end
end
129+
130+
# Return the second type parameter `D` of a `TrackedReal` / `TrackedArray`, and
# `Union{}` for any other value.
getouttype(::TrackedReal{<:Any, D}) where {D} = D
getouttype(::TrackedArray{<:Any, D}) where {D} = D
getouttype(::Any) = Union{}
133+
134+
# Unwrap a `RefValue`; return anything else unchanged.
deref(x) = x
deref(x::Base.RefValue) = x[]
136+
137+
# Call `f` with `N + length(utargs)` positional arguments. Positions listed in `tinds`
# are filled, in order, from the static vector `x`; all other positions are filled, in
# order, from the tuple `utargs`. Every argument is passed through `deref` first.
@generated function splatcall(f, x::SVector{N}, utargs::T, ::Val{tinds}) where {N, T <: Tuple, tinds}
    args = Expr[]
    ti = 1   # next index into `x`
    uti = 1  # next index into `utargs`
    for i in 1:(N + length(T.types))
        if i in tinds
            push!(args, :(deref(x[$ti])))
            ti += 1
        else
            push!(args, :(deref(utargs[$uti])))
            uti += 1
        end
    end
    return quote
        $(Expr(:meta, :inline))
        $(Expr(:call, :f, args...))
    end
end
155+
156+
# Partition the tuple `args` by element type. Returns a 3-tuple
# `(Val{inds}(), selected, rest)` where `inds` holds the positions whose type is a
# `Real` or an `AbstractArray`, `selected` holds the arguments at those positions and
# `rest` holds the remaining arguments, both in their original order.
@generated function splitargs(args::T) where {T <: Tuple}
    N = length(T.types)
    RealOrArray = Union{Real, AbstractArray}
    inds = [i for i in 1:N if T.types[i] <: RealOrArray]
    indsval = :(Val{$(Expr(:tuple, [:($i) for i in inds]...))}())
    maybetracked = Expr(:tuple, [:(args[$i]) for i in inds]...)
    untracked = Expr(:tuple, [:(args[$i]) for i in 1:N if !(i in inds)]...)
    return :($indsval, $maybetracked, $untracked)
end
165+
166+
## A generalization of the broadcasting approach in ReverseDiff for general functions
167+
168+
# Broadcast `f` over `args`, obtaining value and gradient of every element with one
# `ForwardDiff.gradient!` call, and record the operation on the tape.
# `Vararg{Any}` is used rather than `Vararg{<:Any}`: wrapping `Vararg` directly in a
# `UnionAll` is deprecated and both forms accept the same arguments.
@inline function ∇broadcast(f::F, args::Vararg{Any}) where {F}
    # `targs`: the `Real`/`AbstractArray` arguments, `untracked`: everything else,
    # `inds`: positions of `targs` within `args` (as a `Val`).
    inds, targs, untracked = splitargs(args)
    N = length(targs)
    D = promote_type(getouttype.(targs)...)
    result = DiffResults.GradientResult(zero(SVector{N, D}))
    # Elementwise kernel: gathers the scalars of `targs` in an `SVector` and
    # differentiates `f` with respect to them.
    function df(x...)
        return ForwardDiff.gradient!(
            result,
            s -> splatcall(f, s, untracked, inds),
            SVector(x),
        )
    end
    results = broadcast(df, value.(targs)...)
    tp = tape(targs...)
    out_value = DiffResults.value.(results)
    # Boolean results are returned untracked and nothing is recorded.
    eltype(out_value) == Bool && return out_value
    out = track(out_value, D, tp)
    # `results` and `df` are reused by the forward/reverse executors of this instruction.
    cache = (results, df, index_bound.(targs, (out,)))
    record!(tp, SpecialInstruction, ∇broadcast, targs, out, cache)
    return out
end
189+
# Reverse pass for a recorded `∇broadcast` instruction: push the output derivative back
# to the inputs using the cached per-element results, then unseed the output.
@noinline function special_reverse_exec!(instruction::SpecialInstruction{typeof(∇broadcast)})
    input = instruction.input
    output = instruction.output
    output_deriv = deriv(output)
    results, _, bounds = instruction.cache
    N = length(input)
    # With a single input, or when all inputs have the same size, the cached index
    # bounds are not needed.
    if N == 1 || all(isequal(size(input[1])), size.(Base.tail(input)))
        _add_to_deriv!(input, output_deriv, results)
    else
        _add_to_deriv!(input, output_deriv, results, bounds)
    end
    unseed!(output)
    return nothing
end
203+
204+
# Tuple form without bounds: unrolled into one
# `_add_to_deriv!(xs[i], o, r, Val(i))` call per element of `xs`.
@generated function _add_to_deriv!(xs::T, o, r) where {T <: Tuple}
    N = length(T.types)
    return Expr(:block, [:(_add_to_deriv!(xs[$i], o, r, Val($i))) for i in 1:N]...)
end
# Per-input form: a no-op unless the input is a `TrackedReal` / `TrackedArray`.
_add_to_deriv!(_, _, _, _) = nothing
function _add_to_deriv!(x::Union{TrackedReal, TrackedArray}, out_deriv, results, ::Val{i}) where {i}
    # Only inputs for which `istracked` holds are incremented.
    return istracked(x) && diffresult_increment_deriv!(x, out_deriv, results, i)
end

# Tuple form with bounds: same unrolling, additionally forwarding `bounds[i]`.
@generated function _add_to_deriv!(xs::T, o, r, bounds) where {T <: Tuple}
    N = length(T.types)
    return Expr(:block, [:(_add_to_deriv!(xs[$i], o, r, Val($i), bounds[$i])) for i in 1:N]...)
end
_add_to_deriv!(_, _, _, _, _) = nothing
function _add_to_deriv!(x::Union{TrackedReal,TrackedArray}, out_deriv, results, ::Val{i}, bound) where {i}
    return istracked(x) && diffresult_increment_deriv!(x, out_deriv, results, i, bound)
end
221+
222+
# Forward pass for a recorded `∇broadcast` instruction: recompute the cached
# per-element results in place and copy their values into the output.
@noinline function special_forward_exec!(instruction::SpecialInstruction{typeof(∇broadcast)})
    input, output = instruction.input, instruction.output
    results, df, _ = instruction.cache
    broadcast!(df, results, value.(input)...)
    output_value = value(output)
    output_value .= DiffResults.value.(results)
    return nothing
end
230+
231+
## Tracker style broadcasting
232+
## Good for broadcasting real numbers or arrays of non-tracked structs
233+
234+
# Reshape `Δ` to its first `ndims(x)` dimensions.
trim(x, Δ) = reshape(Δ, ntuple(i -> size(Δ, i), Val(ndims(x))))

# Reduce `Δ` to the shape of the broadcast argument `x`:
#  - same size: returned as is;
#  - same length: only reshaped;
#  - otherwise: summed over every dimension in which `x` has size 1, then reshaped.
#    Other dimensions map to `ndims(Δ) + 1`, a dimension `Δ` does not have.
unbroadcast(x::AbstractArray, Δ) =
    size(x) == size(Δ) ? Δ :
    length(x) == length(Δ) ? trim(x, Δ) :
    trim(x, sum(Δ, dims = ntuple(i -> size(x, i) == 1 ? i : ndims(Δ)+1, Val(ndims(Δ)))))

# A number receives the sum of `Δ`; a `RefValue` receives nothing.
unbroadcast(x::Number, Δ) = sum(Δ)
unbroadcast(x::Base.RefValue, _) = nothing
243+
244+
# Wrap a `Real` in a `Dual` with partial `p`; leave any other value untouched.
dual(x, p) = x
dual(x::Real, p) = Dual(x, p)

# Partial derivative of `f` with respect to its `i`-th argument, scaled by `G`.
# Argument `i` is seeded with partial `true`, every other `Real` argument with `false`.
function _deriv(f, G, ::Val{i}, args::Vararg{Any, N}) where {N, i}
    dargs = ntuple(j -> dual(args[j], i==j), Val(N))
    # `ForwardDiff.partials(y, 1)` equals `y.partials[1]` for a `Dual`, and also covers
    # the case where `f` returns a plain number (e.g. a constant branch such as
    # `x > 0 ? x : 0.0`), where direct field access would throw.
    return ForwardDiff.partials(f(dargs...), 1) * G
end
# Tuple with one broadcast `_deriv` result per argument position.
@generated function _derivs(f, G, args::Vararg{Any, N}) where {N}
    return Expr(:tuple, [:(_deriv.(f, G, Val($i), args...)) for i in 1:N]...)
end
254+
# Broadcast `f` over the values of `args` and record the operation on the tape.
# Derivatives are computed later, in the reverse executor of this instruction.
@inline function tracker_∇broadcast(f, args::Vararg{Any, N}) where {N}
    args_values = map(value, args)
    out_value = broadcast(f, args_values...)
    tp = tape(args...)
    # Boolean results are returned untracked and nothing is recorded.
    eltype(out_value) == Bool && return out_value
    out = track(out_value, tp)
    # Only `f` has to be kept for the forward/reverse executors.
    cache = (f,)
    record!(tp, SpecialInstruction, tracker_∇broadcast, args, out, cache)
    return out
end
264+
265+
# Forward pass for a recorded `tracker_∇broadcast` instruction: re-run the broadcast
# on the input values, writing into the output's value in place.
@noinline function special_forward_exec!(instruction::SpecialInstruction{typeof(tracker_∇broadcast)})
    input, output = instruction.input, instruction.output
    f = instruction.cache[1]
    output_value = value(output)
    broadcast!(f, output_value, value.(input)...)
    return nothing
end
272+
273+
# Reverse pass for a recorded `tracker_∇broadcast` instruction.
@noinline function special_reverse_exec!(instruction::SpecialInstruction{typeof(tracker_∇broadcast)})
    input = instruction.input
    output = instruction.output
    f = instruction.cache[1]
    output_deriv = deriv(output)
    # One broadcast derivative per input, already scaled by the output derivative.
    Δargs = _derivs(f, output_deriv, value.(input)...)
    # Reduce each derivative to the shape of its input before accumulating it.
    dxs = map(unbroadcast, input, Δargs)
    map(_add_to_deriv!, input, dxs)
    unseed!(output)
    return nothing
end
285+
286+
## Limited ReverseDiff broadcasting
287+
## Efficient broadcasting for specific functions, e.g. +, -
288+
289+
# Evaluate a broadcast of `f` over the argument tuple `args` through `broadcast`.
@inline _materialize(f, args) = broadcast(f, args...)
290+
291+
# For every DiffRules rule whose module is defined in ReverseDiff, route un-nested
# broadcasts of that function over tracked arguments through `_materialize`, i.e. a
# plain `broadcast` call. One `materialize` method is generated per supported
# combination of argument types.
for (M, f, arity) in DiffRules.diffrules()
    isdefined(ReverseDiff, M) || continue
    if arity == 1
        @eval @inline materialize(bc::RDBroadcasted{typeof($M.$f), <:Tuple{TrackedArray}}) = _materialize(bc.f, bc.args)
    elseif arity == 2
        # Both arguments tracked.
        @eval begin
            @inline materialize(bc::RDBroadcasted{typeof($M.$f), <:Tuple{TrackedArray,TrackedArray}}) = _materialize(bc.f, bc.args)
            @inline materialize(bc::RDBroadcasted{typeof($M.$f), <:Tuple{TrackedArray,TrackedReal}}) = _materialize(bc.f, bc.args)
            # NOTE(review): this is the only `@noinline` method of the family; confirm
            # that it is intentional.
            @noinline materialize(bc::RDBroadcasted{typeof($M.$f), <:Tuple{TrackedReal,TrackedArray}}) = _materialize(bc.f, bc.args)
        end
        # One tracked argument combined with each type in `ARRAY_TYPES`.
        for A in ARRAY_TYPES
            @eval begin
                @inline materialize(bc::RDBroadcasted{typeof($M.$f), <:Tuple{$A,TrackedArray}}) = _materialize(bc.f, bc.args)
                @inline materialize(bc::RDBroadcasted{typeof($M.$f), <:Tuple{TrackedArray, $A}}) = _materialize(bc.f, bc.args)
                @inline materialize(bc::RDBroadcasted{typeof($M.$f), <:Tuple{$A, TrackedReal}}) = _materialize(bc.f, bc.args)
                @inline materialize(bc::RDBroadcasted{typeof($M.$f), <:Tuple{TrackedReal,$A}}) = _materialize(bc.f, bc.args)
            end
        end
        # A tracked array combined with each type in `REAL_TYPES`.
        for R in REAL_TYPES
            @eval begin
                @inline materialize(bc::RDBroadcasted{typeof($M.$f), <:Tuple{$R,TrackedArray}}) = _materialize(bc.f, bc.args)
                @inline materialize(bc::RDBroadcasted{typeof($M.$f), <:Tuple{TrackedArray,$R}}) = _materialize(bc.f, bc.args)
            end
        end
    end
end

src/macros.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,7 @@ macro grad(expr)
215215
back = instruction.cache[1]
216216
input_derivs = back($ReverseDiff.deriv(output))
217217
@assert input_derivs isa Tuple
218-
$ReverseDiff.add_to_deriv!.(input, input_derivs)
218+
$ReverseDiff._add_to_deriv!.(input, input_derivs)
219219
$ReverseDiff.unseed!(output)
220220
return nothing
221221
end
@@ -237,8 +237,8 @@ macro grad(expr)
237237
end
238238
end |> esc
239239
end
240-
add_to_deriv!(d1, d2) = nothing
241-
function add_to_deriv!(d1::Union{TrackedReal, TrackedArray}, d2)
240+
_add_to_deriv!(d1, d2) = nothing
241+
function _add_to_deriv!(d1::Union{TrackedReal, TrackedArray}, d2)
242242
increment_deriv!(d1, d2)
243243
end
244244
function getargs_expr(args_with_types)

0 commit comments

Comments
 (0)