Update CLI tests to no longer require 3rd party logger dependencies (#18899)

This commit is contained in:
Adrian Wälchli 2023-10-31 14:22:17 +01:00 committed by GitHub
parent 018a308269
commit 31b8777350
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 9 additions and 23 deletions

View File

@ -300,7 +300,6 @@ class Trainer:
MisconfigurationException:
If ``gradient_clip_algorithm`` is invalid.
If ``track_grad_norm`` is not a positive number or inf.
"""
super().__init__()

View File

@ -42,9 +42,6 @@ from lightning.pytorch.cli import (
)
from lightning.pytorch.demos.boring_classes import BoringDataModule, BoringModel
from lightning.pytorch.loggers import CSVLogger, TensorBoardLogger
from lightning.pytorch.loggers.comet import _COMET_AVAILABLE
from lightning.pytorch.loggers.neptune import _NEPTUNE_AVAILABLE
from lightning.pytorch.loggers.wandb import _WANDB_AVAILABLE
from lightning.pytorch.strategies import DDPStrategy
from lightning.pytorch.trainer.states import TrainerFn
from lightning.pytorch.utilities.exceptions import MisconfigurationException
@ -1438,7 +1435,7 @@ def test_cli_logger_shorthand():
assert cli.trainer.logger is None
def _test_logger_init_args(logger_name, init, unresolved={}):
def _test_logger_init_args(logger_name, init, unresolved=None):
cli_args = [f"--trainer.logger={logger_name}"]
cli_args += [f"--trainer.logger.{k}={v}" for k, v in init.items()]
cli_args += [f"--trainer.logger.dict_kwargs.{k}={v}" for k, v in unresolved.items()]
@ -1454,48 +1451,38 @@ def _test_logger_init_args(logger_name, init, unresolved={}):
assert data["dict_kwargs"] == unresolved
@pytest.mark.skipif(not _COMET_AVAILABLE, reason="comet-ml is required")
def test_comet_logger_init_args():
_test_logger_init_args(
"CometLogger",
{
"save_dir": "comet", # Resolve from CometLogger.__init__
"workspace": "comet", # Resolve from Comet{,Existing,Offline}Experiment.__init__
},
init={"save_dir": "comet"}, # Resolve from CometLogger.__init__
unresolved={"workspace": "comet"}, # Resolve from Comet{,Existing,Offline}Experiment.__init__
)
@pytest.mark.skipif(not _NEPTUNE_AVAILABLE, reason="neptune is required")
def test_neptune_logger_init_args():
_test_logger_init_args(
"NeptuneLogger",
{
"name": "neptune", # Resolve from NeptuneLogger.__init__
},
{
"description": "neptune", # Unsupported resolving from neptune.internal.init.run.init_run
},
init={"name": "neptune"}, # Resolve from NeptuneLogger.__init__
unresolved={"description": "neptune"}, # Unsupported resolving from neptune.internal.init.run.init_run
)
def test_tensorboard_logger_init_args():
_test_logger_init_args(
"TensorBoardLogger",
{
init={
"save_dir": "tb", # Resolve from TensorBoardLogger.__init__
"comment": "tb", # Resolve from FabricTensorBoardLogger.experiment SummaryWriter local import
},
unresolved={},
)
@pytest.mark.skipif(not _WANDB_AVAILABLE, reason="wandb is required")
def test_wandb_logger_init_args():
_test_logger_init_args(
"WandbLogger",
{
"save_dir": "wandb", # Resolve from WandbLogger.__init__
"notes": "wandb", # Resolve from wandb.sdk.wandb_init.init
},
init={"save_dir": "wandb"}, # Resolve from WandbLogger.__init__
unresolved={"notes": "wandb"}, # Resolve from wandb.sdk.wandb_init.init
)