# Copyright (c) 2020 Sony Corporation. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import operator
from collections import OrderedDict
from .module import Module
from .parameter import Parameter
def _get_abs_string_index(obj, idx):
"""Get the absolute index for the list of modules"""
idx = int(operator.index(idx))
if not (-len(obj) <= idx < len(obj)):
raise IndexError('index {} is out of range'.format(idx))
if idx < 0:
idx += len(obj)
return str(idx)
class ModuleList(Module):
    r"""Hold submodules in a list. This implementation mainly follows
    the Pytorch implementation.

    Submodules are stored in ``self.modules`` under the string keys
    ``'0'``, ``'1'``, ... so that list positions map directly onto keys.

    Args:
        modules (iterable, optional): An iterable of modules to add.
    """

    def __init__(self, modules=None):
        if modules is not None:
            self += modules

    def append(self, module):
        r"""Appends a given module to the end of the list.

        Args:
            module (~nnabla_nas.module.module.Module): A module to append.

        Returns:
            ModuleList: This container, to allow chaining.

        Raises:
            ValueError: If ``module`` is not an instance of ``Module``.
        """
        if not isinstance(module, Module):
            raise ValueError(f'{module} is not an instance of Module.')
        setattr(self, str(len(self)), module)
        return self

    def extend(self, modules):
        r"""Appends modules from a Python iterable to the end of the list.

        Args:
            modules (iterable): An iterable of modules to append.

        Returns:
            ModuleList: This container, to allow chaining.

        Raises:
            ValueError: If any element is not an instance of ``Module``.
        """
        for module in modules:
            self.append(module)
        return self

    def insert(self, index, module):
        r"""Insert a given module before a given index in the list.

        Follows ``list.insert`` semantics: a negative ``index`` counts from
        the end, and an out-of-range ``index`` inserts at the nearest end.

        Args:
            index (int): An index to insert.
            module (~nnabla_nas.module.module.Module): A module to insert.

        Returns:
            ModuleList: This container, to allow chaining.

        Raises:
            ValueError: If ``module`` is not an instance of ``Module``.
            TypeError: If ``index`` is not integer-like.
        """
        if not isinstance(module, Module):
            raise ValueError(f'{module} is not an instance of Module.')
        size = len(self)
        index = int(operator.index(index))
        # Clamp into [0, size] so no gap keys (e.g. '7' in a 3-element list)
        # or negative keys are ever created.
        if index < 0:
            index = max(index + size, 0)
        index = min(index, size)
        # Shift entries at positions [index, size) up by one, walking from the
        # end so that no entry is overwritten before it has been copied.
        for i in range(size, index, -1):
            self.modules[str(i)] = self.modules[str(i - 1)]
        self.modules[str(index)] = module
        return self

    def __getitem__(self, index):
        if isinstance(index, slice):
            return self.__class__(list(self.modules.values())[index])
        return self.modules[_get_abs_string_index(self, index)]

    def __setitem__(self, index, module):
        if not isinstance(module, Module):
            raise ValueError(f'{module} is not an instance of Module.')
        self.modules[_get_abs_string_index(self, index)] = module

    def __delitem__(self, index):
        if isinstance(index, slice):
            for k in range(len(self.modules))[index]:
                delattr(self, str(k))
        else:
            delattr(self, _get_abs_string_index(self, index))
        # Re-key the survivors as '0', '1', ... so that positional indexing
        # keeps matching the string keys after the deletion.
        # NOTE(review): assumes ``self.modules`` is backed by ``self._modules``
        # in the Module base class — confirm.
        indices = [str(i) for i in range(len(self.modules))]
        self._modules = OrderedDict(list(zip(indices, self.modules.values())))

    def __len__(self):
        return len(self.modules)

    def __iter__(self):
        return iter(self.modules.values())

    def __iadd__(self, modules):
        return self.extend(modules)
class ParameterList(Module):
    r"""Hold parameters in a list.

    Parameters are stored in ``self.parameters`` under the string keys
    ``'0'``, ``'1'``, ... so that list positions map directly onto keys.

    Args:
        parameters (iterable, optional): An iterable of parameters to add.
    """

    def __init__(self, parameters=None):
        if parameters is not None:
            self += parameters

    def append(self, parameter):
        r"""Appends a given parameter to the end of the list.

        Args:
            parameter (Parameter): A parameter to append.

        Returns:
            ParameterList: This container, to allow chaining.

        Raises:
            ValueError: If ``parameter`` is not an instance of ``Parameter``.
        """
        if not isinstance(parameter, Parameter):
            raise ValueError(f'{parameter} is not an instance of Parameter.')
        setattr(self, str(len(self)), parameter)
        return self

    def extend(self, parameters):
        """Extends an iterable of parameters to the end of the list.

        Args:
            parameters (iterable): An iterable of Parameters.

        Returns:
            ParameterList: This container, to allow chaining.

        Raises:
            ValueError: If any element is not an instance of ``Parameter``.
        """
        for parameter in parameters:
            self.append(parameter)
        return self

    def insert(self, index, parameter):
        r"""Insert a given parameter before a given index in the list.

        Follows ``list.insert`` semantics: a negative ``index`` counts from
        the end, and an out-of-range ``index`` inserts at the nearest end.

        Args:
            index (int): An index to insert.
            parameter (Parameter): A parameter to insert.

        Returns:
            ParameterList: This container, to allow chaining.

        Raises:
            ValueError: If ``parameter`` is not an instance of ``Parameter``.
            TypeError: If ``index`` is not integer-like.
        """
        if not isinstance(parameter, Parameter):
            raise ValueError(f'{parameter} is not an instance of Parameter.')
        size = len(self)
        index = int(operator.index(index))
        # Clamp into [0, size] so no gap keys or negative keys are created.
        if index < 0:
            index = max(index + size, 0)
        index = min(index, size)
        # Shift entries at positions [index, size) up by one, walking from the
        # end so that no entry is overwritten before it has been copied.
        for i in range(size, index, -1):
            self.parameters[str(i)] = self.parameters[str(i - 1)]
        self.parameters[str(index)] = parameter
        return self

    def __getitem__(self, index):
        if isinstance(index, slice):
            return self.__class__(list(self.parameters.values())[index])
        return self.parameters[_get_abs_string_index(self, index)]

    def __setitem__(self, index, parameter):
        if not isinstance(parameter, Parameter):
            raise ValueError(f'{parameter} is not an instance of Parameter.')
        self.parameters[_get_abs_string_index(self, index)] = parameter

    def __delitem__(self, index):
        if isinstance(index, slice):
            for k in range(len(self.parameters))[index]:
                delattr(self, str(k))
        else:
            delattr(self, _get_abs_string_index(self, index))
        # Re-key the survivors as '0', '1', ... so that positional indexing
        # keeps matching the string keys after the deletion.
        # NOTE(review): assumes ``self.parameters`` is backed by
        # ``self._parameters`` in the Module base class — confirm.
        indices = [str(i) for i in range(len(self.parameters))]
        self._parameters = OrderedDict(
            list(zip(indices, self.parameters.values())))

    def __len__(self):
        return len(self.parameters)

    def __iter__(self):
        return iter(self.parameters.values())

    def __iadd__(self, parameters):
        return self.extend(parameters)
class Sequential(ModuleList):
    r"""A sequential container.

    Modules will be added to it in the order they are passed in the
    constructor. Alternatively, an ordered dict of modules can also be
    passed in.

    Calling the container feeds the input through every registered module
    in insertion order and returns the output of the last one.

    Args:
        *args: Either the modules to chain, in order (registered under the
            keys ``'0'``, ``'1'``, ...), or a single ``OrderedDict`` mapping
            names to modules.
    """

    def __init__(self, *args):
        if len(args) == 1 and isinstance(args[0], OrderedDict):
            for key, module in args[0].items():
                setattr(self, key, module)
        else:
            for idx, module in enumerate(args):
                setattr(self, str(idx), module)

    def __getitem__(self, index):
        r"""Return the module at position ``index``, or a sub-``Sequential``.

        Lookup is positional, so it also works when the container was built
        from an ``OrderedDict`` with non-numeric keys.

        Args:
            index (int or slice): Position of the module, or a slice.

        Returns:
            The module at ``index``, or a new ``Sequential`` holding the
            sliced modules re-keyed as ``'0'``, ``'1'``, ...

        Raises:
            IndexError: If an integer ``index`` is out of range.
        """
        modules = list(self.modules.values())
        if isinstance(index, slice):
            # Unpack the slice: passing the list itself (as the inherited
            # ModuleList.__getitem__ does) would make __init__ store the whole
            # list under key '0' instead of chaining its modules.
            return self.__class__(*modules[index])
        return modules[int(_get_abs_string_index(self, index))]

    def call(self, input):
        r"""Apply every registered module to ``input`` in insertion order.

        Args:
            input: The value fed to the first module.

        Returns:
            The output of the last module (``input`` unchanged if empty).
        """
        for module in self.modules.values():
            input = module(input)
        return input