Allow user to change the channel axis for BatchNorm function and the likes #1666

Open · wants to merge 6 commits into master
3 changes: 3 additions & 0 deletions NEWS.md
@@ -1,5 +1,6 @@
# Flux Release Notes

## v0.12.4
+* Implemented [axis option for normalisation functions](https://github.com/FluxML/Flux.jl/issues/1664).
* Implemented an [`Embedding layer`](https://github.com/FluxML/Flux.jl/pull/1516)
based on `NNlib.gather` and `NNlib.scatter`.
55 changes: 36 additions & 19 deletions src/layers/normalise.jl
@@ -168,8 +168,9 @@ end
# reduce_dims=[1,...,N-2,N] for BatchNorm
# reduce_dims=[1,...,N-2] for InstanceNorm and GroupNorm
function _norm_layer_forward(l, x::AbstractArray{T,N}; reduce_dims, affine_shape) where {T, N}
+  dim = isnothing(l.dim) ? N-1 : l.dim
  if !_isactive(l) && l.track_stats # testmode with tracked stats
-    stats_shape = ntuple(i -> i == N-1 ? size(x, N-1) : 1, N)
+    stats_shape = ntuple(i -> i == dim ? size(x, dim) : 1, N)
    μ = reshape(l.μ, stats_shape)
    σ² = reshape(l.σ², stats_shape)
  else # trainmode or testmode without tracked stats
@@ -180,8 +181,8 @@ function _norm_layer_forward(l, x::AbstractArray{T,N}; reduce_dims, affine_shape
      Zygote.ignore() do
        mtm = l.momentum
        m = prod(size(x, i) for i in reduce_dims) # needed for computing corrected var
-        μnew = vec(N ∈ reduce_dims ? μ : mean(μ, dims=N))
-        σ²new = vec(N ∈ reduce_dims ? σ² : mean(σ², dims=N))
+        μnew = vec(dim+1 ∈ reduce_dims ? μ : mean(μ, dims=dim+1))
+        σ²new = vec(dim+1 ∈ reduce_dims ? σ² : mean(σ², dims=dim+1))
        l.μ = (1-mtm) .* l.μ .+ mtm .* μnew
        l.σ² = (1-mtm) .* l.σ² .+ mtm .* (m / (m - one(eltype(l.σ²)))) .* σ²new
      end
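For intuition, here is a minimal standalone sketch (not part of the diff) of what `reduce_dims` and `affine_shape` mean: statistics are reduced over every dimension except the channel dimension, then broadcast back against `x`. Sizes are made up for illustration.

```julia
using Statistics

x = randn(Float32, 5, 5, 3, 2)      # a W×H×C×N batch; channel dim is 3
reduce_dims = [1, 2, 4]             # every dim except the channel dim
μ  = mean(x; dims=reduce_dims)                  # size (1, 1, 3, 1)
σ² = var(x; dims=reduce_dims, corrected=false)  # one entry per channel
x̂  = (x .- μ) ./ sqrt.(σ² .+ 1f-5)              # broadcasts back to size(x)
@assert size(x̂) == size(x)
```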
@@ -198,6 +199,7 @@ end

"""
    BatchNorm(channels::Integer, λ=identity;
+             dim = nothing,
              initβ=zeros32, initγ=ones32,
              ϵ=1f-5, momentum=0.1f0)
@@ -206,7 +208,7 @@ end

Given an array with `N` dimensions, call the `N-1`th the channel dimension. For
a batch of feature vectors this is just the data dimension, for `WHCN` images
-it's the usual channel dimension.
+it's the usual channel dimension. Use the `dim` keyword to select a different channel dimension.

`BatchNorm` computes the mean and variance for each `D_1×...×D_{N-2}×1×D_N`
input slice and normalises the input accordingly.
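As a usage sketch (assuming the `dim` keyword as proposed in this PR): the default keeps channels at `N-1`, while `dim=1` puts them on the first axis.

```julia
using Flux

bn = BatchNorm(3)                 # channels at the default dim N-1
x = randn(Float32, 28, 28, 3, 4)  # W×H×C×N
@assert size(bn(x)) == size(x)

bn1 = BatchNorm(3; dim=1)         # channels on the first axis instead
y = randn(Float32, 3, 28, 4)      # C×L×N
@assert size(bn1(y)) == size(y)
```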
@@ -243,9 +245,11 @@ mutable struct BatchNorm{F,V,N,W}
  track_stats::Bool
  active::Union{Bool, Nothing}
  chs::Int # number of channels
+  dim::Union{Int, Nothing} # channel dimension
Member:

This should never be `nothing`; use `N - 1` as the default.

Member:

How do you know `N - 1` a priori, though? The question is what an appropriate sentinel value is. Perhaps a negative offset from the end, since by default channels are at dim `end - 1`?

Member:

Right, so it will need to be defined at runtime.
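For illustration only, one way to read the negative-offset suggestion (all names here are hypothetical, not part of this PR): store the channel dim as an offset from the last dimension, so the default needs no `nothing` sentinel and resolves itself at runtime.

```julia
# Hypothetical sketch of the "negative offset from the end" idea.
struct ChannelDim
  offset::Int                      # 1 ⇒ the usual default, dim = N - 1
end

resolve(cd::ChannelDim, N::Int) = N - cd.offset

cd = ChannelDim(1)
resolve(cd, 4)  # 3: the channel dim of a W×H×C×N array
resolve(cd, 2)  # 1: the data dim of a feature-vector batch
```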

end

function BatchNorm(chs::Int, λ=identity;
+                  dim = nothing,
                   initβ=zeros32, initγ=ones32,
                   affine=true, track_stats=true,
                   ϵ=1f-5, momentum=0.1f0)
@@ -258,17 +262,19 @@ function BatchNorm(chs::Int, λ=identity;
  return BatchNorm(λ, β, γ,
                   μ, σ², ϵ, momentum,
                   affine, track_stats,
-                  nothing, chs)
+                  nothing, chs, dim)
end

@functor BatchNorm
trainable(bn::BatchNorm) = hasaffine(bn) ? (bn.β, bn.γ) : ()

function (BN::BatchNorm)(x)
-  @assert size(x, ndims(x)-1) == BN.chs
  N = ndims(x)
-  reduce_dims = [1:N-2; N]
-  affine_shape = ntuple(i -> i == N-1 ? size(x, N-1) : 1, N)
+  dim = isnothing(BN.dim) ? N-1 : BN.dim
+  @assert dim < N
+  @assert size(x, dim) == BN.chs
+  reduce_dims = [1:dim-1; dim+1:N]
+  affine_shape = ntuple(i -> i == dim ? size(x, dim) : 1, N)
  return _norm_layer_forward(BN, x; reduce_dims, affine_shape)
end
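Worked shapes for the forward pass above, assuming `dim = 1` on a 3-d `C×L×N` input (illustrative only):

```julia
N, dim, chs = 3, 1, 4
x = randn(Float32, chs, 10, 2)
reduce_dims = [1:dim-1; dim+1:N]                            # == [2, 3]
affine_shape = ntuple(i -> i == dim ? size(x, dim) : 1, N)  # == (4, 1, 1)
```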

@@ -294,6 +300,7 @@ end

Given an array with `N > 2` dimensions, call the `N-1`th the channel dimension.
For `WHCN` images it's the usual channel dimension.
+Use the `dim` keyword to select a different channel dimension.

`InstanceNorm` computes the mean and variance for each `D_1×...×D_{N-2}×1×1`
input slice and normalises the input accordingly.
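A usage sketch, assuming this PR as written (note that the code treats `dim+1` as the batch-like dimension, so statistics are kept per channel and per `dim+1` slice):

```julia
using Flux

norm = InstanceNorm(3)              # channels at the default dim N-1
x = randn(Float32, 28, 28, 3, 4)    # W×H×C×N
@assert size(norm(x)) == size(x)

norm2 = InstanceNorm(3; dim=2)      # channels on the second axis
y = randn(Float32, 28, 3, 4, 10)
@assert size(norm2(y)) == size(y)
```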
@@ -319,9 +326,11 @@ mutable struct InstanceNorm{F,V,N,W}
  track_stats::Bool
  active::Union{Bool, Nothing}
  chs::Int # number of channels
+  dim::Union{Int, Nothing} # channel dimension
end

function InstanceNorm(chs::Int, λ=identity;
+                      dim = nothing,
                       initβ=zeros32, initγ=ones32,
                       affine=false, track_stats=false,
                       ϵ=1f-5, momentum=0.1f0)
@@ -334,18 +343,20 @@ function InstanceNorm(chs::Int, λ=identity;
  return InstanceNorm(λ, β, γ,
                      μ, σ², ϵ, momentum,
                      affine, track_stats,
-                     nothing, chs)
+                     nothing, chs, dim)
end

@functor InstanceNorm
trainable(in::InstanceNorm) = hasaffine(in) ? (in.β, in.γ) : ()

function (l::InstanceNorm)(x)
-  @assert ndims(x) > 2
-  @assert size(x, ndims(x)-1) == l.chs
  N = ndims(x)
-  reduce_dims = 1:N-2
-  affine_shape = ntuple(i -> i == N-1 ? size(x, N-1) : 1, N)
+  @assert N > 2
+  dim = isnothing(l.dim) ? N-1 : l.dim
+  @assert dim < N
+  @assert size(x, dim) == l.chs
+  reduce_dims = [1:dim-1; dim+2:N]
+  affine_shape = ntuple(i -> i == dim ? size(x, dim) : 1, N)
  return _norm_layer_forward(l, x; reduce_dims, affine_shape)
end
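For the index arithmetic above: with `dim = 2` on a 4-d input, the dims left out of the reduction are the channel dim and the one right after it (illustrative only):

```julia
N, dim = 4, 2
reduce_dims = [1:dim-1; dim+2:N]  # == [1, 4]; dims 2 (channel) and 3 are kept
```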

@@ -361,6 +372,7 @@ end

"""
    GroupNorm(channels::Integer, G::Integer, λ=identity;
+              dim = nothing,
              initβ=zeros32, initγ=ones32,
              affine=true, track_stats=false,
              ϵ=1f-5, momentum=0.1f0)
@@ -377,6 +389,7 @@ The number of channels must be an integer multiple of the number of groups.

Given an array with `N > 2` dimensions, call the `N-1`th the channel dimension.
For `WHCN` images it's the usual channel dimension.
+Use the `dim` keyword to select a different channel dimension.

If `affine=true`, it also applies a shift and a rescale to the input
through learnable per-channel bias `β` and scale `γ` parameters.
@@ -397,12 +410,14 @@ mutable struct GroupNorm{F,V,N,W}
  track_stats::Bool
  active::Union{Bool, Nothing}
  chs::Int # number of channels
+  dim::Union{Int, Nothing} # channel dimension
end

@functor GroupNorm
trainable(gn::GroupNorm) = hasaffine(gn) ? (gn.β, gn.γ) : ()

function GroupNorm(chs::Int, G::Int, λ=identity;
+                   dim = nothing,
                    initβ=zeros32, initγ=ones32,
                    affine=true, track_stats=false,
                    ϵ=1f-5, momentum=0.1f0)
@@ -419,18 +434,20 @@ function GroupNorm(chs::Int, G::Int, λ=identity;
                   μ, σ²,
                   ϵ, momentum,
                   affine, track_stats,
-                  nothing, chs)
+                  nothing, chs, dim)
end

function (gn::GroupNorm)(x)
-  @assert ndims(x) > 2
-  @assert size(x, ndims(x)-1) == gn.chs
  N = ndims(x)
+  @assert N > 2
+  dim = isnothing(gn.dim) ? N-1 : gn.dim
+  @assert dim < N
+  @assert size(x, dim) == gn.chs
  sz = size(x)
-  x = reshape(x, sz[1:N-2]..., sz[N-1]÷gn.G, gn.G, sz[N])
+  x = reshape(x, sz[1:dim-1]..., sz[dim]÷gn.G, gn.G, sz[dim+1:N]...)
  N = ndims(x)
-  reduce_dims = 1:N-2
-  affine_shape = ntuple(i -> i ∈ (N-1, N-2) ? size(x, i) : 1, N)
+  reduce_dims = [1:dim; dim+3:N]
+  affine_shape = ntuple(i -> i ∈ (dim+1, dim) ? size(x, i) : 1, N)
  x = _norm_layer_forward(gn, x; reduce_dims, affine_shape)
  return reshape(x, sz)
end
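Illustrative shapes for the group split above, assuming `dim = 2` and `G = 3` (made-up sizes):

```julia
chs, G, dim = 6, 3, 2
sz = (5, chs, 4, 2)
x = randn(Float32, sz...)
x = reshape(x, sz[1:dim-1]..., sz[dim]÷G, G, sz[dim+1:end]...)
@assert size(x) == (5, 2, 3, 4, 2)  # channels split into (chs÷G, G)
N = ndims(x)                        # 5 after the split
reduce_dims = [1:dim; dim+3:N]      # == [1, 2, 5]; dims 3 and 4 are kept
```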