diff --git a/src/lightning/pytorch/trainer/trainer.py b/src/lightning/pytorch/trainer/trainer.py
index 79070a3eb4..b8c6d2ac31 100644
--- a/src/lightning/pytorch/trainer/trainer.py
+++ b/src/lightning/pytorch/trainer/trainer.py
@@ -300,7 +300,6 @@ class Trainer:
 
             MisconfigurationException:
                 If ``gradient_clip_algorithm`` is invalid.
-                If ``track_grad_norm`` is not a positive number or inf.
 
         """
         super().__init__()
diff --git a/tests/tests_pytorch/test_cli.py b/tests/tests_pytorch/test_cli.py
index fdc49f787c..fc3793a07b 100644
--- a/tests/tests_pytorch/test_cli.py
+++ b/tests/tests_pytorch/test_cli.py
@@ -42,9 +42,6 @@ from lightning.pytorch.cli import (
 )
 from lightning.pytorch.demos.boring_classes import BoringDataModule, BoringModel
 from lightning.pytorch.loggers import CSVLogger, TensorBoardLogger
-from lightning.pytorch.loggers.comet import _COMET_AVAILABLE
-from lightning.pytorch.loggers.neptune import _NEPTUNE_AVAILABLE
-from lightning.pytorch.loggers.wandb import _WANDB_AVAILABLE
 from lightning.pytorch.strategies import DDPStrategy
 from lightning.pytorch.trainer.states import TrainerFn
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
@@ -1438,7 +1435,7 @@ def test_cli_logger_shorthand():
     assert cli.trainer.logger is None
 
 
-def _test_logger_init_args(logger_name, init, unresolved={}):
+def _test_logger_init_args(logger_name, init, unresolved=None):
     cli_args = [f"--trainer.logger={logger_name}"]
     cli_args += [f"--trainer.logger.{k}={v}" for k, v in init.items()]
     cli_args += [f"--trainer.logger.dict_kwargs.{k}={v}" for k, v in unresolved.items()]
@@ -1454,48 +1451,38 @@ def _test_logger_init_args(logger_name, init, unresolved={}):
     assert data["dict_kwargs"] == unresolved
 
 
-@pytest.mark.skipif(not _COMET_AVAILABLE, reason="comet-ml is required")
 def test_comet_logger_init_args():
     _test_logger_init_args(
         "CometLogger",
-        {
-            "save_dir": "comet",  # Resolve from CometLogger.__init__
-            "workspace": "comet",  # Resolve from Comet{,Existing,Offline}Experiment.__init__
-        },
+        init={"save_dir": "comet"},  # Resolve from CometLogger.__init__
+        unresolved={"workspace": "comet"},  # Resolve from Comet{,Existing,Offline}Experiment.__init__
     )
 
 
-@pytest.mark.skipif(not _NEPTUNE_AVAILABLE, reason="neptune is required")
 def test_neptune_logger_init_args():
     _test_logger_init_args(
         "NeptuneLogger",
-        {
-            "name": "neptune",  # Resolve from NeptuneLogger.__init__
-        },
-        {
-            "description": "neptune",  # Unsupported resolving from neptune.internal.init.run.init_run
-        },
+        init={"name": "neptune"},  # Resolve from NeptuneLogger.__init__
+        unresolved={"description": "neptune"},  # Unsupported resolving from neptune.internal.init.run.init_run
     )
 
 
 def test_tensorboard_logger_init_args():
     _test_logger_init_args(
         "TensorBoardLogger",
-        {
+        init={
             "save_dir": "tb",  # Resolve from TensorBoardLogger.__init__
             "comment": "tb",  # Resolve from FabricTensorBoardLogger.experiment SummaryWriter local import
         },
+        unresolved={},
     )
 
 
-@pytest.mark.skipif(not _WANDB_AVAILABLE, reason="wandb is required")
 def test_wandb_logger_init_args():
     _test_logger_init_args(
         "WandbLogger",
-        {
-            "save_dir": "wandb",  # Resolve from WandbLogger.__init__
-            "notes": "wandb",  # Resolve from wandb.sdk.wandb_init.init
-        },
+        init={"save_dir": "wandb"},  # Resolve from WandbLogger.__init__
+        unresolved={"notes": "wandb"},  # Resolve from wandb.sdk.wandb_init.init
     )
 
 