diff --git a/CHANGELOG.md b/CHANGELOG.md index 7fc3178b44..46a58ccfb4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -72,7 +72,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Deprecated the `precision_plugin` constructor argument from `Accelerator` ([#10570](https://github.com/PyTorchLightning/pytorch-lightning/pull/10570)) -- +- Deprecated the property `Trainer.slurm_job_id` in favor of the new `SLURMEnvironment.job_id()` method ([#10622](https://github.com/PyTorchLightning/pytorch-lightning/pull/10622)) - diff --git a/pytorch_lightning/plugins/environments/slurm_environment.py b/pytorch_lightning/plugins/environments/slurm_environment.py index 53fa4c2a83..ad657e1e19 100644 --- a/pytorch_lightning/plugins/environments/slurm_environment.py +++ b/pytorch_lightning/plugins/environments/slurm_environment.py @@ -15,6 +15,7 @@ import logging import os import re +from typing import Optional from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment @@ -37,6 +38,21 @@ class SLURMEnvironment(ClusterEnvironment): def creates_processes_externally(self) -> bool: return True + @staticmethod + def job_id() -> Optional[int]: + job_id = os.environ.get("SLURM_JOB_ID") + if job_id: + try: + job_id = int(job_id) + except ValueError: + job_id = None + + # in interactive mode, don't make logs use the same job id + in_slurm_interactive_mode = os.environ.get("SLURM_JOB_NAME") == "bash" + if in_slurm_interactive_mode: + job_id = None + return job_id + @property def main_address(self) -> str: # figure out the root node addr diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index d970d98c60..b98f13138b 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -18,6 +18,7 @@ import torch import pytorch_lightning as pl from pytorch_lightning.loggers import LightningLoggerBase, LoggerCollection, TensorBoardLogger +from pytorch_lightning.plugins.environments.slurm_environment import SLURMEnvironment from pytorch_lightning.trainer.connectors.logger_connector.result import _METRICS, _OUT_DICT, _PBAR_DICT from pytorch_lightning.trainer.states import RunningStage, TrainerFn from pytorch_lightning.utilities import DeviceType, memory @@ -81,7 +82,7 @@ class LoggerConnector: # default logger self.trainer.logger = ( TensorBoardLogger( - save_dir=self.trainer.default_root_dir, version=self.trainer.slurm_job_id, name="lightning_logs" + save_dir=self.trainer.default_root_dir, version=SLURMEnvironment.job_id(), name="lightning_logs" ) if logger else None diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index cb446632cc..1ccdb9ecae 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -39,6 +39,7 @@ from pytorch_lightning.loops import PredictionLoop, TrainingBatchLoop, TrainingE from pytorch_lightning.loops.dataloader.evaluation_loop import EvaluationLoop from pytorch_lightning.loops.fit_loop import FitLoop from pytorch_lightning.plugins import DDPSpawnPlugin, ParallelPlugin, PLUGIN_INPUT, PrecisionPlugin, TrainingTypePlugin +from pytorch_lightning.plugins.environments.slurm_environment import SLURMEnvironment from pytorch_lightning.profiler import ( AdvancedProfiler, BaseProfiler, @@ -1725,18 +1726,8 @@ class Trainer( @property def slurm_job_id(self) -> Optional[int]: - job_id = os.environ.get("SLURM_JOB_ID") - if job_id: - try: - job_id = int(job_id) - except ValueError: - job_id = None - - # in interactive mode, don't make logs use the same job id - in_slurm_interactive_mode = os.environ.get("SLURM_JOB_NAME") == "bash" - if in_slurm_interactive_mode: - job_id = None - return job_id + rank_zero_deprecation("Method `slurm_job_id` is deprecated in v1.6.0 and will be removed in v1.7.0.") + return SLURMEnvironment.job_id() @property def lightning_optimizers(self) -> List[LightningOptimizer]: diff --git a/tests/deprecated_api/test_remove_1-7.py b/tests/deprecated_api/test_remove_1-7.py index c801696196..9c0c12f981 100644 --- a/tests/deprecated_api/test_remove_1-7.py +++ b/tests/deprecated_api/test_remove_1-7.py @@ -378,6 +378,12 @@ def test_v1_7_0_trainer_log_gpu_memory(tmpdir): _ = Trainer(log_gpu_memory="min_max") +def test_v1_7_0_deprecated_slurm_job_id(): + trainer = Trainer() + with pytest.deprecated_call(match="Method `slurm_job_id` is deprecated in v1.6.0 and will be removed in v1.7.0."): + trainer.slurm_job_id + + @RunIf(min_gpus=1) def test_v1_7_0_deprecate_gpu_stats_monitor(tmpdir): with pytest.deprecated_call(match="The `GPUStatsMonitor` callback was deprecated in v1.5"):