|
| 1 | +################## |
| 2 | +## Broadcasting ## |
| 3 | +################## |
| 4 | + |
| 5 | +using Base.Broadcast: BroadcastStyle, ArrayStyle, Broadcasted, broadcasted |
| 6 | +using ForwardDiff: ForwardDiff, Dual |
| 7 | +import Base.Broadcast: materialize |
| 8 | +const RDBroadcasted{F, T} = Broadcasted{<:Any, <:Any, F, T} |
| 9 | + |
"""
    NotTracked(f)

Wrapper declaring that `f` (a closure, a struct, or an array of structs) does not
contain tracked variables. Broadcasting such a wrapped value during automatic
differentiation with `ReverseDiff` can then use a more efficient code path that produces
a `TrackedArray` instead of an `Array{<:TrackedReal}`.

When the wrapped object is a `Function` or a `Type`, the wrapper is callable and forwards
every positional and keyword argument to it.
"""
struct NotTracked{F} <: Function
    f::F
end
function (wrapper::NotTracked{<:Union{Function, Type}})(args...; kwargs...)
    return wrapper.f(args...; kwargs...)
end
| 19 | + |
# `istypeorclosure(x)` is `false` for reals, arrays of reals and `TrackedArray`s, `true`
# for arrays of `TrackedReal`s, and otherwise defers to `_istypeorclosure` applied to the
# type of `x` (or to the element type when `x` is an array or a `Ref`).
istypeorclosure(::Real) = false
istypeorclosure(::TrackedArray) = false
istypeorclosure(::AbstractArray{<:Real}) = false
istypeorclosure(::AbstractArray{<:TrackedReal}) = true
istypeorclosure(::AbstractArray{T}) where {T} = _istypeorclosure(T)
istypeorclosure(::Base.RefValue{T}) where {T} = _istypeorclosure(T)
istypeorclosure(::T) where {T} = _istypeorclosure(T)
# Compile-time check of whether values of type `F` have at least one field (closures with
# captured variables, non-empty structs, ...).
#
# `fieldcount` throws an `ArgumentError` for types without a definite number of fields
# (abstract types, unions, ...). Such types can reach this function as element types,
# e.g. `Any` for a `Vector{Any}` or a `Ref{Any}`. Since their values may have fields, any
# non-concrete type is conservatively reported as `true`.
@generated function _istypeorclosure(::Type{F}) where {F}
    return !isconcretetype(F) || fieldcount(F) > 0
end
| 28 | + |
# `mayhavetracked(b)` is `false` for types, `NotTracked` wrappers and `Ref`s of
# `NotTracked`; for a `Broadcasted` it checks the function and every argument; anything
# else is decided by `istypeorclosure`.
mayhavetracked(::Type) = false
mayhavetracked(::NotTracked) = false
mayhavetracked(::Base.RefValue{<:NotTracked}) = false
mayhavetracked(x) = istypeorclosure(x)
function mayhavetracked(bc::Broadcasted)
    return mayhavetracked(bc.f) || any(mayhavetracked, bc.args)
end
| 34 | + |
# Broadcast style used for `TrackedArray` and `TrackedReal`. Combining it with any other
# broadcast style yields `TrackedStyle` again.
struct TrackedStyle <: BroadcastStyle end

function Broadcast.BroadcastStyle(::Type{<:Union{TrackedArray, TrackedReal}})
    return TrackedStyle()
end
function Broadcast.BroadcastStyle(::TrackedStyle, ::BroadcastStyle)
    return TrackedStyle()
end
| 39 | + |
# We have to re-build the original broadcast struct to get the appropriate array
# style. We need this primarily to support CuArrays' broadcasting fixes.

# Repeatedly unwrap tracked values via `value`; everything else passes through unchanged.
recur_value(x) = x
recur_value(x::Union{TrackedReal, TrackedArray}) = recur_value(value(x))

# Leaves are unwrapped with `recur_value`; nested `Broadcasted` objects are rebuilt
# through `broadcasted` from their rebuilt arguments.
broadcast_rebuild(x) = recur_value(x)
function broadcast_rebuild(bc::Broadcasted)
    rebuilt_args = map(broadcast_rebuild, bc.args)
    return broadcasted(bc.f, rebuilt_args...)
end
| 49 | + |
# First type parameter (the broadcast style) of a `Broadcasted`.
getstyle(::Broadcasted{S}) where {S} = S

# Strip `NotTracked` wrappers: from a bare wrapper, from a `Ref` holding one (re-wrapped in
# a `Ref` unless the wrapped object is an array), and recursively from the function and
# arguments of a `Broadcasted`. Anything else is returned as is.
remove_not_tracked(x) = x
remove_not_tracked(x::NotTracked) = x.f
remove_not_tracked(r::Base.RefValue{<:NotTracked}) = Ref(remove_not_tracked(r[]))
function remove_not_tracked(r::Base.RefValue{<:NotTracked{<:AbstractArray}})
    return remove_not_tracked(r[])
end
function remove_not_tracked(b::Broadcasted{S}) where {S}
    stripped_f = remove_not_tracked(b.f)
    stripped_args = map(remove_not_tracked, b.args)
    return Broadcasted{S}(stripped_f, stripped_args, b.axes)
end
| 58 | + |
# `onlyrealarray(x)` is `false` only for arrays whose element type is not a `Real`;
# non-array values count as `true`.
onlyrealarray(::Any) = true
onlyrealarray(::AbstractArray) = false
onlyrealarray(::AbstractArray{<:Real}) = true

# `true` when every element of the tuple satisfies `onlyrealarray` (also for `()`).
onlyrealarrays(::Tuple{}) = true
function onlyrealarrays(args::Tuple)
    head, rest = first(args), Base.tail(args)
    return onlyrealarray(head) ? onlyrealarrays(rest) : false
end
| 64 | + |
# `true` when at least one element of the tuple is a `Real`; `false` for `()`.
anyreals(::Tuple{}) = false
function anyreals(args::Tuple)
    first(args) isa Real && return true
    return anyreals(Base.tail(args))
end
| 67 | + |
# Select the broadcasting implementation for `bc`, returned as a `Val`:
#
# - `Val(:fallback)`: the function or any argument is possibly a tracked non-real or an
#   array of tracked reals/non-reals, or the output is not an array of reals.
# - `Val(:tracker)`: output is real, there is no tracked closure or argument other than
#   `TrackedReal`/`TrackedArray`, and each argument is either a real number, an array of
#   untracked reals, a tracked array of reals or an array of untracked non-reals.
# - `Val(:reversediff)`: as above, but no argument is a real number and all array
#   arguments are arrays of untracked reals or tracked arrays of reals.
#
# `T` is the inferred output type; `f` is accepted but not inspected here.
function get_implementation(bc, f, T, args)
    outputisreal = (T !== Union{}) && (T <: AbstractArray{<:Real})
    if mayhavetracked(bc) || !outputisreal
        return Val(:fallback)
    end
    if anyreals(args) || !onlyrealarrays(args)
        return Val(:tracker)
    end
    return Val(:reversediff)
end
# Materialize a broadcast with `TrackedStyle`.
#
# `NotTracked` wrappers are stripped first. The output type is inferred from an untracked
# rebuild of the broadcast, and `get_implementation` (called on the original `_bc`, so it
# still sees the wrappers) picks between `∇broadcast`, `tracker_∇broadcast`, and a plain
# `copy` of the flattened broadcast under the untracked style.
function Base.copy(_bc::Broadcasted{TrackedStyle})
    bc = remove_not_tracked(_bc)
    untracked_bc = broadcast_rebuild(bc)
    flat_bc = Broadcast.flatten(bc)
    flat_untracked_bc = Broadcast.flatten(untracked_bc)
    T = Core.Compiler.return_type(copy, Tuple{typeof(untracked_bc)})
    # The function comes from the untracked broadcast, the arguments from the tracked one.
    f = flat_untracked_bc.f
    args = flat_bc.args
    implementation = get_implementation(_bc, f, T, args)
    implementation isa Val{:reversediff} && return ∇broadcast(f, args...)
    implementation isa Val{:tracker} && return tracker_∇broadcast(f, args...)
    style = getstyle(flat_untracked_bc)
    axes = flat_bc.axes
    return copy(Broadcasted{style, typeof(axes), typeof(f), typeof(args)}(f, args, axes))
end
| 103 | + |
| 104 | +# https://github.com/FluxML/Flux.jl/issues/353 |
# On Julia versions before 1.1.0-DEV.548, replace `Base.Broadcast.flatten` and the
# `make_makeargs` method for tuples starting with a `Broadcasted`. The `let` blocks bind
# the captured values once before the closures are created.
if VERSION < v"1.1.0-DEV.548"
    @eval Base.Broadcast begin
        function flatten(bc::Broadcasted{Style}) where {Style}
            isflat(bc) && return bc
            flatargs = cat_nested(bc)
            let rebuild = make_makeargs(bc), f = bc.f
                flatf = @inline function(xs::Vararg{Any,N}) where N
                    f(rebuild(xs...)...)
                end
                return Broadcasted{Style}(flatf, flatargs, bc.axes)
            end
        end
        @inline function make_makeargs(makeargs, t::Tuple{<:Broadcasted,Vararg{Any}})
            inner = t[1]
            let after = make_makeargs(makeargs, tail(t)), f = inner.f
                let before = make_makeargs(after, inner.args)
                    headargs, tailargs = make_headargs(inner.args), make_tailargs(inner.args)
                    return @inline function(xs::Vararg{Any,N}) where N
                        expanded = before(xs...)
                        a, b = headargs(expanded...), tailargs(expanded...)
                        (f(a...), b...)
                    end
                end
            end
        end
    end
end
| 132 | + |
# Second type parameter `D` of a `TrackedReal`/`TrackedArray`; `Union{}` for anything else.
getouttype(::Any) = Union{}
getouttype(::TrackedArray{<:Any, D}) where {D} = D
getouttype(::TrackedReal{<:Any, D}) where {D} = D
| 136 | + |
# Unwrap a `Ref`; any other value is returned unchanged.
deref(r::Base.RefValue) = r[]
deref(other) = other
| 139 | + |
# Generate the inlined call `f(a_1, ..., a_M)` where `M = N + length(utargs)`. Argument
# positions listed in `tinds` are filled, in order, from the static vector `x`; all other
# positions are filled, in order, from the tuple `utargs`. Each argument goes through
# `deref`.
@generated function splatcall(f, x::SVector{N}, utargs::T, ::Val{tinds}) where {N, T <: Tuple, tinds}
    callargs = Expr[]
    tracked_pos, untracked_pos = 0, 0
    for pos in 1:(N + length(T.types))
        if pos in tinds
            tracked_pos += 1
            push!(callargs, :(deref(x[$tracked_pos])))
        else
            untracked_pos += 1
            push!(callargs, :(deref(utargs[$untracked_pos])))
        end
    end
    call = Expr(:call, :f, callargs...)
    return Expr(:block, Expr(:meta, :inline), call)
end
| 158 | + |
# Partition the tuple `args` by type. Returns `(Val{inds}(), maybetracked, untracked)`
# where `inds` holds the positions of the arguments that are `Real`s or `AbstractArray`s,
# `maybetracked` is the tuple of those arguments, and `untracked` is the tuple of the
# remaining ones. Relative order is preserved in both tuples.
@generated function splitargs(args::T) where {T <: Tuple}
    RealOrArray = Union{Real, AbstractArray}
    picked = Int[]
    others = Int[]
    for i in 1:length(T.types)
        push!(T.types[i] <: RealOrArray ? picked : others, i)
    end
    indsval = :(Val{$(Expr(:tuple, picked...))}())
    maybetracked = Expr(:tuple, [:(args[$i]) for i in picked]...)
    untracked = Expr(:tuple, [:(args[$i]) for i in others]...)
    return :($indsval, $maybetracked, $untracked)
end
| 168 | + |
| 169 | +## A generalization of the broadcasting approach in ReverseDiff for general functions |
| 170 | + |
# Broadcast `f` over `args` and record the operation as a single `SpecialInstruction`.
#
# `splitargs` separates the `Real`/`AbstractArray` arguments (`targs`, at positions
# `inds`) from the rest (`untracked`). For every broadcast element, `df` evaluates `f` and
# its gradient with respect to the `targs` entries via `ForwardDiff.gradient!`. The
# resulting array of `DiffResults` is cached for the reverse pass, together with `df` and
# the index bounds of each input relative to the output.
#
# If the output element type is `Bool`, the plain (untracked) output is returned and
# nothing is recorded.
#
# The varargs are declared as `Vararg{Any}`: the equivalent `Vararg{<:Any}` wraps
# `Vararg` in a `UnionAll`, which is deprecated on recent Julia versions.
@inline function ∇broadcast(f::F, args::Vararg{Any}) where {F}
    inds, targs, untracked = splitargs(args)
    N = length(targs)
    D = promote_type(getouttype.(targs)...)
    result = DiffResults.GradientResult(zero(SVector{N, D}))
    function df(x...)
        return ForwardDiff.gradient!(
            result,
            s -> splatcall(f, s, untracked, inds),
            SVector(x),
        )
    end
    results = broadcast(df, value.(targs)...)
    tp = tape(targs...)
    out_value = DiffResults.value.(results)
    eltype(out_value) == Bool && return out_value
    out = track(out_value, D, tp)
    cache = (results, df, index_bound.(targs, (out,)))
    record!(tp, SpecialInstruction, ∇broadcast, targs, out, cache)
    return out
end
# Reverse pass for a recorded `∇broadcast`: add the output derivative, combined with the
# cached per-element results, to each input, then unseed the output. The cached index
# bounds are only passed along when the inputs do not all share the same size.
@noinline function special_reverse_exec!(instruction::SpecialInstruction{typeof(∇broadcast)})
    inputs, output = instruction.input, instruction.output
    output_deriv = deriv(output)
    results, _, bounds = instruction.cache
    same_shape = length(inputs) == 1 ||
        all(isequal(size(inputs[1])), map(size, Base.tail(inputs)))
    if same_shape
        _add_to_deriv!(inputs, output_deriv, results)
    else
        _add_to_deriv!(inputs, output_deriv, results, bounds)
    end
    unseed!(output)
    return nothing
end
| 206 | + |
# Tuple methods: unrolled at compile time into one call per input, tagged with the
# input's position as a `Val`. The variant taking `bounds` also forwards the matching
# bound.
#
# Element methods: the catch-all does nothing. For `TrackedReal`/`TrackedArray` inputs the
# derivative is incremented through `diffresult_increment_deriv!` when `istracked(x)`
# holds; otherwise `false` is returned.
@generated function _add_to_deriv!(xs::T, o, r) where {T <: Tuple}
    calls = [:(_add_to_deriv!(xs[$k], o, r, Val($k))) for k in 1:length(T.types)]
    return Expr(:block, calls...)
end
_add_to_deriv!(::Any, ::Any, ::Any, ::Any) = nothing
function _add_to_deriv!(x::Union{TrackedReal, TrackedArray}, out_deriv, results, ::Val{i}) where {i}
    return istracked(x) ? diffresult_increment_deriv!(x, out_deriv, results, i) : false
end

@generated function _add_to_deriv!(xs::T, o, r, bounds) where {T <: Tuple}
    calls = [
        :(_add_to_deriv!(xs[$k], o, r, Val($k), bounds[$k])) for k in 1:length(T.types)
    ]
    return Expr(:block, calls...)
end
_add_to_deriv!(::Any, ::Any, ::Any, ::Any, ::Any) = nothing
function _add_to_deriv!(x::Union{TrackedReal,TrackedArray}, out_deriv, results, ::Val{i}, bound) where {i}
    return istracked(x) ? diffresult_increment_deriv!(x, out_deriv, results, i, bound) : false
end
| 224 | + |
# Forward pass for a recorded `∇broadcast`: refresh the input values, recompute the
# cached per-element results in place with `df`, and copy their values into the output.
@noinline function special_forward_exec!(instruction::SpecialInstruction{typeof(∇broadcast)})
    inputs = instruction.input
    output = instruction.output
    results, df, _ = instruction.cache
    map(pull_value!, inputs)
    broadcast!(df, results, map(value, inputs)...)
    value(output) .= DiffResults.value.(results)
    return nothing
end
| 234 | + |
| 235 | +## Tracker style broadcasting |
| 236 | +## Good for broadcasting real numbers or arrays of non-tracked structs |
| 237 | + |
# Reshape `Δ` so that it has as many dimensions as `x`, keeping the leading sizes of `Δ`.
function trim(x, Δ)
    kept = ntuple(d -> size(Δ, d), Val(ndims(x)))
    return reshape(Δ, kept)
end

# Reduce `Δ` back to the shape of the broadcast argument `x`:
# - same size: `Δ` is returned as is;
# - same length: only trailing dimensions are dropped;
# - otherwise: `Δ` is summed over every dimension in which `x` has size 1, then trimmed.
function unbroadcast(x::AbstractArray, Δ)
    if size(x) == size(Δ)
        return Δ
    elseif length(x) == length(Δ)
        return trim(x, Δ)
    else
        # Dimensions that must not be reduced are mapped to an index past `ndims(Δ)`.
        beyond = ndims(Δ) + 1
        dims = ntuple(d -> size(x, d) == 1 ? d : beyond, Val(ndims(Δ)))
        return trim(x, sum(Δ, dims = dims))
    end
end

# A scalar argument receives the sum of `Δ`; a `Ref` argument receives nothing.
unbroadcast(::Number, Δ) = sum(Δ)
unbroadcast(::Base.RefValue, ::Any) = nothing
| 247 | + |
# Wrap reals in a `Dual` carrying the partial `p`; other values pass through untouched.
dual(x::Real, p) = Dual(x, p)
dual(x, p) = x

# First partial of `f` with respect to its `i`-th argument, multiplied by `G`: only
# argument `i` is seeded with `true`.
function _deriv(f, G, ::Val{i}, args::Vararg{Any, N}) where {N, i}
    seeded = ntuple(j -> dual(args[j], j == i), Val(N))
    out = f(seeded...)
    return out.partials[1] * G
end

# Tuple with one broadcast `_deriv` call per argument position.
@generated function _derivs(f, G, args::Vararg{Any, N}) where {N}
    per_arg = [:(_deriv.(f, G, Val($k), args...)) for k in 1:N]
    return Expr(:tuple, per_arg...)
end
# Broadcast `f` over the values of `args`, track the result on the tape of `args`, and
# record a `SpecialInstruction` caching `f`. An output with element type `Bool` is
# returned untracked and nothing is recorded.
@inline function tracker_∇broadcast(f, args::Vararg{Any, N}) where {N}
    out_value = broadcast(f, map(value, args)...)
    tp = tape(args...)
    if eltype(out_value) == Bool
        return out_value
    end
    out = track(out_value, tp)
    record!(tp, SpecialInstruction, tracker_∇broadcast, args, out, (f,))
    return out
end
| 268 | + |
# Forward pass for a recorded `tracker_∇broadcast`: refresh the input values and
# re-broadcast the cached function into the output's value array in place.
@noinline function special_forward_exec!(instruction::SpecialInstruction{typeof(tracker_∇broadcast)})
    inputs = instruction.input
    output = instruction.output
    f = instruction.cache[1]
    destination = value(output)
    map(pull_value!, inputs)
    broadcast!(f, destination, map(value, inputs)...)
    return nothing
end
| 277 | + |
# Reverse pass for a recorded `tracker_∇broadcast`.
#
# `_derivs` gives, for each input, the element-wise partial derivative of the cached
# function scaled by the output derivative. `unbroadcast` reduces each of these to the
# shape of its input, and the result is accumulated with `_add_to_deriv!`. The output is
# unseeded afterwards. (The former unused local `N = length(input)` has been removed.)
@noinline function special_reverse_exec!(instruction::SpecialInstruction{typeof(tracker_∇broadcast)})
    input = instruction.input
    output = instruction.output
    f = instruction.cache[1]
    output_deriv = deriv(output)
    Δargs = _derivs(f, output_deriv, value.(input)...)
    dxs = map(unbroadcast, input, Δargs)
    map(_add_to_deriv!, input, dxs)
    unseed!(output)
    return nothing
end
| 290 | + |
| 291 | +## Limited ReverseDiff broadcasting |
| 292 | +## Efficient broadcasting for specific functions, e.g. +, - |
| 293 | + |
# Eager broadcast used by the `materialize` overloads below.
@inline function _materialize(f, args)
    return broadcast(f, args...)
end

# For every DiffRules rule whose module is defined in `ReverseDiff`, route `materialize`
# of a one- or two-argument broadcast involving tracked arguments to `_materialize`.
for (M, f, arity) in DiffRules.diffrules()
    isdefined(ReverseDiff, M) || continue
    if arity == 1
        @eval @inline function materialize(bc::RDBroadcasted{typeof($M.$f), <:Tuple{TrackedArray}})
            return _materialize(bc.f, bc.args)
        end
    elseif arity == 2
        # Tracked/tracked combinations.
        for (X, Y) in ((:TrackedArray, :TrackedArray), (:TrackedArray, :TrackedReal))
            @eval @inline function materialize(bc::RDBroadcasted{typeof($M.$f), <:Tuple{$X,$Y}})
                return _materialize(bc.f, bc.args)
            end
        end
        # NOTE(review): this is the only method marked `@noinline`; kept exactly as found.
        @eval @noinline function materialize(bc::RDBroadcasted{typeof($M.$f), <:Tuple{TrackedReal,TrackedArray}})
            return _materialize(bc.f, bc.args)
        end
        # Untracked array paired with a tracked array or a tracked real, in either order.
        for A in ARRAY_TYPES
            for (X, Y) in ((A, :TrackedArray), (:TrackedArray, A), (A, :TrackedReal), (:TrackedReal, A))
                @eval @inline function materialize(bc::RDBroadcasted{typeof($M.$f), <:Tuple{$X,$Y}})
                    return _materialize(bc.f, bc.args)
                end
            end
        end
        # Untracked real paired with a tracked array, in either order.
        for R in REAL_TYPES
            for (X, Y) in ((R, :TrackedArray), (:TrackedArray, R))
                @eval @inline function materialize(bc::RDBroadcasted{typeof($M.$f), <:Tuple{$X,$Y}})
                    return _materialize(bc.f, bc.args)
                end
            end
        end
    end
end
0 commit comments