# Source code for nnabla_nas.module.parameter

# Copyright (c) 2020 Sony Corporation. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import nnabla as nn
import numpy as np
from nnabla.parameter import set_parameter


class Parameter(nn.Variable):
    r"""Parameter is a Variable.

    A kind of Variable that is to be considered a module parameter. Parameters
    are :class:`~nnabla.Variable` subclasses, that have a very special property
    when used with :class:`Module` s - when they're assigned as Module
    attributes they are automatically added to the list of its parameters.

    On construction the parameter is also registered through
    :func:`nnabla.parameter.set_parameter`, using its ``repr`` (which embeds
    ``id(self)``) as the parameter name.

    Args:
        shape (tuple of int): The shape of Parameter.
        need_grad (bool, optional): If the parameter requires gradient.
            Defaults to True.
        initializer (:obj:`nnabla.initializer.BaseInitializer` or
            :obj:`numpy.ndarray`): An initialization function to be applied to
            the parameter. :obj:`numpy.ndarray` can also be given to
            initialize parameters from numpy array data; its shape must equal
            ``shape``. When None, the data is filled with zeros.
            Defaults to None.
        scope (str, optional): When given (by keyword or as the fourth
            positional argument) and not None, the parameter is registered
            inside ``nn.parameter_scope(scope)``; otherwise it is registered
            in the current parameter scope. Defaults to ''.

    Raises:
        ValueError: If ``initializer`` is not None, not a
            :obj:`numpy.ndarray` and not callable.
    """

    def __init__(self, *args, **kwargs):
        # Data and gradient were already set up in ``__new__``; this only
        # registers the parameter. Python forwards the constructor arguments
        # unchanged to both ``__new__`` and ``__init__``, so ``scope`` may
        # arrive as a keyword or as the 4th positional argument
        # (shape, need_grad, initializer, scope).
        super(Parameter, self).__init__()
        if 'scope' in kwargs:
            scope = kwargs['scope']
        elif len(args) > 3:
            scope = args[3]
        else:
            scope = None

        if scope is None:
            set_parameter(self.__repr__(), self)
        else:
            with nn.parameter_scope(scope):
                set_parameter(self.__repr__(), self)

    def __new__(cls, shape, need_grad=True, initializer=None, scope=''):
        # ``scope`` is accepted so the constructor signature is complete; it
        # is consumed by ``__init__``, not here.
        assert shape is not None, '`shape` must not be None.'
        obj = super().__new__(cls, shape, need_grad)
        if initializer is None:
            obj.data.zero()
        elif isinstance(initializer, np.ndarray):
            assert tuple(shape) == initializer.shape, (
                f'initializer shape {initializer.shape} does not match '
                f'parameter shape {tuple(shape)}.'
            )
            obj.d = initializer
        elif callable(initializer):
            # The callable receives the shape as a list of plain Python ints
            # (converting e.g. numpy integer dimensions).
            obj.d = initializer(shape=list(map(int, shape)))
        else:
            raise ValueError(
                '`initializer` must be either a `numpy.ndarray` '
                'or an instance inherited from '
                '`nnabla.initializer.BaseInitializer`.'
            )
        # Start every parameter with a zeroed gradient buffer.
        obj.grad.zero()
        return obj

    def __repr__(self):
        return (f'<Parameter({self.shape}, need_grad={self.need_grad})'
                f' at {hex(id(self))}>')