diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index 8dcec7db41..eff633558b 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -126,6 +126,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed a formatting issue when the filename in `ModelCheckpoint` contained metrics that were substrings of each other ([#17610](https://github.com/Lightning-AI/lightning/pull/17610)) +- Fixed `WandbLogger` ignoring the `WANDB_PROJECT` environment variable ([#16222](https://github.com/Lightning-AI/lightning/pull/16222)) + + ## [2.0.1.post0] - 2023-04-11 ### Fixed diff --git a/src/lightning/pytorch/loggers/wandb.py b/src/lightning/pytorch/loggers/wandb.py index 10570bbbc2..d5f928bd8e 100644 --- a/src/lightning/pytorch/loggers/wandb.py +++ b/src/lightning/pytorch/loggers/wandb.py @@ -260,7 +260,8 @@ class WandbLogger(Logger): dir: Same as save_dir. id: Same as version. anonymous: Enables or explicitly disables anonymous logging. - project: The name of the project to which this run will belong. + project: The name of the project to which this run will belong. If not set, the environment variable + `WANDB_PROJECT` will be used as a fallback. If both are not set, it defaults to ``'lightning_logs'``. log_model: Log checkpoints created by :class:`~lightning.pytorch.callbacks.ModelCheckpoint` as W&B artifacts. `latest` and `best` aliases are automatically set. @@ -293,7 +294,7 @@ class WandbLogger(Logger): dir: Optional[_PATH] = None, id: Optional[str] = None, anonymous: Optional[bool] = None, - project: str = "lightning_logs", + project: Optional[str] = None, log_model: Union[str, bool] = False, experiment: Union[Run, RunDisabled, None] = None, prefix: str = "", @@ -334,6 +335,8 @@ class WandbLogger(Logger): elif dir is not None: dir = os.fspath(dir) + project = project or os.environ.get("WANDB_PROJECT", "lightning_logs") + # set wandb init arguments self._wandb_init: Dict[str, Any] = { "name": name, diff --git a/tests/tests_pytorch/loggers/test_wandb.py b/tests/tests_pytorch/loggers/test_wandb.py index b90581d443..e221326547 100644 --- a/tests/tests_pytorch/loggers/test_wandb.py +++ b/tests/tests_pytorch/loggers/test_wandb.py @@ -29,10 +29,20 @@ from lightning.pytorch.utilities.exceptions import MisconfigurationException @mock.patch("lightning.pytorch.loggers.wandb.Run", new=mock.Mock) @mock.patch("lightning.pytorch.loggers.wandb.wandb") def test_wandb_project_name(*_): - logger = WandbLogger() + with mock.patch.dict(os.environ, {}): + logger = WandbLogger() assert logger.name == "lightning_logs" - logger = WandbLogger(project="project") + with mock.patch.dict(os.environ, {}): + logger = WandbLogger(project="project") + assert logger.name == "project" + + with mock.patch.dict(os.environ, {"WANDB_PROJECT": "env_project"}): + logger = WandbLogger() + assert logger.name == "env_project" + + with mock.patch.dict(os.environ, {"WANDB_PROJECT": "env_project"}): + logger = WandbLogger(project="project") assert logger.name == "project"