|
1 | 1 |
|
2 |
| -using ChainRulesCore: ChainRulesCore, NoTangent, ProjectTo, unthunk |
3 |
| -const NoT = NoTangent() |
| 2 | +using ChainRulesCore: ChainRulesCore, ProjectTo, unthunk, RuleConfig, HasReverseMode, rrule_via_ad |
| 3 | +const NoT = ChainRulesCore.NoTangent() |
4 | 4 |
|
5 | 5 | """
|
6 | 6 | destructure(model) -> vector, reconstructor
|
@@ -124,9 +124,11 @@ function (::_Tangent_biwalk)(f, x, aux) # use with prune = NoT
|
124 | 124 | y = _trainmap(f, ch, _trainable(x), au)
|
125 | 125 | y isa Tuple{} && return NoT
|
126 | 126 | p = ProjectTo(x)
|
127 |
| - if p isa ProjectTo # e.g. Array, NamedTuple |
128 |
| - p(y) |
129 |
| - else # p === identity for unknown structs |
| 127 | + # if p isa ProjectTo # e.g. Array, NamedTuple |
| 128 | + # p(y) # but for NamedTuple, this hits https://github.com/JuliaDiff/ChainRulesCore.jl/issues/538 |
| 129 | + if x isa Union{Number, AbstractArray} # these don't use Tangent |
| 130 | + ProjectTo(x)(unthunk(y)) |
| 131 | + else |
130 | 132 | Tangent{typeof(x), typeof(y)}(y)
|
131 | 133 | end
|
132 | 134 | end
|
@@ -174,3 +176,64 @@ function ChainRulesCore.rrule(::typeof(_maybewarn))
|
174 | 176 | @warn "second derivatives of destructure may not work yet, sorry!" maxlog=3
|
175 | 177 | nothing, _ -> (NoT,)
|
176 | 178 | end
|
| 179 | + |
"""
    total(f, model)

Applies `f` to every [`trainable`](@ref), [`isnumeric`](@ref) parameter in
the model, and returns the sum. Differentiable. Counts shared weights once.

# Examples
```jldoctest
julia> m = (x = [3.0, 4.0], y = (sin, [5.0]), z = (6, 7));

julia> total(sum, m)
12.0

julia> total(norm, m)
10.0

julia> total(length, m) == length(destructure(m)[1])
true
```
"""
function total(f, x)
    # Collect f(leaf) for every numeric trainable leaf; `fmap` with this walk
    # visits shared (===) arrays only once, so tied weights are counted once.
    acc = []
    fmap(x; exclude = isnumeric,
            walk = (recurse, node) -> foreach(recurse, _trainable(node))) do leaf
        push!(acc, f(leaf))
    end
    return sum(acc)
end
| 205 | + |
# Reverse rule for `total`: the forward pass records one pullback per leaf
# (via `_total_hobbit`), and the backward pass replays them (`_total_grad`).
function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(total), f, x)
    value, pullbacks = _total_hobbit(config, f, x)
    function total_back(dz)
        df, dx = _total_grad(unthunk(dz), x, pullbacks)
        return (NoT, df, dx)
    end
    return value, total_back
end
| 211 | + |
# Forward pass helper for the `total` rrule: evaluates `f` at every numeric
# trainable leaf through `rrule_via_ad`, summing the values and returning a
# structure of pullbacks mirroring the shape of `x`.
function _total_hobbit(config::RuleConfig, f, x)
    leafvals = []
    pullbacks = fmap(x; exclude = isnumeric,
                        walk = (recurse, node) -> map(recurse, _trainable(node))) do leaf
        val, pb = rrule_via_ad(config, f, leaf)
        push!(leafvals, val)
        pb  # stored in place of the leaf, for use in _total_grad
    end
    return sum(leafvals), pullbacks
end
| 221 | + |
# Backward pass helper for the `total` rrule: runs each leaf's pullback on the
# incoming cotangent `dz`, accumulating the gradient w.r.t. `f` separately and
# assembling the gradient w.r.t. `x` as a Tangent-structured tree.
function _total_grad(dz, x, backs)
    fgrads = []
    dx = fmap(x, backs; exclude = isnumeric, walk = _Tangent_biwalk, prune = NoT) do leaf, pb
        df, dleaf = pb(dz)
        push!(fgrads, df)
        dleaf
    end
    return sum(fgrads), dx
end
| 231 | + |
# Second-order rule: differentiating through `_total_grad` itself.
# Explicitly marked experimental by the warning below.
function ChainRulesCore.rrule(::typeof(_total_grad), dz, x, backs)
    @warn "second derivatives of total(f, x) may not work yet, sorry!" maxlog=3
    function grad_back((df, dx))
        # Fix: `Zero` does not exist in ChainRulesCore ≥ 1.0 (and is not imported
        # here), so the original `df isa Zero` threw UndefVarError whenever this
        # pullback ran. `AbstractZero` is the common supertype of `ZeroTangent`
        # and `NoTangent`, which is what this check intends to detect.
        df isa ChainRulesCore.AbstractZero || @error "second derivatives of total(f, x) with respect to the function are wrong!"
        # NOTE(review): `total` as defined in this file takes two arguments,
        # `total(f, x)`; this one-argument call looks like it would raise a
        # MethodError if reached — confirm the intended reduction over `dx`.
        (NoT, total(dx), NoT, NoT)
    end
    _total_grad(dz, x, backs), grad_back
end
0 commit comments