Update CLI tests to no longer require 3rd party logger dependencies (#18899)
This commit is contained in:
parent
018a308269
commit
31b8777350
|
@ -300,7 +300,6 @@ class Trainer:
|
|||
|
||||
MisconfigurationException:
|
||||
If ``gradient_clip_algorithm`` is invalid.
|
||||
If ``track_grad_norm`` is not a positive number or inf.
|
||||
|
||||
"""
|
||||
super().__init__()
|
||||
|
|
|
@ -42,9 +42,6 @@ from lightning.pytorch.cli import (
|
|||
)
|
||||
from lightning.pytorch.demos.boring_classes import BoringDataModule, BoringModel
|
||||
from lightning.pytorch.loggers import CSVLogger, TensorBoardLogger
|
||||
from lightning.pytorch.loggers.comet import _COMET_AVAILABLE
|
||||
from lightning.pytorch.loggers.neptune import _NEPTUNE_AVAILABLE
|
||||
from lightning.pytorch.loggers.wandb import _WANDB_AVAILABLE
|
||||
from lightning.pytorch.strategies import DDPStrategy
|
||||
from lightning.pytorch.trainer.states import TrainerFn
|
||||
from lightning.pytorch.utilities.exceptions import MisconfigurationException
|
||||
|
@ -1438,7 +1435,7 @@ def test_cli_logger_shorthand():
|
|||
assert cli.trainer.logger is None
|
||||
|
||||
|
||||
def _test_logger_init_args(logger_name, init, unresolved={}):
|
||||
def _test_logger_init_args(logger_name, init, unresolved=None):
|
||||
cli_args = [f"--trainer.logger={logger_name}"]
|
||||
cli_args += [f"--trainer.logger.{k}={v}" for k, v in init.items()]
|
||||
cli_args += [f"--trainer.logger.dict_kwargs.{k}={v}" for k, v in unresolved.items()]
|
||||
|
@ -1454,48 +1451,38 @@ def _test_logger_init_args(logger_name, init, unresolved={}):
|
|||
assert data["dict_kwargs"] == unresolved
|
||||
|
||||
|
||||
@pytest.mark.skipif(not _COMET_AVAILABLE, reason="comet-ml is required")
|
||||
def test_comet_logger_init_args():
|
||||
_test_logger_init_args(
|
||||
"CometLogger",
|
||||
{
|
||||
"save_dir": "comet", # Resolve from CometLogger.__init__
|
||||
"workspace": "comet", # Resolve from Comet{,Existing,Offline}Experiment.__init__
|
||||
},
|
||||
init={"save_dir": "comet"}, # Resolve from CometLogger.__init__
|
||||
unresolved={"workspace": "comet"}, # Resolve from Comet{,Existing,Offline}Experiment.__init__
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not _NEPTUNE_AVAILABLE, reason="neptune is required")
|
||||
def test_neptune_logger_init_args():
|
||||
_test_logger_init_args(
|
||||
"NeptuneLogger",
|
||||
{
|
||||
"name": "neptune", # Resolve from NeptuneLogger.__init__
|
||||
},
|
||||
{
|
||||
"description": "neptune", # Unsupported resolving from neptune.internal.init.run.init_run
|
||||
},
|
||||
init={"name": "neptune"}, # Resolve from NeptuneLogger.__init__
|
||||
unresolved={"description": "neptune"}, # Unsupported resolving from neptune.internal.init.run.init_run
|
||||
)
|
||||
|
||||
|
||||
def test_tensorboard_logger_init_args():
|
||||
_test_logger_init_args(
|
||||
"TensorBoardLogger",
|
||||
{
|
||||
init={
|
||||
"save_dir": "tb", # Resolve from TensorBoardLogger.__init__
|
||||
"comment": "tb", # Resolve from FabricTensorBoardLogger.experiment SummaryWriter local import
|
||||
},
|
||||
unresolved={},
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not _WANDB_AVAILABLE, reason="wandb is required")
|
||||
def test_wandb_logger_init_args():
|
||||
_test_logger_init_args(
|
||||
"WandbLogger",
|
||||
{
|
||||
"save_dir": "wandb", # Resolve from WandbLogger.__init__
|
||||
"notes": "wandb", # Resolve from wandb.sdk.wandb_init.init
|
||||
},
|
||||
init={"save_dir": "wandb"}, # Resolve from WandbLogger.__init__
|
||||
unresolved={"notes": "wandb"}, # Resolve from wandb.sdk.wandb_init.init
|
||||
)
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue