# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import json
import os
import pickle
import sys
from argparse import Namespace
from contextlib import contextmanager, ExitStack, redirect_stdout
from io import StringIO
from typing import Callable, List, Optional, Union
from unittest import mock
from unittest.mock import ANY

import pytest
import torch
import yaml
from packaging import version
from torch.optim import SGD
from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR

from pytorch_lightning import __version__, Callback, LightningDataModule, LightningModule, seed_everything, Trainer
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.demos.boring_classes import BoringDataModule, BoringModel
from pytorch_lightning.loggers import _COMET_AVAILABLE, _NEPTUNE_AVAILABLE, _WANDB_AVAILABLE, TensorBoardLogger
from pytorch_lightning.plugins.environments import SLURMEnvironment
from pytorch_lightning.profiler import PyTorchProfiler
from pytorch_lightning.strategies import DDPStrategy
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import _TPU_AVAILABLE
from pytorch_lightning.utilities.cli import (
    _JSONARGPARSE_SIGNATURES_AVAILABLE,
    instantiate_class,
    LightningArgumentParser,
    LightningCLI,
    LRSchedulerTypeTuple,
    SaveConfigCallback,
)
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _TORCHVISION_AVAILABLE
from tests_pytorch.helpers.runif import RunIf
from tests_pytorch.helpers.utils import no_warning_call

torchvision_version = version.parse("0")
if _TORCHVISION_AVAILABLE:
    torchvision_version = version.parse(__import__("torchvision").__version__)

if _JSONARGPARSE_SIGNATURES_AVAILABLE:
    from jsonargparse import lazy_instance


@contextmanager
def mock_subclasses(baseclass, *subclasses):
    """Mocks baseclass so that it only has the given child subclasses."""
    with ExitStack() as stack:
        mgr = mock.patch.object(baseclass, "__subclasses__", return_value=[*subclasses])
        stack.enter_context(mgr)
        for mgr in [mock.patch.object(s, "__subclasses__", return_value=[]) for s in subclasses]:
            stack.enter_context(mgr)
        yield None
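
# A minimal usage sketch of `mock_subclasses` (illustrative only; the helper is exercised
# for real by the short-argument tests further below). While the context is active,
# `__subclasses__` reports exactly the classes passed in:
#
#     with mock_subclasses(LightningModule, BoringModel):
#         assert LightningModule.__subclasses__() == [BoringModel]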


@mock.patch("argparse.ArgumentParser.parse_args")
def test_default_args(mock_argparse):
    """Tests default argument parser for Trainer."""
    mock_argparse.return_value = Namespace(**Trainer.default_attributes())

    parser = LightningArgumentParser(add_help=False, parse_as_dict=False)
    args = parser.parse_args([])

    args.max_epochs = 5
    trainer = Trainer.from_argparse_args(args)

    assert isinstance(trainer, Trainer)
    assert trainer.max_epochs == 5


@pytest.mark.parametrize("cli_args", [["--accumulate_grad_batches=22"], []])
def test_add_argparse_args_redefined(cli_args):
    """Redefines some default Trainer arguments via the CLI and checks that the Trainer initializes correctly."""
    parser = LightningArgumentParser(add_help=False, parse_as_dict=False)
    parser.add_lightning_class_args(Trainer, None)

    args = parser.parse_args(cli_args)

    # make sure we can pickle args
    pickle.dumps(args)

    # check that a few deprecated args are not in the namespace:
    for depr_name in ("gradient_clip", "nb_gpu_nodes", "max_nb_epochs"):
        assert depr_name not in args

    trainer = Trainer.from_argparse_args(args=args)
    pickle.dumps(trainer)

    assert isinstance(trainer, Trainer)


@pytest.mark.parametrize("cli_args", [["--callbacks=1", "--logger"], ["--foo", "--bar=1"]])
def test_add_argparse_args_redefined_error(cli_args, monkeypatch):
    """Asserts that an error is raised when non-default CLI arguments are passed."""

    class _UnkArgError(Exception):
        pass

    def _raise():
        raise _UnkArgError

    parser = LightningArgumentParser(add_help=False, parse_as_dict=False)
    parser.add_lightning_class_args(Trainer, None)

    monkeypatch.setattr(parser, "exit", lambda *args: _raise(), raising=True)

    with pytest.raises(_UnkArgError):
        parser.parse_args(cli_args)


@pytest.mark.parametrize(
    ["cli_args", "expected"],
    [
        ("--auto_lr_find=True --auto_scale_batch_size=power", dict(auto_lr_find=True, auto_scale_batch_size="power")),
        (
            "--auto_lr_find any_string --auto_scale_batch_size ON",
            dict(auto_lr_find="any_string", auto_scale_batch_size=True),
        ),
        ("--auto_lr_find=Yes --auto_scale_batch_size=On", dict(auto_lr_find=True, auto_scale_batch_size=True)),
        ("--auto_lr_find Off --auto_scale_batch_size No", dict(auto_lr_find=False, auto_scale_batch_size=False)),
        ("--auto_lr_find TRUE --auto_scale_batch_size FALSE", dict(auto_lr_find=True, auto_scale_batch_size=False)),
        ("--tpu_cores=8", dict(tpu_cores=8)),
        ("--tpu_cores=1,", dict(tpu_cores="1,")),
        ("--limit_train_batches=100", dict(limit_train_batches=100)),
        ("--limit_train_batches 0.8", dict(limit_train_batches=0.8)),
        ("--enable_model_summary FALSE", dict(enable_model_summary=False)),
        (
            "",
            dict(
                # These parameters are marked as Optional[...] in Trainer.__init__,
                # with None as default. They should not be changed by the argparse
                # interface.
                min_steps=None,
                accelerator=None,
                profiler=None,
            ),
        ),
    ],
)
def test_parse_args_parsing(cli_args, expected):
    """Test that simple types are parsed correctly and None optionals are not modified."""
    cli_args = cli_args.split(" ") if cli_args else []
    with mock.patch("sys.argv", ["any.py"] + cli_args):
        parser = LightningArgumentParser(add_help=False, parse_as_dict=False)
        parser.add_lightning_class_args(Trainer, None)
        args = parser.parse_args()

    for k, v in expected.items():
        assert getattr(args, k) == v
    if "tpu_cores" not in expected or _TPU_AVAILABLE:
        assert Trainer.from_argparse_args(args)


@pytest.mark.parametrize(
    ["cli_args", "expected", "instantiate"],
    [
        (["--gpus", "[0, 2]"], dict(gpus=[0, 2]), False),
        (["--tpu_cores=[1,3]"], dict(tpu_cores=[1, 3]), False),
        (['--accumulate_grad_batches={"5":3,"10":20}'], dict(accumulate_grad_batches={5: 3, 10: 20}), True),
    ],
)
def test_parse_args_parsing_complex_types(cli_args, expected, instantiate):
    """Test parsing complex types."""
    with mock.patch("sys.argv", ["any.py"] + cli_args):
        parser = LightningArgumentParser(add_help=False, parse_as_dict=False)
        parser.add_lightning_class_args(Trainer, None)
        args = parser.parse_args()

    for k, v in expected.items():
        assert getattr(args, k) == v
    if instantiate:
        assert Trainer.from_argparse_args(args)


@pytest.mark.parametrize(
    ["cli_args", "expected_gpu"],
    [
        ("--accelerator gpu --devices 1", [0]),
        ("--accelerator gpu --devices 0,", [0]),
        ("--accelerator gpu --devices 1,", [1]),
        ("--accelerator gpu --devices 0,1", [0, 1]),
    ],
)
def test_parse_args_parsing_gpus(monkeypatch, cli_args, expected_gpu):
    """Test parsing of gpus and instantiation of Trainer."""
    monkeypatch.setattr("torch.cuda.device_count", lambda: 2)
    monkeypatch.setattr("torch.cuda.is_available", lambda: True)
    cli_args = cli_args.split(" ") if cli_args else []
    with mock.patch("sys.argv", ["any.py"] + cli_args):
        parser = LightningArgumentParser(add_help=False, parse_as_dict=False)
        parser.add_lightning_class_args(Trainer, None)
        args = parser.parse_args()

    trainer = Trainer.from_argparse_args(args)
    assert trainer.device_ids == expected_gpu


@pytest.mark.skipif(
    sys.version_info < (3, 7),
    reason="signature inspection while mocking is not working in Python < 3.7 despite autospec",
)
@pytest.mark.parametrize(
    ["cli_args", "extra_args"],
    [
        ({}, {}),
        (dict(logger=False), {}),
        (dict(logger=False), dict(logger=True)),
        (dict(logger=False), dict(enable_checkpointing=True)),
    ],
)
def test_init_from_argparse_args(cli_args, extra_args):
    unknown_args = dict(unknown_arg=0)

    # unknown args in the argparser/namespace should be ignored
    with mock.patch("pytorch_lightning.Trainer.__init__", autospec=True, return_value=None) as init:
        trainer = Trainer.from_argparse_args(Namespace(**cli_args, **unknown_args), **extra_args)
        expected = dict(cli_args)
        expected.update(extra_args)  # extra args should override any cli arg
        init.assert_called_with(trainer, **expected)

    # passing in unknown manual args should throw an error
    with pytest.raises(TypeError, match=r"__init__\(\) got an unexpected keyword argument 'unknown_arg'"):
        Trainer.from_argparse_args(Namespace(**cli_args), **extra_args, **unknown_args)


class Model(LightningModule):
    def __init__(self, model_param: int):
        super().__init__()
        self.model_param = model_param


def _model_builder(model_param: int) -> Model:
    return Model(model_param)


def _trainer_builder(
    limit_train_batches: int, fast_dev_run: bool = False, callbacks: Optional[Union[List[Callback], Callback]] = None
) -> Trainer:
    return Trainer(limit_train_batches=limit_train_batches, fast_dev_run=fast_dev_run, callbacks=callbacks)
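
# The parametrization below relies on `LightningCLI` accepting either the classes
# themselves (`Trainer`/`Model`) or callables returning instances of them
# (`_trainer_builder`/`_model_builder`); both variants must behave identically.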


@pytest.mark.parametrize(["trainer_class", "model_class"], [(Trainer, Model), (_trainer_builder, _model_builder)])
def test_lightning_cli(trainer_class, model_class, monkeypatch):
    """Test that LightningCLI correctly instantiates model, trainer and calls fit."""
    expected_model = dict(model_param=7)
    expected_trainer = dict(limit_train_batches=100)

    def fit(trainer, model):
        for k, v in expected_model.items():
            assert getattr(model, k) == v
        for k, v in expected_trainer.items():
            assert getattr(trainer, k) == v
        save_callback = [x for x in trainer.callbacks if isinstance(x, SaveConfigCallback)]
        assert len(save_callback) == 1
        save_callback[0].on_train_start(trainer, model)

    def on_train_start(callback, trainer, _):
        config_dump = callback.parser.dump(callback.config, skip_none=False)
        for k, v in expected_model.items():
            assert f"  {k}: {v}" in config_dump
        for k, v in expected_trainer.items():
            assert f"  {k}: {v}" in config_dump
        trainer.ran_asserts = True

    monkeypatch.setattr(Trainer, "fit", fit)
    monkeypatch.setattr(SaveConfigCallback, "on_train_start", on_train_start)

    with mock.patch("sys.argv", ["any.py", "fit", "--model.model_param=7", "--trainer.limit_train_batches=100"]):
        cli = LightningCLI(model_class, trainer_class=trainer_class, save_config_callback=SaveConfigCallback)
        assert hasattr(cli.trainer, "ran_asserts") and cli.trainer.ran_asserts


def test_lightning_cli_args_callbacks(tmpdir):

    callbacks = [
        dict(
            class_path="pytorch_lightning.callbacks.LearningRateMonitor",
            init_args=dict(logging_interval="epoch", log_momentum=True),
        ),
        dict(class_path="pytorch_lightning.callbacks.ModelCheckpoint", init_args=dict(monitor="NAME")),
    ]

    class TestModel(BoringModel):
        def on_fit_start(self):
            callback = [c for c in self.trainer.callbacks if isinstance(c, LearningRateMonitor)]
            assert len(callback) == 1
            assert callback[0].logging_interval == "epoch"
            assert callback[0].log_momentum is True
            callback = [c for c in self.trainer.callbacks if isinstance(c, ModelCheckpoint)]
            assert len(callback) == 1
            assert callback[0].monitor == "NAME"
            self.trainer.ran_asserts = True

    with mock.patch("sys.argv", ["any.py", "fit", f"--trainer.callbacks={json.dumps(callbacks)}"]):
        cli = LightningCLI(TestModel, trainer_defaults=dict(default_root_dir=str(tmpdir), fast_dev_run=True))

    assert cli.trainer.ran_asserts
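
# For reference, the JSON list passed above via `--trainer.callbacks` corresponds to a
# YAML config of the following shape (a hand-written sketch, not generated by the test):
#
#     trainer:
#       callbacks:
#         - class_path: pytorch_lightning.callbacks.LearningRateMonitor
#           init_args:
#             logging_interval: epoch
#             log_momentum: true
#         - class_path: pytorch_lightning.callbacks.ModelCheckpoint
#           init_args:
#             monitor: NAME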


def test_lightning_cli_single_arg_callback():
    with mock.patch("sys.argv", ["any.py", "--trainer.callbacks=DeviceStatsMonitor"]):
        cli = LightningCLI(BoringModel, run=False)

    assert cli.config.trainer.callbacks.class_path == "pytorch_lightning.callbacks.DeviceStatsMonitor"
    assert not isinstance(cli.config_init.trainer, list)


@pytest.mark.parametrize("run", (False, True))
def test_lightning_cli_configurable_callbacks(tmpdir, run):
    class MyLightningCLI(LightningCLI):
        def add_arguments_to_parser(self, parser):
            parser.add_lightning_class_args(LearningRateMonitor, "learning_rate_monitor")

        def fit(self, **_):
            pass

    cli_args = ["fit"] if run else []
    cli_args += [f"--trainer.default_root_dir={tmpdir}", "--learning_rate_monitor.logging_interval=epoch"]

    with mock.patch("sys.argv", ["any.py"] + cli_args):
        cli = MyLightningCLI(BoringModel, run=run)

    callback = [c for c in cli.trainer.callbacks if isinstance(c, LearningRateMonitor)]
    assert len(callback) == 1
    assert callback[0].logging_interval == "epoch"


def test_lightning_cli_args_cluster_environments(tmpdir):
    plugins = [dict(class_path="pytorch_lightning.plugins.environments.SLURMEnvironment")]

    class TestModel(BoringModel):
        def on_fit_start(self):
            # Ensure SLURMEnvironment is set, instead of default LightningEnvironment
            assert isinstance(self.trainer._accelerator_connector.cluster_environment, SLURMEnvironment)
            self.trainer.ran_asserts = True

    with mock.patch("sys.argv", ["any.py", "fit", f"--trainer.plugins={json.dumps(plugins)}"]):
        cli = LightningCLI(TestModel, trainer_defaults=dict(default_root_dir=str(tmpdir), fast_dev_run=True))

    assert cli.trainer.ran_asserts


def test_lightning_cli_args(tmpdir):

    cli_args = [
        "fit",
        f"--data.data_dir={tmpdir}",
        f"--trainer.default_root_dir={tmpdir}",
        "--trainer.max_epochs=1",
        "--trainer.enable_model_summary=False",
        "--seed_everything=1234",
    ]

    with mock.patch("sys.argv", ["any.py"] + cli_args):
        cli = LightningCLI(BoringModel, BoringDataModule, trainer_defaults={"callbacks": [LearningRateMonitor()]})

    config_path = tmpdir / "lightning_logs" / "version_0" / "config.yaml"
    assert os.path.isfile(config_path)
    with open(config_path) as f:
        loaded_config = yaml.safe_load(f.read())

    cli_config = cli.config["fit"].as_dict()
    assert cli_config["seed_everything"] == 1234
    assert "model" not in loaded_config and "model" not in cli_config  # no arguments to include
    assert loaded_config["data"] == cli_config["data"]
    assert loaded_config["trainer"] == cli_config["trainer"]
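
# The `config.yaml` checked above is written by `SaveConfigCallback`, which `LightningCLI`
# adds by default: it records the fully parsed config under the trainer's log directory so
# a run can be reproduced later with `--config`.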


def test_lightning_cli_save_config_cases(tmpdir):

    config_path = tmpdir / "config.yaml"
    cli_args = ["fit", f"--trainer.default_root_dir={tmpdir}", "--trainer.logger=False", "--trainer.fast_dev_run=1"]

    # With fast_dev_run!=False config should not be saved
    with mock.patch("sys.argv", ["any.py"] + cli_args):
        LightningCLI(BoringModel)
    assert not os.path.isfile(config_path)

    # With fast_dev_run==False config should be saved
    cli_args[-1] = "--trainer.max_epochs=1"
    with mock.patch("sys.argv", ["any.py"] + cli_args):
        LightningCLI(BoringModel)
    assert os.path.isfile(config_path)

    # If run again on same directory exception should be raised since config file already exists
    with mock.patch("sys.argv", ["any.py"] + cli_args), pytest.raises(RuntimeError):
        LightningCLI(BoringModel)


def test_lightning_cli_config_and_subclass_mode(tmpdir):
    input_config = {
        "fit": {
            "model": {"class_path": "pytorch_lightning.demos.boring_classes.BoringModel"},
            "data": {
                "class_path": "pytorch_lightning.demos.boring_classes.BoringDataModule",
                "init_args": {"data_dir": str(tmpdir)},
            },
            "trainer": {"default_root_dir": str(tmpdir), "max_epochs": 1, "enable_model_summary": False},
        }
    }
    config_path = tmpdir / "config.yaml"
    with open(config_path, "w") as f:
        f.write(yaml.dump(input_config))

    with mock.patch("sys.argv", ["any.py", "--config", str(config_path)]):
        cli = LightningCLI(
            BoringModel,
            BoringDataModule,
            subclass_mode_model=True,
            subclass_mode_data=True,
            trainer_defaults={"callbacks": LearningRateMonitor()},
        )

    config_path = tmpdir / "lightning_logs" / "version_0" / "config.yaml"
    assert os.path.isfile(config_path)
    with open(config_path) as f:
        loaded_config = yaml.safe_load(f.read())

    cli_config = cli.config["fit"].as_dict()
    assert loaded_config["model"] == cli_config["model"]
    assert loaded_config["data"] == cli_config["data"]
    assert loaded_config["trainer"] == cli_config["trainer"]


def any_model_any_data_cli():
    LightningCLI(LightningModule, LightningDataModule, subclass_mode_model=True, subclass_mode_data=True)


def test_lightning_cli_help():

    cli_args = ["any.py", "fit", "--help"]
    out = StringIO()
    with mock.patch("sys.argv", cli_args), redirect_stdout(out), pytest.raises(SystemExit):
        any_model_any_data_cli()
    out = out.getvalue()

    assert "--print_config" in out
    assert "--config" in out
    assert "--seed_everything" in out
    assert "--model.help" in out
    assert "--data.help" in out

    skip_params = {"self"}
    for param in inspect.signature(Trainer.__init__).parameters.keys():
        if param not in skip_params:
            assert f"--trainer.{param}" in out

    cli_args = ["any.py", "fit", "--data.help=pytorch_lightning.demos.boring_classes.BoringDataModule"]
    out = StringIO()
    with mock.patch("sys.argv", cli_args), redirect_stdout(out), pytest.raises(SystemExit):
        any_model_any_data_cli()

    assert "--data.init_args.data_dir" in out.getvalue()


def test_lightning_cli_print_config():
    cli_args = [
        "any.py",
        "predict",
        "--seed_everything=1234",
        "--model=pytorch_lightning.demos.boring_classes.BoringModel",
        "--data=pytorch_lightning.demos.boring_classes.BoringDataModule",
        "--print_config",
    ]
    out = StringIO()
    with mock.patch("sys.argv", cli_args), redirect_stdout(out), pytest.raises(SystemExit):
        any_model_any_data_cli()

    text = out.getvalue()
    # test dump_header
    assert text.startswith(f"# pytorch_lightning=={__version__}")

    outval = yaml.safe_load(text)
    assert outval["seed_everything"] == 1234
    assert outval["model"]["class_path"] == "pytorch_lightning.demos.boring_classes.BoringModel"
    assert outval["data"]["class_path"] == "pytorch_lightning.demos.boring_classes.BoringDataModule"
    assert outval["ckpt_path"] is None
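
# `--print_config` dumps the fully resolved config to stdout and exits. The typical
# round-trip (exercised by `test_optimizers_and_lr_schedulers_reload` below) is roughly:
#
#     python any.py predict --print_config > config.yaml
#     python any.py predict --config config.yaml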


def test_lightning_cli_submodules(tmpdir):
    class MainModule(BoringModel):
        def __init__(self, submodule1: LightningModule, submodule2: LightningModule, main_param: int = 1):
            super().__init__()
            self.submodule1 = submodule1
            self.submodule2 = submodule2

    config = """model:
        main_param: 2
        submodule1:
            class_path: pytorch_lightning.demos.boring_classes.BoringModel
        submodule2:
            class_path: pytorch_lightning.demos.boring_classes.BoringModel
    """
    config_path = tmpdir / "config.yaml"
    with open(config_path, "w") as f:
        f.write(config)

    cli_args = [f"--trainer.default_root_dir={tmpdir}", f"--config={str(config_path)}"]

    with mock.patch("sys.argv", ["any.py"] + cli_args):
        cli = LightningCLI(MainModule, run=False)

    assert cli.config["model"]["main_param"] == 2
    assert isinstance(cli.model.submodule1, BoringModel)
    assert isinstance(cli.model.submodule2, BoringModel)


@pytest.mark.skipif(torchvision_version < version.parse("0.8.0"), reason="torchvision>=0.8.0 is required")
def test_lightning_cli_torch_modules(tmpdir):
    class TestModule(BoringModel):
        def __init__(self, activation: torch.nn.Module = None, transform: Optional[List[torch.nn.Module]] = None):
            super().__init__()
            self.activation = activation
            self.transform = transform

    config = """model:
        activation:
          class_path: torch.nn.LeakyReLU
          init_args:
            negative_slope: 0.2
        transform:
          - class_path: torchvision.transforms.Resize
            init_args:
              size: 64
          - class_path: torchvision.transforms.CenterCrop
            init_args:
              size: 64
    """
    config_path = tmpdir / "config.yaml"
    with open(config_path, "w") as f:
        f.write(config)

    cli_args = [f"--trainer.default_root_dir={tmpdir}", f"--config={str(config_path)}"]

    with mock.patch("sys.argv", ["any.py"] + cli_args):
        cli = LightningCLI(TestModule, run=False)

    assert isinstance(cli.model.activation, torch.nn.LeakyReLU)
    assert cli.model.activation.negative_slope == 0.2
    assert len(cli.model.transform) == 2
    assert all(isinstance(v, torch.nn.Module) for v in cli.model.transform)


class BoringModelRequiredClasses(BoringModel):
    def __init__(self, num_classes: int, batch_size: int = 8):
        super().__init__()
        self.num_classes = num_classes
        self.batch_size = batch_size


class BoringDataModuleBatchSizeAndClasses(BoringDataModule):
    def __init__(self, batch_size: int = 8):
        super().__init__()
        self.batch_size = batch_size
        self.num_classes = 5  # only available after instantiation


def test_lightning_cli_link_arguments(tmpdir):
    class MyLightningCLI(LightningCLI):
        def add_arguments_to_parser(self, parser):
            parser.link_arguments("data.batch_size", "model.batch_size")
            parser.link_arguments("data.num_classes", "model.num_classes", apply_on="instantiate")

    cli_args = [f"--trainer.default_root_dir={tmpdir}", "--data.batch_size=12"]

    with mock.patch("sys.argv", ["any.py"] + cli_args):
        cli = MyLightningCLI(BoringModelRequiredClasses, BoringDataModuleBatchSizeAndClasses, run=False)

    assert cli.model.batch_size == 12
    assert cli.model.num_classes == 5

    class MyLightningCLI(LightningCLI):
        def add_arguments_to_parser(self, parser):
            parser.link_arguments("data.batch_size", "model.init_args.batch_size")
            parser.link_arguments("data.num_classes", "model.init_args.num_classes", apply_on="instantiate")

    cli_args[-1] = "--model=tests_pytorch.utilities.test_cli.BoringModelRequiredClasses"

    with mock.patch("sys.argv", ["any.py"] + cli_args):
        cli = MyLightningCLI(
            BoringModelRequiredClasses, BoringDataModuleBatchSizeAndClasses, subclass_mode_model=True, run=False
        )

    assert cli.model.batch_size == 8
    assert cli.model.num_classes == 5
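
# Note on `link_arguments`: with the default `apply_on="parse"` the link is applied to
# parsed config values (`data.batch_size` above), while `apply_on="instantiate"` reads the
# attribute off the instantiated source object, which is why `num_classes` (set only inside
# `BoringDataModuleBatchSizeAndClasses.__init__`) can be forwarded to the model.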


class EarlyExitTestModel(BoringModel):
    def on_fit_start(self):
        raise MisconfigurationException("Error on fit start")


@RunIf(skip_windows=True)
@pytest.mark.parametrize("logger", (False, True))
@pytest.mark.parametrize("strategy", ("ddp_spawn", "ddp"))
def test_cli_distributed_save_config_callback(tmpdir, logger, strategy):
    from torch.multiprocessing import ProcessRaisedException

    with mock.patch("sys.argv", ["any.py", "fit"]), pytest.raises(
        (MisconfigurationException, ProcessRaisedException), match=r"Error on fit start"
    ):
        LightningCLI(
            EarlyExitTestModel,
            trainer_defaults={
                "default_root_dir": str(tmpdir),
                "logger": logger,
                "max_steps": 1,
                "max_epochs": 1,
                "strategy": strategy,
                "accelerator": "auto",
                "devices": 1,
            },
        )
    if logger:
        config_dir = tmpdir / "lightning_logs"
        # no more version dirs should get created
        assert os.listdir(config_dir) == ["version_0"]
        config_path = config_dir / "version_0" / "config.yaml"
    else:
        config_path = tmpdir / "config.yaml"
    assert os.path.isfile(config_path)


def test_cli_config_overwrite(tmpdir):
    trainer_defaults = {"default_root_dir": str(tmpdir), "logger": False, "max_steps": 1, "max_epochs": 1}

    argv = ["any.py", "fit"]
    with mock.patch("sys.argv", argv):
        LightningCLI(BoringModel, trainer_defaults=trainer_defaults)
    with mock.patch("sys.argv", argv), pytest.raises(RuntimeError, match="Aborting to avoid overwriting"):
        LightningCLI(BoringModel, trainer_defaults=trainer_defaults)
    with mock.patch("sys.argv", argv):
        LightningCLI(BoringModel, save_config_overwrite=True, trainer_defaults=trainer_defaults)
@pytest.mark.parametrize("run", (False, True))
|
|
|
|
def test_lightning_cli_optimizer(tmpdir, run):
|
2021-07-01 10:04:11 +00:00
|
|
|
class MyLightningCLI(LightningCLI):
|
|
|
|
def add_arguments_to_parser(self, parser):
|
|
|
|
parser.add_optimizer_args(torch.optim.Adam)
|
|
|
|
|
2021-12-01 15:41:22 +00:00
|
|
|
match = "BoringModel.configure_optimizers` will be overridden by " "`MyLightningCLI.configure_optimizers`"
|
2021-08-28 04:43:14 +00:00
|
|
|
argv = ["fit", f"--trainer.default_root_dir={tmpdir}", "--trainer.fast_dev_run=1"] if run else []
|
|
|
|
with mock.patch("sys.argv", ["any.py"] + argv), pytest.warns(UserWarning, match=match):
|
|
|
|
cli = MyLightningCLI(BoringModel, run=run)
|
2021-07-01 10:04:11 +00:00
|
|
|
|
|
|
|
assert cli.model.configure_optimizers is not BoringModel.configure_optimizers
|
2021-08-28 04:43:14 +00:00
|
|
|
|
|
|
|
if not run:
|
|
|
|
optimizer = cli.model.configure_optimizers()
|
|
|
|
assert isinstance(optimizer, torch.optim.Adam)
|
|
|
|
else:
|
|
|
|
assert len(cli.trainer.optimizers) == 1
|
|
|
|
assert isinstance(cli.trainer.optimizers[0], torch.optim.Adam)
|
2022-01-18 19:23:32 +00:00
|
|
|
assert len(cli.trainer.lr_scheduler_configs) == 0
|
2021-07-01 10:04:11 +00:00
|
|
|
|
|
|
|
|
|
|
|
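
# `add_optimizer_args` makes the CLI generate a `configure_optimizers` implementation for
# the model (hence the "will be overridden" warning asserted above), so the optimizer can
# be configured entirely from the command line or a config file.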


def test_lightning_cli_optimizer_and_lr_scheduler(tmpdir):
    class MyLightningCLI(LightningCLI):
        def add_arguments_to_parser(self, parser):
            parser.add_optimizer_args(torch.optim.Adam)
            parser.add_lr_scheduler_args(torch.optim.lr_scheduler.ExponentialLR)

    cli_args = ["fit", f"--trainer.default_root_dir={tmpdir}", "--trainer.fast_dev_run=1", "--lr_scheduler.gamma=0.8"]

    with mock.patch("sys.argv", ["any.py"] + cli_args):
        cli = MyLightningCLI(BoringModel)

    assert cli.model.configure_optimizers is not BoringModel.configure_optimizers
    assert len(cli.trainer.optimizers) == 1
    assert isinstance(cli.trainer.optimizers[0], torch.optim.Adam)
    assert len(cli.trainer.lr_scheduler_configs) == 1
    assert isinstance(cli.trainer.lr_scheduler_configs[0].scheduler, torch.optim.lr_scheduler.ExponentialLR)
    assert cli.trainer.lr_scheduler_configs[0].scheduler.gamma == 0.8


def test_cli_no_need_configure_optimizers():
    class BoringModel(LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(32, 2)

        def training_step(self, *_):
            ...

        def train_dataloader(self):
            ...

        # did not define `configure_optimizers`

    from pytorch_lightning.trainer.configuration_validator import __verify_train_val_loop_configuration

    with mock.patch("sys.argv", ["any.py", "fit", "--optimizer=Adam"]), mock.patch(
        "pytorch_lightning.Trainer._run_train"
    ) as run, mock.patch(
        "pytorch_lightning.trainer.configuration_validator.__verify_train_val_loop_configuration",
        wraps=__verify_train_val_loop_configuration,
    ) as verify:
        cli = LightningCLI(BoringModel)
        run.assert_called_once()
        verify.assert_called_once_with(cli.trainer, cli.model)


def test_lightning_cli_optimizer_and_lr_scheduler_subclasses(tmpdir):
    class MyLightningCLI(LightningCLI):
        def add_arguments_to_parser(self, parser):
            parser.add_optimizer_args((torch.optim.SGD, torch.optim.Adam))
            parser.add_lr_scheduler_args((torch.optim.lr_scheduler.StepLR, torch.optim.lr_scheduler.ExponentialLR))

    optimizer_arg = dict(class_path="torch.optim.Adam", init_args=dict(lr=0.01))
    lr_scheduler_arg = dict(class_path="torch.optim.lr_scheduler.StepLR", init_args=dict(step_size=50))
    cli_args = [
        "fit",
        f"--trainer.default_root_dir={tmpdir}",
        "--trainer.max_epochs=1",
        f"--optimizer={json.dumps(optimizer_arg)}",
        f"--lr_scheduler={json.dumps(lr_scheduler_arg)}",
    ]

    with mock.patch("sys.argv", ["any.py"] + cli_args):
        cli = MyLightningCLI(BoringModel)

    assert len(cli.trainer.optimizers) == 1
    assert isinstance(cli.trainer.optimizers[0], torch.optim.Adam)
    assert len(cli.trainer.lr_scheduler_configs) == 1
    assert isinstance(cli.trainer.lr_scheduler_configs[0].scheduler, torch.optim.lr_scheduler.StepLR)
    assert cli.trainer.lr_scheduler_configs[0].scheduler.step_size == 50
@pytest.mark.parametrize("use_generic_base_class", [False, True])
|
|
|
|
def test_lightning_cli_optimizers_and_lr_scheduler_with_link_to(use_generic_base_class, tmpdir):
|
2021-07-01 10:04:11 +00:00
|
|
|
class MyLightningCLI(LightningCLI):
|
|
|
|
def add_arguments_to_parser(self, parser):
|
2021-09-17 17:00:46 +00:00
|
|
|
parser.add_optimizer_args(
|
2022-05-03 12:16:37 +00:00
|
|
|
(torch.optim.Optimizer,) if use_generic_base_class else torch.optim.Adam,
|
2021-09-17 17:00:46 +00:00
|
|
|
nested_key="optim1",
|
|
|
|
link_to="model.optim1",
|
|
|
|
)
|
2021-07-26 11:37:35 +00:00
|
|
|
parser.add_optimizer_args((torch.optim.ASGD, torch.optim.SGD), nested_key="optim2", link_to="model.optim2")
|
2021-09-17 17:00:46 +00:00
|
|
|
parser.add_lr_scheduler_args(
|
2022-05-03 12:16:37 +00:00
|
|
|
LRSchedulerTypeTuple if use_generic_base_class else torch.optim.lr_scheduler.ExponentialLR,
|
2021-09-17 17:00:46 +00:00
|
|
|
link_to="model.scheduler",
|
|
|
|
)
|
2021-07-01 10:04:11 +00:00
|
|
|
|
|
|
|
class TestModel(BoringModel):
|
2021-07-26 11:37:35 +00:00
|
|
|
def __init__(self, optim1: dict, optim2: dict, scheduler: dict):
|
2021-07-01 10:04:11 +00:00
|
|
|
super().__init__()
|
|
|
|
self.optim1 = instantiate_class(self.parameters(), optim1)
|
|
|
|
self.optim2 = instantiate_class(self.parameters(), optim2)
|
|
|
|
self.scheduler = instantiate_class(self.optim1, scheduler)
|
|
|
|
|
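
    # `instantiate_class(args, init)` expects `init` to be a dict with `class_path` and
    # optional `init_args` and resolves roughly to `Class(args, **init_args)`; e.g.
    # `instantiate_class(params, {"class_path": "torch.optim.Adam", "init_args": {"lr": 0.01}})`
    # behaves like `torch.optim.Adam(params, lr=0.01)` (an illustrative sketch of the helper).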

    cli_args = ["fit", f"--trainer.default_root_dir={tmpdir}", "--trainer.max_epochs=1"]
    if use_generic_base_class:
        cli_args += [
            "--optim1",
            "Adam",
            "--optim1.weight_decay",
            "0.001",
            "--optim2=SGD",
            "--optim2.lr=0.01",
            "--lr_scheduler=ExponentialLR",
        ]
    else:
        cli_args += ["--optim2=SGD", "--optim2.lr=0.01"]
    cli_args += ["--lr_scheduler.gamma=0.2"]

    with mock.patch("sys.argv", ["any.py"] + cli_args):
        cli = MyLightningCLI(TestModel)

    assert isinstance(cli.model.optim1, torch.optim.Adam)
    assert isinstance(cli.model.optim2, torch.optim.SGD)
    assert cli.model.optim2.param_groups[0]["lr"] == 0.01
    assert isinstance(cli.model.scheduler, torch.optim.lr_scheduler.ExponentialLR)
@pytest.mark.parametrize("fn", [fn.value for fn in TrainerFn])
|
|
|
|
def test_lightning_cli_trainer_fn(fn):
|
|
|
|
class TestCLI(LightningCLI):
|
|
|
|
def __init__(self, *args, **kwargs):
|
|
|
|
self.called = []
|
|
|
|
super().__init__(*args, **kwargs)
|
|
|
|
|
|
|
|
def before_fit(self):
|
|
|
|
self.called.append("before_fit")
|
|
|
|
|
|
|
|
def fit(self, **_):
|
|
|
|
self.called.append("fit")
|
|
|
|
|
|
|
|
def after_fit(self):
|
|
|
|
self.called.append("after_fit")
|
|
|
|
|
|
|
|
def before_validate(self):
|
|
|
|
self.called.append("before_validate")
|
|
|
|
|
|
|
|
def validate(self, **_):
|
|
|
|
self.called.append("validate")
|
|
|
|
|
|
|
|
def after_validate(self):
|
|
|
|
self.called.append("after_validate")
|
|
|
|
|
|
|
|
def before_test(self):
|
|
|
|
self.called.append("before_test")
|
|
|
|
|
|
|
|
def test(self, **_):
|
|
|
|
self.called.append("test")
|
|
|
|
|
|
|
|
def after_test(self):
|
|
|
|
self.called.append("after_test")
|
|
|
|
|
|
|
|
def before_predict(self):
|
|
|
|
self.called.append("before_predict")
|
|
|
|
|
|
|
|
def predict(self, **_):
|
|
|
|
self.called.append("predict")
|
|
|
|
|
|
|
|
def after_predict(self):
|
|
|
|
self.called.append("after_predict")
|
|
|
|
|
|
|
|
def before_tune(self):
|
|
|
|
self.called.append("before_tune")
|
|
|
|
|
|
|
|
def tune(self, **_):
|
|
|
|
self.called.append("tune")
|
|
|
|
|
|
|
|
def after_tune(self):
|
|
|
|
self.called.append("after_tune")
|
|
|
|
|
|
|
|
with mock.patch("sys.argv", ["any.py", fn]):
|
|
|
|
cli = TestCLI(BoringModel)
|
|
|
|
assert cli.called == [f"before_{fn}", fn, f"after_{fn}"]
|
|
|
|
|
|
|
|
|
|
|
|
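
# For each subcommand `<fn>`, `LightningCLI` looks up the optional `before_<fn>` and
# `after_<fn>` hooks and calls them around the subcommand itself, which is what the
# `cli.called` assertion above verifies.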


def test_lightning_cli_subcommands():
    subcommands = LightningCLI.subcommands()
    trainer = Trainer()
    for subcommand, exclude in subcommands.items():
        fn = getattr(trainer, subcommand)
        parameters = list(inspect.signature(fn).parameters)
        for e in exclude:
            # if this fails, it's because the parameter has been removed from the associated `Trainer` function
            # and the `LightningCLI` subcommand exclusion list needs to be updated
            assert e in parameters


def test_lightning_cli_custom_subcommand():
    class TestTrainer(Trainer):
        def foo(self, model: LightningModule, x: int, y: float = 1.0):
            """Sample extra function.

            Args:
                model: A model
                x: The x
                y: The y
            """

    class TestCLI(LightningCLI):
        @staticmethod
        def subcommands():
            subcommands = LightningCLI.subcommands()
            subcommands["foo"] = {"model"}
            return subcommands

    out = StringIO()
    with mock.patch("sys.argv", ["any.py", "-h"]), redirect_stdout(out), pytest.raises(SystemExit):
        TestCLI(BoringModel, trainer_class=TestTrainer)
    out = out.getvalue()
    assert "Sample extra function." in out
    assert "{fit,validate,test,predict,tune,foo}" in out

    out = StringIO()
    with mock.patch("sys.argv", ["any.py", "foo", "-h"]), redirect_stdout(out), pytest.raises(SystemExit):
        TestCLI(BoringModel, trainer_class=TestTrainer)
    out = out.getvalue()
    assert "A model" not in out
    assert "Sample extra function:" in out
    assert "--x X" in out
    assert "The x (required, type: int)" in out
    assert "--y Y" in out
    assert "The y (type: float, default: 1.0)" in out


def test_lightning_cli_run():
    with mock.patch("sys.argv", ["any.py"]):
        cli = LightningCLI(BoringModel, run=False)
    assert cli.trainer.global_step == 0
    assert isinstance(cli.trainer, Trainer)
    assert isinstance(cli.model, LightningModule)

    with mock.patch("sys.argv", ["any.py", "fit"]):
        cli = LightningCLI(BoringModel, trainer_defaults={"max_steps": 1, "max_epochs": 1})
    assert cli.trainer.global_step == 1
    assert isinstance(cli.trainer, Trainer)
    assert isinstance(cli.model, LightningModule)


class TestModel(BoringModel):
    def __init__(self, foo, bar=5):
        super().__init__()
        self.foo = foo
        self.bar = bar


def test_lightning_cli_model_short_arguments():
    with mock.patch("sys.argv", ["any.py", "fit", "--model=BoringModel"]), mock.patch(
        "pytorch_lightning.Trainer._fit_impl"
    ) as run, mock_subclasses(LightningModule, BoringModel, TestModel):
        cli = LightningCLI(trainer_defaults={"fast_dev_run": 1})
        assert isinstance(cli.model, BoringModel)
        run.assert_called_once_with(cli.model, ANY, ANY, ANY, ANY)

    with mock.patch("sys.argv", ["any.py", "--model=TestModel", "--model.foo", "123"]), mock_subclasses(
        LightningModule, BoringModel, TestModel
    ):
        cli = LightningCLI(run=False)
        assert isinstance(cli.model, TestModel)
        assert cli.model.foo == 123
        assert cli.model.bar == 5
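
# Short class-name arguments such as `--model=BoringModel` are resolved against the known
# subclasses of `LightningModule`; `mock_subclasses` (defined near the top of this file)
# pins that subclass set so the lookup stays deterministic in these tests.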


class MyDataModule(BoringDataModule):
    def __init__(self, foo, bar=5):
        super().__init__()
        self.foo = foo
        self.bar = bar


def test_lightning_cli_datamodule_short_arguments():
    # with set model
    with mock.patch("sys.argv", ["any.py", "fit", "--data=BoringDataModule"]), mock.patch(
        "pytorch_lightning.Trainer._fit_impl"
    ) as run, mock_subclasses(LightningDataModule, BoringDataModule):
        cli = LightningCLI(BoringModel, trainer_defaults={"fast_dev_run": 1})
        assert isinstance(cli.datamodule, BoringDataModule)
        run.assert_called_once_with(ANY, ANY, ANY, cli.datamodule, ANY)

    with mock.patch("sys.argv", ["any.py", "--data=MyDataModule", "--data.foo", "123"]), mock_subclasses(
        LightningDataModule, MyDataModule
    ):
        cli = LightningCLI(BoringModel, run=False)
        assert isinstance(cli.datamodule, MyDataModule)
        assert cli.datamodule.foo == 123
        assert cli.datamodule.bar == 5

    # with configurable model
    with mock.patch("sys.argv", ["any.py", "fit", "--model", "BoringModel", "--data=BoringDataModule"]), mock.patch(
        "pytorch_lightning.Trainer._fit_impl"
    ) as run, mock_subclasses(LightningModule, BoringModel), mock_subclasses(LightningDataModule, BoringDataModule):
        cli = LightningCLI(trainer_defaults={"fast_dev_run": 1})
        assert isinstance(cli.model, BoringModel)
        assert isinstance(cli.datamodule, BoringDataModule)
        run.assert_called_once_with(cli.model, ANY, ANY, cli.datamodule, ANY)

    with mock.patch("sys.argv", ["any.py", "--model", "BoringModel", "--data=MyDataModule"]), mock_subclasses(
        LightningModule, BoringModel
    ), mock_subclasses(LightningDataModule, MyDataModule):
        cli = LightningCLI(run=False)
        assert isinstance(cli.model, BoringModel)
        assert isinstance(cli.datamodule, MyDataModule)

    with mock.patch("sys.argv", ["any.py"]):
        cli = LightningCLI(BoringModel, run=False)
    # data was not passed, but we are adding it automatically because there are datamodules registered
    assert "data" in cli.parser.groups
    assert not hasattr(cli.parser.groups["data"], "group_class")

    with mock.patch("sys.argv", ["any.py"]):
        cli = LightningCLI(BoringModel, BoringDataModule, run=False)
    # since we are passing the DataModule, that's what's added to the parser
    assert cli.parser.groups["data"].group_class is BoringDataModule
@pytest.mark.parametrize("use_class_path_callbacks", [False, True])
|
2022-06-02 01:00:02 +00:00
|
|
|
def test_callbacks_append(use_class_path_callbacks):
|
2022-03-08 17:26:10 +00:00
|
|
|
|
2021-09-17 17:00:46 +00:00
|
|
|
"""This test validates registries are used when simplified command line are being used."""
|
|
|
|
cli_args = [
|
|
|
|
"--optimizer",
|
|
|
|
"Adam",
|
|
|
|
"--optimizer.lr",
|
|
|
|
"0.0001",
|
2022-06-02 01:00:02 +00:00
|
|
|
"--trainer.callbacks+=LearningRateMonitor",
|
2021-09-17 17:54:06 +00:00
|
|
|
"--trainer.callbacks.logging_interval=epoch",
|
|
|
|
"--trainer.callbacks.log_momentum=True",
|
2021-09-22 14:19:02 +00:00
|
|
|
"--model=BoringModel",
|
2022-06-02 01:00:02 +00:00
|
|
|
"--trainer.callbacks+",
|
|
|
|
"ModelCheckpoint",
|
2021-09-17 17:54:06 +00:00
|
|
|
"--trainer.callbacks.monitor=loss",
|
2021-09-17 17:00:46 +00:00
|
|
|
"--lr_scheduler",
|
|
|
|
"StepLR",
|
|
|
|
"--lr_scheduler.step_size=50",
|
|
|
|
]
|
|
|
|
|
2021-09-17 17:54:06 +00:00
|
|
|
extras = []
|
|
|
|
if use_class_path_callbacks:
|
|
|
|
callbacks = [
|
|
|
|
{"class_path": "pytorch_lightning.callbacks.Callback"},
|
|
|
|
{"class_path": "pytorch_lightning.callbacks.Callback", "init_args": {}},
|
|
|
|
]
|
2022-06-02 01:00:02 +00:00
|
|
|
cli_args += [f"--trainer.callbacks+={json.dumps(callbacks)}"]
|
2021-09-17 17:54:06 +00:00
|
|
|
extras = [Callback, Callback]
|
|
|
|
|
2022-05-03 12:16:37 +00:00
|
|
|
with mock.patch("sys.argv", ["any.py"] + cli_args), mock_subclasses(LightningModule, BoringModel):
|
2021-09-22 14:19:02 +00:00
|
|
|
cli = LightningCLI(run=False)
|
2021-09-17 17:00:46 +00:00
|
|
|
|
2021-09-22 14:19:02 +00:00
|
|
|
assert isinstance(cli.model, BoringModel)
|
2021-09-17 17:00:46 +00:00
|
|
|
optimizers, lr_scheduler = cli.model.configure_optimizers()
|
|
|
|
assert isinstance(optimizers[0], torch.optim.Adam)
|
|
|
|
assert optimizers[0].param_groups[0]["lr"] == 0.0001
|
|
|
|
assert lr_scheduler[0].step_size == 50
|
|
|
|
|
2021-09-17 17:54:06 +00:00
|
|
|
callback_types = [type(c) for c in cli.trainer.callbacks]
|
|
|
|
expected = [LearningRateMonitor, SaveConfigCallback, ModelCheckpoint] + extras
|
|
|
|
assert all(t in callback_types for t in expected)
|
|
|
|
|
|
|
|
|
2021-09-17 17:00:46 +00:00
|
|
|
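
# The `+` suffix (`--trainer.callbacks+=...`) is jsonargparse's list-append syntax: each
# occurrence appends to the list instead of replacing it, and the following
# `--trainer.callbacks.<param>` options configure the most recently appended entry.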


def test_optimizers_and_lr_schedulers_reload(tmpdir):
    base = ["any.py", "--trainer.max_epochs=1"]
    input = base + [
        "--lr_scheduler",
        "OneCycleLR",
        "--lr_scheduler.total_steps=10",
        "--lr_scheduler.max_lr=1",
        "--optimizer",
        "Adam",
        "--optimizer.lr=0.1",
    ]

    # save config
    out = StringIO()
    with mock.patch("sys.argv", input + ["--print_config"]), redirect_stdout(out), pytest.raises(SystemExit):
        LightningCLI(BoringModel, run=False)

    # validate yaml
    yaml_config = out.getvalue()
    dict_config = yaml.safe_load(yaml_config)
    assert dict_config["optimizer"]["class_path"] == "torch.optim.Adam"
    assert dict_config["optimizer"]["init_args"]["lr"] == 0.1
    assert dict_config["lr_scheduler"]["class_path"] == "torch.optim.lr_scheduler.OneCycleLR"

    # reload config
    yaml_config_file = tmpdir / "config.yaml"
    yaml_config_file.write_text(yaml_config, "utf-8")
    with mock.patch("sys.argv", base + [f"--config={yaml_config_file}"]):
        LightningCLI(BoringModel, run=False)
def test_optimizers_and_lr_schedulers_add_arguments_to_parser_implemented_reload(tmpdir):
    class TestLightningCLI(LightningCLI):
        def __init__(self, *args):
            super().__init__(*args, run=False)

        def add_arguments_to_parser(self, parser):
            parser.add_optimizer_args(nested_key="opt1", link_to="model.opt1_config")
            parser.add_optimizer_args(
                (torch.optim.ASGD, torch.optim.SGD), nested_key="opt2", link_to="model.opt2_config"
            )
            parser.add_lr_scheduler_args(link_to="model.sch_config")
            parser.add_argument("--something", type=str, nargs="+")

    class TestModel(BoringModel):
        def __init__(self, opt1_config: dict, opt2_config: dict, sch_config: dict):
            super().__init__()
            self.opt1_config = opt1_config
            self.opt2_config = opt2_config
            self.sch_config = sch_config
            opt1 = instantiate_class(self.parameters(), opt1_config)
            assert isinstance(opt1, torch.optim.Adam)
            opt2 = instantiate_class(self.parameters(), opt2_config)
            assert isinstance(opt2, torch.optim.ASGD)
            sch = instantiate_class(opt1, sch_config)
            assert isinstance(sch, torch.optim.lr_scheduler.OneCycleLR)

    base = ["any.py", "--trainer.max_epochs=1"]
    input = base + [
        "--lr_scheduler",
        "OneCycleLR",
        "--lr_scheduler.total_steps=10",
        "--lr_scheduler.max_lr=1",
        "--opt1",
        "Adam",
        "--opt2=ASGD",
        "--opt2.lr=0.1",
        "--lr_scheduler.anneal_strategy=linear",
        "--something",
        "a",
        "b",
        "c",
    ]

    # save config
    out = StringIO()
    with mock.patch("sys.argv", input + ["--print_config"]), redirect_stdout(out), pytest.raises(SystemExit):
        TestLightningCLI(TestModel)

    # validate yaml
    yaml_config = out.getvalue()
    dict_config = yaml.safe_load(yaml_config)
    assert dict_config["opt1"]["class_path"] == "torch.optim.Adam"
    assert dict_config["opt2"]["class_path"] == "torch.optim.ASGD"
    assert dict_config["opt2"]["init_args"]["lr"] == 0.1
    assert dict_config["lr_scheduler"]["class_path"] == "torch.optim.lr_scheduler.OneCycleLR"
    assert dict_config["lr_scheduler"]["init_args"]["anneal_strategy"] == "linear"
    assert dict_config["something"] == ["a", "b", "c"]

    # reload config
    yaml_config_file = tmpdir / "config.yaml"
    yaml_config_file.write_text(yaml_config, "utf-8")
    with mock.patch("sys.argv", base + [f"--config={yaml_config_file}"]):
        cli = TestLightningCLI(TestModel)

    assert cli.model.opt1_config["class_path"] == "torch.optim.Adam"
    assert cli.model.opt2_config["class_path"] == "torch.optim.ASGD"
    assert cli.model.opt2_config["init_args"]["lr"] == 0.1
    assert cli.model.sch_config["class_path"] == "torch.optim.lr_scheduler.OneCycleLR"
    assert cli.model.sch_config["init_args"]["anneal_strategy"] == "linear"


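# The subcommand config tests below pass the config inline as `--config={...}` (a
# stringified dict, parsed as YAML). A config given before the subcommand nests options
# under the subcommand name ("test", "validate"); one given after it applies directly.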
def test_lightning_cli_config_with_subcommand():
    config = {"test": {"trainer": {"limit_test_batches": 1}, "verbose": True, "ckpt_path": "foobar"}}
    with mock.patch("sys.argv", ["any.py", f"--config={config}"]), mock.patch(
        "pytorch_lightning.Trainer.test", autospec=True
    ) as test_mock:
        cli = LightningCLI(BoringModel)

    test_mock.assert_called_once_with(cli.trainer, cli.model, verbose=True, ckpt_path="foobar")
    assert cli.trainer.limit_test_batches == 1


def test_lightning_cli_config_before_subcommand():
    config = {
        "validate": {"trainer": {"limit_val_batches": 1}, "verbose": False, "ckpt_path": "barfoo"},
        "test": {"trainer": {"limit_test_batches": 1}, "verbose": True, "ckpt_path": "foobar"},
    }

    with mock.patch("sys.argv", ["any.py", f"--config={config}", "test"]), mock.patch(
        "pytorch_lightning.Trainer.test", autospec=True
    ) as test_mock:
        cli = LightningCLI(BoringModel)

    test_mock.assert_called_once_with(cli.trainer, model=cli.model, verbose=True, ckpt_path="foobar")
    assert cli.trainer.limit_test_batches == 1

    save_config_callback = cli.trainer.callbacks[0]
    assert save_config_callback.config.trainer.limit_test_batches == 1
    assert save_config_callback.parser.subcommand == "test"

    with mock.patch("sys.argv", ["any.py", f"--config={config}", "validate"]), mock.patch(
        "pytorch_lightning.Trainer.validate", autospec=True
    ) as validate_mock:
        cli = LightningCLI(BoringModel)

    validate_mock.assert_called_once_with(cli.trainer, cli.model, verbose=False, ckpt_path="barfoo")
    assert cli.trainer.limit_val_batches == 1

    save_config_callback = cli.trainer.callbacks[0]
    assert save_config_callback.config.trainer.limit_val_batches == 1
    assert save_config_callback.parser.subcommand == "validate"


def test_lightning_cli_config_before_subcommand_two_configs():
    config1 = {"validate": {"trainer": {"limit_val_batches": 1}, "verbose": False, "ckpt_path": "barfoo"}}
    config2 = {"test": {"trainer": {"limit_test_batches": 1}, "verbose": True, "ckpt_path": "foobar"}}

    with mock.patch("sys.argv", ["any.py", f"--config={config1}", f"--config={config2}", "test"]), mock.patch(
        "pytorch_lightning.Trainer.test", autospec=True
    ) as test_mock:
        cli = LightningCLI(BoringModel)

    test_mock.assert_called_once_with(cli.trainer, model=cli.model, verbose=True, ckpt_path="foobar")
    assert cli.trainer.limit_test_batches == 1

    with mock.patch("sys.argv", ["any.py", f"--config={config1}", f"--config={config2}", "validate"]), mock.patch(
        "pytorch_lightning.Trainer.validate", autospec=True
    ) as validate_mock:
        cli = LightningCLI(BoringModel)

    validate_mock.assert_called_once_with(cli.trainer, cli.model, verbose=False, ckpt_path="barfoo")
    assert cli.trainer.limit_val_batches == 1


def test_lightning_cli_config_after_subcommand():
    config = {"trainer": {"limit_test_batches": 1}, "verbose": True, "ckpt_path": "foobar"}
    with mock.patch("sys.argv", ["any.py", "test", f"--config={config}"]), mock.patch(
        "pytorch_lightning.Trainer.test", autospec=True
    ) as test_mock:
        cli = LightningCLI(BoringModel)

    test_mock.assert_called_once_with(cli.trainer, cli.model, verbose=True, ckpt_path="foobar")
    assert cli.trainer.limit_test_batches == 1


def test_lightning_cli_config_before_and_after_subcommand():
    config1 = {"test": {"trainer": {"limit_test_batches": 1}, "verbose": True, "ckpt_path": "foobar"}}
    config2 = {"trainer": {"fast_dev_run": 1}, "verbose": False, "ckpt_path": "foobar"}
    with mock.patch("sys.argv", ["any.py", f"--config={config1}", "test", f"--config={config2}"]), mock.patch(
        "pytorch_lightning.Trainer.test", autospec=True
    ) as test_mock:
        cli = LightningCLI(BoringModel)

    test_mock.assert_called_once_with(cli.trainer, model=cli.model, verbose=False, ckpt_path="foobar")
    assert cli.trainer.limit_test_batches == 1
    assert cli.trainer.fast_dev_run == 1


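# `parser_kwargs` keyed by subcommand name configures each subparser separately; here it
# gives `fit` and `validate` different `default_config_files`, roughly:
#
#   LightningCLI(Model, parser_kwargs={"fit": {"default_config_files": ["fit.yaml"]}})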
def test_lightning_cli_parse_kwargs_with_subcommands(tmpdir):
    fit_config = {"trainer": {"limit_train_batches": 2}}
    fit_config_path = tmpdir / "fit.yaml"
    fit_config_path.write_text(str(fit_config), "utf8")

    validate_config = {"trainer": {"limit_val_batches": 3}}
    validate_config_path = tmpdir / "validate.yaml"
    validate_config_path.write_text(str(validate_config), "utf8")

    parser_kwargs = {
        "fit": {"default_config_files": [str(fit_config_path)]},
        "validate": {"default_config_files": [str(validate_config_path)]},
    }

    with mock.patch("sys.argv", ["any.py", "fit"]), mock.patch(
        "pytorch_lightning.Trainer.fit", autospec=True
    ) as fit_mock:
        cli = LightningCLI(BoringModel, parser_kwargs=parser_kwargs)
    fit_mock.assert_called()
    assert cli.trainer.limit_train_batches == 2
    assert cli.trainer.limit_val_batches == 1.0

    with mock.patch("sys.argv", ["any.py", "validate"]), mock.patch(
        "pytorch_lightning.Trainer.validate", autospec=True
    ) as validate_mock:
        cli = LightningCLI(BoringModel, parser_kwargs=parser_kwargs)
    validate_mock.assert_called()
    assert cli.trainer.limit_train_batches == 1.0
    assert cli.trainer.limit_val_batches == 3


def test_lightning_cli_subcommands_common_default_config_files(tmpdir):
    class Model(BoringModel):
        def __init__(self, foo: int, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.foo = foo

    config = {"fit": {"model": {"foo": 123}}}
    config_path = tmpdir / "default.yaml"
    config_path.write_text(str(config), "utf8")
    parser_kwargs = {"default_config_files": [str(config_path)]}

    with mock.patch("sys.argv", ["any.py", "fit"]), mock.patch(
        "pytorch_lightning.Trainer.fit", autospec=True
    ) as fit_mock:
        cli = LightningCLI(Model, parser_kwargs=parser_kwargs)
    fit_mock.assert_called()
    assert cli.model.foo == 123


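# `cli.instantiate_trainer(**kwargs)` builds a fresh Trainer from the parsed config plus
# the given overrides without mutating `cli.config_init`, which is what the assertions
# below rely on.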
def test_lightning_cli_reinstantiate_trainer():
    with mock.patch("sys.argv", ["any.py"]):
        cli = LightningCLI(BoringModel, run=False)
    assert cli.trainer.max_epochs == 1000

    class TestCallback(Callback):
        ...

    # make sure a new trainer can be easily created
    trainer = cli.instantiate_trainer(max_epochs=123, callbacks=[TestCallback()])
    # the new config is used
    assert trainer.max_epochs == 123
    assert {c.__class__ for c in trainer.callbacks} == {c.__class__ for c in cli.trainer.callbacks}.union(
        {TestCallback}
    )
    # the existing config is not updated
    assert cli.config_init["trainer"]["max_epochs"] is None


def test_cli_configure_optimizers_warning():
    match = "configure_optimizers` will be overridden by `LightningCLI"
    with mock.patch("sys.argv", ["any.py"]), no_warning_call(UserWarning, match=match):
        LightningCLI(BoringModel, run=False)
    with mock.patch("sys.argv", ["any.py", "--optimizer=Adam"]), pytest.warns(UserWarning, match=match):
        LightningCLI(BoringModel, run=False)


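# `--optimizer.help=<name>` prints the help/signature of the given class and exits; the
# shorthand class name ("Adam") should resolve to the same help text as the full import
# path ("torch.optim.Adam").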
def test_cli_help_message():
    # full class path
    cli_args = ["any.py", "--optimizer.help=torch.optim.Adam"]
    classpath_help = StringIO()
    with mock.patch("sys.argv", cli_args), redirect_stdout(classpath_help), pytest.raises(SystemExit):
        LightningCLI(BoringModel, run=False)

    cli_args = ["any.py", "--optimizer.help=Adam"]
    shorthand_help = StringIO()
    with mock.patch("sys.argv", cli_args), redirect_stdout(shorthand_help), pytest.raises(SystemExit):
        LightningCLI(BoringModel, run=False)

    # the help messages should match
    assert shorthand_help.getvalue() == classpath_help.getvalue()
    # make sure it's not empty
    assert "Implements Adam" in shorthand_help.getvalue()


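# `ReduceLROnPlateau` steps on a monitored metric rather than unconditionally, so the CLI
# accepts `--lr_scheduler.monitor` and the auto-generated `configure_optimizers` returns
# the scheduler wrapped in a dict, roughly {"lr_scheduler": {"scheduler": sch, ...}}.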
def test_cli_reducelronplateau():
    with mock.patch(
        "sys.argv", ["any.py", "--optimizer=Adam", "--lr_scheduler=ReduceLROnPlateau", "--lr_scheduler.monitor=foo"]
    ):
        cli = LightningCLI(BoringModel, run=False)
    config = cli.model.configure_optimizers()
    assert isinstance(config["lr_scheduler"]["scheduler"], ReduceLROnPlateau)
    assert config["lr_scheduler"]["scheduler"].monitor == "foo"


def test_cli_configureoptimizers_can_be_overridden():
    class MyCLI(LightningCLI):
        def __init__(self):
            super().__init__(BoringModel, run=False)

        @staticmethod
        def configure_optimizers(self, optimizer, lr_scheduler=None):
            assert isinstance(self, BoringModel)
            assert lr_scheduler is None
            return 123

    with mock.patch("sys.argv", ["any.py", "--optimizer=Adam"]):
        cli = MyCLI()
    assert cli.model.configure_optimizers() == 123

    # with no optimization config, we don't override
    with mock.patch("sys.argv", ["any.py"]):
        cli = MyCLI()
    [optimizer], [scheduler] = cli.model.configure_optimizers()
    assert isinstance(optimizer, SGD)
    assert isinstance(scheduler, StepLR)
    with mock.patch("sys.argv", ["any.py", "--lr_scheduler=StepLR"]):
        cli = MyCLI()
    [optimizer], [scheduler] = cli.model.configure_optimizers()
    assert isinstance(optimizer, SGD)
    assert isinstance(scheduler, StepLR)


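# `lazy_instance` (provided by jsonargparse) wraps a class and kwargs as a serializable
# default, so the CLI re-instantiates it on every parse instead of sharing one module
# instance between models; that is what the final `is not` assertion checks.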
def test_cli_parameter_with_lazy_instance_default():
    class TestModel(BoringModel):
        def __init__(self, activation: torch.nn.Module = lazy_instance(torch.nn.LeakyReLU, negative_slope=0.05)):
            super().__init__()
            self.activation = activation

    model = TestModel()
    assert isinstance(model.activation, torch.nn.LeakyReLU)

    with mock.patch("sys.argv", ["any.py"]):
        cli = LightningCLI(TestModel, run=False)
        assert isinstance(cli.model.activation, torch.nn.LeakyReLU)
        assert cli.model.activation.negative_slope == 0.05
        assert cli.model.activation is not model.activation


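# `trainer_defaults` can carry an instance as a default; using `lazy_instance` keeps it
# configurable from the command line, so `--trainer.strategy.process_group_backend=group`
# updates a field of the default `DDPStrategy` instead of replacing the whole strategy.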
def test_ddpstrategy_instantiation_and_find_unused_parameters():
    strategy_default = lazy_instance(DDPStrategy, find_unused_parameters=True)
    with mock.patch("sys.argv", ["any.py", "--trainer.strategy.process_group_backend=group"]):
        cli = LightningCLI(
            BoringModel,
            run=False,
            trainer_defaults={"strategy": strategy_default},
        )

    assert cli.config.trainer.strategy.init_args.find_unused_parameters is True
    assert isinstance(cli.config_init.trainer.strategy, DDPStrategy)
    assert cli.config_init.trainer.strategy.process_group_backend == "group"
    assert strategy_default is not cli.config_init.trainer.strategy


def test_cli_logger_shorthand():
    with mock.patch("sys.argv", ["any.py"]):
        cli = LightningCLI(TestModel, run=False, trainer_defaults={"logger": False})
    assert cli.trainer.logger is None

    with mock.patch("sys.argv", ["any.py", "--trainer.logger=TensorBoardLogger", "--trainer.logger.save_dir=foo"]):
        cli = LightningCLI(TestModel, run=False, trainer_defaults={"logger": False})
    assert isinstance(cli.trainer.logger, TensorBoardLogger)

    with mock.patch("sys.argv", ["any.py", "--trainer.logger=False"]):
        cli = LightningCLI(TestModel, run=False)
    assert cli.trainer.logger is None


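# Shared helper for the logger tests below. `init` holds parameters resolvable from the
# logger signature (serialized under `init_args`); `unresolved` holds ones jsonargparse
# cannot resolve (e.g. swallowed by `**kwargs`), which must go through `dict_kwargs`:
#
#   --trainer.logger=NeptuneLogger --trainer.logger.dict_kwargs.description=neptune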
def _test_logger_init_args(logger_name, init, unresolved=None):
    unresolved = unresolved or {}
    cli_args = [f"--trainer.logger={logger_name}"]
    cli_args += [f"--trainer.logger.{k}={v}" for k, v in init.items()]
    cli_args += [f"--trainer.logger.dict_kwargs.{k}={v}" for k, v in unresolved.items()]
    cli_args.append("--print_config")

    out = StringIO()
    with mock.patch("sys.argv", ["any.py"] + cli_args), redirect_stdout(out), pytest.raises(SystemExit):
        LightningCLI(TestModel, run=False)

    data = yaml.safe_load(out.getvalue())["trainer"]["logger"]
    assert {k: data["init_args"][k] for k in init} == init
    if unresolved:
        assert data["dict_kwargs"] == unresolved


@pytest.mark.skipif(not _COMET_AVAILABLE, reason="comet-ml is required")
def test_comet_logger_init_args():
    _test_logger_init_args("CometLogger", {"save_dir": "comet", "workspace": "comet"})


@pytest.mark.skipif(not _NEPTUNE_AVAILABLE, reason="neptune-client is required")
def test_neptune_logger_init_args():
    _test_logger_init_args("NeptuneLogger", {"name": "neptune"}, {"description": "neptune"})


def test_tensorboard_logger_init_args():
    _test_logger_init_args("TensorBoardLogger", {"save_dir": "tb", "name": "tb"})


@pytest.mark.skipif(not _WANDB_AVAILABLE, reason="wandb is required")
def test_wandb_logger_init_args():
    _test_logger_init_args("WandbLogger", {"save_dir": "wandb", "notes": "wandb"})


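# `seed_everything_default` semantics exercised below: False disables seeding, True picks
# a random seed, an int sets a default seed, and an explicit `--seed_everything` on the
# command line always takes precedence over the default.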
def test_cli_auto_seeding():
    with mock.patch("sys.argv", ["any.py"]):
        cli = LightningCLI(TestModel, run=False, seed_everything_default=False)
    assert cli.seed_everything_default is False
    assert cli.config["seed_everything"] is False

    with mock.patch("sys.argv", ["any.py"]):
        cli = LightningCLI(TestModel, run=False, seed_everything_default=True)
    assert cli.seed_everything_default is True
    assert isinstance(cli.config["seed_everything"], int)

    with mock.patch("sys.argv", ["any.py", "--seed_everything", "3"]):
        cli = LightningCLI(TestModel, run=False, seed_everything_default=False)
    assert cli.seed_everything_default is False
    assert cli.config["seed_everything"] == 3

    with mock.patch("sys.argv", ["any.py", "--seed_everything", "3"]):
        cli = LightningCLI(TestModel, run=False, seed_everything_default=True)
    assert cli.seed_everything_default is True
    assert cli.config["seed_everything"] == 3

    with mock.patch("sys.argv", ["any.py", "--seed_everything", "3"]):
        cli = LightningCLI(TestModel, run=False, seed_everything_default=10)
    assert cli.seed_everything_default == 10
    assert cli.config["seed_everything"] == 3

    with mock.patch("sys.argv", ["any.py", "--seed_everything", "false"]):
        cli = LightningCLI(TestModel, run=False, seed_everything_default=10)
    assert cli.seed_everything_default == 10
    assert cli.config["seed_everything"] is False

    with mock.patch("sys.argv", ["any.py", "--seed_everything", "false"]):
        cli = LightningCLI(TestModel, run=False, seed_everything_default=True)
    assert cli.seed_everything_default is True
    assert cli.config["seed_everything"] is False

    with mock.patch("sys.argv", ["any.py", "--seed_everything", "true"]):
        cli = LightningCLI(TestModel, run=False, seed_everything_default=False)
    assert cli.seed_everything_default is False
    assert isinstance(cli.config["seed_everything"], int)

    seed_everything(123)
    with mock.patch("sys.argv", ["any.py"]):
        cli = LightningCLI(TestModel, run=False)
    assert cli.seed_everything_default is True
    assert cli.config["seed_everything"] == 123  # the original seed is kept


def test_unresolvable_import_paths():
    class TestModel(BoringModel):
        def __init__(self, a_func: Callable = torch.softmax):
            super().__init__()
            self.a_func = a_func

    out = StringIO()
    with mock.patch("sys.argv", ["any.py", "--print_config"]), redirect_stdout(out), pytest.raises(SystemExit):
        LightningCLI(TestModel, run=False)

    assert "a_func: torch.softmax" in out.getvalue()


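# Same `init_args` vs `dict_kwargs` split as in the logger tests: `profile_memory` and
# `record_shapes` are presumably forwarded by `PyTorchProfiler` to `torch.profiler`
# through `**profiler_kwargs`, so they are not part of the resolvable signature and must
# be passed via `dict_kwargs`.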
def test_pytorch_profiler_init_args():
    init = {
        "dirpath": "profiler",
        "row_limit": 10,
        "group_by_input_shapes": True,
    }
    unresolved = {
        "profile_memory": True,
        "record_shapes": True,
    }
    cli_args = ["--trainer.profiler=PyTorchProfiler"]
    cli_args += [f"--trainer.profiler.{k}={v}" for k, v in init.items()]
    cli_args += [f"--trainer.profiler.dict_kwargs.{k}={v}" for k, v in unresolved.items()]

    with mock.patch("sys.argv", ["any.py"] + cli_args):
        cli = LightningCLI(TestModel, run=False)

    assert isinstance(cli.config_init.trainer.profiler, PyTorchProfiler)
    assert {k: cli.config.trainer.profiler.init_args[k] for k in init} == init
    assert cli.config.trainer.profiler.dict_kwargs == unresolved