Skip to content

Commit b9bbdc4

Browse files
committed
add total
1 parent e60b71e commit b9bbdc4

File tree

4 files changed

+95
-7
lines changed

4 files changed

+95
-7
lines changed

docs/src/api.md

+3-1
Original file line numberDiff line numberDiff line change
@@ -52,11 +52,13 @@ Optimisers.trainable
5252
Optimisers.isnumeric
5353
```
5454

55-
Such restrictions are also obeyed by this function for flattening a model:
55+
Such restrictions are also obeyed by this function for flattening a model,
56+
and one for applying a function to every parameter:
5657

5758
```@docs
5859
Optimisers.destructure
5960
Optimisers.Restructure
61+
Optimisers.total
6062
```
6163

6264
## Rule Definition

src/Optimisers.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ export AbstractRule
99
include("adjust.jl")
1010

1111
include("destructure.jl")
12-
export destructure
12+
export destructure, total
1313

1414
include("rules.jl")
1515
export Descent, Adam, Momentum, Nesterov, Rprop, RMSProp,

src/destructure.jl

+68-5
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11

2-
using ChainRulesCore: ChainRulesCore, NoTangent, ProjectTo, unthunk
3-
const NoT = NoTangent()
2+
using ChainRulesCore: ChainRulesCore, ProjectTo, unthunk, RuleConfig, HasReverseMode, rrule_via_ad
3+
const NoT = ChainRulesCore.NoTangent()
44

55
"""
66
destructure(model) -> vector, reconstructor
@@ -124,9 +124,11 @@ function (::_Tangent_biwalk)(f, x, aux) # use with prune = NoT
124124
y = _trainmap(f, ch, _trainable(x), au)
125125
y isa Tuple{} && return NoT
126126
p = ProjectTo(x)
127-
if p isa ProjectTo # e.g. Array, NamedTuple
128-
p(y)
129-
else # p === identity for unknown structs
127+
# if p isa ProjectTo # e.g. Array, NamedTuple
128+
# p(y) # but for NamedTuple, this hits https://github.com/JuliaDiff/ChainRulesCore.jl/issues/538
129+
if x isa Union{Number, AbstractArray} # these don't use Tangent
130+
ProjectTo(x)(unthunk(y))
131+
else
130132
Tangent{typeof(x), typeof(y)}(y)
131133
end
132134
end
@@ -174,3 +176,64 @@ function ChainRulesCore.rrule(::typeof(_maybewarn))
174176
@warn "second derivatives of destructure may not work yet, sorry!" maxlog=3
175177
nothing, _ -> (NoT,)
176178
end
179+
180+
"""
181+
total(f, model)
182+
183+
Applies `f` to every [`trainable`](@ref), [`isnumeric`](@ref) parameter in
184+
the model, and returns the sum. Differentiable. Counts shared weights once.
185+
186+
# Examples
187+
```jldoctest
188+
julia> m = (x = [3.0, 4.0], y = (sin, [5.0]), z = (6, 7));
189+
190+
julia> total(sum, m)
191+
12.0
192+
193+
julia> total(norm, m)
194+
10.0
195+
196+
julia> total(length, m) == length(destructure(m)[1])
197+
true
198+
```
199+
"""
200+
function total(f, x)
201+
values = []
202+
fmap(y -> push!(values, f(y)), x; exclude = isnumeric, walk = (f, z) -> foreach(f, _trainable(z)))
203+
sum(values)
204+
end
205+
206+
function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(total), f, x)
207+
z, backs = _total_hobbit(config, f, x)
208+
total_back(dz) = (NoT, _total_grad(unthunk(dz), x, backs)...)
209+
z, total_back
210+
end
211+
212+
function _total_hobbit(config::RuleConfig, f, x)
213+
values = []
214+
backs = fmap(x; exclude = isnumeric, walk = (f, z) -> map(f, _trainable(z))) do y
215+
val, back = rrule_via_ad(config, f, y)
216+
push!(values, val)
217+
back
218+
end
219+
sum(values), backs
220+
end
221+
222+
function _total_grad(dz, x, backs)
223+
dfs = []
224+
dx = fmap(x, backs; exclude = isnumeric, walk = _Tangent_biwalk, prune = NoT) do y, b
225+
df, dy = b(dz)
226+
push!(dfs, df)
227+
dy
228+
end
229+
sum(dfs), dx
230+
end
231+
232+
function ChainRulesCore.rrule(::typeof(_total_grad), dz, x, backs)
233+
@warn "second derivatives of total(f, x) may not work yet, sorry!" maxlog=3
234+
function grad_back((df, dx))
235+
df isa Zero || @error "second derivatives of total(f, x) with respect to the function are wrong!"
236+
(NoT, total(dx), NoT, NoT)
237+
end
238+
_total_grad(dz, x, backs), grad_back
239+
end

test/destructure.jl

+23
Original file line numberDiff line numberDiff line change
@@ -296,3 +296,26 @@ tmp1
296296
y, bk = Zygote.pullback(x -> sum(destructure(x)[1]), (3, 4))
297297
@test bk(1.0) == (nothing,)
298298
end
299+
300+
@testset "total" begin
301+
@test total(sum, m1) == sum(1:3)
302+
@test total(prod, m2) == prod(1:3) + prod(4:6)
303+
@test total(sum, m3) == sum(1:6)
304+
@test total(sum, m4) == sum(1:6) # shared only counts once
305+
@test total(sum, m6) == 6 + 4 + im
306+
307+
@test gradient(m -> total(sum, m), m1) == ([1,1,1],)
308+
@test gradient(m -> total(sum, m), m3)[1] == (x = [1,1,1], y = nothing, z = [1,1,1])
309+
@test gradient(m -> total(sum, m), m4)[1] == (x = [1,1,1], y = nothing, z = [1,1,1])
310+
g6 = gradient(m -> abs2(total(sum, m)), m6)[1]
311+
@test g6.a isa Vector{Float64}
312+
313+
@test gradient-> total(x -> sum(x.*λ), m3), 1.0) == (21.0,)
314+
@test gradient-> total(x -> sum(x.*λ), m4), 1.0) == (21.0,)
315+
316+
@testset "second derivatives" begin
317+
f3 = v -> total(norm, (x=v, y=sin, z=[4,5,6.0]))
318+
@test_broken Zygote.hessian_reverse(f3, [1,2,3.0]) Zygote.hessian_dual(f3, [1,2,3.0])
319+
# typeof(∂(canonicalize)))(Δ::NamedTuple{(:backing,), Tuple{NamedTuple...
320+
end
321+
end

0 commit comments

Comments
 (0)