Skip to content

Commit c10db62

Browse files
authored
Merge pull request #134 from JuliaDiff/mt/broadcasting
broadcasting perf fix
2 parents 31cf77e + 4097259 commit c10db62

File tree

3 files changed

+325
-3
lines changed

3 files changed

+325
-3
lines changed

src/ReverseDiff.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ include("tape.jl")
2929
include("tracked.jl")
3030
include("macros.jl")
3131
include("derivatives/arrays.jl")
32+
include("derivatives/broadcast.jl")
3233
include("derivatives/propagation.jl")
3334
include("derivatives/scalars.jl")
3435
include("derivatives/elementwise.jl")

src/derivatives/broadcast.jl

Lines changed: 321 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,321 @@
1+
##################
2+
## Broadcasting ##
3+
##################
4+
5+
using Base.Broadcast: BroadcastStyle, ArrayStyle, Broadcasted, broadcasted
6+
using ForwardDiff: ForwardDiff, Dual
7+
import Base.Broadcast: materialize
8+
const RDBroadcasted{F, T} = Broadcasted{<:Any, <:Any, F, T}
9+
10+
"""
11+
NotTracked(f::Function)
12+
13+
A struct that can be used to wrap around closures, structs and arrays of structs declaring that they do not contain tracked variables. This enables a more efficient broadcasting of such functions and structs when doing automatic differentiation with `ReverseDiff` producing a `TrackedArray` instead of an `Array{<:TrackedReal}`.
14+
"""
15+
struct NotTracked{F} <: Function
16+
f::F
17+
end
18+
(f::NotTracked{<:Union{Function, Type}})(args...; kwargs...) = f.f(args...; kwargs...)
19+
20+
istypeorclosure(::F) where {F} = _istypeorclosure(F)
21+
istypeorclosure(::AbstractArray{F}) where {F} = _istypeorclosure(F)
22+
istypeorclosure(::Base.RefValue{F}) where {F} = _istypeorclosure(F)
23+
istypeorclosure(::AbstractArray{<:Real}) = false
24+
istypeorclosure(::TrackedArray) = false
25+
istypeorclosure(::AbstractArray{<:TrackedReal}) = true
26+
istypeorclosure(::Real) = false
27+
@generated _istypeorclosure(::Type{F}) where {F} = :($(fieldcount(F) > 0))
28+
29+
mayhavetracked(b) = istypeorclosure(b)
30+
mayhavetracked(b::Type) = false
31+
mayhavetracked(b::NotTracked) = false
32+
mayhavetracked(b::Base.RefValue{<:NotTracked}) = false
33+
mayhavetracked(b::Broadcasted) = mayhavetracked(b.f) || any(mayhavetracked, b.args)
34+
35+
struct TrackedStyle <: BroadcastStyle end
36+
37+
Broadcast.BroadcastStyle(::Type{<:Union{TrackedArray, TrackedReal}}) = TrackedStyle()
38+
Broadcast.BroadcastStyle(::TrackedStyle, b::BroadcastStyle) = TrackedStyle()
39+
40+
# We have to re-build the original broadcast struct to get the appropriate array
41+
# style. We need this primarily to support CuArrays' broadcasting fixes.
42+
broadcast_rebuild(xs) = recur_value(xs)
43+
recur_value(xs) = xs
44+
recur_value(xs::Union{TrackedReal, TrackedArray}) = recur_value(value(xs))
45+
46+
function broadcast_rebuild(bc::Broadcasted)
47+
broadcasted(bc.f, broadcast_rebuild.(bc.args)...)
48+
end
49+
50+
getstyle(::Broadcasted{Style}) where {Style} = Style
51+
remove_not_tracked(f) = f
52+
remove_not_tracked(f::NotTracked) = f.f
53+
remove_not_tracked(f::Base.RefValue{<:NotTracked}) = Ref(remove_not_tracked(f[]))
54+
remove_not_tracked(f::Base.RefValue{<:NotTracked{<:AbstractArray}}) = remove_not_tracked(f[])
55+
function remove_not_tracked(b::Broadcasted{style}) where {style}
56+
return Broadcasted{style}(remove_not_tracked(b.f), remove_not_tracked.(b.args), b.axes)
57+
end
58+
59+
onlyrealarrays(args::Tuple) = onlyrealarray(first(args)) && onlyrealarrays(Base.tail(args))
60+
onlyrealarrays(::Tuple{}) = true
61+
onlyrealarray(::AbstractArray{<:Real}) = true
62+
onlyrealarray(::AbstractArray) = false
63+
onlyrealarray(::Any) = true
64+
65+
anyreals(args::Tuple) = first(args) isa Real || anyreals(Base.tail(args))
66+
anyreals(args::Tuple{}) = false
67+
68+
function get_implementation(bc, f, T, args)
69+
outputisreal = (T <: AbstractArray{<:Real}) && (T !== Union{})
70+
# Each arg is either a real number, an array of untraked reals, a tracked array of reals or an array of untracked non-reals,
71+
# Output is real, and
72+
# No tracked closure or arguments, except TrackedReal and TrackedArray.
73+
if !mayhavetracked(bc) && outputisreal && (anyreals(args) || !onlyrealarrays(args))
74+
return Val(:tracker)
75+
# No arg is a real number and array args must be arrays of untracked reals or tracked arrays of reals,
76+
# Output is real, and
77+
# No tracked closure or arguments, except TrackedReal and TrackedArray.
78+
elseif !mayhavetracked(bc) && outputisreal
79+
return Val(:reversediff)
80+
# Function or any arg is possibly a tracked non-real or an array of tracked reals/non-reals,
81+
# Or output is not an array of reals
82+
else
83+
return Val(:fallback)
84+
end
85+
end
86+
function Base.copy(_bc::Broadcasted{TrackedStyle})
87+
bc = remove_not_tracked(_bc)
88+
flattened_bc = Broadcast.flatten(bc)
89+
untracked_bc = broadcast_rebuild(bc)
90+
flattened_untracked_bc = Broadcast.flatten(untracked_bc)
91+
T = Core.Compiler.return_type(copy, Tuple{typeof(untracked_bc)})
92+
f, args = flattened_untracked_bc.f, flattened_bc.args
93+
implementation = get_implementation(_bc, f, T, args)
94+
if implementation isa Val{:reversediff}
95+
return ∇broadcast(f, args...)
96+
elseif implementation isa Val{:tracker}
97+
return tracker_∇broadcast(f, args...)
98+
else
99+
style, axes = getstyle(flattened_untracked_bc), flattened_bc.axes
100+
return copy(Broadcasted{style, typeof(axes), typeof(f), typeof(args)}(f, args, axes))
101+
end
102+
end
103+
104+
# https://github.com/FluxML/Flux.jl/issues/353
105+
if VERSION < v"1.1.0-DEV.548"
106+
@eval Base.Broadcast begin
107+
function flatten(bc::Broadcasted{Style}) where {Style}
108+
isflat(bc) && return bc
109+
args = cat_nested(bc)
110+
let makeargs = make_makeargs(bc), f = bc.f
111+
newf = @inline function(args::Vararg{Any,N}) where N
112+
f(makeargs(args...)...)
113+
end
114+
return Broadcasted{Style}(newf, args, bc.axes)
115+
end
116+
end
117+
@inline function make_makeargs(makeargs, t::Tuple{<:Broadcasted,Vararg{Any}})
118+
bc = t[1]
119+
let makeargs = make_makeargs(makeargs, tail(t)), f = bc.f
120+
let makeargs = make_makeargs(makeargs, bc.args)
121+
headargs, tailargs = make_headargs(bc.args), make_tailargs(bc.args)
122+
return @inline function(args::Vararg{Any,N}) where N
123+
args1 = makeargs(args...)
124+
a, b = headargs(args1...), tailargs(args1...)
125+
(f(a...), b...)
126+
end
127+
end
128+
end
129+
end
130+
end
131+
end
132+
133+
getouttype(::TrackedReal{<:Any, D}) where {D} = D
134+
getouttype(::TrackedArray{<:Any, D}) where {D} = D
135+
getouttype(::Any) = Union{}
136+
137+
deref(x) = x
138+
deref(x::Base.RefValue) = x[]
139+
140+
@generated function splatcall(f, x::SVector{N}, utargs::T, ::Val{tinds}) where {N, T <: Tuple, tinds}
141+
args = []
142+
ti = 1
143+
uti = 1
144+
for i in 1:(N + length(T.types))
145+
if i in tinds
146+
push!(args, :(deref(x[$ti])))
147+
ti += 1
148+
else
149+
push!(args, :(deref(utargs[$uti])))
150+
uti += 1
151+
end
152+
end
153+
return quote
154+
$(Expr(:meta, :inline))
155+
$(Expr(:call, :f, args...))
156+
end
157+
end
158+
159+
@generated function splitargs(args::T) where {T <: Tuple}
160+
N = length(T.types)
161+
RealOrArray = Union{Real, AbstractArray}
162+
inds = [i for i in 1:N if T.types[i] <: RealOrArray]
163+
indsval = :(Val{$(Expr(:tuple, [:($i) for i in inds]...))}())
164+
maybetracked = Expr(:tuple, [:(args[$i]) for i in inds]...)
165+
untracked = Expr(:tuple, [:(args[$i]) for i in 1:N if !(i in inds)]...)
166+
return :($indsval, $maybetracked, $untracked)
167+
end
168+
169+
## A generalization of the broadcasting approach in ReverseDiff for general functions
170+
171+
@inline function ∇broadcast(f::F, args::Vararg{<:Any}) where {F}
172+
inds, targs, untracked = splitargs(args)
173+
N = length(targs)
174+
D = promote_type(getouttype.(targs)...)
175+
result = DiffResults.GradientResult(zero(SVector{N, D}))
176+
function df(x...)
177+
return ForwardDiff.gradient!(
178+
result,
179+
s -> splatcall(f, s, untracked, inds),
180+
SVector(x),
181+
)
182+
end
183+
results = broadcast(df, value.(targs)...)
184+
tp = tape(targs...)
185+
out_value = DiffResults.value.(results)
186+
eltype(out_value) == Bool && return out_value
187+
out = track(out_value, D, tp)
188+
cache = (results, df, index_bound.(targs, (out,)))
189+
record!(tp, SpecialInstruction, ∇broadcast, targs, out, cache)
190+
return out
191+
end
192+
@noinline function special_reverse_exec!(instruction::SpecialInstruction{typeof(∇broadcast)})
193+
input = instruction.input
194+
output = instruction.output
195+
output_deriv = deriv(output)
196+
results, _, bounds = instruction.cache
197+
N = length(input)
198+
if N == 1 || all(isequal(size(input[1])), size.(Base.tail(input)))
199+
_add_to_deriv!(input, output_deriv, results)
200+
else
201+
_add_to_deriv!(input, output_deriv, results, bounds)
202+
end
203+
unseed!(output)
204+
return nothing
205+
end
206+
207+
@generated function _add_to_deriv!(xs::T, o, r) where {T <: Tuple}
208+
N = length(T.types)
209+
return Expr(:block, [:(_add_to_deriv!(xs[$i], o, r, Val($i))) for i in 1:N]...)
210+
end
211+
_add_to_deriv!(_, _, _, _) = nothing
212+
function _add_to_deriv!(x::Union{TrackedReal, TrackedArray}, out_deriv, results, ::Val{i}) where {i}
213+
return istracked(x) && diffresult_increment_deriv!(x, out_deriv, results, i)
214+
end
215+
216+
@generated function _add_to_deriv!(xs::T, o, r, bounds) where {T <: Tuple}
217+
N = length(T.types)
218+
return Expr(:block, [:(_add_to_deriv!(xs[$i], o, r, Val($i), bounds[$i])) for i in 1:N]...)
219+
end
220+
_add_to_deriv!(_, _, _, _, _) = nothing
221+
function _add_to_deriv!(x::Union{TrackedReal,TrackedArray}, out_deriv, results, ::Val{i}, bound) where {i}
222+
return istracked(x) && diffresult_increment_deriv!(x, out_deriv, results, i, bound)
223+
end
224+
225+
@noinline function special_forward_exec!(instruction::SpecialInstruction{typeof(∇broadcast)})
226+
input, output = instruction.input, instruction.output
227+
results, df, _ = instruction.cache
228+
pull_value!.(input)
229+
broadcast!(df, results, value.(input)...)
230+
output_value = value(output)
231+
output_value .= DiffResults.value.(results)
232+
return nothing
233+
end
234+
235+
## Tracker style broadcasting
236+
## Good for broadcasting real numbers or arrays of non-tracked structs
237+
238+
trim(x, Δ) = reshape(Δ, ntuple(i -> size(Δ, i), Val(ndims(x))))
239+
240+
unbroadcast(x::AbstractArray, Δ) =
241+
size(x) == size(Δ) ? Δ :
242+
length(x) == length(Δ) ? trim(x, Δ) :
243+
trim(x, sum(Δ, dims = ntuple(i -> size(x, i) == 1 ? i : ndims(Δ)+1, Val(ndims(Δ)))))
244+
245+
unbroadcast(x::Number, Δ) = sum(Δ)
246+
unbroadcast(x::Base.RefValue, _) = nothing
247+
248+
dual(x, p) = x
249+
dual(x::Real, p) = Dual(x, p)
250+
251+
function _deriv(f, G, ::Val{i}, args::Vararg{Any, N}) where {N, i}
252+
dargs = ntuple(j -> dual(args[j], i==j), Val(N))
253+
return f(dargs...).partials[1] * G
254+
end
255+
@generated function _derivs(f, G, args::Vararg{Any, N}) where {N}
256+
return Expr(:tuple, [:(_deriv.(f, G, Val($i), args...)) for i in 1:N]...)
257+
end
258+
@inline function tracker_∇broadcast(f, args::Vararg{Any, N}) where {N}
259+
args_values = map(value, args)
260+
out_value = broadcast(f, args_values...)
261+
tp = tape(args...)
262+
eltype(out_value) == Bool && return out_value
263+
out = track(out_value, tp)
264+
cache = (f,)
265+
record!(tp, SpecialInstruction, tracker_∇broadcast, args, out, cache)
266+
return out
267+
end
268+
269+
@noinline function special_forward_exec!(instruction::SpecialInstruction{typeof(tracker_∇broadcast)})
270+
input, output = instruction.input, instruction.output
271+
f = instruction.cache[1]
272+
output_value = value(output)
273+
pull_value!.(input)
274+
broadcast!(f, output_value, value.(input)...)
275+
return nothing
276+
end
277+
278+
@noinline function special_reverse_exec!(instruction::SpecialInstruction{typeof(tracker_∇broadcast)})
279+
input = instruction.input
280+
output = instruction.output
281+
f = instruction.cache[1]
282+
output_deriv = deriv(output)
283+
N = length(input)
284+
Δargs = _derivs(f, output_deriv, value.(input)...)
285+
dxs = map(unbroadcast, input, Δargs)
286+
map(_add_to_deriv!, input, dxs)
287+
unseed!(output)
288+
return nothing
289+
end
290+
291+
## Limited ReverseDiff broadcasting
292+
## Efficient broadcasting for specific functions, e.g. +, -
293+
294+
@inline _materialize(f, args) = broadcast(f, args...)
295+
296+
for (M, f, arity) in DiffRules.diffrules()
297+
isdefined(ReverseDiff, M) || continue
298+
if arity == 1
299+
@eval @inline materialize(bc::RDBroadcasted{typeof($M.$f), <:Tuple{TrackedArray}}) = _materialize(bc.f, bc.args)
300+
elseif arity == 2
301+
@eval begin
302+
@inline materialize(bc::RDBroadcasted{typeof($M.$f), <:Tuple{TrackedArray,TrackedArray}}) = _materialize(bc.f, bc.args)
303+
@inline materialize(bc::RDBroadcasted{typeof($M.$f), <:Tuple{TrackedArray,TrackedReal}}) = _materialize(bc.f, bc.args)
304+
@noinline materialize(bc::RDBroadcasted{typeof($M.$f), <:Tuple{TrackedReal,TrackedArray}}) = _materialize(bc.f, bc.args)
305+
end
306+
for A in ARRAY_TYPES
307+
@eval begin
308+
@inline materialize(bc::RDBroadcasted{typeof($M.$f), <:Tuple{$A,TrackedArray}}) = _materialize(bc.f, bc.args)
309+
@inline materialize(bc::RDBroadcasted{typeof($M.$f), <:Tuple{TrackedArray, $A}}) = _materialize(bc.f, bc.args)
310+
@inline materialize(bc::RDBroadcasted{typeof($M.$f), <:Tuple{$A, TrackedReal}}) = _materialize(bc.f, bc.args)
311+
@inline materialize(bc::RDBroadcasted{typeof($M.$f), <:Tuple{TrackedReal,$A}}) = _materialize(bc.f, bc.args)
312+
end
313+
end
314+
for R in REAL_TYPES
315+
@eval begin
316+
@inline materialize(bc::RDBroadcasted{typeof($M.$f), <:Tuple{$R,TrackedArray}}) = _materialize(bc.f, bc.args)
317+
@inline materialize(bc::RDBroadcasted{typeof($M.$f), <:Tuple{TrackedArray,$R}}) = _materialize(bc.f, bc.args)
318+
end
319+
end
320+
end
321+
end

src/macros.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,7 @@ macro grad(expr)
215215
back = instruction.cache[1]
216216
input_derivs = back($ReverseDiff.deriv(output))
217217
@assert input_derivs isa Tuple
218-
$ReverseDiff.add_to_deriv!.(input, input_derivs)
218+
$ReverseDiff._add_to_deriv!.(input, input_derivs)
219219
$ReverseDiff.unseed!(output)
220220
return nothing
221221
end
@@ -237,8 +237,8 @@ macro grad(expr)
237237
end
238238
end |> esc
239239
end
240-
add_to_deriv!(d1, d2) = nothing
241-
function add_to_deriv!(d1::Union{TrackedReal, TrackedArray}, d2)
240+
_add_to_deriv!(d1, d2) = nothing
241+
function _add_to_deriv!(d1::Union{TrackedReal, TrackedArray}, d2)
242242
increment_deriv!(d1, d2)
243243
end
244244
function getargs_expr(args_with_types)

0 commit comments

Comments
 (0)