# Copyright (c) 2020 Sony Corporation. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import Counter
from collections import OrderedDict
import json
import os
import nnabla.functions as F
from nnabla.initializer import ConstantInitializer
import numpy as np
from . import modules as darts
from .... import module as Mo
from ..base import ClassificationModel as Model
from ..misc import AuxiliaryHeadCIFAR
from .helper import save_dart_arch, visualize_dart_arch
from hydra import utils
[docs]
class SearchNet(Model):
r"""DARTS: Differentiable Architecture Search.
This is the search space for DARTS.
Args:
in_channels (int): The number of input channels.
init_channels (int): The initial number of channels on each cell.
num_cells (int): The number of cells.
num_classes (int): The number of classes.
num_choices (int, optional): The number of choice blocks on each cell.
Defaults to 4.
multiplier (int, optional): The multiplier. Defaults to 4.
mode (str, optional): The sampling strategy ('full', 'max', 'sample').
Defaults to 'full'.
shared (bool, optional): If parameters are shared between cells.
Defaults to False.
stem_multiplier (int, optional): The multiplier used for stem
convolution. Defaults to 3.
"""
def __init__(self, in_channels, init_channels, num_cells, num_classes,
num_choices=4, multiplier=4, mode='full', shared=False,
stem_multiplier=3):
self._in_channels = in_channels
self._init_channels = init_channels
self._num_cells = num_cells
self._num_classes = num_classes
self._num_choices = num_choices
self._multiplier = multiplier
self._mode = mode
self._shared = shared
num_channels = stem_multiplier * init_channels
# initialize the arch parameters
self._alpha = self._init_alpha()
# build the network
self._stem = darts.StemConv(in_channels, num_channels)
self._cells = self._init_cells(num_cells, num_channels)
self._ave_pool = Mo.AvgPool(kernel=(8, 8))
self._linear = Mo.Linear(self._last_channels, num_classes)
[docs]
def call(self, input):
out_p = out_c = self._stem(input)
for cell in self._cells:
out_c, out_p = cell(out_p, out_c), out_c
out_c = self._ave_pool(out_c)
return self._linear(out_c)
def _init_cells(self, num_cells, C):
"""Initializes the cells used in DARTS.
Args:
num_cells (int): The number of cells.
C (int): The number of channels.
Returns:
ModuleList: List of cells.
"""
cells = Mo.ModuleList()
Cpp, Cp, C = C, C, self._init_channels
reduction_p, reduction_c = False, False
for i in range(num_cells):
reduction_c = i in (num_cells // 3, 2 * num_cells // 3)
C *= reduction_c + 1
cells.append(
darts.Cell(
num_choices=self._num_choices,
multiplier=self._multiplier,
channels=(Cpp, Cp, C),
reductions=(reduction_p, reduction_c),
mode=self._mode,
alpha=self._alpha[reduction_c] if self._shared else None
)
)
reduction_p = reduction_c
Cpp, Cp = Cp, self._multiplier * C
if i == 2 * num_cells // 3:
self._c_auxiliary = Cp
# save the last channels for the last module
self._last_channels = Cp
return cells
def _init_alpha(self):
r"""Returns a list of alpha parameters.
Returns:
ModuleList: List of alpha parameters. The first is used in a normal
cell and the second is used in the reduction cell.
"""
shape = (len(darts.CANDIDATES), 1, 1, 1, 1)
init = ConstantInitializer(0.0)
n = self._num_choices * (self._num_choices + 3) // 2
alpha = Mo.ModuleList()
if self._shared:
for _ in range(2):
params = Mo.ParameterList()
for _ in range(n):
params.append(Mo.Parameter(shape, initializer=init))
alpha.append(params)
return alpha
[docs]
def get_net_parameters(self, grad_only=False):
r"""Returns an `OrderedDict` containing model parameters.
Args:
grad_only (bool, optional): If sets to `True`, then only parameters
with `need_grad=True` are returned. Defaults to False.
Returns:
OrderedDict: A dictionary containing parameters.
"""
p = self.get_parameters(grad_only)
return OrderedDict([(k, v) for k, v in p.items() if 'alpha' not in k])
[docs]
def get_arch_parameters(self, grad_only=False):
r"""Returns an `OrderedDict` containing architecture parameters.
Args:
grad_only (bool, optional): If sets to `True`, then only parameters
with `need_grad=True` are returned. Defaults to False.
Returns:
OrderedDict: A dictionary containing parameters.
"""
if self._shared:
return self._alpha.get_parameters()
p = self.get_parameters(grad_only)
return OrderedDict([(k, v) for k, v in p.items() if 'alpha' in k])
[docs]
def summary(self):
r"""Summary of the model."""
stats = []
arch_params = self.get_arch_parameters()
count = Counter([np.argmax(m.d.flat) for m in arch_params.values()])
op_names = list(darts.CANDIDATES.keys())
total = len(arch_params)
for k in range(len(op_names)):
name = op_names[k]
stats.append(name + f' = {count[k]/total*100:.2f}%\t')
return ''.join(stats)
[docs]
def save_parameters(self, path=None, params=None, grad_only=False):
super().save_parameters(path, params=params, grad_only=grad_only)
if self._shared:
# save the architectures
output_path = os.path.dirname(path)
save_dart_arch(self, output_path)
[docs]
def save_net_nnp(self, path, inp, out, calc_latency=False,
func_real_latency=None, func_accum_latency=None,
save_params=None):
super().save_net_nnp(path, inp, out, calc_latency=False,
func_real_latency=func_real_latency,
func_accum_latency=func_accum_latency,
save_params=save_params)
if self._shared:
# save the architectures
save_dart_arch(self, path)
[docs]
def visualize(self, path):
if self._shared:
# save the architectures
visualize_dart_arch(path)
[docs]
def loss(self, outputs, targets, loss_weights=None):
loss = F.mean(F.softmax_cross_entropy(outputs[0], targets[0]))
if len(outputs) == 2: # use auxiliar head
loss_weights = loss_weights or (1.0, 1.0)
aux_loss = F.mean(F.softmax_cross_entropy(outputs[1], targets[0]))
loss = loss_weights[0] * loss + loss_weights[1] * aux_loss
return loss
[docs]
class TrainNet(Model):
"""TrainNet used for DARTS."""
def __init__(self, in_channels, init_channels, num_cells, num_classes,
genotype, num_choices=4, multiplier=4, stem_multiplier=3,
drop_path=0, auxiliary=False):
self._num_ops = len(darts.CANDIDATES)
self._multiplier = multiplier
self._init_channels = init_channels
self._num_cells = num_cells
self._auxiliary = auxiliary
self._num_choices = num_choices
self._drop_path = drop_path
num_channels = stem_multiplier * init_channels
genotype_path = os.path.realpath(os.path.join(utils.get_original_cwd(), genotype))
genotype = json.load(open(genotype_path, 'r'))
# initialize the arch parameters
self._stem = darts.StemConv(in_channels, num_channels)
self._cells = self._init_cells(num_cells, num_channels, genotype)
self._ave_pool = Mo.AvgPool(kernel=(8, 8))
self._linear = Mo.Linear(self._last_channels, num_classes)
# auxiliary head
if auxiliary:
self._auxiliary_head = AuxiliaryHeadCIFAR(
self._c_auxiliary, num_classes)
[docs]
def call(self, input):
logits_aux = None
out_p = out_c = self._stem(input)
for i, cell in enumerate(self._cells):
out_p, out_c = out_c, cell(out_p, out_c)
if i == 2 * self._num_cells//3:
if self.training and self._auxiliary:
logits_aux = self._auxiliary_head(out_c)
out_c = self._ave_pool(out_c)
logits = self._linear(out_c)
return logits if logits_aux is None else (logits, logits_aux)
def _init_cells(self, num_cells, channel_c, genotype):
cells = Mo.ModuleList()
channel_p_p, channel_p, channel_c = channel_c, channel_c,\
self._init_channels
reduction_p, reduction_c = False, False
for i in range(num_cells):
reduction_c = i in (num_cells // 3, 2 * num_cells // 3)
channel_c *= reduction_c + 1
cells.append(
Cell(channels=(channel_p_p, channel_p, channel_c),
reductions=(reduction_p, reduction_c),
genotype=genotype,
drop_path=self._drop_path)
)
reduction_p = reduction_c
channel_p_p, channel_p = channel_p, self._multiplier * channel_c
if i == 2*num_cells//3:
self._c_auxiliary = channel_p
# save the last channels for the last module
self._last_channels = channel_p
return cells
[docs]
def loss(self, outputs, targets, loss_weights=None):
loss = F.mean(F.softmax_cross_entropy(outputs[0], targets[0]))
if len(outputs) == 2: # use auxiliar head
loss_weights = loss_weights or (1.0, 1.0)
aux_loss = F.mean(F.softmax_cross_entropy(outputs[1], targets[0]))
loss = loss_weights[0] * loss + loss_weights[1] * aux_loss
return loss
class Cell(Mo.Module):
def __init__(self, channels, reductions, genotype, drop_path):
self._drop_path = drop_path
# preprocess the inputs
self._prep = Mo.ModuleList()
if reductions[0]:
self._prep.append(
darts.FactorizedReduce(channels[0], channels[2]))
else:
self._prep.append(darts.ReLUConvBN(
channels[0], channels[2], kernel=(1, 1)))
self._prep.append(darts.ReLUConvBN(
channels[1], channels[2], kernel=(1, 1)))
cell_type = 'reduce' if reductions[-1] else 'normal'
cell_arch = genotype[cell_type + '_alpha']
# build choice blocks
self._indices = list()
self._blocks = Mo.ModuleList()
candidates = list(darts.CANDIDATES.values())
for i in range(len(cell_arch)):
for (op_idx, choice_idx) in cell_arch[str(i + 2)]:
stride = 2 if reductions[-1] and choice_idx < 2 else 1
self._blocks.append(candidates[op_idx](channels[2], stride))
self._indices.append(choice_idx)
def call(self, *input):
"""Each cell has two inputs and one output."""
out = [op(x) for op, x in zip(self._prep, input)]
for i in range(len(self._indices) // 2):
idx = (self._indices[2*i], self._indices[2*i + 1])
ops = (self._blocks[2*i], self._blocks[2*i + 1])
choice = list()
for j, op in zip(idx, ops):
choice.append(op(out[j]))
if self.training and not isinstance(op, Mo.Identity):
choice[-1] = darts.DropPath(self._drop_path)(choice[-1])
out.append(F.add2(choice[0], choice[1]))
return F.concatenate(*out[2:], axis=1)