@@ -226,7 +226,9 @@ function _norm_layer_forward(
226226 l, x:: AbstractArray{T, N} ; reduce_dims, affine_shape,
227227) where {T, N}
228228 if ! _isactive (l, x) && l. track_stats # testmode with tracked stats
229- stats_shape = ntuple (i -> i == N- 1 ? size (x, N- 1 ) : 1 , N)
229+ stats_shape = ChainRulesCore. ignore_derivatives () do
230+ ntuple (i -> i == N- 1 ? size (x, N- 1 ) : 1 , N)
231+ end
230232 μ = reshape (l. μ, stats_shape)
231233 σ² = reshape (l. σ², stats_shape)
232234 else # trainmode or testmode without tracked stats
@@ -347,7 +349,9 @@ trainable(bn::BatchNorm) = hasaffine(bn) ? (β = bn.β, γ = bn.γ) : (;)
# Forward pass of a `BatchNorm` layer on an `N`-dimensional array `x`.
# Dimension `N-1` of `x` is checked against `BN.chs`, then the work is handed to
# `_norm_layer_forward` together with the dims and shape computed here.
function (BN::BatchNorm)(x::AbstractArray{T,N}) where {T,N}
    chdim = N - 1
    _size_check(BN, x, chdim => BN.chs)
    # The shape is plain integer bookkeeping, so it is built inside
    # `ignore_derivatives` to keep it out of differentiation.
    affine_shape = ChainRulesCore.ignore_derivatives() do
        ntuple(N) do d
            d == chdim ? size(x, chdim) : 1
        end
    end
    # Every dimension except `N-1`.
    reduce_dims = vcat(1:N-2, N)
    return _norm_layer_forward(BN, x; reduce_dims, affine_shape)
end
353357
@@ -439,7 +443,9 @@ trainable(in::InstanceNorm) = hasaffine(in) ? (β = in.β, γ = in.γ) : (;)
# Forward pass of an `InstanceNorm` layer on an `N`-dimensional array `x`.
# Dimension `N-1` of `x` is checked against `l.chs`; `_norm_layer_forward` then
# receives `reduce_dims = 1:N-2` and a shape that is 1 everywhere except at `N-1`.
function (l::InstanceNorm)(x::AbstractArray{T,N}) where {T,N}
    chdim = N - 1
    _size_check(l, x, chdim => l.chs)
    # Shape construction is wrapped in `ignore_derivatives` so that it is not
    # differentiated.
    affine_shape = ChainRulesCore.ignore_derivatives() do
        ntuple(d -> d == chdim ? size(x, chdim) : 1, N)
    end
    return _norm_layer_forward(l, x; reduce_dims = 1:N-2, affine_shape)
end
445451
@@ -456,10 +462,10 @@ end
456462
457463"""
458464 GroupNorm(channels::Int, G::Int, λ = identity;
459- initβ = zeros32,
465+ initβ = zeros32,
460466 initγ = ones32,
461- affine = true,
462- eps = 1f-5,
467+ affine = true,
468+ eps = 1f-5,
463469 momentum = 0.1f0)
464470
465471[Group Normalization](https://arxiv.org/abs/1803.08494) layer.
@@ -538,12 +544,14 @@ function GroupNorm(chs::Int, G::Int, λ=identity;
538544end
539545
# Forward pass of a `GroupNorm` layer.
# The second-to-last dimension of `x` is checked against `gn.chs`, then split
# into two dimensions of sizes `size(x, ndims(x)-1) ÷ gn.G` and `gn.G`. The
# reshaped array goes through `_norm_layer_forward` and the result is reshaped
# back to the size of `x`.
function (gn::GroupNorm)(x::AbstractArray)
    _size_check(gn, x, ndims(x) - 1 => gn.chs)
    orig_size = size(x)
    leading = orig_size[1:end-2]
    per_group = orig_size[end-1] ÷ gn.G
    grouped = reshape(x, leading..., per_group, gn.G, orig_size[end])
    N = ndims(grouped)  # == ndims(x) + 1
    # Sizes are kept only at dims `N-2` and `N-1`; every other entry is 1.
    # Built inside `ignore_derivatives` so that it is not differentiated.
    affine_shape = ChainRulesCore.ignore_derivatives() do
        ntuple(d -> (d == N - 1 || d == N - 2) ? size(grouped, d) : 1, N)
    end
    normed = _norm_layer_forward(gn, grouped; reduce_dims = 1:N-2, affine_shape)
    return reshape(normed, orig_size)
end
0 commit comments