Allow user to change the channel axis for BatchNorm function and the likes #1666


Open
wants to merge 6 commits into master
Conversation

vboussange

Proposal for issue #1664

This PR allows the user to customize the channel axis for normalisation functions (BatchNorm, GroupNorm and InstanceNorm).

Example

```julia
using Flux

channel_size = 3
channel_axis = 1

# Place the channels along the first axis instead of the default second-to-last
BN = BatchNorm(channel_size, dim = channel_axis)
x = randn(channel_size, 10, 10)
BN(x)
```

```
@@ -243,9 +245,11 @@ mutable struct BatchNorm{F,V,N,W}
    track_stats::Bool
    active::Union{Bool, Nothing}
    chs::Int # number of channels
    dim::Union{Int, Nothing} # channel dimension
```
Member
This should never be `nothing`; it should default to N - 1.

Member

How do you know N - 1 a priori, though? The question is what an appropriate sentinel value is. Perhaps a negative offset from the end, since by default the channels are at dim `end - 1`?

Member

Right, so it will need to be resolved at runtime.
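A minimal sketch of the runtime-resolution idea discussed above, assuming a hypothetical helper `resolve_channel_dim` (not part of this PR): a non-positive `dim` is treated as an offset from the last dimension of the input, so a sentinel of `-1` recovers Flux's current default of `ndims(x) - 1` without knowing N in advance.

```julia
# Hypothetical helper (illustration only): resolve the channel
# dimension once the input's rank N is known at call time.
# A positive `dim` is used as-is; a non-positive `dim` is an offset
# from the end, so the sentinel -1 means N - 1 (Flux's current
# convention of channels in the second-to-last dimension).
function resolve_channel_dim(dim::Int, N::Int)
    d = dim > 0 ? dim : N + dim   # e.g. dim = -1, N = 4 → d = 3
    1 <= d <= N || throw(ArgumentError("invalid channel dim $d for $N-dim input"))
    return d
end

resolve_channel_dim(-1, 4)  # default: second-to-last dim of a 4D input
resolve_channel_dim(1, 3)   # explicit channels-first, as in the example above
```

This keeps the stored field a plain `Int` (no `Union{Int, Nothing}`), since the sentinel is only interpreted when an input is seen.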

@DhairyaLGandhi
Member

How well can cuDNN handle this?

@ToucheSir
Member

Looping back to answer the cuDNN question: it supports exactly one other configuration for the channel dim, via CUDNN_TENSOR_NHWC. I think adding that to NNlibCUDA at https://github.com/FluxML/NNlibCUDA.jl/blob/96a334633ef3a3707c85fc1754c2c7eb8849db4e/src/cudnn/batchnorm.jl#L27 would be a good first step towards getting this going again. Here are a couple of example implementation PRs from MXNet and PyTorch.
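To illustrate the constraint: since Julia arrays are column-major, Flux's default WHCN layout (channels at `ndims(x) - 1`) corresponds to cuDNN's CUDNN_TENSOR_NCHW, while channels-first (`dim = 1`, channels fastest-varying in memory) corresponds to CUDNN_TENSOR_NHWC. A hypothetical dispatch sketch (not the NNlibCUDA implementation; the helper name and symbols are assumptions for illustration):

```julia
# Hypothetical sketch: map the requested channel dimension of an
# N-dimensional input to one of the two cuDNN-supported tensor formats.
# Any other layout has no cuDNN fast path and would need a fallback
# (e.g. permutedims into a supported layout, or a generic GPU kernel).
function cudnn_format_for(dim::Int, N::Int)
    dim == N - 1 && return :CUDNN_TENSOR_NCHW   # Flux's default WHCN layout
    dim == 1     && return :CUDNN_TENSOR_NHWC   # channels fastest-varying
    return nothing                              # unsupported: take a fallback path
end

cudnn_format_for(3, 4)  # default channel dim of a 4D input → NCHW
cudnn_format_for(1, 4)  # channels-first → NHWC
cudnn_format_for(2, 4)  # anything else → nothing (fallback)
```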
