# Copyright (c) 2020 Sony Corporation. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import nnabla as nn
import numpy as np
from .search import Searcher
from nnabla_nas.utils.estimator.latency import LatencyEstimator
from nnabla_nas.utils.estimator.latency import LatencyGraphEstimator
class ProxylessNasSearcher(Searcher):
    r"""ProxylessNAS: Direct Neural Architecture Search on Target Task and
    Hardware.

    The searcher alternates between two kinds of steps:

    * ``train_on_batch`` updates the network weights by back-propagating
      the training loss.
    * ``valid_on_batch`` updates the architecture parameters from a
      reward-weighted combination of per-sample gradients, using a running
      reward baseline stored in ``self._reward``.
    """

    def callback_on_start(self):
        r"""Initializes the reward baseline and loads a checkpoint.

        ``self._reward`` is created as a one-element ``NdArray`` filled with
        zero; it is the baseline subtracted from each sampled reward in
        ``valid_on_batch``.
        """
        self._reward = nn.NdArray.from_numpy_array(np.zeros((1,)))
        # load checkpoint if available
        self.load_checkpoint()

    def train_on_batch(self, key='train'):
        r"""Updates the model parameters.

        Args:
            key (str, optional): Selects the graph passed to
                ``update_graph`` and the optimizer that is stepped.
                Defaults to ``'train'``. The placeholder and the data
                loader used here are always the ``'train'`` ones,
                regardless of ``key``.
        """
        self.update_graph(key)
        params = self.model.get_net_parameters(grad_only=True)
        self.optimizer[key].set_parameters(params)
        bz, p = self.mbs_train, self.placeholder['train']
        self.optimizer[key].zero_grad()

        distributed = self.comm.n_procs > 1
        if distributed:
            # Collect the gradient arrays once; they are all-reduced after
            # the accumulation loop below.
            grads = [x.grad for x in params.values()]
            self.event.default_stream_synchronize()

        # Accumulate gradients over `accum_train` mini-batches.
        for _ in range(self.accum_train):
            self._load_data(p, self.dataloader['train'].next())
            p['loss'].forward(clear_no_need_grad=True)
            for k, m in p['metrics'].items():
                m.forward(clear_buffer=True)
                self.monitor.update(f'{k}/train', m.d.copy(), bz)
            p['loss'].backward(clear_buffer=True)
            loss = p['loss'].d.copy()
            # NOTE(review): the loss is rescaled by `accum_train` before
            # logging -- presumably because the loss graph is pre-divided
            # by the accumulation count; confirm where the graph is built.
            self.monitor.update('loss/train', loss * self.accum_train, bz)

        if distributed:
            self.comm.all_reduce(grads, division=True, inplace=False)
            self.event.add_default_stream_event()

        self.optimizer[key].update()

    def valid_on_batch(self):
        r"""Updates the architecture parameters.

        The same ``accum_valid`` validation mini-batches are evaluated for
        ``n_iter`` graph updates. For each update a reward is computed as
        the mean of ``1 - error`` over the mini-batches, multiplied by
        ``min(1, bound / value) ** weight`` for every regularizer. The
        gradient written to each architecture parameter is the average of
        ``(reward - baseline) * gradient`` over the ``n_iter`` samples,
        after which the baseline is refreshed and the ``'valid'`` optimizer
        is stepped.

        Raises:
            NotImplementedError: If a regularizer is neither a
                ``LatencyGraphEstimator`` nor a ``LatencyEstimator``.
        """
        # beta: weight of the newest mean reward in the baseline update.
        # n_iter: number of graph updates evaluated per call.
        beta, n_iter = 0.9, 10
        bz, p = self.mbs_valid, self.placeholder['valid']
        # Fetch the validation data once so every sampled graph is scored
        # on exactly the same mini-batches.
        valid_data = [self.dataloader['valid'].next()
                      for _ in range(self.accum_valid)]
        rewards, grads = [], []

        distributed = self.comm.n_procs > 1
        if distributed:
            self.event.default_stream_synchronize()

        for _ in range(n_iter):
            reward = 0
            self.update_graph('valid')
            arch_params = self.model.get_arch_parameters(grad_only=True)
            self.optimizer['valid'].set_parameters(arch_params)

            for minibatch in valid_data:
                self._load_data(p, minibatch)
                p['loss'].forward(clear_buffer=True)
                for k, m in p['metrics'].items():
                    m.forward(clear_buffer=True)
                    self.monitor.update(f'{k}/valid', m.d.copy(), bz)
                loss = p['loss'].d.copy()
                reward += (1 - p['metrics']['error'].d) / self.accum_valid
                self.monitor.update('loss/valid', loss * self.accum_valid, bz)

            # adding constraints
            for k, v in self.regularizer.items():
                if isinstance(v, LatencyGraphEstimator):
                    # when using LatencyGraphEstimator (by graph):
                    # build a graph from batch-size-1 inputs and estimate
                    # from its output.
                    inp = [nn.Variable((1,) + si[1:])
                           for si in self.model.input_shapes]
                    out = self.model.call(*inp)
                    value = v.get_estimation(out)
                elif isinstance(v, LatencyEstimator):
                    # when using LatencyEstimator (by module)
                    value = v.get_estimation(self.model)
                else:
                    raise NotImplementedError(
                        f'Unsupported regularizer {k!r} of type '
                        f'{type(v).__name__}'
                    )
                # The reward is left untouched while value <= bound and is
                # scaled down once the estimate exceeds the bound.
                reward *= (min(1.0, v._bound / value))**v._weight
                self.monitor.update(k, value, 1)

            rewards.append(reward)
            # No backward pass is run in this method, so the gradients
            # copied here must already be set on the parameters.
            # NOTE(review): presumably `update_graph` fills them when it
            # samples the architecture -- confirm in the model code.
            grads.append([m.g.copy() for m in arch_params.values()])

        # compute gradients: baseline-subtracted, reward-weighted average of
        # the per-sample gradients.
        for j, m in enumerate(arch_params.values()):
            m.grad.zero()
            for r, sample_grads in zip(rewards, grads):
                m.g += (r - self._reward.data)*sample_grads[j]/n_iter

        # update global reward: the new mean reward gets weight `beta`, the
        # previous baseline gets weight `1 - beta`.
        self._reward.data = beta*sum(rewards)/n_iter + \
            (1 - beta)*self._reward.data

        if distributed:
            self.comm.all_reduce(
                [x.grad for x in arch_params.values()],
                division=True,
                inplace=False
            )
            self.comm.all_reduce(self._reward, division=True, inplace=False)
            self.event.add_default_stream_event()

        self.monitor.update('reward', self._reward.data[0], self.bs_valid)
        self.optimizer['valid'].update()