proper checkpoint implementation (#1043)
* enabled early stopping/checkpooiunt even without val step * enabled early stopping/checkpooiunt even without val step * enabled early stopping/checkpooiunt even without val step * enabled early stopping/checkpooiunt even without val step * enabled early stopping/checkpooiunt even without val step * enabled early stopping/checkpooiunt even without val step * enabled early stopping/checkpooiunt even without val step * enabled early stopping/checkpooiunt even without val step * enabled early stopping/checkpooiunt even without val step * enabled early stopping/checkpooiunt even without val step * enabled early stopping/checkpooiunt even without val step * enabled early stopping/checkpooiunt even without val step * enabled early stopping/checkpooiunt even without val step * enabled early stopping/checkpooiunt even without val step * enabled early stopping/checkpooiunt even without val step * enabled early stopping/checkpooiunt even without val step * enabled early stopping/checkpooiunt even without val step * enabled early stopping/checkpooiunt even without val step * enabled early stopping/checkpooiunt even without val step * enabled early stopping/checkpooiunt even without val step * enabled early stopping/checkpooiunt even without val step * enabled early stopping/checkpooiunt even without val step * enabled early stopping/checkpooiunt even without val step * enabled early stopping/checkpooiunt even without val step * enabled early stopping/checkpooiunt even without val step * enabled early stopping/checkpooiunt even without val step * enabled early stopping/checkpooiunt even without val step * enabled early stopping/checkpooiunt even without val step * name formatting * version * testing * add test * fix test * Update model_checkpoint.py * doctests * pylint * tests * debug * debug * enabled early stopping/checkpooiunt even without val step * fix MNIST download (#1044) * fix MNIST download * simple * name formatting * version * testing * add test * fix test * 
doctests * tests * debug * debug * rebased 1041 * rebased 1041 * tests * rebased 1041 * rebased 1041 * rebased 1041 * rebased 1041 * rebased 1041 * rebased 1041 * rebased 1041 * rebased 1041 * rebased 1041 * rebased 1041 * rebased 1041 * rebased 1041 * rebased 1041 Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
This commit is contained in:
parent
165b9fb3f3
commit
bcb45d906d
|
@ -25,6 +25,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
- Support for user defined callbacks ([#889](https://github.com/PyTorchLightning/pytorch-lightning/pull/889) and [#950](https://github.com/PyTorchLightning/pytorch-lightning/pull/950))
|
||||
- Added support for multiple loggers to be passed to `Trainer` as an iterable (e.g. list, tuple, etc.) ([#903](https://github.com/PyTorchLightning/pytorch-lightning/pull/903))
|
||||
- Added support for logging hparams as dict ([#1029](https://github.com/PyTorchLightning/pytorch-lightning/pull/1029))
|
||||
- Checkpoint and early stopping now work without val step ([#1041](https://github.com/PyTorchLightning/pytorch-lightning/pull/1041))
|
||||
|
||||
### Changed
|
||||
|
||||
|
|
|
@ -1,18 +1,12 @@
|
|||
r"""
|
||||
Model Checkpoint
|
||||
==============
|
||||
Save the model as often as requested.
|
||||
|
||||
"""
|
||||
|
||||
import os
|
||||
import glob
|
||||
import shutil
|
||||
import logging as log
|
||||
import warnings
|
||||
import re
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .base import Callback
|
||||
from pytorch_lightning.callbacks.base import Callback
|
||||
|
||||
|
||||
class ModelCheckpoint(Callback):
|
||||
|
@ -20,21 +14,23 @@ class ModelCheckpoint(Callback):
|
|||
Save the model after every epoch.
|
||||
|
||||
Args:
|
||||
dirpath: path to save the model file.
|
||||
filepath: path to save the model file.
|
||||
Can contain named formatting options to be auto-filled.
|
||||
|
||||
Example::
|
||||
|
||||
# save epoch and val_loss in name
|
||||
ModelCheckpoint(filepath='{epoch:02d}-{val_loss:.2f}.hdf5')
|
||||
# no path
|
||||
ModelCheckpoint()
|
||||
# saves like /my/path/epoch_0.ckpt
|
||||
|
||||
# saves file like: /my/path/here/sample-mnist_epoch=02_val_loss=0.32.ckpt
|
||||
# if model already exists, the file will be: /my/path/here/sample-mnist-v0_epoch=02_val_loss=0.32.ckpt
|
||||
# save any arbitrary metrics, such as val_loss, etc. in the name
|
||||
ModelCheckpoint(filepath='/my/path/{epoch}-{val_loss:.2f}-{other_metric:.2f}')
|
||||
# saves file like: /my/path/epoch=2-val_loss=0.20-other_metric=0.30.ckpt
|
||||
|
||||
|
||||
monitor: quantity to monitor.
|
||||
verbose: verbosity mode, False or True.
|
||||
save_top_k: if `save_top_k == k`,
|
||||
monitor (str): quantity to monitor.
|
||||
verbose (bool): verbosity mode, False or True.
|
||||
save_top_k (int): if `save_top_k == k`,
|
||||
the best k models according to
|
||||
the quantity monitored will be saved.
|
||||
if ``save_top_k == 0``, no models are saved.
|
||||
|
@ -43,7 +39,7 @@ class ModelCheckpoint(Callback):
|
|||
if ``save_top_k >= 2`` and the callback is called multiple
|
||||
times inside an epoch, the name of the saved file will be
|
||||
appended with a version count starting with `v0`.
|
||||
mode: one of {auto, min, max}.
|
||||
mode (str): one of {auto, min, max}.
|
||||
If ``save_top_k != 0``, the decision
|
||||
to overwrite the current save file is made
|
||||
based on either the maximization or the
|
||||
|
@ -51,46 +47,43 @@ class ModelCheckpoint(Callback):
|
|||
this should be `max`, for `val_loss` this should
|
||||
be `min`, etc. In `auto` mode, the direction is
|
||||
automatically inferred from the name of the monitored quantity.
|
||||
save_weights_only: if True, then only the model's weights will be
|
||||
save_weights_only (bool): if True, then only the model's weights will be
|
||||
saved (`model.save_weights(filepath)`), else the full model
|
||||
is saved (`model.save(filepath)`).
|
||||
period: Interval (number of epochs) between checkpoints.
|
||||
prefix: String name for particular model
|
||||
period (int): Interval (number of epochs) between checkpoints.
|
||||
|
||||
Example:
|
||||
Example::
|
||||
|
||||
from pytorch_lightning import Trainer
|
||||
from pytorch_lightning.callbacks import ModelCheckpoint
|
||||
|
||||
# saves checkpoints to my_path whenever 'val_loss' has a new min
|
||||
checkpoint_callback = ModelCheckpoint('my_path')
|
||||
checkpoint_callback = ModelCheckpoint(filepath='my_path')
|
||||
Trainer(checkpoint_callback=checkpoint_callback)
|
||||
"""
|
||||
#: checkpoint extension
|
||||
EXTENSION = '.ckpt'
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dirpath: str,
|
||||
monitor: str = 'val_loss',
|
||||
verbose: bool = False,
|
||||
save_top_k: int = 1,
|
||||
save_weights_only: bool = False,
|
||||
mode: str = 'auto',
|
||||
period: int = 1,
|
||||
prefix: str = ''
|
||||
):
|
||||
# save epoch and val_loss in name
|
||||
ModelCheckpoint(filepath='/my/path/here/sample-mnist_{epoch:02d}-{val_loss:.2f}')
|
||||
# saves file like: /my/path/here/sample-mnist_epoch=02-val_loss=0.32.ckpt
|
||||
"""
|
||||
|
||||
def __init__(self, filepath, monitor: str = 'val_loss', verbose: bool = False,
|
||||
save_top_k: int = 1, save_weights_only: bool = False,
|
||||
mode: str = 'auto', period: int = 1, prefix: str = ''):
|
||||
super().__init__()
|
||||
if save_top_k and os.path.isdir(dirpath) and len(os.listdir(dirpath)) > 0:
|
||||
if save_top_k and os.path.isdir(filepath) and len(os.listdir(filepath)) > 0:
|
||||
warnings.warn(
|
||||
f"Checkpoint directory {dirpath} exists and is not empty with save_top_k != 0."
|
||||
f"Checkpoint directory {filepath} exists and is not empty with save_top_k != 0."
|
||||
"All files in this directory will be deleted when a checkpoint is saved!"
|
||||
)
|
||||
|
||||
self.monitor = monitor
|
||||
self.verbose = verbose
|
||||
self.dirpath = dirpath
|
||||
os.makedirs(dirpath, exist_ok=True)
|
||||
if os.path.isdir(filepath):
|
||||
self.dirpath, self.filename = filepath, '{epoch}'
|
||||
else:
|
||||
self.dirpath, self.filename = os.path.split(filepath)
|
||||
|
||||
os.makedirs(self.dirpath, exist_ok=True)
|
||||
self.save_top_k = save_top_k
|
||||
self.save_weights_only = save_weights_only
|
||||
self.period = period
|
||||
|
@ -102,14 +95,6 @@ class ModelCheckpoint(Callback):
|
|||
self.best = 0
|
||||
self.save_function = None
|
||||
|
||||
# this creates a unique prefix if the given one already exists
|
||||
existing_checkpoints = sorted(glob.glob(os.path.join(self.dirpath, '*' + self.EXTENSION)))
|
||||
existing_names = set(os.path.basename(ckpt).split('_epoch=')[0] for ckpt in existing_checkpoints)
|
||||
version_cnt = 0
|
||||
while self.prefix in existing_names:
|
||||
self.prefix = f'{prefix}-v{version_cnt}'
|
||||
version_cnt += 1
|
||||
|
||||
mode_dict = {
|
||||
'min': (np.less, np.Inf, 'min'),
|
||||
'max': (np.greater, -np.Inf, 'max'),
|
||||
|
@ -125,39 +110,65 @@ class ModelCheckpoint(Callback):
|
|||
|
||||
self.monitor_op, self.kth_value, self.mode = mode_dict[mode]
|
||||
|
||||
def _del_model(self, filepath: str) -> None:
|
||||
# shutil.rmtree(filepath)
|
||||
def _del_model(self, filepath):
|
||||
os.remove(filepath)
|
||||
|
||||
def _save_model(self, filepath: str) -> None:
|
||||
def _save_model(self, filepath):
|
||||
# make paths
|
||||
os.makedirs(self.dirpath, exist_ok=True)
|
||||
os.makedirs(os.path.dirname(filepath), exist_ok=True)
|
||||
|
||||
# delegate the saving to the model
|
||||
if self.save_function is not None:
|
||||
self.save_function(filepath)
|
||||
else:
|
||||
raise ValueError("Method `.save_function()` not set")
|
||||
raise ValueError(".save_function() not set")
|
||||
|
||||
def check_monitor_top_k(self, current: float) -> bool:
|
||||
def check_monitor_top_k(self, current):
|
||||
less_than_k_models = len(self.best_k_models) < self.save_top_k
|
||||
if less_than_k_models:
|
||||
return True
|
||||
return self.monitor_op(current, self.best_k_models[self.kth_best_model])
|
||||
|
||||
def _get_available_filepath(self, current: float, epoch: int) -> str:
|
||||
current_str = f'{current:.2f}' if current else 'NaN'
|
||||
fname = f'{self.prefix}_epoch={epoch}_{self.monitor}={current_str}'
|
||||
filepath = os.path.join(self.dirpath, fname + self.EXTENSION)
|
||||
assert not os.path.isfile(filepath)
|
||||
def format_checkpoint_name(self, epoch, metrics, ver=None):
|
||||
"""Generate a filename according to the defined template.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> tmpdir = os.path.dirname(__file__)
|
||||
>>> ckpt = ModelCheckpoint(os.path.join(tmpdir, '{epoch}'))
|
||||
>>> os.path.basename(ckpt.format_checkpoint_name(0, {}))
|
||||
'epoch=0.ckpt'
|
||||
>>> ckpt = ModelCheckpoint(os.path.join(tmpdir, '{epoch:03d}'))
|
||||
>>> os.path.basename(ckpt.format_checkpoint_name(5, {}))
|
||||
'epoch=005.ckpt'
|
||||
>>> ckpt = ModelCheckpoint(os.path.join(tmpdir, '{epoch}-{val_loss:.2f}'))
|
||||
>>> os.path.basename(ckpt.format_checkpoint_name(2, dict(val_loss=0.123456)))
|
||||
'epoch=2-val_loss=0.12.ckpt'
|
||||
>>> ckpt = ModelCheckpoint(os.path.join(tmpdir, '{missing:d}'))
|
||||
>>> os.path.basename(ckpt.format_checkpoint_name(0, {}))
|
||||
'missing=0.ckpt'
|
||||
"""
|
||||
# check if user passed in keys to the string
|
||||
groups = re.findall(r'(\{.*?)[:\}]', self.filename)
|
||||
|
||||
if len(groups) == 0:
|
||||
# default name
|
||||
filename = f'{self.prefix}_ckpt_epoch_{epoch}'
|
||||
else:
|
||||
metrics['epoch'] = epoch
|
||||
filename = self.filename
|
||||
for tmp in groups:
|
||||
name = tmp[1:]
|
||||
filename = filename.replace(tmp, name + '={' + name)
|
||||
if name not in metrics:
|
||||
metrics[name] = 0
|
||||
filename = filename.format(**metrics)
|
||||
str_ver = f'_v{ver}' if ver is not None else ''
|
||||
filepath = os.path.join(self.dirpath, self.prefix + filename + str_ver + '.ckpt')
|
||||
return filepath
|
||||
|
||||
def on_validation_end(self, trainer, pl_module) -> None:
|
||||
# only run on main process
|
||||
if trainer.proc_rank != 0:
|
||||
return
|
||||
|
||||
logs = trainer.callback_metrics
|
||||
def on_validation_end(self, trainer, pl_module):
|
||||
metrics = trainer.callback_metrics
|
||||
epoch = trainer.current_epoch
|
||||
self.epochs_since_last_check += 1
|
||||
|
||||
|
@ -166,27 +177,36 @@ class ModelCheckpoint(Callback):
|
|||
return
|
||||
if self.epochs_since_last_check >= self.period:
|
||||
self.epochs_since_last_check = 0
|
||||
current = logs.get(self.monitor)
|
||||
filepath = self._get_available_filepath(current, epoch)
|
||||
|
||||
filepath = self.format_checkpoint_name(epoch, metrics)
|
||||
version_cnt = 0
|
||||
while os.path.isfile(filepath):
|
||||
filepath = self.format_checkpoint_name(epoch, metrics, ver=version_cnt)
|
||||
# this epoch called before
|
||||
version_cnt += 1
|
||||
|
||||
if self.save_top_k != -1:
|
||||
current = metrics.get(self.monitor)
|
||||
|
||||
if current is None:
|
||||
warnings.warn(f'Can save best model only with {self.monitor} available,'
|
||||
' skipping.', RuntimeWarning)
|
||||
warnings.warn(
|
||||
f'Can save best model only with {self.monitor} available,'
|
||||
' skipping.', RuntimeWarning)
|
||||
else:
|
||||
if self.check_monitor_top_k(current):
|
||||
self._do_check_save(filepath, current, epoch)
|
||||
else:
|
||||
if self.verbose > 0:
|
||||
log.info('Epoch %05d: %s was not in top %i', epoch, self.monitor, self.save_top_k)
|
||||
log.info(
|
||||
f'\nEpoch {epoch:05d}: {self.monitor}'
|
||||
f' was not in top {self.save_top_k}')
|
||||
|
||||
else:
|
||||
if self.verbose > 0:
|
||||
log.info('Epoch %05d: saving model to %s', epoch, filepath)
|
||||
log.info(f'\nEpoch {epoch:05d}: saving model to {filepath}')
|
||||
self._save_model(filepath)
|
||||
|
||||
def _do_check_save(self, filepath: str, current: float, epoch: int) -> None:
|
||||
def _do_check_save(self, filepath, current, epoch):
|
||||
# remove kth
|
||||
if len(self.best_k_models) == self.save_top_k:
|
||||
delpath = self.kth_best_model
|
||||
|
@ -205,6 +225,8 @@ class ModelCheckpoint(Callback):
|
|||
self.best = _op(self.best_k_models.values())
|
||||
|
||||
if self.verbose > 0:
|
||||
log.info('Epoch {epoch:05d}: %s reached %0.5f (best %0.5f), saving model to %s as top %i',
|
||||
epoch, self.monitor, current, self.best, filepath, self.save_top_k)
|
||||
log.info(
|
||||
f'\nEpoch {epoch:05d}: {self.monitor} reached'
|
||||
f' {current:0.5f} (best {self.best:0.5f}), saving model to'
|
||||
f' {filepath} as top {self.save_top_k}')
|
||||
self._save_model(filepath)
|
||||
|
|
|
@ -68,19 +68,7 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):
|
|||
#: True if using amp
|
||||
self.use_amp = False
|
||||
|
||||
@property
|
||||
def hparams(self) -> Namespace:
|
||||
if not hasattr(self, '_hparams'):
|
||||
return Namespace()
|
||||
assert isinstance(self._hparams, dict)
|
||||
return Namespace(**self._hparams)
|
||||
|
||||
@hparams.setter
|
||||
def hparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
|
||||
"""Set the model hyper-parameters."""
|
||||
if isinstance(params, Namespace):
|
||||
params = vars(params)
|
||||
self._hparams = params
|
||||
self.hparams = None
|
||||
|
||||
def print(self, *args, **kwargs):
|
||||
r"""
|
||||
|
|
|
@ -46,6 +46,10 @@ class LightningLoggerBase(ABC):
|
|||
# in case converting from namespace
|
||||
if isinstance(params, Namespace):
|
||||
params = vars(params)
|
||||
|
||||
if params is None:
|
||||
params = {}
|
||||
|
||||
return params
|
||||
|
||||
@abstractmethod
|
||||
|
|
|
@ -48,9 +48,15 @@ class TrainerCallbackConfigMixin(ABC):
|
|||
else:
|
||||
ckpt_path = os.path.join(self.default_save_path, "checkpoints")
|
||||
|
||||
# when no val step is defined, use 'loss' otherwise 'val_loss'
|
||||
train_step_only = not self.is_overriden('validation_step')
|
||||
monitor_key = 'loss' if train_step_only else 'val_loss'
|
||||
|
||||
self.ckpt_path = ckpt_path
|
||||
os.makedirs(ckpt_path, exist_ok=True)
|
||||
self.checkpoint_callback = ModelCheckpoint(
|
||||
dirpath=ckpt_path
|
||||
filepath=ckpt_path,
|
||||
monitor=monitor_key
|
||||
)
|
||||
elif self.checkpoint_callback is False:
|
||||
self.checkpoint_callback = None
|
||||
|
|
|
@ -165,7 +165,6 @@ class TrainerEvaluationLoopMixin(ABC):
|
|||
process_output: ...
|
||||
training_tqdm_dict: ...
|
||||
proc_rank: int
|
||||
checkpoint_callback: ...
|
||||
current_epoch: int
|
||||
callback_metrics: ...
|
||||
test_dataloaders: DataLoader
|
||||
|
@ -377,11 +376,6 @@ class TrainerEvaluationLoopMixin(ABC):
|
|||
# Validation/Test end callbacks
|
||||
if test_mode:
|
||||
self.on_test_end()
|
||||
else:
|
||||
# model checkpointing
|
||||
if self.checkpoint_callback is not None:
|
||||
self.checkpoint_callback.on_validation_end(self, self.get_model())
|
||||
self.on_validation_end()
|
||||
|
||||
def evaluation_forward(self, model, batch, batch_idx, dataloader_idx, test_mode: bool = False):
|
||||
# make dataloader_idx arg in validation_step optional
|
||||
|
|
|
@ -1132,9 +1132,6 @@ class Trainer(TrainerIOMixin,
|
|||
# wait for all processes to catch up
|
||||
torch_xla.core.xla_model.rendezvous("pl.Trainer.run_pretrain_routine")
|
||||
|
||||
# set up checkpoint callback
|
||||
self.configure_checkpoint_callback()
|
||||
|
||||
# register auto-resubmit when on SLURM
|
||||
self.register_slurm_signal_handlers()
|
||||
|
||||
|
@ -1151,6 +1148,9 @@ class Trainer(TrainerIOMixin,
|
|||
# if cluster resets state, the model will update with the saved weights
|
||||
self.model = model
|
||||
|
||||
# set up checkpoint callback
|
||||
self.configure_checkpoint_callback()
|
||||
|
||||
# restore training and model before hpc call
|
||||
self.restore_weights(model)
|
||||
|
||||
|
|
|
@ -165,14 +165,15 @@ class TrainerIOMixin(ABC):
|
|||
def save_checkpoint(self, filepath):
|
||||
checkpoint = self.dump_checkpoint()
|
||||
|
||||
# do the actual save
|
||||
try:
|
||||
self._atomic_save(checkpoint, filepath)
|
||||
except AttributeError:
|
||||
if 'hparams' in checkpoint:
|
||||
del checkpoint['hparams']
|
||||
if self.proc_rank == 0:
|
||||
# do the actual save
|
||||
try:
|
||||
self._atomic_save(checkpoint, filepath)
|
||||
except AttributeError:
|
||||
if 'hparams' in checkpoint:
|
||||
del checkpoint['hparams']
|
||||
|
||||
self._atomic_save(checkpoint, filepath)
|
||||
self._atomic_save(checkpoint, filepath)
|
||||
|
||||
def restore(self, checkpoint_path, on_gpu):
|
||||
"""
|
||||
|
|
|
@ -203,6 +203,7 @@ class TrainerTrainLoopMixin(ABC):
|
|||
max_steps: int
|
||||
max_steps: int
|
||||
total_batch_idx: int
|
||||
checkpoint_callback: ...
|
||||
|
||||
# Callback system
|
||||
callbacks: List[Callback]
|
||||
|
@ -212,6 +213,7 @@ class TrainerTrainLoopMixin(ABC):
|
|||
on_batch_end: Callable
|
||||
on_epoch_start: Callable
|
||||
on_epoch_end: Callable
|
||||
on_validation_end: Callable
|
||||
|
||||
@property
|
||||
def max_nb_epochs(self):
|
||||
|
@ -454,9 +456,6 @@ class TrainerTrainLoopMixin(ABC):
|
|||
if self.fast_dev_run or should_check_val:
|
||||
self.run_evaluation(test_mode=self.testing)
|
||||
|
||||
if self.enable_early_stop:
|
||||
self.early_stop_callback.check_metrics(self.callback_metrics)
|
||||
|
||||
# when logs should be saved
|
||||
should_save_log = (batch_idx + 1) % self.log_save_interval == 0 or early_stop_epoch
|
||||
if should_save_log or self.fast_dev_run:
|
||||
|
@ -469,6 +468,17 @@ class TrainerTrainLoopMixin(ABC):
|
|||
# logs user requested information to logger
|
||||
self.log_metrics(batch_step_metrics, grad_norm_dic)
|
||||
|
||||
# ---------------
|
||||
# CHECKPOINTING, EARLY STOPPING
|
||||
# ---------------
|
||||
# save checkpoint even when no test or val step are defined
|
||||
train_step_only = not self.is_overriden('validation_step')
|
||||
if self.fast_dev_run or should_check_val or train_step_only:
|
||||
self.call_checkpoint_callback()
|
||||
|
||||
if self.enable_early_stop:
|
||||
self.early_stop_callback.check_metrics(self.callback_metrics)
|
||||
|
||||
# progress global step according to grads progress
|
||||
if (self.batch_idx + 1) % self.accumulate_grad_batches == 0:
|
||||
self.global_step += 1
|
||||
|
@ -705,3 +715,8 @@ class TrainerTrainLoopMixin(ABC):
|
|||
output = self.process_output(output, train=True)
|
||||
|
||||
return output
|
||||
|
||||
def call_checkpoint_callback(self):
|
||||
if self.checkpoint_callback is not None:
|
||||
self.checkpoint_callback.on_validation_end(self, self.get_model())
|
||||
self.on_validation_end()
|
||||
|
|
|
@ -46,6 +46,7 @@ class DictHparamsModel(LightningModule):
|
|||
|
||||
def __init__(self, hparams: Dict):
|
||||
super(DictHparamsModel, self).__init__()
|
||||
self.hparams = hparams
|
||||
self.l1 = torch.nn.Linear(hparams.get('in_features'), hparams['out_features'])
|
||||
|
||||
def forward(self, x):
|
||||
|
|
|
@ -239,5 +239,6 @@ def set_random_master_port():
|
|||
def init_checkpoint_callback(logger, path_dir=None):
|
||||
exp_path = get_data_path(logger, path_dir=path_dir)
|
||||
ckpt_dir = os.path.join(exp_path, 'checkpoints')
|
||||
os.mkdir(ckpt_dir)
|
||||
checkpoint = ModelCheckpoint(ckpt_dir)
|
||||
return checkpoint
|
||||
|
|
|
@ -256,66 +256,57 @@ def mocked_device_count_0(monkeypatch):
|
|||
monkeypatch.setattr(torch.cuda, 'device_count', device_count)
|
||||
|
||||
|
||||
test_num_gpus_data = [
|
||||
@pytest.mark.gpus_param_tests
|
||||
@pytest.mark.parametrize(["gpus", "expected_num_gpus", "distributed_backend"], [
|
||||
pytest.param(None, 0, None, id="None - expect 0 gpu to use."),
|
||||
pytest.param(0, 0, None, id="Oth gpu, expect 1 gpu to use."),
|
||||
pytest.param(1, 1, None, id="1st gpu, expect 1 gpu to use."),
|
||||
pytest.param(-1, PRETEND_N_OF_GPUS, "ddp", id="-1 - use all gpus"),
|
||||
pytest.param('-1', PRETEND_N_OF_GPUS, "ddp", id="'-1' - use all gpus"),
|
||||
pytest.param(3, 3, "ddp", id="3rd gpu - 1 gpu to use (backend:ddp)")
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.gpus_param_tests
|
||||
@pytest.mark.parametrize(["gpus", "expected_num_gpus", "distributed_backend"], test_num_gpus_data)
|
||||
])
|
||||
def test_trainer_gpu_parse(mocked_device_count, gpus, expected_num_gpus, distributed_backend):
|
||||
assert Trainer(gpus=gpus, distributed_backend=distributed_backend).num_gpus == expected_num_gpus
|
||||
|
||||
|
||||
test_num_gpus_data_0 = [
|
||||
@pytest.mark.gpus_param_tests
|
||||
@pytest.mark.parametrize(["gpus", "expected_num_gpus", "distributed_backend"], [
|
||||
pytest.param(None, 0, None, id="None - expect 0 gpu to use."),
|
||||
pytest.param(None, 0, "ddp", id="None - expect 0 gpu to use."),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.gpus_param_tests
|
||||
@pytest.mark.parametrize(["gpus", "expected_num_gpus", "distributed_backend"], test_num_gpus_data_0)
|
||||
])
|
||||
def test_trainer_num_gpu_0(mocked_device_count_0, gpus, expected_num_gpus, distributed_backend):
|
||||
assert Trainer(gpus=gpus, distributed_backend=distributed_backend).num_gpus == expected_num_gpus
|
||||
|
||||
|
||||
test_root_gpu_data = [
|
||||
@pytest.mark.gpus_param_tests
|
||||
@pytest.mark.parametrize(['gpus', 'expected_root_gpu', "distributed_backend"], [
|
||||
pytest.param(None, None, "ddp", id="None is None"),
|
||||
pytest.param(0, None, "ddp", id="O gpus, expect gpu root device to be None."),
|
||||
pytest.param(1, 0, "ddp", id="1 gpu, expect gpu root device to be 0."),
|
||||
pytest.param(-1, 0, "ddp", id="-1 - use all gpus, expect gpu root device to be 0."),
|
||||
pytest.param('-1', 0, "ddp", id="'-1' - use all gpus, expect gpu root device to be 0."),
|
||||
pytest.param(3, 0, "ddp", id="3 gpus, expect gpu root device to be 0.(backend:ddp)")]
|
||||
|
||||
|
||||
@pytest.mark.gpus_param_tests
|
||||
@pytest.mark.parametrize(['gpus', 'expected_root_gpu', "distributed_backend"], test_root_gpu_data)
|
||||
pytest.param(3, 0, "ddp", id="3 gpus, expect gpu root device to be 0.(backend:ddp)")
|
||||
])
|
||||
def test_root_gpu_property(mocked_device_count, gpus, expected_root_gpu, distributed_backend):
|
||||
assert Trainer(gpus=gpus, distributed_backend=distributed_backend).root_gpu == expected_root_gpu
|
||||
|
||||
|
||||
test_root_gpu_data_for_0_devices_passing = [
|
||||
@pytest.mark.gpus_param_tests
|
||||
@pytest.mark.parametrize([
|
||||
'gpus', 'expected_root_gpu', "distributed_backend"], [
|
||||
pytest.param(None, None, None, id="None is None"),
|
||||
pytest.param(None, None, "ddp", id="None is None"),
|
||||
pytest.param(0, None, "ddp", id="None is None"),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.gpus_param_tests
|
||||
@pytest.mark.parametrize([
|
||||
'gpus', 'expected_root_gpu', "distributed_backend"], test_root_gpu_data_for_0_devices_passing)
|
||||
])
|
||||
def test_root_gpu_property_0_passing(
|
||||
mocked_device_count_0, gpus, expected_root_gpu, distributed_backend):
|
||||
assert Trainer(gpus=gpus, distributed_backend=distributed_backend).root_gpu == expected_root_gpu
|
||||
|
||||
|
||||
# Asking for a gpu when none are available will result in a MisconfigurationException
|
||||
test_root_gpu_data_for_0_devices_raising = [
|
||||
@pytest.mark.gpus_param_tests
|
||||
@pytest.mark.parametrize([
|
||||
'gpus', 'expected_root_gpu', "distributed_backend"], [
|
||||
pytest.param(1, None, "ddp"),
|
||||
pytest.param(3, None, "ddp"),
|
||||
pytest.param(3, None, "ddp"),
|
||||
|
@ -323,34 +314,27 @@ test_root_gpu_data_for_0_devices_raising = [
|
|||
pytest.param([0, 1], None, "ddp"),
|
||||
pytest.param(-1, None, "ddp"),
|
||||
pytest.param('-1', None, "ddp")
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.gpus_param_tests
|
||||
@pytest.mark.parametrize([
|
||||
'gpus', 'expected_root_gpu', "distributed_backend"], test_root_gpu_data_for_0_devices_raising)
|
||||
])
|
||||
def test_root_gpu_property_0_raising(
|
||||
mocked_device_count_0, gpus, expected_root_gpu, distributed_backend):
|
||||
with pytest.raises(MisconfigurationException):
|
||||
Trainer(gpus=gpus, distributed_backend=distributed_backend).root_gpu
|
||||
|
||||
|
||||
test_determine_root_gpu_device_data = [
|
||||
@pytest.mark.gpus_param_tests
|
||||
@pytest.mark.parametrize(['gpus', 'expected_root_gpu'], [
|
||||
pytest.param(None, None, id="No gpus, expect gpu root device to be None"),
|
||||
pytest.param([0], 0, id="Oth gpu, expect gpu root device to be 0."),
|
||||
pytest.param([1], 1, id="1st gpu, expect gpu root device to be 1."),
|
||||
pytest.param([3], 3, id="3rd gpu, expect gpu root device to be 3."),
|
||||
pytest.param([1, 2], 1, id="[1, 2] gpus, expect gpu root device to be 1."),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.gpus_param_tests
|
||||
@pytest.mark.parametrize(['gpus', 'expected_root_gpu'], test_determine_root_gpu_device_data)
|
||||
])
|
||||
def test_determine_root_gpu_device(gpus, expected_root_gpu):
|
||||
assert determine_root_gpu_device(gpus) == expected_root_gpu
|
||||
|
||||
|
||||
test_parse_gpu_ids_data = [
|
||||
@pytest.mark.gpus_param_tests
|
||||
@pytest.mark.parametrize(['gpus', 'expected_gpu_ids'], [
|
||||
pytest.param(None, None),
|
||||
pytest.param(0, None),
|
||||
pytest.param(1, [0]),
|
||||
|
@ -362,16 +346,13 @@ test_parse_gpu_ids_data = [
|
|||
pytest.param('3', [3]),
|
||||
pytest.param('1, 3', [1, 3]),
|
||||
pytest.param('-1', list(range(PRETEND_N_OF_GPUS)), id="'-1' - use all gpus"),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.gpus_param_tests
|
||||
@pytest.mark.parametrize(['gpus', 'expected_gpu_ids'], test_parse_gpu_ids_data)
|
||||
])
|
||||
def test_parse_gpu_ids(mocked_device_count, gpus, expected_gpu_ids):
|
||||
assert parse_gpu_ids(gpus) == expected_gpu_ids
|
||||
|
||||
|
||||
test_parse_gpu_invalid_inputs_data = [
|
||||
@pytest.mark.gpus_param_tests
|
||||
@pytest.mark.parametrize(['gpus'], [
|
||||
pytest.param(0.1),
|
||||
pytest.param(-2),
|
||||
pytest.param(False),
|
||||
|
@ -380,11 +361,7 @@ test_parse_gpu_invalid_inputs_data = [
|
|||
pytest.param([None]),
|
||||
pytest.param(['0']),
|
||||
pytest.param((0, 1)),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.gpus_param_tests
|
||||
@pytest.mark.parametrize(['gpus'], test_parse_gpu_invalid_inputs_data)
|
||||
])
|
||||
def test_parse_gpu_fail_on_unsupported_inputs(mocked_device_count, gpus):
|
||||
with pytest.raises(MisconfigurationException):
|
||||
parse_gpu_ids(gpus)
|
||||
|
|
|
@ -1,5 +1,8 @@
|
|||
import os
|
||||
|
||||
import tests.models.utils as tutils
|
||||
from pytorch_lightning import Trainer, LightningModule
|
||||
from pytorch_lightning.callbacks import ModelCheckpoint
|
||||
from tests.models import (
|
||||
TestModelBase,
|
||||
LightTrainDataloader,
|
||||
|
|
|
@ -27,6 +27,28 @@ from pytorch_lightning.trainer.logging import TrainerLoggingMixin
|
|||
from pytorch_lightning.utilities.debugging import MisconfigurationException
|
||||
|
||||
|
||||
def test_hparams_save_load(tmpdir):
|
||||
model = DictHparamsModel({'in_features': 28 * 28, 'out_features': 10})
|
||||
|
||||
# logger file to get meta
|
||||
trainer_options = dict(
|
||||
default_save_path=tmpdir,
|
||||
max_epochs=2,
|
||||
)
|
||||
|
||||
# fit model
|
||||
trainer = Trainer(**trainer_options)
|
||||
result = trainer.fit(model)
|
||||
|
||||
assert result == 1
|
||||
|
||||
# try to load the model now
|
||||
pretrained_model = tutils.load_model_from_checkpoint(
|
||||
trainer.checkpoint_callback.dirpath,
|
||||
module_class=DictHparamsModel
|
||||
)
|
||||
|
||||
|
||||
def test_no_val_module(tmpdir):
|
||||
"""Tests use case where trainer saves the model, and user loads it from tags independently."""
|
||||
tutils.reset_seed()
|
||||
|
@ -126,7 +148,8 @@ def test_gradient_accumulation_scheduling(tmpdir):
|
|||
assert Trainer(accumulate_grad_batches={1: 2.5, 3: 5})
|
||||
|
||||
# test optimizer call freq matches scheduler
|
||||
def _optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, second_order_closure=None):
|
||||
def _optimizer_step(self, epoch, batch_idx, optimizer,
|
||||
optimizer_idx, second_order_closure=None):
|
||||
# only test the first 12 batches in epoch
|
||||
if batch_idx < 12:
|
||||
if epoch == 0:
|
||||
|
@ -255,11 +278,11 @@ def test_model_checkpoint_options(tmp_path):
|
|||
assert len(file_lists) == len(losses), "Should save all models when save_top_k=-1"
|
||||
|
||||
# verify correct naming
|
||||
for fname in {'_epoch=4_val_loss=2.50.ckpt',
|
||||
'_epoch=3_val_loss=5.00.ckpt',
|
||||
'_epoch=2_val_loss=2.80.ckpt',
|
||||
'_epoch=1_val_loss=9.00.ckpt',
|
||||
'_epoch=0_val_loss=10.00.ckpt'}:
|
||||
for fname in {'epoch=4.ckpt',
|
||||
'epoch=3.ckpt',
|
||||
'epoch=2.ckpt',
|
||||
'epoch=1.ckpt',
|
||||
'epoch=0.ckpt'}:
|
||||
assert fname in file_lists
|
||||
|
||||
save_dir = tmp_path / "2"
|
||||
|
@ -286,7 +309,7 @@ def test_model_checkpoint_options(tmp_path):
|
|||
|
||||
# -----------------
|
||||
# CASE K=1 (2.5, epoch 4)
|
||||
checkpoint_callback = ModelCheckpoint(save_dir, save_top_k=1, verbose=1, prefix='test_prefix')
|
||||
checkpoint_callback = ModelCheckpoint(save_dir, save_top_k=1, verbose=1, prefix='test_prefix_')
|
||||
checkpoint_callback.save_function = mock_save_function
|
||||
trainer = Trainer()
|
||||
|
||||
|
@ -299,7 +322,7 @@ def test_model_checkpoint_options(tmp_path):
|
|||
file_lists = set(os.listdir(save_dir))
|
||||
|
||||
assert len(file_lists) == 1, "Should save 1 model when save_top_k=1"
|
||||
assert 'test_prefix_epoch=4_val_loss=2.50.ckpt' in file_lists
|
||||
assert 'test_prefix_epoch=4.ckpt' in file_lists
|
||||
|
||||
save_dir = tmp_path / "4"
|
||||
save_dir.mkdir()
|
||||
|
@ -322,8 +345,8 @@ def test_model_checkpoint_options(tmp_path):
|
|||
file_lists = set(os.listdir(save_dir))
|
||||
|
||||
assert len(file_lists) == 3, 'Should save 2 model when save_top_k=2'
|
||||
for fname in {'_epoch=4_val_loss=2.50.ckpt',
|
||||
'_epoch=2_val_loss=2.80.ckpt',
|
||||
for fname in {'epoch=4.ckpt',
|
||||
'epoch=2.ckpt',
|
||||
'other_file.ckpt'}:
|
||||
assert fname in file_lists
|
||||
|
||||
|
@ -368,9 +391,9 @@ def test_model_checkpoint_options(tmp_path):
|
|||
file_lists = set(os.listdir(save_dir))
|
||||
|
||||
assert len(file_lists) == 3, 'Should save 3 models when save_top_k=3'
|
||||
for fname in {'_epoch=0_val_loss=2.80.ckpt',
|
||||
'_epoch=0_val_loss=2.50.ckpt',
|
||||
'_epoch=0_val_loss=5.00.ckpt'}:
|
||||
for fname in {'epoch=0.ckpt',
|
||||
'epoch=0.ckpt',
|
||||
'epoch=0.ckpt'}:
|
||||
assert fname in file_lists
|
||||
|
||||
|
||||
|
@ -620,25 +643,3 @@ def test_default_args(tmpdir):
|
|||
|
||||
assert isinstance(trainer, Trainer)
|
||||
assert trainer.max_epochs == 5
|
||||
|
||||
|
||||
def test_hparams_save_load(tmpdir):
|
||||
model = DictHparamsModel({'in_features': 28 * 28, 'out_features': 10})
|
||||
|
||||
# logger file to get meta
|
||||
trainer_options = dict(
|
||||
default_save_path=tmpdir,
|
||||
max_epochs=2,
|
||||
)
|
||||
|
||||
# fit model
|
||||
trainer = Trainer(**trainer_options)
|
||||
result = trainer.fit(model)
|
||||
|
||||
assert result == 1
|
||||
|
||||
# try to load the model now
|
||||
pretrained_model = tutils.load_model_from_checkpoint(
|
||||
trainer.checkpoint_callback.dirpath,
|
||||
module_class=DictHparamsModel
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue