# Copyright (c) 2020 Sony Corporation. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod
from more_itertools import consume
from pathlib import Path
from os import environ
import json
import nnabla as nn
from ..utils.helper import ProgressMeter, get_output_path
from ..utils.data import transforms
class Runner(ABC):
    r"""Runner is a basic class for training a model.

    You can adapt this class for your own runner by reimplementing the
    abstract methods of this class.

    Args:
        model (`nnabla_nas.contrib.model.Model`): The search model used to
            search the architecture.
        optimizer (dict): This stores optimizers for both `train` and `valid`
            graphs. Must only store instances of `Optimizer`.
        regularizer (dict): This stores regularizers such as the latency and
            memory estimators.
        dataloader (dict): This stores dataloaders for both `train` and `valid`
            graphs.
        hparams (Configuration): This stores all hyperparameters used during
            training.
        args (Configuration): This stores other variables used during
            training: event, communicator, output_path...
    """

    def __init__(self, model, optimizer, regularizer, dataloader, hparams, args):
        self.model = model
        self.dataloader = dataloader
        self.optimizer = optimizer
        self.regularizer = regularizer
        self.hparams = hparams
        self.args = args

        # additional arguments
        hp = self.hparams
        self.comm = args['comm']
        self.event = args['event']
        n_procs = self.comm.n_procs

        # hp['batch_size_XX'] is the GLOBAL batch size used for train / valid,
        # which is independent of the number of GPUs.
        # self.bs_XX is the LOCAL batch size (the batch size for each process).
        # self.mbs_XX is the MINIBATCH size, the number of samples fed at once
        # through the forward pass of the network.
        self.bs_train = hp['batch_size_train'] // n_procs
        self.mbs_train = hp['mini_batch_train']
        self.bs_valid = hp['batch_size_valid'] // n_procs
        self.mbs_valid = hp['mini_batch_valid']

        # Number of minibatches accumulated into one local batch; update_graph
        # divides the loss by this value.
        # NOTE(review): if mini_batch_XX exceeds the local batch size this is 0
        # and the loss gets divided by zero — confirm configs guarantee bs >= mbs.
        self.accum_train = self.bs_train // self.mbs_train
        self.accum_valid = self.bs_valid // self.mbs_valid

        # Number of local batches per epoch (floor division; presumably
        # len(dataloader) counts samples — TODO confirm).
        self.one_epoch_train = len(self.dataloader['train']) // self.bs_train
        self.one_epoch_valid = len(self.dataloader['valid']) // self.bs_valid
        self.cur_epoch = 0

        # setup placeholders: one list of input and target variables per
        # graph, each variable with a leading minibatch dimension.
        def create_variables(mbs, shapes):
            return [nn.Variable([mbs] + list(shape)) for shape in shapes]

        self.placeholder = {
            'train': {
                'inputs': create_variables(self.mbs_train, hp['input_shapes']),
                'targets': create_variables(self.mbs_train, hp['target_shapes']),
            },
            'valid': {
                'inputs': create_variables(self.mbs_valid, hp['input_shapes']),
                'targets': create_variables(self.mbs_valid, hp['target_shapes']),
            },
        }

        # monitor log info
        self._abs_output_path = str(Path(get_output_path(is_abspath=True)) / self.args['output_path'])
        self._rel_output_path = str(Path(get_output_path(is_abspath=False)) / self.args['output_path'])
        # the meter is quiet on every rank except rank 0
        self.monitor = ProgressMeter(self.one_epoch_train, self._abs_output_path,
                                     quiet=self.comm.rank > 0)

        # Check if we should run in fast mode where the whole computation graph
        # is kept in memory and mixed operations just switch the propagation
        # path. Setting the environment variable to any value (even an empty
        # string) enables it.
        fast_mode = environ.get('NNABLA_NAS_MIXEDOP_FAST_MODE') is not None
        self.monitor.info('NNABLA_NAS_MIXEDOP_FAST_MODE is {}\n'.format(
            'enabled' if fast_mode else 'disabled'))
        self._fast_mode = fast_mode

    @property
    def fast_mode(self):
        """bool: Whether NNABLA_NAS_MIXEDOP_FAST_MODE was set when the runner was created."""
        return self._fast_mode

    @abstractmethod
    def run(self):
        r"""Run the training process."""
        pass

    def update_graph(self, key='train'):
        r"""Build the graph and update the placeholder.

        A new architecture is generated via ``model.apply`` and the model
        outputs are stored in ``self.placeholder[key]['outputs']``. The
        transformed inputs, the loss and the metrics are stored in the same
        placeholder; in fast mode these three are built only on the first call
        and reused afterwards.

        Args:
            key (str, optional): Type of graph: 'train', 'valid' or 'warmup'.
                'warmup' uses the 'train' placeholder and dataloader.
                Defaults to 'train'.

        Raises:
            ValueError: If `key` is not one of the allowed values.
            AttributeError: If the dataloader's transform name does not match
                a function in ``utils/data/transforms.py``.
        """
        if key not in ('train', 'valid', 'warmup'):
            raise ValueError(f'key = {key} is not allowed')
        if key == 'warmup':
            key = 'train'

        placeholder = self.placeholder[key]
        dataloader = self.dataloader[key]
        if dataloader.transform is None:
            dataloader.transform = 'none_transform'

        # dataloader.transform is the name of a function in
        # utils/data/transforms.py; calling that function with `key` returns
        # the transformation applied to each input variable. Only the lookup
        # is guarded so errors raised inside the factory are not misreported.
        transform_name = dataloader.transform
        try:
            func = getattr(transforms, transform_name)
        except AttributeError as err:
            raise AttributeError(
                f"Transformation function '{transform_name}' "
                f"NOT defined in {transforms.__name__}") from err
        transform = func(key)

        training = key == 'train'
        model = self.model

        # apply data transformations (cached across calls in fast mode)
        if not self.fast_mode or 'transformed' not in placeholder:
            placeholder['transformed'] = [transform(x) for x in placeholder['inputs']]
        inputs = placeholder['transformed']

        # generate a new architecture
        model.apply(None, training=training)
        outputs = model(*inputs)
        if not isinstance(outputs, tuple):
            outputs = (outputs,)
        outputs = [output.apply(persistent=True) for output in outputs]
        placeholder['outputs'] = outputs

        # add the model's loss function, scaled down by the number of
        # accumulated minibatches (cached across calls in fast mode)
        if not self.fast_mode or 'loss' not in placeholder:
            targets = placeholder['targets']
            loss_weights = self.hparams['loss_weights']
            accum = self.accum_train if training else self.accum_valid
            loss = model.loss(outputs, targets, loss_weights) / accum
            placeholder['loss'] = loss.apply(persistent=True)

        # metrics to monitor during training, computed on unlinked copies of
        # the outputs with need_grad=False (cached across calls in fast mode)
        if not self.fast_mode or 'metrics' not in placeholder:
            targets = placeholder['targets']
            unlinked = [v.get_unlinked_variable().apply(need_grad=False) for v in outputs]
            metrics = model.metrics(unlinked, targets)
            for metric in metrics.values():
                metric.apply(persistent=True)
            placeholder['metrics'] = metrics

    @staticmethod
    def _load_data(placeholder, data):
        """Copy one batch of data into the placeholder variables.

        Args:
            placeholder (dict): A placeholder holding 'inputs' and 'targets'
                lists of `nn.Variable`.
            data (dict): Arrays for 'inputs' and 'targets', matched to the
                placeholder variables by position.
        """
        for key in ('inputs', 'targets'):
            for inp, x in zip(placeholder[key], data[key]):
                # An NdArray is assigned directly as the variable's data;
                # anything else (presumably numpy arrays) is written via `.d`.
                if isinstance(x, nn.NdArray):
                    inp.data = x
                else:
                    inp.d = x

    def save_checkpoint(self, checkpoint_info=None):
        r"""Save the current states of the runner.

        Writes the model weights to ``<output>/checkpoint/weights.h5`` and a
        JSON summary to ``<output>/checkpoint/checkpoint.json``.

        Args:
            checkpoint_info (dict, optional): Extra information stored in the
                JSON file. It is updated in place with the keys 'epoch',
                'optimizers' and 'params_path'. Defaults to a new empty dict.
        """
        # Use a fresh dict per call: a shared `{}` default would carry state
        # from one call to the next, because the dict is mutated below.
        if checkpoint_info is None:
            checkpoint_info = {}

        path = Path(self._abs_output_path) / 'checkpoint'
        path.mkdir(parents=True, exist_ok=True)
        # paths recorded inside the JSON are built from the relative output path
        relpath = Path(self._rel_output_path) / 'checkpoint'

        checkpoint_info['epoch'] = self.cur_epoch

        # save optimizers state
        checkpoint_info['optimizers'] = {
            name: optimizer.save_checkpoint(str(relpath), name)
            for name, optimizer in self.optimizer.items()
        }

        # Coerce the best error to a plain float before json.dump (presumably
        # it may be a scalar type json cannot serialize — confirm).
        if 'best_metric' in checkpoint_info and 'error' in checkpoint_info['best_metric']:
            checkpoint_info['best_metric']['error'] = float(checkpoint_info['best_metric']['error'])

        # save parameters
        self.model.save_parameters(str(path / 'weights.h5'))
        checkpoint_info['params_path'] = str(relpath / 'weights.h5')

        with path.joinpath('checkpoint.json').open('w') as f:
            json.dump(checkpoint_info, f)
        self.monitor.info(f"Checkpoint saved: {str(path)}\n")

    def load_checkpoint(self):
        r"""Restore the runner state from a saved checkpoint, if one exists.

        Sets ``self.cur_epoch`` to the epoch after the saved one, restores each
        optimizer (restricted to the parameters it was saved with) and loads
        the model weights.

        Returns:
            dict or None: The loaded checkpoint info, or None when no
            ``checkpoint/checkpoint.json`` exists under the output path.
        """
        # NOTE(review): save_checkpoint writes under
        # get_output_path(is_abspath=True) / args['output_path'], while this
        # reads from get_output_path() without args['output_path']. The two
        # locations only coincide when output_path is empty or '.' (and the
        # default of get_output_path() resolves to the same directory) —
        # confirm this is intended.
        path = Path(get_output_path()) / 'checkpoint' / 'checkpoint.json'
        if not path.is_file():
            return None

        with path.open('r') as f:
            checkpoint_info = json.load(f)
        self.cur_epoch = checkpoint_info['epoch'] + 1

        # load optimizers; make sure the optimizer parameters match those
        # that were saved
        model_params = self.model.get_parameters()
        for name, optim_info in checkpoint_info['optimizers'].items():
            params = {k: model_params[k] for k in optim_info['params_names']}
            self.optimizer[name].set_parameters(params)
            self.optimizer[name].load_checkpoint(optim_info)

        # load parameters (params_path is the relative path recorded by
        # save_checkpoint)
        self.model.load_parameters(checkpoint_info['params_path'])
        self.monitor.info(f"Checkpoint loaded: {str(path)}\n")
        return checkpoint_info

    @abstractmethod
    def train_on_batch(self, key='train'):
        r"""Runs the model update on a single batch of train data."""
        pass

    @abstractmethod
    def valid_on_batch(self):
        r"""Runs the model update on a single batch of valid data."""
        pass

    @abstractmethod
    def callback_on_epoch_end(self):
        r"""Calls this after one epoch."""
        pass

    @abstractmethod
    def callback_on_start(self):
        r"""Calls this on starting the run method."""
        pass

    @abstractmethod
    def callback_on_finish(self):
        r"""Calls this on finishing the run method."""
        pass