# Copyright (c) 2020 Sony Corporation. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import nnabla as nn
import nnabla.functions as F
from nnabla.initializer import ConstantInitializer
from .module import Module
from .parameter import Parameter
from ..utils.helper import AverageMeter
class BatchNormalization(Module):
    r"""Batch normalization layer.

    Args:
        n_features (int): Number of dimensional features.
        n_dims (int): Number of dimensions.
        axes (:obj:`tuple` of :obj:`int`): Mean and variance for each
            element in ``axes`` are calculated using elements on the
            rest axes. For example, if an input is 4 dimensions,
            and ``axes`` is ``[1]``, batch mean is calculated as
            ``np.mean(inp.d, axis=(0, 2, 3), keepdims=True)``
            (using numpy expression as an example). Exactly one axis
            must be given. Defaults to ``(1,)``.
        decay_rate (float, optional): Decay rate of running mean and
            variance. Defaults to 0.9.
        eps (float, optional): Tiny value to avoid zero division by std.
            Defaults to 1e-5.
        output_stat (bool, optional): Output batch mean and variance.
            Defaults to `False`.
        fix_parameters (bool): When set to `True`, the beta and gamma will
            not be updated.
        param_init (dict):
            Parameter initializers can be set with a dict. A key of the
            dict must be ``'beta'``, ``'gamma'``, ``'mean'`` or ``'var'``.
            A value of the dict must be an :obj:`~nnabla.initializer.
            Initializer` or a :obj:`numpy.ndarray`.
            E.g.::

                {
                    'beta': ConstantInitializer(0),
                    'gamma': np.ones(gamma_shape) * 2
                }
        name (string): the name of this module

    Returns:
        :class:`~nnabla.Variable`: N-D array.

    Raises:
        ValueError: If ``axes`` does not contain exactly one axis.

    References:
        Ioffe and Szegedy, Batch Normalization: Accelerating Deep
        Network Training by Reducing Internal Covariate Shift.
        https://arxiv.org/abs/1502.03167
    """

    def __init__(self, n_features, n_dims, axes=(1,), decay_rate=0.9,
                 eps=1e-5, output_stat=False, fix_parameters=False,
                 param_init=None, name=''):
        Module.__init__(self, name=name)
        self._scope_name = f'<batchnorm at {hex(id(self))}>'

        # Keep a private list copy so the instance never shares (or mutates)
        # the caller's sequence or the signature default.
        axes = list(axes)
        if len(axes) != 1:
            # Explicit raise instead of `assert`: asserts vanish under -O.
            raise ValueError(
                f'BatchNormalization supports exactly one axis, got {axes}')

        # Statistics/affine parameters broadcast against the input: size 1
        # everywhere except along the feature axis.
        shape_stat = [1] * n_dims
        shape_stat[axes[0]] = n_features

        if param_init is None:
            param_init = {}
        beta_init = param_init.get('beta', ConstantInitializer(0))
        gamma_init = param_init.get('gamma', ConstantInitializer(1))
        mean_init = param_init.get('mean', ConstantInitializer(0))
        var_init = param_init.get('var', ConstantInitializer(1))

        def _fixed_variable(init):
            # `param_init` values may be initializers (callables taking a
            # shape) or ready-made numpy arrays; only the former is called.
            values = init(shape_stat) if callable(init) else init
            return nn.Variable.from_numpy_array(values)

        if fix_parameters:
            # Plain variables rather than `Parameter`s for the frozen
            # beta/gamma.
            self._beta = _fixed_variable(beta_init)
            self._gamma = _fixed_variable(gamma_init)
        else:
            self._beta = Parameter(shape_stat, initializer=beta_init,
                                   scope=self._scope_name)
            self._gamma = Parameter(shape_stat, initializer=gamma_init,
                                    scope=self._scope_name)
        # Running statistics are never trained by gradient.
        self._mean = Parameter(shape_stat, need_grad=False,
                               initializer=mean_init,
                               scope=self._scope_name)
        self._var = Parameter(shape_stat, need_grad=False,
                              initializer=var_init,
                              scope=self._scope_name)

        self._axes = axes
        self._decay_rate = decay_rate
        self._eps = eps
        self._n_features = n_features
        self._fix_parameters = fix_parameters
        self._output_stat = output_stat

        # for set running statistics: when the flag below is switched on,
        # `call` normalizes with the current batch statistics and
        # accumulates them in the two meters.
        self.set_running_statistics = False
        self.mean_est = AverageMeter(self._scope_name)
        self.var_est = AverageMeter(self._scope_name)

    def call(self, input):
        """Apply batch normalization to ``input``.

        Args:
            input (:class:`~nnabla.Variable`): N-D array to normalize.

        Returns:
            The result of :func:`nnabla.functions.batch_normalization`.

        When ``self.set_running_statistics`` is true, the mean and variance
        of the current batch are computed over every axis except the feature
        axis, recorded in ``self.mean_est`` / ``self.var_est`` (weighted by
        ``input.shape[0]``) and used directly for the normalization with
        ``batch_stat=False``. Otherwise the stored running statistics and
        the module's ``training`` flag are used.
        """
        if not self.set_running_statistics:
            return F.batch_normalization(input, self._beta, self._gamma,
                                         self._mean, self._var, self._axes,
                                         self._decay_rate, self._eps,
                                         self.training, self._output_stat)

        # Note: this code path is verified with only the once-for-all
        # algorithm so far.
        n_dims = len(input.shape)
        feature_axis = self._axes[0] % n_dims  # also accepts a negative axis
        reduce_axes = tuple(
            i for i in range(n_dims) if i != feature_axis)

        batch_mean = F.mean(input, axis=reduce_axes, keepdims=True)
        # Var[x] = E[x^2] - E[x]^2
        batch_var = F.mean(input ** 2, axis=reduce_axes,
                           keepdims=True) - batch_mean ** 2

        # NOTE(review): `.d` reads the variables' current data; presumably
        # this relies on the values already being computed (e.g. auto-forward
        # mode) -- confirm against the caller.
        self.mean_est.update(batch_mean.d, input.shape[0])
        self.var_est.update(batch_var.d, input.shape[0])

        # The input may carry fewer features than this layer was built for;
        # use only the leading slice of beta/gamma along the feature axis.
        n_active = batch_mean.shape[feature_axis]
        active = tuple(
            slice(None, n_active) if i == feature_axis else slice(None)
            for i in range(n_dims))

        return F.batch_normalization(
            input, self._beta[active], self._gamma[active],
            batch_mean, batch_var, axes=self._axes,
            decay_rate=self._decay_rate, eps=self._eps, batch_stat=False
        )