Skip to content

Make fmap(f, x, y) useful #37

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Feb 9, 2022
5 changes: 3 additions & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@ jobs:
fail-fast: false
matrix:
version:
- '1.5' # Replace this with the minimum Julia version that your package supports.
# - '1' # automatically expands to the latest stable 1.x release of Julia
- '1.0'
- '1.6' # Replace this with the minimum Julia version that your package supports.
- '1' # automatically expands to the latest stable 1.x release of Julia
- 'nightly'
os:
- ubuntu-latest
Expand Down
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
name = "Functors"
uuid = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
authors = ["Mike J Innes <mike.j.innes@gmail.com>"]
version = "0.2.7"
version = "0.2.8"

[compat]
julia = "1"
Documenter = "0.27"
julia = "1"

[extras]
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
Expand Down
12 changes: 12 additions & 0 deletions src/Functors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ usually using the macro [@functor](@ref).
"""
functor

@static if VERSION >= v"1.5" # var"@functor" doesn't work on 1.0, temporarily disable
"""
@functor T
@functor T (x,)
Expand Down Expand Up @@ -65,6 +66,7 @@ TwoThirds(Foo(10, 20), Foo(3, 4), 560)
```
"""
var"@functor"
end # VERSION

"""
Functors.isleaf(x)
Expand Down Expand Up @@ -182,6 +184,16 @@ This function walks (maps) over `xs` calling the continuation `f'` to continue t
julia> fmap(x -> 10x, m, walk=(f, x) -> x isa Bar ? x : Functors._default_walk(f, x))
Foo(Bar([1, 2, 3]), (40, 50, Bar(Foo(6, 7))))
```

The behaviour when the same node appears twice can be altered by giving a value
to the `prune` keyword, which is then used in place of all but the first:

```jldoctest
julia> twice = [1, 2];

julia> fmap(float, (x = twice, y = [1,2], z = twice); prune = missing)
(x = [1.0, 2.0], y = [1.0, 2.0], z = missing)
```
"""
fmap

Expand Down
36 changes: 11 additions & 25 deletions src/functor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ functor(T, x) = (), _ -> x
functor(x) = functor(typeof(x), x)

functor(::Type{<:Tuple}, x) = x, y -> y
functor(::Type{<:NamedTuple}, x) = x, y -> y
functor(::Type{<:NamedTuple{L}}, x) where L = NamedTuple{L}(map(s -> getproperty(x, s), L)), identity

functor(::Type{<:AbstractArray}, x) = x, y -> y
functor(::Type{<:AbstractArray{<:Number}}, x) = (), _ -> x
Expand Down Expand Up @@ -43,12 +43,11 @@ function _default_walk(f, x)
re(map(f, func))
end

function fmap(f, x; exclude = isleaf, walk = _default_walk, cache = IdDict())
haskey(cache, x) && return cache[x]
y = exclude(x) ? f(x) : walk(x -> fmap(f, x, exclude = exclude, walk = walk, cache = cache), x)
cache[x] = y
struct NoKeyword end

return y
function fmap(f, x; exclude = isleaf, walk = _default_walk, cache = IdDict(), prune = NoKeyword())
haskey(cache, x) && return prune isa NoKeyword ? cache[x] : prune
cache[x] = exclude(x) ? f(x) : walk(x -> fmap(f, x; exclude=exclude, walk=walk, cache=cache, prune=prune), x)
end

###
Expand All @@ -74,27 +73,16 @@ end
### Vararg forms
###

function fmap(f, x, dx...; cache = IdDict())
haskey(cache, x) && return cache[x]
cache[x] = isleaf(x) ? f(x, dx...) : _default_walk((x...) -> fmap(f, x..., cache = cache), x, dx...)
function fmap(f, x, ys...; exclude = isleaf, walk = _default_walk, cache = IdDict(), prune = NoKeyword())
haskey(cache, x) && return prune isa NoKeyword ? cache[x] : prune
cache[x] = exclude(x) ? f(x, ys...) : walk((xy...,) -> fmap(f, xy...; exclude=exclude, walk=walk, cache=cache, prune=prune), x, ys...)
end

function functor_tuple(f, x::Tuple, dx::Tuple)
map(x, dx) do x, x̄
_default_walk(f, x, x̄)
end
end
functor_tuple(f, x, dx) = f(x, dx)
functor_tuple(f, x, ::Nothing) = x

function _default_walk(f, x, dx)
function _default_walk(f, x, ys...)
func, re = functor(x)
map(func, dx) do x, x̄
# functor_tuple(f, x, x̄)
f(x, x̄)
end |> re
yfuncs = map(y -> functor(typeof(x), y)[1], ys)
re(map(f, func, yfuncs...))
end
_default_walk(f, ::Nothing, ::Nothing) = nothing

###
### FlexibleFunctors.jl
Expand All @@ -112,9 +100,7 @@ function makeflexiblefunctor(m::Module, T, pfield)
func = NamedTuple{pfields}(map(p -> getproperty(x, p), pfields))
return func, re
end

end

end

function flexiblefunctorm(T, pfield = :params)
Expand Down
186 changes: 157 additions & 29 deletions test/basics.jl
Original file line number Diff line number Diff line change
@@ -1,25 +1,16 @@
struct Foo
x
y
end

using Functors: functor

struct Foo; x; y; end
@functor Foo

struct Bar
x
end
struct Bar; x; end
@functor Bar

struct Baz
x
y
z
end
@functor Baz (y,)
struct OneChild3; x; y; z; end
@functor OneChild3 (y,)

struct NoChildren
x
y
end
struct NoChildren2; x; y; end

@static if VERSION >= v"1.6"
@testset "ComposedFunction" begin
Expand All @@ -31,6 +22,10 @@ end
end
end

###
### Basic functionality
###

@testset "Nested" begin
model = Bar(Foo(1, [1, 2, 3]))

Expand All @@ -53,20 +48,80 @@ end
@test fmap(f, x; exclude = x -> x isa AbstractArray) == x
end

@testset "Property list" begin
model = OneChild3(1, 2, 3)
model′ = fmap(x -> 2x, model)

@test (model′.x, model′.y, model′.z) == (1, 4, 3)
end

@testset "cache" begin
shared = [1,2,3]
m1 = Foo(shared, Foo([1,2,3], Foo(shared, [1,2,3])))
m1f = fmap(float, m1)
@test m1f.x === m1f.y.y.x
@test m1f.x !== m1f.y.x
m1p = fmapstructure(identity, m1; prune = nothing)
@test m1p == (x = [1, 2, 3], y = (x = [1, 2, 3], y = (x = nothing, y = [1, 2, 3])))

# A non-leaf node can also be repeated:
m2 = Foo(Foo(shared, 4), Foo(shared, 4))
@test m2.x === m2.y
m2f = fmap(float, m2)
@test m2f.x.x === m2f.y.x
m2p = fmapstructure(identity, m2; prune = Bar(0))
@test m2p == (x = (x = [1, 2, 3], y = 4), y = Bar(0))

# Repeated isbits types should not automatically be regarded as shared:
m3 = Foo(Foo(shared, 1:3), Foo(1:3, shared))
m3p = fmapstructure(identity, m3; prune = 0)
@test m3p.y.y == 0
@test_broken m3p.y.x == 1:3
end

@testset "functor(typeof(x), y) from @functor" begin
nt1, re1 = functor(Foo, (x=1, y=2, z=3))
@test nt1 == (x = 1, y = 2)
@test re1((x = 10, y = 20)) == Foo(10, 20)
re1((y = 22, x = 11)) # gives Foo(22, 11), is that a bug?

nt2, re2 = functor(Foo, (z=33, x=1, y=2))
@test nt2 == (x = 1, y = 2)
@test re2((x = 10, y = 20)) == Foo(10, 20)

@test_throws Exception functor(Foo, (z=33, x=1)) # type NamedTuple has no field y

nt3, re3 = functor(OneChild3, (x=1, y=2, z=3))
@test nt3 == (y = 2,)
@test re3((y = 20,)) == OneChild3(1, 20, 3)
re3(22) # gives OneChild3(1, 22, 3), is that a bug?
end

@testset "functor(typeof(x), y) for Base types" begin
nt11, re11 = functor(NamedTuple{(:x, :y)}, (x=1, y=2, z=3))
@test nt11 == (x = 1, y = 2)
@test re11((x = 10, y = 20)) == (x = 10, y = 20)
re11((y = 22, x = 11))
re11((11, 22)) # passes right through

nt12, re12 = functor(NamedTuple{(:x, :y)}, (z=33, x=1, y=2))
@test nt12 == (x = 1, y = 2)
@test re12((x = 10, y = 20)) == (x = 10, y = 20)

@test_throws Exception functor(NamedTuple{(:x, :y)}, (z=33, x=1))
end

###
### Extras
###

@testset "Walk" begin
model = Foo((0, Bar([1, 2, 3])), [4, 5])

model′ = fmapstructure(identity, model)
@test model′ == (; x=(0, (; x=[1, 2, 3])), y=[4, 5])
end

@testset "Property list" begin
model = Baz(1, 2, 3)
model′ = fmap(x -> 2x, model)

@test (model′.x, model′.y, model′.z) == (1, 4, 3)
end

@testset "fcollect" begin
m1 = [1, 2, 3]
m2 = 1
Expand All @@ -78,7 +133,7 @@ end

m1 = [1, 2, 3]
m2 = Bar(m1)
m0 = NoChildren(:a, :b)
m0 = NoChildren2(:a, :b)
m3 = Foo(m2, m0)
m4 = Bar(m3)
@test all(fcollect(m4) .=== [m4, m3, m2, m1, m0])
Expand All @@ -89,6 +144,79 @@ end
@test all(fcollect(m3) .=== [m3, m1, m2])
end

###
### Vararg forms
###

@testset "fmap(f, x, y)" begin
m1 = (x = [1,2], y = 3)
n1 = (x = [4,5], y = 6)
@test fmap(+, m1, n1) == (x = [5, 7], y = 9)

# Reconstruction type comes from the first argument
foo1 = Foo([7,8], 9)
@test fmap(+, m1, foo1) == (x = [8, 10], y = 12)
@test fmap(+, foo1, n1) isa Foo
@test fmap(+, foo1, n1).x == [11, 13]

# Mismatched trees should be an error
m2 = (x = [1,2], y = (a = [3,4], b = 5))
n2 = (x = [6,7], y = 8)
@test_throws Exception fmap(first∘tuple, m2, n2) # ERROR: type Int64 has no field a
@test_throws Exception fmap(first∘tuple, m2, n2)

# The cache uses IDs from the first argument
shared = [1,2,3]
m3 = (x = shared, y = [4,5,6], z = shared)
n3 = (x = shared, y = shared, z = [7,8,9])
@test fmap(+, m3, n3) == (x = [2, 4, 6], y = [5, 7, 9], z = [2, 4, 6])
z3 = fmap(+, m3, n3)
@test z3.x === z3.z

# Pruning of duplicates:
@test fmap(+, m3, n3; prune = nothing) == (x = [2,4,6], y = [5,7,9], z = nothing)

# More than two arguments:
z4 = fmap(+, m3, n3, m3, n3)
@test z4 == fmap(x -> 2x, z3)
@test z4.x === z4.z

@test fmap(+, foo1, m1, n1) isa Foo
@static if VERSION >= v"1.6" # fails on Julia 1.0
@test fmap(.*, m1, foo1, n1) == (x = [4*7, 2*5*8], y = 3*6*9)
end
end

@static if VERSION >= v"1.6" # Julia 1.0: LoadError: error compiling top-level scope: type definition not allowed inside a local scope
@testset "old test update.jl" begin
struct M{F,T,S}
σ::F
W::T
b::S
end

@functor M

(m::M)(x) = m.σ.(m.W * x .+ m.b)

m = M(identity, ones(Float32, 3, 4), zeros(Float32, 3))
x = ones(Float32, 4, 2)
m̄, _ = gradient((m,x) -> sum(m(x)), m, x)
m̂ = Functors.fmap(m, m̄) do x, y
isnothing(x) && return y
isnothing(y) && return x
x .- 0.1f0 .* y
end

@test m̂.W ≈ fill(0.8f0, size(m.W))
@test m̂.b ≈ fill(-0.2f0, size(m.b))
end
end # VERSION

###
### FlexibleFunctors.jl
###

struct FFoo
x
y
Expand All @@ -102,13 +230,13 @@ struct FBar
end
@flexiblefunctor FBar p

struct FBaz
struct FOneChild4
x
y
z
p
end
@flexiblefunctor FBaz p
@flexiblefunctor FOneChild4 p

@testset "Flexible Nested" begin
model = FBar(FFoo(1, [1, 2, 3], (:y, )), (:x,))
Expand All @@ -132,7 +260,7 @@ end
end

@testset "Flexible Property list" begin
model = FBaz(1, 2, 3, (:x, :z))
model = FOneChild4(1, 2, 3, (:x, :z))
model′ = fmap(x -> 2x, model)

@test (model′.x, model′.y, model′.z) == (2, 2, 6)
Expand All @@ -147,7 +275,7 @@ end
@test all(fcollect(m4, exclude = x -> x isa Array) .=== [m4, m3])
@test all(fcollect(m4, exclude = x -> x isa FFoo) .=== [m4])

m0 = NoChildren(:a, :b)
m0 = NoChildren2(:a, :b)
m1 = [1, 2, 3]
m2 = FBar(m1, ())
m3 = FFoo(m2, m0, (:x, :y,))
Expand Down
1 change: 0 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ using Zygote

include("basics.jl")
include("base.jl")
include("update.jl")

if VERSION < v"1.6" # || VERSION > v"1.7-"
@warn "skipping doctests, on Julia $VERSION"
Expand Down
Loading