Skip to content

Make fmap(f, x, y) useful #37

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Feb 9, 2022
36 changes: 11 additions & 25 deletions src/functor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ functor(T, x) = (), _ -> x
functor(x) = functor(typeof(x), x)

functor(::Type{<:Tuple}, x) = x, y -> y
functor(::Type{<:NamedTuple}, x) = x, y -> y
functor(::Type{<:NamedTuple{L}}, x) where L = NamedTuple{L}(map(s -> getproperty(x, s), L)), identity

functor(::Type{<:AbstractArray}, x) = x, y -> y
functor(::Type{<:AbstractArray{<:Number}}, x) = (), _ -> x
Expand Down Expand Up @@ -43,12 +43,11 @@ function _default_walk(f, x)
re(map(f, func))
end

function fmap(f, x; exclude = isleaf, walk = _default_walk, cache = IdDict())
haskey(cache, x) && return cache[x]
y = exclude(x) ? f(x) : walk(x -> fmap(f, x, exclude = exclude, walk = walk, cache = cache), x)
cache[x] = y
struct NoKeyword end

return y
function fmap(f, x; exclude = isleaf, walk = _default_walk, cache = IdDict(), prune = NoKeyword())
haskey(cache, x) && return prune isa NoKeyword ? cache[x] : prune
cache[x] = exclude(x) ? f(x) : walk(x -> fmap(f, x; exclude, walk, cache, prune), x)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To use keywords like this, this package cannot claim to support Julia 1.0. At present it is only tested on 1.5+. Maybe we should just move to 1.6?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We already have on CI, might as well make it official.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure what repo I was looking at, because I just checked back and it's 1.5 still. Do we need to cut a breaking release for minimum version bumps like this again?

Copy link
Member Author

@mcabbott mcabbott Feb 8, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's 1.0 here: https://github.com/FluxML/Functors.jl/blob/master/Project.toml#L7

Tests do in fact pass on 1.0, for Functors v0.2.7, despite the lack of CI.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And, people say not to do this on a patch release, because it closes the door on a bugfix-for-1.0 release. Do we care?

We may also need a breaking release for #33, and to make the cache not used on isbits arguments. We could gang these together.

end

###
Expand All @@ -74,27 +73,16 @@ end
### Vararg forms
###

function fmap(f, x, dx...; cache = IdDict())
haskey(cache, x) && return cache[x]
cache[x] = isleaf(x) ? f(x, dx...) : _default_walk((x...) -> fmap(f, x..., cache = cache), x, dx...)
function fmap(f, x, ys...; exclude = isleaf, walk = _default_walk, cache = IdDict(), prune = NoKeyword())
haskey(cache, x) && return prune isa NoKeyword ? cache[x] : prune
cache[x] = exclude(x) ? f(x, ys...) : walk((xy...,) -> fmap(f, xy...; exclude, walk, cache, prune), x, ys...)
end

function functor_tuple(f, x::Tuple, dx::Tuple)
map(x, dx) do x, x̄
_default_walk(f, x, x̄)
end
end
functor_tuple(f, x, dx) = f(x, dx)
functor_tuple(f, x, ::Nothing) = x

function _default_walk(f, x, dx)
function _default_walk(f, x, ys...)
func, re = functor(x)
map(func, dx) do x, x̄
# functor_tuple(f, x, x̄)
f(x, x̄)
end |> re
yfuncs = map(y -> functor(typeof(x), y)[1], ys)
re(map(f, func, yfuncs...))
end
_default_walk(f, ::Nothing, ::Nothing) = nothing

###
### FlexibleFunctors.jl
Expand All @@ -112,9 +100,7 @@ function makeflexiblefunctor(m::Module, T, pfield)
func = NamedTuple{pfields}(map(p -> getproperty(x, p), pfields))
return func, re
end

end

end

function flexiblefunctorm(T, pfield = :params)
Expand Down
182 changes: 153 additions & 29 deletions test/basics.jl
Original file line number Diff line number Diff line change
@@ -1,25 +1,16 @@
struct Foo
x
y
end

using Functors: functor

struct Foo; x; y; end
@functor Foo

struct Bar
x
end
struct Bar; x; end
@functor Bar

struct Baz
x
y
z
end
@functor Baz (y,)
struct OneChild3; x; y; z; end
@functor OneChild3 (y,)

struct NoChildren
x
y
end
struct NoChildren2; x; y; end

@static if VERSION >= v"1.6"
@testset "ComposedFunction" begin
Expand All @@ -31,6 +22,10 @@ end
end
end

###
### Basic functionality
###

@testset "Nested" begin
model = Bar(Foo(1, [1, 2, 3]))

Expand All @@ -53,20 +48,80 @@ end
@test fmap(f, x; exclude = x -> x isa AbstractArray) == x
end

@testset "Property list" begin
model = OneChild3(1, 2, 3)
model′ = fmap(x -> 2x, model)

@test (model′.x, model′.y, model′.z) == (1, 4, 3)
end

@testset "cache" begin
shared = [1,2,3]
m1 = Foo(shared, Foo([1,2,3], Foo(shared, [1,2,3])))
m1f = fmap(float, m1)
@test m1f.x === m1f.y.y.x
@test m1f.x !== m1f.y.x
m1p = fmapstructure(identity, m1; prune = nothing)
@test m1p == (x = [1, 2, 3], y = (x = [1, 2, 3], y = (x = nothing, y = [1, 2, 3])))

# A non-leaf node can also be repeated:
m2 = Foo(Foo(shared, 4), Foo(shared, 4))
@test m2.x === m2.y
m2f = fmap(float, m2)
@test m2f.x.x === m2f.y.x
m2p = fmapstructure(identity, m2; prune = Bar(0))
@test m2p == (x = (x = [1, 2, 3], y = 4), y = Bar(0))

# Repeated isbits types should not automatically be regarded as shared:
m3 = Foo(Foo(shared, 1:3), Foo(1:3, shared))
m3p = fmapstructure(identity, m3; prune = 0)
@test m3p.y.y == 0
@test_broken m3p.y.x == 1:3
end

@testset "functor(typeof(x), y) from @functor" begin
nt1, re1 = functor(Foo, (x=1, y=2, z=3))
@test nt1 == (x = 1, y = 2)
@test re1((x = 10, y = 20)) == Foo(10, 20)
re1((y = 22, x = 11)) # gives Foo(22, 11), is that a bug?

nt2, re2 = functor(Foo, (z=33, x=1, y=2))
@test nt2 == (x = 1, y = 2)
@test re2((x = 10, y = 20)) == Foo(10, 20)

@test_throws Exception functor(Foo, (z=33, x=1)) # type NamedTuple has no field y

nt3, re3 = functor(OneChild3, (x=1, y=2, z=3))
@test nt3 == (y = 2,)
@test re3((y = 20,)) == OneChild3(1, 20, 3)
re3(22) # gives OneChild3(1, 22, 3), is that a bug?
end

@testset "functor(typeof(x), y) for Base types" begin
nt11, re11 = functor(NamedTuple{(:x, :y)}, (x=1, y=2, z=3))
@test nt11 == (x = 1, y = 2)
@test re11((x = 10, y = 20)) == (x = 10, y = 20)
re11((y = 22, x = 11))
re11((11, 22)) # passes right through

nt12, re12 = functor(NamedTuple{(:x, :y)}, (z=33, x=1, y=2))
@test nt12 == (x = 1, y = 2)
@test re12((x = 10, y = 20)) == (x = 10, y = 20)

@test_throws Exception functor(NamedTuple{(:x, :y)}, (z=33, x=1))
end

###
### Extras
###

@testset "Walk" begin
model = Foo((0, Bar([1, 2, 3])), [4, 5])

model′ = fmapstructure(identity, model)
@test model′ == (; x=(0, (; x=[1, 2, 3])), y=[4, 5])
end

@testset "Property list" begin
model = Baz(1, 2, 3)
model′ = fmap(x -> 2x, model)

@test (model′.x, model′.y, model′.z) == (1, 4, 3)
end

@testset "fcollect" begin
m1 = [1, 2, 3]
m2 = 1
Expand All @@ -78,7 +133,7 @@ end

m1 = [1, 2, 3]
m2 = Bar(m1)
m0 = NoChildren(:a, :b)
m0 = NoChildren2(:a, :b)
m3 = Foo(m2, m0)
m4 = Bar(m3)
@test all(fcollect(m4) .=== [m4, m3, m2, m1, m0])
Expand All @@ -89,6 +144,75 @@ end
@test all(fcollect(m3) .=== [m3, m1, m2])
end

###
### Vararg forms
###

@testset "fmap(f, x, y)" begin
m1 = (x = [1,2], y = 3)
n1 = (x = [4,5], y = 6)
@test fmap(+, m1, n1) == (x = [5, 7], y = 9)

# Reconstruction type comes from the first argument
foo1 = Foo([7,8], 9)
@test fmap(+, m1, foo1) == (x = [8, 10], y = 12)
@test fmap(+, foo1, n1) isa Foo
@test fmap(+, foo1, n1).x == [11, 13]

# Mismatched trees should be an error
m2 = (x = [1,2], y = (a = [3,4], b = 5))
n2 = (x = [6,7], y = 8)
@test_throws Exception fmap(first∘tuple, m2, n2) # ERROR: type Int64 has no field a
@test_throws Exception fmap(first∘tuple, m2, n2)

# The cache uses IDs from the first argument
shared = [1,2,3]
m3 = (x = shared, y = [4,5,6], z = shared)
n3 = (x = shared, y = shared, z = [7,8,9])
@test fmap(+, m3, n3) == (x = [2, 4, 6], y = [5, 7, 9], z = [2, 4, 6])
z3 = fmap(+, m3, n3)
@test z3.x === z3.z

# Pruning of duplicates:
@test fmap(+, m3, n3; prune = nothing) == (x = [2,4,6], y = [5,7,9], z = nothing)

# More than two arguments:
z4 = fmap(+, m3, n3, m3, n3)
@test z4 == fmap(x -> 2x, z3)
@test z4.x === z4.z

@test fmap(+, foo1, m1, n1) isa Foo
@test fmap(.*, m1, foo1, n1) == (x = [4*7, 2*5*8], y = 3*6*9)
end

@testset "old test update.jl" begin
struct M{F,T,S}
σ::F
W::T
b::S
end

@functor M

(m::M)(x) = m.σ.(m.W * x .+ m.b)

m = M(identity, ones(Float32, 3, 4), zeros(Float32, 3))
x = ones(Float32, 4, 2)
m̄, _ = gradient((m,x) -> sum(m(x)), m, x)
m̂ = Functors.fmap(m, m̄) do x, y
isnothing(x) && return y
isnothing(y) && return x
x .- 0.1f0 .* y
end

@test m̂.W ≈ fill(0.8f0, size(m.W))
@test m̂.b ≈ fill(-0.2f0, size(m.b))
end

###
### FlexibleFunctors.jl
###

struct FFoo
x
y
Expand All @@ -102,13 +226,13 @@ struct FBar
end
@flexiblefunctor FBar p

struct FBaz
struct FOneChild4
x
y
z
p
end
@flexiblefunctor FBaz p
@flexiblefunctor FOneChild4 p

@testset "Flexible Nested" begin
model = FBar(FFoo(1, [1, 2, 3], (:y, )), (:x,))
Expand All @@ -132,7 +256,7 @@ end
end

@testset "Flexible Property list" begin
model = FBaz(1, 2, 3, (:x, :z))
model = FOneChild4(1, 2, 3, (:x, :z))
model′ = fmap(x -> 2x, model)

@test (model′.x, model′.y, model′.z) == (2, 2, 6)
Expand All @@ -147,7 +271,7 @@ end
@test all(fcollect(m4, exclude = x -> x isa Array) .=== [m4, m3])
@test all(fcollect(m4, exclude = x -> x isa FFoo) .=== [m4])

m0 = NoChildren(:a, :b)
m0 = NoChildren2(:a, :b)
m1 = [1, 2, 3]
m2 = FBar(m1, ())
m3 = FFoo(m2, m0, (:x, :y,))
Expand Down
1 change: 0 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ using Zygote

include("basics.jl")
include("base.jl")
include("update.jl")

if VERSION < v"1.6" # || VERSION > v"1.7-"
@warn "skipping doctests, on Julia $VERSION"
Expand Down
23 changes: 0 additions & 23 deletions test/update.jl

This file was deleted.