# Copyright (c) 2020 Sony Corporation. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
import random
import numpy as np
from tqdm import tqdm
from collections import OrderedDict
import nnabla as nn
import nnabla.functions as F
from ... import contrib
from .search import Searcher
from ...contrib.common.ofa.utils.random_resize_crop import OFAResize
from ...contrib.common.ofa.elastic_nn.utils import set_running_statistics
from ...utils.data import transforms
[docs]
class OFASearcher(Searcher):
r"""An implementation of OFA."""
def __init__(self, model, optimizer, regularizer, dataloader, hparams, args):
super().__init__(model, optimizer, regularizer, dataloader, hparams, args)
manual_seed = 0
nn.seed(manual_seed)
np.random.seed(manual_seed)
random.seed(manual_seed)
self.bs_test = self.bs_valid
self.mbs_test = self.mbs_valid
self.accum_test = self.bs_test // self.mbs_test
self.one_epoch_test = len(self.dataloader['test']) // self.bs_test
self.image_size_list = self.hparams['train_image_size_list']
OFAResize.IMAGE_SIZE_LIST = self.image_size_list
OFAResize.ACTIVE_SIZE = max(self.image_size_list)
OFAResize.IMAGE_SIZE_SEG = 4 if 'image_size_seg' not in self.hparams else self.hparams['image_size_seg']
OFAResize.CONTINUOUS = True if "image_size_continuous" not in self.hparams \
else self.hparams['image_size_continuous']
if self.hparams['lambda_kd'] > 0: # knowledge distillation
name, attributes = list(self.hparams['teacher_network'].items())[0]
self.teacher_model = contrib.__dict__[name].TrainNet(**attributes)
self.update_graph('valid')
self.metrics = {
k: nn.NdArray.from_numpy_array(np.zeros((1,)))
for k in self.placeholder['valid']['metrics']
}
# loss and metric
self.loss = nn.NdArray.from_numpy_array(np.zeros((1,)))
[docs]
def run(self):
r"""Run the training process."""
self.callback_on_start()
# Test for init parameters
if self.hparams['task'] != 'fullnet':
self.valid_genotypes(mode='test')
# training
for self.cur_epoch in range(self.cur_epoch, self.hparams['epoch']):
self.monitor.reset()
OFAResize.IS_TRAINING = True
lr = self.optimizer['train'].get_learning_rate()
self.monitor.info(f'Running epoch={self.cur_epoch}\tlr={lr:.5f}\n')
OFAResize.EPOCH = self.cur_epoch
for i in range(self.one_epoch_train):
self.train_on_batch(self.cur_epoch, i)
if i % (self.args['print_frequency']) == 0:
train_keys = [m.name for m in self.monitor.meters.values()
if 'train' in m.name]
self.monitor.display(i, key=train_keys)
if self.cur_epoch % self.hparams["validation_frequency"] == 0:
self.valid_genotypes(mode='valid')
return self
[docs]
def callback_on_start(self):
keys = self.hparams['no_decay_keys']
net_params = [
self.get_net_parameters_with_keys(keys, mode='exclude', grad_only=True), # parameters with weight decay
self.get_net_parameters_with_keys(keys, mode='include', grad_only=True), # parameters without weight decay
]
self.optimizer['train'].set_parameters(net_params[0])
self.optimizer['train_no_decay'].set_parameters(net_params[1])
# load checkpoint if available
self.load_checkpoint()
if self.comm.n_procs > 1:
self._grads_net = [x.grad for x in net_params[0].values()]
self._grads_no_decay_net = [x.grad for x in net_params[1].values()]
self.event.default_stream_synchronize()
[docs]
def train_on_batch(self, epoch, n_iter, key='train'):
r"""Update the model parameters."""
OFAResize.BATCH = n_iter
batch = [self.dataloader['train'].next()
for _ in range(self.accum_train)]
bz, p = self.mbs_train, self.placeholder['train']
if key == 'train':
self.optimizer['train'].zero_grad()
self.optimizer['train_no_decay'].zero_grad()
else:
self.optimizer[key].zero_grad()
if self.comm.n_procs > 1:
self.event.default_stream_synchronize()
self.update_graph(key)
for _, data in enumerate(batch):
self._load_data(p, data)
p['loss'].forward(clear_no_need_grad=True)
for k, m in p['metrics'].items():
m.forward(clear_buffer=True)
self.monitor.update(f'{k}/train', m.d.copy(), bz)
p['loss'].backward(clear_buffer=True)
loss = p['loss'].d.copy()
self.monitor.update('loss/train', loss * self.accum_train, bz)
if self.comm.n_procs > 1:
self.comm.all_reduce(self._grads_net, division=True, inplace=False)
self.comm.all_reduce(self._grads_no_decay_net, division=True, inplace=False)
self.event.add_default_stream_event()
if key != 'train':
self.optimizer[key].update()
else:
self.optimizer['train'].update()
self.optimizer['train_no_decay'].update()
[docs]
def valid_on_batch(self, is_test=False):
r"""Updates the architecture parameters."""
key = 'test' if is_test else 'valid'
bz = self.mbs_test if is_test else self.mbs_valid
accum = self.accum_test if is_test else self.accum_valid
p = self.placeholder['valid']
if self.comm.n_procs > 1:
self.event.default_stream_synchronize()
for _ in range(accum):
self._load_data(p, self.dataloader[key].next())
p['loss'].forward(clear_buffer=True)
for k, m in p['metrics'].items():
m.forward(clear_buffer=True)
self.metrics[k].data += m.d.copy() * bz
loss = p['loss'].d.copy()
self.loss.data += loss * accum * bz
if self.comm.n_procs > 1:
self.comm.all_reduce(
[self.loss] + list(self.metrics.values()), division=True, inplace=False)
self.event.add_default_stream_event()
[docs]
def valid_genotypes(self, mode='valid'):
assert mode in ['valid', 'test']
is_test = True if mode == 'test' else False
OFAResize.IS_TRAINING = False
for genotype in self.hparams['valid_genotypes']:
for img_size in self.hparams['valid_image_size_list']:
self.monitor.reset()
OFAResize.ACTIVE_SIZE = img_size
self.model.set_valid_arch(genotype)
self.reset_running_statistics()
for _ in tqdm(range(self.one_epoch_valid if mode == 'valid' else self.one_epoch_test),
desc=f'{mode} [{self.cur_epoch}/{self.hparams["epoch"]}]'):
self.update_graph(mode)
self.valid_on_batch(is_test=is_test)
self.monitor.info(f'img_size={img_size}, genotype={genotype} \n')
self.callback_on_epoch_end(is_test=is_test)
self.monitor.write(self.cur_epoch)
self.loss.zero()
for k in self.metrics:
self.metrics[k].zero()
[docs]
def callback_on_epoch_end(self, epoch=None, is_test=False, info=None):
if is_test:
num_of_samples = self.one_epoch_test * self.accum_test * self.mbs_test
else:
num_of_samples = self.one_epoch_valid * self.accum_valid * self.mbs_valid
self.loss.data /= num_of_samples
for k in self.metrics:
self.metrics[k].data /= num_of_samples
if self.comm.rank == 0:
self.monitor.update('loss/valid', self.loss.data[0], 1)
self.monitor.info(f'loss/valid={self.loss.data[0]:.4f}\n')
for k in self.metrics:
self.monitor.update(f'{k}/valid', self.metrics[k].data[0], 1)
self.monitor.info(f'{k}/valid={self.metrics[k].data[0]:.4f}\n')
if info:
self.monitor.info(f'{info}\n')
if self.args['save_nnp']:
self.model.save_net_nnp(
self._abs_output_path,
self.placeholder['valid']['inputs'][0],
self.placeholder['valid']['outputs'][0],
save_params=self.args.get('save_params'))
if not is_test:
self.save_checkpoint()
[docs]
def update_graph(self, key='train'):
r"""Builds the graph and update the placeholder.
Args:
key (str, optional): Type of graph. Defaults to 'train'.
"""
assert key in ('train', 'valid', 'test')
self.model.apply(training=key not in ['valid', 'test'])
if self.hparams['lambda_kd'] > 0:
self.teacher_model.apply(training='train')
fake_key = 'train' if key == 'train' else 'valid'
p = self.placeholder[fake_key]
resize = OFAResize()
if self.dataloader[fake_key].transform is None:
self.dataloader[fake_key].transform = 'none_transform'
try:
func = getattr(transforms, self.dataloader[fake_key].transform)
transform = func(fake_key)
except AttributeError:
print(
'ERROR: Transformation function \'' +
self.dataloader[fake_key].transform +
'\' NOT defined in ' + transforms.__name__
)
raise AttributeError
accum = self.accum_test if key == 'test' else (self.accum_valid if key == 'valid' else self.accum_train)
# outputs
inputs = [resize(transform(x)) for x in p['inputs']]
outputs = self.model(*inputs)
outputs = outputs if isinstance(outputs, (tuple, list)) else (outputs,)
p['outputs'] = [x.apply(persistent=True) for x in outputs]
if fake_key == 'valid' and self.hparams['valid_ce_loss']:
# cross entropy loss
p['loss'] = F.mean(F.softmax_cross_entropy(p['outputs'][0], p['targets'][0])) / accum
else:
if self.hparams['lambda_kd'] > 0:
with nn.no_grad():
soft_logits = self.teacher_model(*inputs)
soft_logits = soft_logits if isinstance(soft_logits, (tuple, list)) else (soft_logits,)
p['soft_logits'] = [x.apply(need_grad=False) for x in soft_logits]
kd_loss = self.model.kd_loss(
p['outputs'], p['soft_logits'], p['targets'], self.hparams['loss_weights'])
# loss function
if self.hparams['lambda_kd'] > 0:
p['loss'] = (self.model.loss(p['outputs'], p['targets'], self.hparams['loss_weights'])
+ self.hparams['lambda_kd'] * kd_loss) / accum
else:
p['loss'] = self.model.loss(p['outputs'], p['targets'], self.hparams['loss_weights']) / accum
p['loss'].apply(persistent=True)
# metrics to monitor during training
targets = [out.get_unlinked_variable().apply(need_grad=False) for out in p['outputs']]
p['metrics'] = self.model.metrics(targets, p['targets'])
for v in p['metrics'].values():
v.apply(persistent=True)
[docs]
def get_net_parameters_with_keys(self, keys, mode='include', grad_only=False):
r"""Returns an `OrderedDict` containing model parameters.
Args:
keys (list of str): Patterns of parameters to be considered for inclusion
or exclusion. Note: Keys passed must be in regular expression format.
mode (str, optional): Mode of getting network parameters with keys.
- Selects parameters satisfying the keys if mode=='include'
- Selects parameters not satisfying the keys if mode=='exclude'
Choices: ['include', 'exclude']. Defaults to 'include'.
grad_only (bool, optional): If sets to `True`, then only parameters
with `need_grad=True` are returned. Defaults to False.
Returns:
OrderedDict: A dictionary containing parameters.
"""
pattern = re.compile('|'.join(keys)) # compile the pattern of all keys
net_params = self.model.get_net_parameters(grad_only)
if mode == 'include': # without weight decay
param_dict = OrderedDict()
for name in net_params.keys():
if re.search(pattern, name) is not None:
param_dict[name] = net_params[name]
return param_dict
elif mode == 'exclude': # with weight decay
param_dict = OrderedDict()
for name in net_params.keys():
if re.search(pattern, name) is None:
param_dict[name] = net_params[name]
return param_dict
else:
raise ValueError('do not support %s' % mode)
[docs]
def reset_running_statistics(self, net=None, subset_size=2000, subset_batch_size=200,
dataloader=None, dataloader_batch_size=None, inp_shape=None):
if net is None:
net = self.model
if dataloader is None:
subset_train_dataloader = self.dataloader['train']
dataloader_batch_size = self.mbs_train
if inp_shape is None:
inp_shape = self.hparams['input_shapes']
set_running_statistics(net, subset_train_dataloader, dataloader_batch_size,
subset_size, subset_batch_size, inp_shape)
[docs]
def callback_on_finish(self):
pass