From 9de594dd7556d086379f6ad0bf3d3823cbdf2701 Mon Sep 17 00:00:00 2001
From: Victor
Date: Wed, 14 Jul 2021 19:18:35 +0200
Subject: [PATCH 1/6] adding dim to BatchNorm

---
 src/layers/normalise.jl | 27 +++++++++++++++++----------
 1 file changed, 17 insertions(+), 10 deletions(-)

diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl
index b4f7e1f134..9c63632324 100644
--- a/src/layers/normalise.jl
+++ b/src/layers/normalise.jl
@@ -167,9 +167,11 @@ end
 # Compute the statistics on the slices specified by reduce_dims.
 # reduce_dims=[1,...,N-2,N] for BatchNorm
 # reduce_dims=[1,...,N-2] for InstanceNorm and GroupNorm
-function _norm_layer_forward(l, x::AbstractArray{T,N}; reduce_dims, affine_shape) where {T, N}
+function _norm_layer_forward(l, x::AbstractArray{T,N}; reduce_dims, affine_shape, dim=nothing) where {T, N}
+  # todo:change
+  isnothing(dim) ? dim = N-1 : nothing
   if !_isactive(l) && l.track_stats # testmode with tracked stats
-    stats_shape = ntuple(i -> i == N-1 ? size(x, N-1) : 1, N)
+    stats_shape = ntuple(i -> i == dim ? size(x,dim) : 1, N)
     μ = reshape(l.μ, stats_shape)
     σ² = reshape(l.σ², stats_shape)
   else # trainmode or testmode without tracked stats
@@ -198,14 +200,16 @@ end
 """
     BatchNorm(channels::Integer, λ=identity;
+              dim = nothing,
               initβ=zeros32, initγ=ones32,
               ϵ=1f-5, momentum= 0.1f0)
 
 [Batch Normalization](https://arxiv.org/abs/1502.03167) layer.
 `channels` should be the size of the channel dimension in your data (see below).
 
-Given an array with `N` dimensions, call the `N-1`th the channel dimension. For
-a batch of feature vectors this is just the data dimension, for `WHCN` images
+Given an array with `N` dimensions, call the `N-1`th the channel dimension.
+If `dim` is specified, call `dim` the channel dimension. For a batch of feature vectors
+this is just the data dimension, for `WHCN` images
 it's the usual channel dimension.
 
 `BatchNorm` computes the mean and variance for each `D_1×...×D_{N-2}×1×D_N`
@@ -243,9 +247,11 @@ mutable struct BatchNorm{F,V,N,W}
   track_stats::Bool
   active::Union{Bool, Nothing}
   chs::Int # number of channels
+  dim::Int # channel dimension
 end
 
 function BatchNorm(chs::Int, λ=identity;
+                   dim = nothing,
                    initβ=zeros32, initγ=ones32,
                    affine=true, track_stats=true,
                    ϵ=1f-5, momentum=0.1f0)
@@ -258,18 +264,19 @@ function BatchNorm(chs::Int, λ=identity;
   return BatchNorm(λ, β, γ,
                    μ, σ², ϵ, momentum,
                    affine, track_stats,
-                   nothing, chs)
+                   nothing, chs, dim)
 end
 
 @functor BatchNorm
 trainable(bn::BatchNorm) = hasaffine(bn) ? (bn.β, bn.γ) : ()
 
 function (BN::BatchNorm)(x)
-  @assert size(x, ndims(x)-1) == BN.chs
   N = ndims(x)
-  reduce_dims = [1:N-2; N]
-  affine_shape = ntuple(i -> i == N-1 ? size(x, N-1) : 1, N)
-  return _norm_layer_forward(BN, x; reduce_dims, affine_shape)
+  isnothing(BN.dim) ? dim = N-1 : dim = BN.dim
+  @assert size(x, dim) == BN.chs
+  reduce_dims = [1:dim-1;dim+1:N]
+  affine_shape = ntuple(i -> i == dim ? size(x, dim) : 1, N)
+  return _norm_layer_forward(BN, x; reduce_dims, affine_shape, dim)
 end
 
 testmode!(m::BatchNorm, mode=true) =
@@ -346,7 +353,7 @@ function (l::InstanceNorm)(x)
   N = ndims(x)
   reduce_dims = 1:N-2
   affine_shape = ntuple(i -> i == N-1 ? size(x, N-1) : 1, N)
-  return _norm_layer_forward(l, x; reduce_dims, affine_shape)
+  return _norm_layer_forward(l, x; reduce_dims, affine_shape, N-1)
 end
 
 testmode!(m::InstanceNorm, mode=true) =
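For reference, the dimension arithmetic this patch introduces can be checked at the REPL. A minimal sketch (not part of the patch), assuming a 4-dimensional `WHCN` input and a hypothetical non-default channel dimension:

    # BatchNorm reduces over every dimension except the channel dimension `dim`:
    N = 4                                          # e.g. a WHCN image batch
    dim = 2                                        # hypothetical non-default channel dimension
    reduce_dims = [1:dim-1; dim+1:N]               # == [1, 3, 4]
    # the affine parameters broadcast along `dim` only:
    x = rand(Float32, 5, 3, 7, 2)
    affine_shape = ntuple(i -> i == dim ? size(x, dim) : 1, N)   # == (1, 3, 1, 1)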
From a1cb5e25d15b306a88af4a78a87fb42c5b257624 Mon Sep 17 00:00:00 2001
From: Victor
Date: Thu, 15 Jul 2021 10:39:52 +0200
Subject: [PATCH 2/6] fix problem with InstanceNorm

---
 src/layers/normalise.jl | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl
index 9c63632324..9a2592cfe3 100644
--- a/src/layers/normalise.jl
+++ b/src/layers/normalise.jl
@@ -276,7 +276,7 @@ function (BN::BatchNorm)(x)
   @assert size(x, dim) == BN.chs
   reduce_dims = [1:dim-1;dim+1:N]
   affine_shape = ntuple(i -> i == dim ? size(x, dim) : 1, N)
-  return _norm_layer_forward(BN, x; reduce_dims, affine_shape, dim)
+  return _norm_layer_forward(BN, x; reduce_dims, affine_shape, dim=dim)
 end
 
 testmode!(m::BatchNorm, mode=true) =
@@ -353,7 +353,7 @@ function (l::InstanceNorm)(x)
   N = ndims(x)
   reduce_dims = 1:N-2
   affine_shape = ntuple(i -> i == N-1 ? size(x, N-1) : 1, N)
-  return _norm_layer_forward(l, x; reduce_dims, affine_shape, N-1)
+  return _norm_layer_forward(l, x; reduce_dims, affine_shape, dim=N-1)
 end
 
 testmode!(m::InstanceNorm, mode=true) =
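The fix above matters for the `InstanceNorm` call: after the `;`, a bare identifier like `dim` is valid shorthand for `dim=dim` (Julia 1.5+), but an expression such as `N-1` cannot stand in for a keyword, so `...; reduce_dims, affine_shape, N-1` does not parse as intended. A minimal sketch with a hypothetical function, not from the patch:

    f(x; dim=nothing) = dim
    dim = 2
    f(1; dim)          # shorthand for f(1; dim=dim), returns 2
    # f(1; dim-1)      # syntax error: an expression cannot name a keyword implicitly
    f(1; dim=dim-1)    # explicit form, returns 1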
From 7f7751c32bdfaf583f02518147684ae1e21b62ed Mon Sep 17 00:00:00 2001
From: Victor
Date: Fri, 16 Jul 2021 12:36:30 +0200
Subject: [PATCH 3/6] extended channel axis option to InstanceNorm and GroupNorm

---
 src/layers/normalise.jl | 43 ++++++++++++++++++++++++-----------------
 1 file changed, 25 insertions(+), 18 deletions(-)

diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl
index 9a2592cfe3..40717e2397 100644
--- a/src/layers/normalise.jl
+++ b/src/layers/normalise.jl
@@ -167,9 +167,8 @@ end
 # Compute the statistics on the slices specified by reduce_dims.
 # reduce_dims=[1,...,N-2,N] for BatchNorm
 # reduce_dims=[1,...,N-2] for InstanceNorm and GroupNorm
-function _norm_layer_forward(l, x::AbstractArray{T,N}; reduce_dims, affine_shape, dim=nothing) where {T, N}
-  # todo:change
-  isnothing(dim) ? dim = N-1 : nothing
+function _norm_layer_forward(l, x::AbstractArray{T,N}; reduce_dims, affine_shape) where {T, N}
+  isnothing(l.dim) ? dim = N-1 : dim = l.dim
   if !_isactive(l) && l.track_stats # testmode with tracked stats
     stats_shape = ntuple(i -> i == dim ? size(x,dim) : 1, N)
     μ = reshape(l.μ, stats_shape)
     σ² = reshape(l.σ², stats_shape)
   else # trainmode or testmode without tracked stats
@@ -247,7 +246,7 @@ mutable struct BatchNorm{F,V,N,W}
   track_stats::Bool
   active::Union{Bool, Nothing}
   chs::Int # number of channels
-  dim::Int # channel dimension
+  dim::Union{Int, Nothing} # channel dimension
 end
 
 function BatchNorm(chs::Int, λ=identity;
@@ -273,10 +272,11 @@ trainable(bn::BatchNorm) = hasaffine(bn) ? (bn.β, bn.γ) : ()
 
 function (BN::BatchNorm)(x)
   N = ndims(x)
   isnothing(BN.dim) ? dim = N-1 : dim = BN.dim
+  @assert dim < N
   @assert size(x, dim) == BN.chs
   reduce_dims = [1:dim-1;dim+1:N]
   affine_shape = ntuple(i -> i == dim ? size(x, dim) : 1, N)
-  return _norm_layer_forward(BN, x; reduce_dims, affine_shape, dim=dim)
+  return _norm_layer_forward(BN, x; reduce_dims, affine_shape)
 end
 
 testmode!(m::BatchNorm, mode=true) =
@@ -326,9 +326,11 @@ mutable struct InstanceNorm{F,V,N,W}
   track_stats::Bool
   active::Union{Bool, Nothing}
   chs::Int # number of channels
+  dim::Union{Int, Nothing} # channel dimension
 end
 
 function InstanceNorm(chs::Int, λ=identity;
+                      dim = nothing,
                       initβ=zeros32, initγ=ones32,
                       affine=false, track_stats=false,
                       ϵ=1f-5, momentum=0.1f0)
@@ -341,19 +343,21 @@ function InstanceNorm(chs::Int, λ=identity;
   return InstanceNorm(λ, β, γ,
                       μ, σ², ϵ, momentum,
                       affine, track_stats,
-                      nothing, chs)
+                      nothing, chs, dim)
 end
 
 @functor InstanceNorm
 trainable(in::InstanceNorm) = hasaffine(in) ? (in.β, in.γ) : ()
 
 function (l::InstanceNorm)(x)
-  @assert ndims(x) > 2
-  @assert size(x, ndims(x)-1) == l.chs
   N = ndims(x)
-  reduce_dims = 1:N-2
-  affine_shape = ntuple(i -> i == N-1 ? size(x, N-1) : 1, N)
-  return _norm_layer_forward(l, x; reduce_dims, affine_shape, dim=N-1)
+  @assert N > 2
+  isnothing(l.dim) ? dim = N-1 : dim = l.dim
+  @assert dim < N
+  @assert size(x, dim) == l.chs
+  reduce_dims = [1:dim-1;dim+2:N];
+  affine_shape = ntuple(i -> i == dim ? size(x,dim) : 1, N)
+  return _norm_layer_forward(l, x; reduce_dims, affine_shape)
 end
 
 testmode!(m::InstanceNorm, mode=true) =
@@ -404,12 +408,14 @@ mutable struct GroupNorm{F,V,N,W}
   track_stats::Bool
   active::Union{Bool, Nothing}
   chs::Int # number of channels
+  dim::Union{Int, Nothing} # channel dimension
 end
 
 @functor GroupNorm
 trainable(gn::GroupNorm) = hasaffine(gn) ? (gn.β, gn.γ) : ()
 
 function GroupNorm(chs::Int, G::Int, λ=identity;
+                   dim = nothing,
                    initβ=zeros32, initγ=ones32,
                    affine=true, track_stats=false,
                    ϵ=1f-5, momentum=0.1f0)
@@ -426,18 +432,19 @@ function GroupNorm(chs::Int, G::Int, λ=identity;
             μ, σ², ϵ, momentum,
             affine, track_stats,
-            nothing, chs)
+            nothing, chs, dim)
 end
 
 function (gn::GroupNorm)(x)
-  @assert ndims(x) > 2
-  @assert size(x, ndims(x)-1) == gn.chs
   N = ndims(x)
+  @assert N > 2
+  isnothing(gn.dim) ? dim = N-1 : dim = gn.dim
+  @assert dim < N
+  @assert size(x, dim) == gn.chs
   sz = size(x)
-  x = reshape(x, sz[1:N-2]..., sz[N-1]÷gn.G, gn.G, sz[N])
-  N = ndims(x)
-  reduce_dims = 1:N-2
-  affine_shape = ntuple(i -> i ∈ (N-1, N-2) ? size(x, i) : 1, N)
+  x = reshape(x, sz[1:dim-1]..., sz[dim]÷gn.G, gn.G, sz[N], sz[dim+2:N]...,)
+  reduce_dims = [1:dim-1;dim+2:N];
+  affine_shape = ntuple(i -> i ∈ (dim, dim-1) ? size(x, i) : 1, N)
   x = _norm_layer_forward(gn, x; reduce_dims, affine_shape)
   return reshape(x, sz)
 end
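As a sanity check on the `reduce_dims` arithmetic introduced above (a sketch, not part of the patch): with the default `dim = N-1`, the new expressions recover the old behaviour, and the InstanceNorm version implicitly assumes the batch dimension directly follows the channel dimension:

    N = 4; dim = N-1                    # default channel dimension
    # BatchNorm: reduce over everything except the channel dimension
    [1:dim-1; dim+1:N] == [1, 2, 4]     # true; was [1:N-2; N]
    # InstanceNorm: also keep the batch dimension dim+1; dim+2:N is empty here
    [1:dim-1; dim+2:N] == [1, 2]        # true; was 1:N-2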
From 3a772c8396897573213904c104e473e17b91b7e3 Mon Sep 17 00:00:00 2001
From: Victor
Date: Fri, 16 Jul 2021 13:09:14 +0200
Subject: [PATCH 4/6] quick fix

---
 src/layers/normalise.jl | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl
index 40717e2397..1db4290696 100644
--- a/src/layers/normalise.jl
+++ b/src/layers/normalise.jl
@@ -170,7 +170,7 @@
 function _norm_layer_forward(l, x::AbstractArray{T,N}; reduce_dims, affine_shape) where {T, N}
   isnothing(l.dim) ? dim = N-1 : dim = l.dim
   if !_isactive(l) && l.track_stats # testmode with tracked stats
-    stats_shape = ntuple(i -> i == dim ? size(x,dim) : 1, N)
+    stats_shape = ntuple(i -> i == dim ? size(x, dim) : 1, N)
     μ = reshape(l.μ, stats_shape)
     σ² = reshape(l.σ², stats_shape)
   else # trainmode or testmode without tracked stats
@@ -356,7 +356,7 @@ function (l::InstanceNorm)(x)
   @assert dim < N
   @assert size(x, dim) == l.chs
   reduce_dims = [1:dim-1;dim+2:N];
-  affine_shape = ntuple(i -> i == dim ? size(x,dim) : 1, N)
+  affine_shape = ntuple(i -> i == dim ? size(x, dim) : 1, N)
   return _norm_layer_forward(l, x; reduce_dims, affine_shape)
 end
 
@@ -442,9 +442,9 @@ function (gn::GroupNorm)(x)
   @assert dim < N
   @assert size(x, dim) == gn.chs
   sz = size(x)
-  x = reshape(x, sz[1:dim-1]..., sz[dim]÷gn.G, gn.G, sz[N], sz[dim+2:N]...,)
+  x = reshape(x, sz[1:dim-1]..., sz[dim]÷gn.G, gn.G, sz[dim+1:N]...)
   reduce_dims = [1:dim-1;dim+2:N];
   affine_shape = ntuple(i -> i ∈ (dim, dim-1) ? size(x, i) : 1, N)
   x = _norm_layer_forward(gn, x; reduce_dims, affine_shape)
   return reshape(x, sz)
 end
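To see what the corrected reshape does, here is a sketch (not part of the patch) with a `WHCN` batch of 6 channels split into `G = 3` groups:

    x = rand(Float32, 5, 5, 6, 2)       # W×H×C×N with C = 6
    G = 3
    N = ndims(x); dim = N - 1           # default channel dimension
    sz = size(x)
    y = reshape(x, sz[1:dim-1]..., sz[dim]÷G, G, sz[dim+1:N]...)
    size(y)                             # (5, 5, 2, 3, 2): channels split into G groups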
From a4d01073195295bae4b4ecb6b3c36b4d32db031c Mon Sep 17 00:00:00 2001
From: Victor
Date: Fri, 16 Jul 2021 13:47:38 +0200
Subject: [PATCH 5/6] tests passing, ready to PR

---
 src/layers/normalise.jl | 19 +++++++++++--------
 1 file changed, 11 insertions(+), 8 deletions(-)

diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl
index 1db4290696..6078ccdad3 100644
--- a/src/layers/normalise.jl
+++ b/src/layers/normalise.jl
@@ -181,8 +181,8 @@ function _norm_layer_forward(l, x::AbstractArray{T,N}; reduce_dims, affine_shape
     Zygote.ignore() do
       mtm = l.momentum
       m = prod(size(x, i) for i in reduce_dims)  # needed for computing corrected var
-      μnew = vec(N ∈ reduce_dims ? μ : mean(μ, dims=N))
-      σ²new = vec(N ∈ reduce_dims ? σ² : mean(σ², dims=N))
+      μnew = vec(dim+1 ∈ reduce_dims ? μ : mean(μ, dims=dim+1))
+      σ²new = vec(dim+1 ∈ reduce_dims ? σ² : mean(σ², dims=dim+1))
       l.μ = (1-mtm) .* l.μ .+ mtm .* μnew
       l.σ² = (1-mtm) .* l.σ² .+ mtm .* (m / (m - one(eltype(l.σ²)))) .* σ²new
     end
@@ -206,10 +206,9 @@ end
 [Batch Normalization](https://arxiv.org/abs/1502.03167) layer.
 `channels` should be the size of the channel dimension in your data (see below).
 
-Given an array with `N` dimensions, call the `N-1`th the channel dimension.
-If `dim` is specified, call `dim` the channel dimension. For a batch of feature vectors
-this is just the data dimension, for `WHCN` images
-it's the usual channel dimension.
+Given an array with `N` dimensions, call the `N-1`th the channel dimension. For
+a batch of feature vectors this is just the data dimension, for `WHCN` images
+it's the usual channel dimension. Use the keyword `dim` to set a different channel dimension.
 
 `BatchNorm` computes the mean and variance for each `D_1×...×D_{N-2}×1×D_N`
 input slice and normalises the input accordingly.
@@ -301,6 +300,7 @@ end
 
 Given an array with `N > 2` dimensions, call the `N-1`th the channel dimension.
 For `WHCN` images it's the usual channel dimension.
+Use the keyword `dim` to set a different channel dimension.
 
 `InstanceNorm` computes the mean and variance for each `D_1×...×D_{N-2}×1×1`
 input slice and normalises the input accordingly.
@@ -373,5 +373,6 @@ end
 """
     GroupNorm(channels::Integer, G::Integer, λ=identity;
+              dim = nothing,
               initβ=zeros32, initγ=ones32,
               affine=true, track_stats=false,
               ϵ=1f-5, momentum=0.1f0)
@@ -389,7 +390,8 @@ The number of channels must be an integer multiple of the number of groups.
 
 Given an array with `N > 2` dimensions, call the `N-1`th the channel dimension.
 For `WHCN` images it's the usual channel dimension.
+Use the keyword `dim` to set a different channel dimension.
 
 If `affine=true`, it also applies a shift and a rescale to the input
 through to learnable per-channel bias `β` and scale `γ` parameters.
@@ -443,8 +445,9 @@ function (gn::GroupNorm)(x)
   @assert size(x, dim) == gn.chs
   sz = size(x)
   x = reshape(x, sz[1:dim-1]..., sz[dim]÷gn.G, gn.G, sz[dim+1:N]...)
-  reduce_dims = [1:dim-1;dim+2:N];
-  affine_shape = ntuple(i -> i ∈ (dim, dim-1) ? size(x, i) : 1, N)
+  N = ndims(x)
+  reduce_dims = [1:dim;dim+3:N];
+  affine_shape = ntuple(i -> i ∈ (dim+1, dim) ? size(x, i) : 1, N)
   x = _norm_layer_forward(gn, x; reduce_dims, affine_shape)
   return reshape(x, sz)
 end

From 6113c343974a8f6d55cd7fcf06cef0092685e85c Mon Sep 17 00:00:00 2001
From: Victor
Date: Fri, 16 Jul 2021 14:03:04 +0200
Subject: [PATCH 6/6] added NEWS entry

---
 NEWS.md | 1 +
 1 file changed, 1 insertion(+)

diff --git a/NEWS.md b/NEWS.md
index 65a8458680..7d0131cadb 100644
--- a/NEWS.md
+++ b/NEWS.md
@@ -1,5 +1,6 @@
 # Flux Release Notes
 
 ## v0.12.4
+* Implemented [axis option for normalisation functions](https://github.com/FluxML/Flux.jl/issues/1664).
 * Implemented an [`Embedding layer`](https://github.com/FluxML/Flux.jl/pull/1516) based on `NNlib.gather` and `NNlib.scatter`.
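Taken together, the series makes the following possible; a usage sketch assuming the patches above are applied (the `dim` keyword does not exist in released Flux):

    using Flux
    x = rand(Float32, 3, 5, 2)    # channels in dimension 1 instead of the default N-1 = 2
    bn = BatchNorm(3, dim=1)      # track statistics per channel along dimension 1
    y = bn(x)
    size(y) == size(x)            # true; dimensions 2 and 3 were normalised over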