zygote broadcast type stability #1301

Closed
maartenvd opened this issue Sep 1, 2022 · 5 comments · Fixed by #1302

Comments

@maartenvd
Contributor

I have a package that defines a simple utility array type: https://github.com/maartenvd/MPSKit.jl/blob/diskarray/src/utility/periodicarray.jl
Zygote changes the type when taking the derivative, which later makes my backward rules fail. Here is a minimal example:

julia> f_add(x) = x + 3;
julia> function myfun(x)
       y = f_add.(x);
       @show typeof(y)
       norm(y)
       end
myfun (generic function with 1 method)
julia> myfun'(PeriodicArray(rand(5,5)));
typeof(y) = Matrix{Float64}

julia> myfun(PeriodicArray(rand(5,5)));
typeof(y) = PeriodicArray{Float64, 2}

This type change causes failures: rrule is then called with a tangent of type PeriodicArray but a (wrong) primal of type Matrix.

@mcabbott
Member

mcabbott commented Sep 1, 2022

What Zygote's forward-mode broadcast path is doing is this:

out = dual_function(f).(args...)
eltype(out) <: Dual || return (out, _ -> nothing)
y = map(x -> x.value, out)
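For readers unfamiliar with the trick, here is a minimal sketch of the idea behind those three lines. It assumes a hand-rolled dual type (the hypothetical MiniDual below) in place of the ForwardDiff.Dual that Zygote actually uses, and a made-up example function g: seeding each element with derivative 1 and broadcasting over the duals yields the primal values and the elementwise derivatives in a single pass.

```julia
# Hypothetical minimal dual number: a value plus one derivative component.
# Zygote itself uses ForwardDiff.Dual; this is only a stand-in.
struct MiniDual
    value::Float64
    partial::Float64
end

# Just enough arithmetic for the example function below.
Base.:+(a::MiniDual, b::Real) = MiniDual(a.value + b, a.partial)
Base.:*(a::MiniDual, b::MiniDual) =
    MiniDual(a.value * b.value, a.value * b.partial + a.partial * b.value)

g(x) = x * x + 3   # example scalar function; its derivative is 2x

x = [1.0, 2.0, 3.0]

# Seed every element with derivative 1 and broadcast g over the duals:
# one forward pass yields g.(x) and the elementwise derivative together.
out = g.(MiniDual.(x, 1.0))

# Unpack, analogous to `y = map(x -> x.value, out)` in the snippet above.
y    = map(d -> d.value, out)
dydx = map(d -> d.partial, out)

println(y)     # [4.0, 7.0, 12.0]
println(dydx)  # [2.0, 4.0, 6.0]
```

On a plain Vector the unpacking step is harmless; it only matters for containers whose type is decided by broadcasting rather than by map.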

julia> out = Zygote.dual_function(f_add).(PeriodicArray(rand(2,3)))
2×3 PeriodicArray{ForwardDiff.Dual{Nothing, Float64, 1}, 2}:
 Dual{Nothing}(3.18449,1.0)  Dual{Nothing}(3.32056,1.0)  Dual{Nothing}(3.9397,1.0)
 Dual{Nothing}(3.20766,1.0)  Dual{Nothing}(3.20073,1.0)  Dual{Nothing}(3.19711,1.0)

julia> y = map(x -> x.value, out)
2×3 Matrix{Float64}:
 3.18449  3.32056  3.9397
 3.20766  3.20073  3.19711

I am not sure why the second step uses map, since calling broadcast(x -> x.value, out) would preserve this type. Want to make a PR and see what breaks?
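A self-contained way to see the map/broadcast difference, assuming a hypothetical wrapper type WrapArray that opts into broadcasting via Base.BroadcastStyle and a similar method for Broadcasted, but does not overload similar(::AbstractArray, T, dims). That combination matches what the REPL output above shows for PeriodicArray (broadcast keeps it, map returns a Matrix), though it is not MPSKit's actual code:

```julia
# Hypothetical stand-in for a custom array type such as PeriodicArray;
# not MPSKit's implementation, only enough to show the dispatch.
struct WrapArray{T,N} <: AbstractArray{T,N}
    data::Array{T,N}
end

Base.size(A::WrapArray) = size(A.data)
Base.getindex(A::WrapArray, I::Int...) = A.data[I...]
Base.setindex!(A::WrapArray, v, I::Int...) = (A.data[I...] = v)

# Opt into broadcasting: broadcasts involving a WrapArray allocate their
# result through this similar method, so the result is a WrapArray too.
Base.BroadcastStyle(::Type{<:WrapArray}) = Broadcast.ArrayStyle{WrapArray}()
function Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{WrapArray}},
                      ::Type{T}) where {T}
    return WrapArray(similar(Array{T}, axes(bc)))
end

w = WrapArray(rand(2, 3))

# map goes through similar(::AbstractArray, T, dims), which is not overloaded
# here, so it falls back to allocating a plain Array.
y_map = map(x -> x + 3, w)

# broadcast consults BroadcastStyle, so the wrapper type survives.
y_bc = broadcast(x -> x + 3, w)

println(typeof(y_map))  # Matrix{Float64}
println(typeof(y_bc))   # WrapArray{Float64, 2}
```

If a type also overloaded similar(::MyType, T, dims), map would presumably preserve it as well; using broadcast avoids depending on that extra method.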

FWIW, ChainRules uses https://github.com/JuliaDiff/ChainRules.jl/blob/main/src/rulesets/Base/broadcast.jl#L108-L114, which does not preserve such types; it would be nice if it did too:

julia> myfun(PeriodicArray(rand(5,5)));
typeof(y) = PeriodicArray{Float64, 2}

julia> gradient(myfun, PeriodicArray(rand(5,5)));
┌ Debug: split broadcasting forwards
│   frule_fun = frule_via_ad (generic function with 1 method)
│   f = f_add (generic function with 1 method)
└ @ ChainRules ~/.julia/packages/ChainRules/fgVxV/src/rulesets/Base/broadcast.jl:107
typeof(y) = Matrix{Float64}

julia> ys, ydots = ChainRules.unzip_broadcast(PeriodicArray(rand(2,3))) do a
           frule_via_ad(Diffractor.DiffractorRuleConfig(), (NoTangent(), one(a)), f_add, a)
       end;

julia> ys
2×3 Matrix{Float64}:
 3.95438  3.31328  3.83527
 3.59566  3.02989  3.05676

julia> ydots
2×3 Matrix{Float64}:
 1.0  1.0  1.0
 1.0  1.0  1.0
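One way such a split could keep the container type, sketched under assumptions: the hypothetical helper unzip_keep below only handles 2-tuples and is not how ChainRules.unzip_broadcast is implemented, the Wrap type is a stand-in for PeriodicArray, and the closure a -> (a + 3, one(a)) stands in for the frule_via_ad call. Because the forward broadcast and both projections (first.(...) and last.(...)) are broadcasts, the wrapper's BroadcastStyle decides the output type at every step:

```julia
# Hypothetical wrapper type standing in for PeriodicArray (not MPSKit's code).
struct Wrap{T,N} <: AbstractArray{T,N}
    data::Array{T,N}
end
Base.size(A::Wrap) = size(A.data)
Base.getindex(A::Wrap, I::Int...) = A.data[I...]
Base.setindex!(A::Wrap, v, I::Int...) = (A.data[I...] = v)

# Broadcasts involving a Wrap allocate their result as a Wrap.
Base.BroadcastStyle(::Type{<:Wrap}) = Broadcast.ArrayStyle{Wrap}()
Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{Wrap}}, ::Type{T}) where {T} =
    Wrap(similar(Array{T}, axes(bc)))

# Hypothetical helper, NOT ChainRules.unzip_broadcast: f must return a
# 2-tuple per element. Every step is a broadcast, so the container's
# BroadcastStyle picks the output type each time.
function unzip_keep(f, A)
    tups = f.(A)                      # a Wrap whose elements are tuples
    return first.(tups), last.(tups)  # split into two Wraps
end

# The closure stands in for frule_via_ad on f_add: (primal, tangent) of x + 3.
ys, ydots = unzip_keep(a -> (a + 3, one(a)), Wrap(rand(2, 3)))

println(typeof(ys))     # Wrap{Float64, 2}
println(typeof(ydots))  # Wrap{Float64, 2}
```

The price of this simple version is an intermediate array of tuples; a production implementation may well avoid that allocation.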

@maartenvd
Contributor Author

I'll try to make a pull request. I will also look a bit at the ChainRules case.

@maartenvd
Contributor Author

I cannot reproduce the ChainRules case:

julia> gradient(myfun, PeriodicArray(rand(5,5)));
typeof(y) = PeriodicArray{Float64, 2}
ERROR: BoundsError: attempt to access Tuple{Nothing, Nothing} at index [3]

@mcabbott
Member

mcabbott commented Sep 1, 2022

This was a bit obscure, sorry, but mine was using JuliaDiff/Diffractor.jl#89. How did you get this error, though? It might be the cause of some other problems. (And can you post the stacktrace?)

@maartenvd
Contributor Author

That error disappeared after restarting Julia, so I'm just going to quietly ignore it and hope it's not because I wrote a funky rule somewhere else. I can now reproduce the Diffractor thing :)
