proper checkpoint implementation (#1043)

* enabled early stopping/checkpoint even without val step

* name formatting

* version

* testing

* add test

* fix test

* Update model_checkpoint.py

* doctests

* pylint

* tests

* debug

* debug

* enabled early stopping/checkpoint even without val step

* fix MNIST download (#1044)

* fix MNIST download

* simple

* name formatting

* version

* testing

* add test

* fix test

* doctests

* tests

* debug

* debug

* rebased 1041

* tests

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
William Falcon 2020-03-04 23:02:19 -05:00 committed by GitHub
parent 165b9fb3f3
commit bcb45d906d
14 changed files with 207 additions and 193 deletions

View File

@ -25,6 +25,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Support for user defined callbacks ([#889](https://github.com/PyTorchLightning/pytorch-lightning/pull/889) and [#950](https://github.com/PyTorchLightning/pytorch-lightning/pull/950))
- Added support for multiple loggers to be passed to `Trainer` as an iterable (e.g. list, tuple, etc.) ([#903](https://github.com/PyTorchLightning/pytorch-lightning/pull/903))
- Added support for logging hparams as dict ([#1029](https://github.com/PyTorchLightning/pytorch-lightning/pull/1029))
- Checkpoint and early stopping now work without val step ([#1041](https://github.com/PyTorchLightning/pytorch-lightning/pull/1041))
### Changed

View File

@ -1,18 +1,12 @@
r"""
Model Checkpoint
==============
Save the model as often as requested.
"""
import os
import glob
import shutil
import logging as log
import warnings
import re
import numpy as np
from .base import Callback
from pytorch_lightning.callbacks.base import Callback
class ModelCheckpoint(Callback):
@ -20,21 +14,23 @@ class ModelCheckpoint(Callback):
Save the model after every epoch.
Args:
dirpath: path to save the model file.
filepath: path to save the model file.
Can contain named formatting options to be auto-filled.
Example::
# save epoch and val_loss in name
ModelCheckpoint(filepath='{epoch:02d}-{val_loss:.2f}.hdf5')
# no path
ModelCheckpoint()
# saves like /my/path/epoch_0.ckpt
# saves file like: /my/path/here/sample-mnist_epoch=02_val_loss=0.32.ckpt
# if the file already exists, the checkpoint will be: /my/path/here/sample-mnist-v0_epoch=02_val_loss=0.32.ckpt
# save any arbitrary metrics like `val_loss`, etc. in the name
ModelCheckpoint(filepath='/my/path/{epoch}-{val_loss:.2f}-{other_metric:.2f}')
# saves file like: /my/path/epoch=2-val_loss=0.2_other_metric=0.3.ckpt
monitor: quantity to monitor.
verbose: verbosity mode, False or True.
save_top_k: if `save_top_k == k`,
monitor (str): quantity to monitor.
verbose (bool): verbosity mode, False or True.
save_top_k (int): if `save_top_k == k`,
the best k models according to
the quantity monitored will be saved.
if ``save_top_k == 0``, no models are saved.
@ -43,7 +39,7 @@ class ModelCheckpoint(Callback):
if ``save_top_k >= 2`` and the callback is called multiple
times inside an epoch, the name of the saved file will be
appended with a version count starting with `v0`.
mode: one of {auto, min, max}.
mode (str): one of {auto, min, max}.
If ``save_top_k != 0``, the decision
to overwrite the current save file is made
based on either the maximization or the
@ -51,46 +47,43 @@ class ModelCheckpoint(Callback):
this should be `max`, for `val_loss` this should
be `min`, etc. In `auto` mode, the direction is
automatically inferred from the name of the monitored quantity.
save_weights_only: if True, then only the model's weights will be
save_weights_only (bool): if True, then only the model's weights will be
saved (`model.save_weights(filepath)`), else the full model
is saved (`model.save(filepath)`).
period: Interval (number of epochs) between checkpoints.
prefix: String name for particular model
period (int): Interval (number of epochs) between checkpoints.
Example:
Example::
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
# saves checkpoints to my_path whenever 'val_loss' has a new min
checkpoint_callback = ModelCheckpoint('my_path')
checkpoint_callback = ModelCheckpoint(filepath='my_path')
Trainer(checkpoint_callback=checkpoint_callback)
"""
#: checkpoint extension
EXTENSION = '.ckpt'
def __init__(
self,
dirpath: str,
monitor: str = 'val_loss',
verbose: bool = False,
save_top_k: int = 1,
save_weights_only: bool = False,
mode: str = 'auto',
period: int = 1,
prefix: str = ''
):
# save epoch and val_loss in name
ModelCheckpoint(filepath='/my/path/here/sample-mnist_{epoch:02d}-{val_loss:.2f}')
# saves file like: /my/path/here/sample-mnist_epoch=02_val_loss=0.32.ckpt
"""
def __init__(self, filepath, monitor: str = 'val_loss', verbose: bool = False,
save_top_k: int = 1, save_weights_only: bool = False,
mode: str = 'auto', period: int = 1, prefix: str = ''):
super().__init__()
if save_top_k and os.path.isdir(dirpath) and len(os.listdir(dirpath)) > 0:
if save_top_k and os.path.isdir(filepath) and len(os.listdir(filepath)) > 0:
warnings.warn(
f"Checkpoint directory {dirpath} exists and is not empty with save_top_k != 0."
f"Checkpoint directory {filepath} exists and is not empty with save_top_k != 0."
"All files in this directory will be deleted when a checkpoint is saved!"
)
self.monitor = monitor
self.verbose = verbose
self.dirpath = dirpath
os.makedirs(dirpath, exist_ok=True)
if os.path.isdir(filepath):
self.dirpath, self.filename = filepath, '{epoch}'
else:
self.dirpath, self.filename = os.path.split(filepath)
os.makedirs(self.dirpath, exist_ok=True)
self.save_top_k = save_top_k
self.save_weights_only = save_weights_only
self.period = period
@ -102,14 +95,6 @@ class ModelCheckpoint(Callback):
self.best = 0
self.save_function = None
# this creates a unique prefix if the given one already exists
existing_checkpoints = sorted(glob.glob(os.path.join(self.dirpath, '*' + self.EXTENSION)))
existing_names = set(os.path.basename(ckpt).split('_epoch=')[0] for ckpt in existing_checkpoints)
version_cnt = 0
while self.prefix in existing_names:
self.prefix = f'{prefix}-v{version_cnt}'
version_cnt += 1
mode_dict = {
'min': (np.less, np.Inf, 'min'),
'max': (np.greater, -np.Inf, 'max'),
@ -125,39 +110,65 @@ class ModelCheckpoint(Callback):
self.monitor_op, self.kth_value, self.mode = mode_dict[mode]
def _del_model(self, filepath: str) -> None:
# shutil.rmtree(filepath)
def _del_model(self, filepath):
os.remove(filepath)
def _save_model(self, filepath: str) -> None:
def _save_model(self, filepath):
# make paths
os.makedirs(self.dirpath, exist_ok=True)
os.makedirs(os.path.dirname(filepath), exist_ok=True)
# delegate the saving to the model
if self.save_function is not None:
self.save_function(filepath)
else:
raise ValueError("Method `.save_function()` not set")
raise ValueError(".save_function() not set")
def check_monitor_top_k(self, current: float) -> bool:
def check_monitor_top_k(self, current):
less_than_k_models = len(self.best_k_models) < self.save_top_k
if less_than_k_models:
return True
return self.monitor_op(current, self.best_k_models[self.kth_best_model])
def _get_available_filepath(self, current: float, epoch: int) -> str:
current_str = f'{current:.2f}' if current else 'NaN'
fname = f'{self.prefix}_epoch={epoch}_{self.monitor}={current_str}'
filepath = os.path.join(self.dirpath, fname + self.EXTENSION)
assert not os.path.isfile(filepath)
def format_checkpoint_name(self, epoch, metrics, ver=None):
"""Generate a filename according define template.
Examples
--------
>>> tmpdir = os.path.dirname(__file__)
>>> ckpt = ModelCheckpoint(os.path.join(tmpdir, '{epoch}'))
>>> os.path.basename(ckpt.format_checkpoint_name(0, {}))
'epoch=0.ckpt'
>>> ckpt = ModelCheckpoint(os.path.join(tmpdir, '{epoch:03d}'))
>>> os.path.basename(ckpt.format_checkpoint_name(5, {}))
'epoch=005.ckpt'
>>> ckpt = ModelCheckpoint(os.path.join(tmpdir, '{epoch}-{val_loss:.2f}'))
>>> os.path.basename(ckpt.format_checkpoint_name(2, dict(val_loss=0.123456)))
'epoch=2-val_loss=0.12.ckpt'
>>> ckpt = ModelCheckpoint(os.path.join(tmpdir, '{missing:d}'))
>>> os.path.basename(ckpt.format_checkpoint_name(0, {}))
'missing=0.ckpt'
"""
# check if user passed in keys to the string
groups = re.findall(r'(\{.*?)[:\}]', self.filename)
if len(groups) == 0:
# default name
filename = f'{self.prefix}_ckpt_epoch_{epoch}'
else:
metrics['epoch'] = epoch
filename = self.filename
for tmp in groups:
name = tmp[1:]
filename = filename.replace(tmp, name + '={' + name)
if name not in metrics:
metrics[name] = 0
filename = filename.format(**metrics)
str_ver = f'_v{ver}' if ver is not None else ''
filepath = os.path.join(self.dirpath, self.prefix + filename + str_ver + '.ckpt')
return filepath
def on_validation_end(self, trainer, pl_module) -> None:
# only run on main process
if trainer.proc_rank != 0:
return
logs = trainer.callback_metrics
def on_validation_end(self, trainer, pl_module):
metrics = trainer.callback_metrics
epoch = trainer.current_epoch
self.epochs_since_last_check += 1
@ -166,27 +177,36 @@ class ModelCheckpoint(Callback):
return
if self.epochs_since_last_check >= self.period:
self.epochs_since_last_check = 0
current = logs.get(self.monitor)
filepath = self._get_available_filepath(current, epoch)
filepath = self.format_checkpoint_name(epoch, metrics)
version_cnt = 0
while os.path.isfile(filepath):
filepath = self.format_checkpoint_name(epoch, metrics, ver=version_cnt)
# a checkpoint for this epoch already exists; bump the version count
version_cnt += 1
if self.save_top_k != -1:
current = metrics.get(self.monitor)
if current is None:
warnings.warn(f'Can save best model only with {self.monitor} available,'
' skipping.', RuntimeWarning)
warnings.warn(
f'Can save best model only with {self.monitor} available,'
' skipping.', RuntimeWarning)
else:
if self.check_monitor_top_k(current):
self._do_check_save(filepath, current, epoch)
else:
if self.verbose > 0:
log.info('Epoch %05d: %s was not in top %i', epoch, self.monitor, self.save_top_k)
log.info(
f'\nEpoch {epoch:05d}: {self.monitor}'
f' was not in top {self.save_top_k}')
else:
if self.verbose > 0:
log.info('Epoch %05d: saving model to %s', epoch, filepath)
log.info(f'\nEpoch {epoch:05d}: saving model to {filepath}')
self._save_model(filepath)
def _do_check_save(self, filepath: str, current: float, epoch: int) -> None:
def _do_check_save(self, filepath, current, epoch):
# remove kth
if len(self.best_k_models) == self.save_top_k:
delpath = self.kth_best_model
@ -205,6 +225,8 @@ class ModelCheckpoint(Callback):
self.best = _op(self.best_k_models.values())
if self.verbose > 0:
log.info('Epoch {epoch:05d}: %s reached %0.5f (best %0.5f), saving model to %s as top %i',
epoch, self.monitor, current, self.best, filepath, self.save_top_k)
log.info(
f'\nEpoch {epoch:05d}: {self.monitor} reached'
f' {current:0.5f} (best {self.best:0.5f}), saving model to'
f' {filepath} as top {self.save_top_k}')
self._save_model(filepath)
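
A minimal sketch of the new filename templating for reviewers (the temp directory and metric values are illustrative; the expected name follows the doctests above):

import os
import tempfile
from pytorch_lightning.callbacks import ModelCheckpoint

tmpdir = tempfile.mkdtemp()  # any writable directory works
ckpt = ModelCheckpoint(filepath=os.path.join(tmpdir, 'sample-mnist_{epoch:02d}-{val_loss:.2f}'))
name = os.path.basename(ckpt.format_checkpoint_name(2, dict(val_loss=0.123456)))
# name == 'sample-mnist_epoch=02-val_loss=0.12.ckpt'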

View File

@ -68,19 +68,7 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):
#: True if using amp
self.use_amp = False
@property
def hparams(self) -> Namespace:
if not hasattr(self, '_hparams'):
return Namespace()
assert isinstance(self._hparams, dict)
return Namespace(**self._hparams)
@hparams.setter
def hparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
"""Set the model hyper-parameters."""
if isinstance(params, Namespace):
params = vars(params)
self._hparams = params
self.hparams = None
def print(self, *args, **kwargs):
r"""

View File

@ -46,6 +46,10 @@ class LightningLoggerBase(ABC):
# in case converting from namespace
if isinstance(params, Namespace):
params = vars(params)
if params is None:
params = {}
return params
@abstractmethod

View File

@ -48,9 +48,15 @@ class TrainerCallbackConfigMixin(ABC):
else:
ckpt_path = os.path.join(self.default_save_path, "checkpoints")
# when no val step is defined, use 'loss' otherwise 'val_loss'
train_step_only = not self.is_overriden('validation_step')
monitor_key = 'loss' if train_step_only else 'val_loss'
self.ckpt_path = ckpt_path
os.makedirs(ckpt_path, exist_ok=True)
self.checkpoint_callback = ModelCheckpoint(
dirpath=ckpt_path
filepath=ckpt_path,
monitor=monitor_key
)
elif self.checkpoint_callback is False:
self.checkpoint_callback = None
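
The net effect: when a LightningModule defines no validation_step, the auto-created callback monitors the training 'loss' instead of 'val_loss'. A tiny self-contained sketch of that selection (the helper name is hypothetical, not part of the PR):

# sketch, assuming `has_val_step` reflects whether validation_step is overridden
def pick_monitor_key(has_val_step: bool) -> str:
    return 'val_loss' if has_val_step else 'loss'

assert pick_monitor_key(False) == 'loss'      # train-only modules monitor the training loss
assert pick_monitor_key(True) == 'val_loss'   # modules with a val step keep the old default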

View File

@ -165,7 +165,6 @@ class TrainerEvaluationLoopMixin(ABC):
process_output: ...
training_tqdm_dict: ...
proc_rank: int
checkpoint_callback: ...
current_epoch: int
callback_metrics: ...
test_dataloaders: DataLoader
@ -377,11 +376,6 @@ class TrainerEvaluationLoopMixin(ABC):
# Validation/Test end callbacks
if test_mode:
self.on_test_end()
else:
# model checkpointing
if self.checkpoint_callback is not None:
self.checkpoint_callback.on_validation_end(self, self.get_model())
self.on_validation_end()
def evaluation_forward(self, model, batch, batch_idx, dataloader_idx, test_mode: bool = False):
# make dataloader_idx arg in validation_step optional

View File

@ -1132,9 +1132,6 @@ class Trainer(TrainerIOMixin,
# wait for all processes to catch up
torch_xla.core.xla_model.rendezvous("pl.Trainer.run_pretrain_routine")
# set up checkpoint callback
self.configure_checkpoint_callback()
# register auto-resubmit when on SLURM
self.register_slurm_signal_handlers()
@ -1151,6 +1148,9 @@ class Trainer(TrainerIOMixin,
# if cluster resets state, the model will update with the saved weights
self.model = model
# set up checkpoint callback
self.configure_checkpoint_callback()
# restore training and model before hpc call
self.restore_weights(model)

View File

@ -165,14 +165,15 @@ class TrainerIOMixin(ABC):
def save_checkpoint(self, filepath):
checkpoint = self.dump_checkpoint()
# do the actual save
try:
self._atomic_save(checkpoint, filepath)
except AttributeError:
if 'hparams' in checkpoint:
del checkpoint['hparams']
if self.proc_rank == 0:
# do the actual save
try:
self._atomic_save(checkpoint, filepath)
except AttributeError:
if 'hparams' in checkpoint:
del checkpoint['hparams']
self._atomic_save(checkpoint, filepath)
self._atomic_save(checkpoint, filepath)
def restore(self, checkpoint_path, on_gpu):
"""

View File

@ -203,6 +203,7 @@ class TrainerTrainLoopMixin(ABC):
max_steps: int
max_steps: int
total_batch_idx: int
checkpoint_callback: ...
# Callback system
callbacks: List[Callback]
@ -212,6 +213,7 @@ class TrainerTrainLoopMixin(ABC):
on_batch_end: Callable
on_epoch_start: Callable
on_epoch_end: Callable
on_validation_end: Callable
@property
def max_nb_epochs(self):
@ -454,9 +456,6 @@ class TrainerTrainLoopMixin(ABC):
if self.fast_dev_run or should_check_val:
self.run_evaluation(test_mode=self.testing)
if self.enable_early_stop:
self.early_stop_callback.check_metrics(self.callback_metrics)
# when logs should be saved
should_save_log = (batch_idx + 1) % self.log_save_interval == 0 or early_stop_epoch
if should_save_log or self.fast_dev_run:
@ -469,6 +468,17 @@ class TrainerTrainLoopMixin(ABC):
# logs user requested information to logger
self.log_metrics(batch_step_metrics, grad_norm_dic)
# ---------------
# CHECKPOINTING, EARLY STOPPING
# ---------------
# save checkpoint even when no test or val step are defined
train_step_only = not self.is_overriden('validation_step')
if self.fast_dev_run or should_check_val or train_step_only:
self.call_checkpoint_callback()
if self.enable_early_stop:
self.early_stop_callback.check_metrics(self.callback_metrics)
# progress global step according to grads progress
if (self.batch_idx + 1) % self.accumulate_grad_batches == 0:
self.global_step += 1
@ -705,3 +715,8 @@ class TrainerTrainLoopMixin(ABC):
output = self.process_output(output, train=True)
return output
def call_checkpoint_callback(self):
if self.checkpoint_callback is not None:
self.checkpoint_callback.on_validation_end(self, self.get_model())
self.on_validation_end()
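
End to end, a hedged sketch of a train-only module that now gets checkpoints and early stopping driven by 'loss' (class name, sizes, and data are illustrative, not from the PR):

import torch
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning import LightningModule, Trainer

class TrainOnlyModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.cross_entropy(self(x), y)
        # 'loss' is the metric the default checkpoint/early-stopping callbacks now monitor
        return {'loss': loss}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

    def train_dataloader(self):
        data = TensorDataset(torch.randn(64, 32), torch.randint(0, 2, (64,)))
        return DataLoader(data, batch_size=8)

# no validation_step or val_dataloader: checkpoints are still saved every epoch
Trainer(max_epochs=2).fit(TrainOnlyModel())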

View File

@ -46,6 +46,7 @@ class DictHparamsModel(LightningModule):
def __init__(self, hparams: Dict):
super(DictHparamsModel, self).__init__()
self.hparams = hparams
self.l1 = torch.nn.Linear(hparams.get('in_features'), hparams['out_features'])
def forward(self, x):

View File

@ -239,5 +239,6 @@ def set_random_master_port():
def init_checkpoint_callback(logger, path_dir=None):
exp_path = get_data_path(logger, path_dir=path_dir)
ckpt_dir = os.path.join(exp_path, 'checkpoints')
os.mkdir(ckpt_dir)
checkpoint = ModelCheckpoint(ckpt_dir)
return checkpoint

View File

@ -256,66 +256,57 @@ def mocked_device_count_0(monkeypatch):
monkeypatch.setattr(torch.cuda, 'device_count', device_count)
test_num_gpus_data = [
@pytest.mark.gpus_param_tests
@pytest.mark.parametrize(["gpus", "expected_num_gpus", "distributed_backend"], [
pytest.param(None, 0, None, id="None - expect 0 gpu to use."),
pytest.param(0, 0, None, id="Oth gpu, expect 1 gpu to use."),
pytest.param(1, 1, None, id="1st gpu, expect 1 gpu to use."),
pytest.param(-1, PRETEND_N_OF_GPUS, "ddp", id="-1 - use all gpus"),
pytest.param('-1', PRETEND_N_OF_GPUS, "ddp", id="'-1' - use all gpus"),
pytest.param(3, 3, "ddp", id="3rd gpu - 1 gpu to use (backend:ddp)")
]
@pytest.mark.gpus_param_tests
@pytest.mark.parametrize(["gpus", "expected_num_gpus", "distributed_backend"], test_num_gpus_data)
])
def test_trainer_gpu_parse(mocked_device_count, gpus, expected_num_gpus, distributed_backend):
assert Trainer(gpus=gpus, distributed_backend=distributed_backend).num_gpus == expected_num_gpus
test_num_gpus_data_0 = [
@pytest.mark.gpus_param_tests
@pytest.mark.parametrize(["gpus", "expected_num_gpus", "distributed_backend"], [
pytest.param(None, 0, None, id="None - expect 0 gpu to use."),
pytest.param(None, 0, "ddp", id="None - expect 0 gpu to use."),
]
@pytest.mark.gpus_param_tests
@pytest.mark.parametrize(["gpus", "expected_num_gpus", "distributed_backend"], test_num_gpus_data_0)
])
def test_trainer_num_gpu_0(mocked_device_count_0, gpus, expected_num_gpus, distributed_backend):
assert Trainer(gpus=gpus, distributed_backend=distributed_backend).num_gpus == expected_num_gpus
test_root_gpu_data = [
@pytest.mark.gpus_param_tests
@pytest.mark.parametrize(['gpus', 'expected_root_gpu', "distributed_backend"], [
pytest.param(None, None, "ddp", id="None is None"),
pytest.param(0, None, "ddp", id="O gpus, expect gpu root device to be None."),
pytest.param(1, 0, "ddp", id="1 gpu, expect gpu root device to be 0."),
pytest.param(-1, 0, "ddp", id="-1 - use all gpus, expect gpu root device to be 0."),
pytest.param('-1', 0, "ddp", id="'-1' - use all gpus, expect gpu root device to be 0."),
pytest.param(3, 0, "ddp", id="3 gpus, expect gpu root device to be 0.(backend:ddp)")]
@pytest.mark.gpus_param_tests
@pytest.mark.parametrize(['gpus', 'expected_root_gpu', "distributed_backend"], test_root_gpu_data)
pytest.param(3, 0, "ddp", id="3 gpus, expect gpu root device to be 0.(backend:ddp)")
])
def test_root_gpu_property(mocked_device_count, gpus, expected_root_gpu, distributed_backend):
assert Trainer(gpus=gpus, distributed_backend=distributed_backend).root_gpu == expected_root_gpu
test_root_gpu_data_for_0_devices_passing = [
@pytest.mark.gpus_param_tests
@pytest.mark.parametrize([
'gpus', 'expected_root_gpu', "distributed_backend"], [
pytest.param(None, None, None, id="None is None"),
pytest.param(None, None, "ddp", id="None is None"),
pytest.param(0, None, "ddp", id="None is None"),
]
@pytest.mark.gpus_param_tests
@pytest.mark.parametrize([
'gpus', 'expected_root_gpu', "distributed_backend"], test_root_gpu_data_for_0_devices_passing)
])
def test_root_gpu_property_0_passing(
mocked_device_count_0, gpus, expected_root_gpu, distributed_backend):
assert Trainer(gpus=gpus, distributed_backend=distributed_backend).root_gpu == expected_root_gpu
# Asking for a gpu when none are available will result in a MisconfigurationException
test_root_gpu_data_for_0_devices_raising = [
@pytest.mark.gpus_param_tests
@pytest.mark.parametrize([
'gpus', 'expected_root_gpu', "distributed_backend"], [
pytest.param(1, None, "ddp"),
pytest.param(3, None, "ddp"),
pytest.param(3, None, "ddp"),
@ -323,34 +314,27 @@ test_root_gpu_data_for_0_devices_raising = [
pytest.param([0, 1], None, "ddp"),
pytest.param(-1, None, "ddp"),
pytest.param('-1', None, "ddp")
]
@pytest.mark.gpus_param_tests
@pytest.mark.parametrize([
'gpus', 'expected_root_gpu', "distributed_backend"], test_root_gpu_data_for_0_devices_raising)
])
def test_root_gpu_property_0_raising(
mocked_device_count_0, gpus, expected_root_gpu, distributed_backend):
with pytest.raises(MisconfigurationException):
Trainer(gpus=gpus, distributed_backend=distributed_backend).root_gpu
test_determine_root_gpu_device_data = [
@pytest.mark.gpus_param_tests
@pytest.mark.parametrize(['gpus', 'expected_root_gpu'], [
pytest.param(None, None, id="No gpus, expect gpu root device to be None"),
pytest.param([0], 0, id="Oth gpu, expect gpu root device to be 0."),
pytest.param([1], 1, id="1st gpu, expect gpu root device to be 1."),
pytest.param([3], 3, id="3rd gpu, expect gpu root device to be 3."),
pytest.param([1, 2], 1, id="[1, 2] gpus, expect gpu root device to be 1."),
]
@pytest.mark.gpus_param_tests
@pytest.mark.parametrize(['gpus', 'expected_root_gpu'], test_determine_root_gpu_device_data)
])
def test_determine_root_gpu_device(gpus, expected_root_gpu):
assert determine_root_gpu_device(gpus) == expected_root_gpu
test_parse_gpu_ids_data = [
@pytest.mark.gpus_param_tests
@pytest.mark.parametrize(['gpus', 'expected_gpu_ids'], [
pytest.param(None, None),
pytest.param(0, None),
pytest.param(1, [0]),
@ -362,16 +346,13 @@ test_parse_gpu_ids_data = [
pytest.param('3', [3]),
pytest.param('1, 3', [1, 3]),
pytest.param('-1', list(range(PRETEND_N_OF_GPUS)), id="'-1' - use all gpus"),
]
@pytest.mark.gpus_param_tests
@pytest.mark.parametrize(['gpus', 'expected_gpu_ids'], test_parse_gpu_ids_data)
])
def test_parse_gpu_ids(mocked_device_count, gpus, expected_gpu_ids):
assert parse_gpu_ids(gpus) == expected_gpu_ids
test_parse_gpu_invalid_inputs_data = [
@pytest.mark.gpus_param_tests
@pytest.mark.parametrize(['gpus'], [
pytest.param(0.1),
pytest.param(-2),
pytest.param(False),
@ -380,11 +361,7 @@ test_parse_gpu_invalid_inputs_data = [
pytest.param([None]),
pytest.param(['0']),
pytest.param((0, 1)),
]
@pytest.mark.gpus_param_tests
@pytest.mark.parametrize(['gpus'], test_parse_gpu_invalid_inputs_data)
])
def test_parse_gpu_fail_on_unsupported_inputs(mocked_device_count, gpus):
with pytest.raises(MisconfigurationException):
parse_gpu_ids(gpus)

View File

@ -1,5 +1,8 @@
import os
import tests.models.utils as tutils
from pytorch_lightning import Trainer, LightningModule
from pytorch_lightning.callbacks import ModelCheckpoint
from tests.models import (
TestModelBase,
LightTrainDataloader,

View File

@ -27,6 +27,28 @@ from pytorch_lightning.trainer.logging import TrainerLoggingMixin
from pytorch_lightning.utilities.debugging import MisconfigurationException
def test_hparams_save_load(tmpdir):
model = DictHparamsModel({'in_features': 28 * 28, 'out_features': 10})
# logger file to get meta
trainer_options = dict(
default_save_path=tmpdir,
max_epochs=2,
)
# fit model
trainer = Trainer(**trainer_options)
result = trainer.fit(model)
assert result == 1
# try to load the model now
pretrained_model = tutils.load_model_from_checkpoint(
trainer.checkpoint_callback.dirpath,
module_class=DictHparamsModel
)
def test_no_val_module(tmpdir):
"""Tests use case where trainer saves the model, and user loads it from tags independently."""
tutils.reset_seed()
@ -126,7 +148,8 @@ def test_gradient_accumulation_scheduling(tmpdir):
assert Trainer(accumulate_grad_batches={1: 2.5, 3: 5})
# test optimizer call freq matches scheduler
def _optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, second_order_closure=None):
def _optimizer_step(self, epoch, batch_idx, optimizer,
optimizer_idx, second_order_closure=None):
# only test the first 12 batches in epoch
if batch_idx < 12:
if epoch == 0:
@ -255,11 +278,11 @@ def test_model_checkpoint_options(tmp_path):
assert len(file_lists) == len(losses), "Should save all models when save_top_k=-1"
# verify correct naming
for fname in {'_epoch=4_val_loss=2.50.ckpt',
'_epoch=3_val_loss=5.00.ckpt',
'_epoch=2_val_loss=2.80.ckpt',
'_epoch=1_val_loss=9.00.ckpt',
'_epoch=0_val_loss=10.00.ckpt'}:
for fname in {'epoch=4.ckpt',
'epoch=3.ckpt',
'epoch=2.ckpt',
'epoch=1.ckpt',
'epoch=0.ckpt'}:
assert fname in file_lists
save_dir = tmp_path / "2"
@ -286,7 +309,7 @@ def test_model_checkpoint_options(tmp_path):
# -----------------
# CASE K=1 (2.5, epoch 4)
checkpoint_callback = ModelCheckpoint(save_dir, save_top_k=1, verbose=1, prefix='test_prefix')
checkpoint_callback = ModelCheckpoint(save_dir, save_top_k=1, verbose=1, prefix='test_prefix_')
checkpoint_callback.save_function = mock_save_function
trainer = Trainer()
@ -299,7 +322,7 @@ def test_model_checkpoint_options(tmp_path):
file_lists = set(os.listdir(save_dir))
assert len(file_lists) == 1, "Should save 1 model when save_top_k=1"
assert 'test_prefix_epoch=4_val_loss=2.50.ckpt' in file_lists
assert 'test_prefix_epoch=4.ckpt' in file_lists
save_dir = tmp_path / "4"
save_dir.mkdir()
@ -322,8 +345,8 @@ def test_model_checkpoint_options(tmp_path):
file_lists = set(os.listdir(save_dir))
assert len(file_lists) == 3, 'Should save 2 model when save_top_k=2'
for fname in {'_epoch=4_val_loss=2.50.ckpt',
'_epoch=2_val_loss=2.80.ckpt',
for fname in {'epoch=4.ckpt',
'epoch=2.ckpt',
'other_file.ckpt'}:
assert fname in file_lists
@ -368,9 +391,9 @@ def test_model_checkpoint_options(tmp_path):
file_lists = set(os.listdir(save_dir))
assert len(file_lists) == 3, 'Should save 3 models when save_top_k=3'
for fname in {'_epoch=0_val_loss=2.80.ckpt',
'_epoch=0_val_loss=2.50.ckpt',
'_epoch=0_val_loss=5.00.ckpt'}:
for fname in {'epoch=0.ckpt',
'epoch=0.ckpt',
'epoch=0.ckpt'}:
assert fname in file_lists
@ -620,25 +643,3 @@ def test_default_args(tmpdir):
assert isinstance(trainer, Trainer)
assert trainer.max_epochs == 5
def test_hparams_save_load(tmpdir):
model = DictHparamsModel({'in_features': 28 * 28, 'out_features': 10})
# logger file to get meta
trainer_options = dict(
default_save_path=tmpdir,
max_epochs=2,
)
# fit model
trainer = Trainer(**trainer_options)
result = trainer.fit(model)
assert result == 1
# try to load the model now
pretrained_model = tutils.load_model_from_checkpoint(
trainer.checkpoint_callback.dirpath,
module_class=DictHparamsModel
)