refactor slurm_job_id (#10622)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: thomas chaton <thomas@grid.ai>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: ananthsub <ananth.subramaniam@gmail.com>
This commit is contained in:
parent
338f3cf636
commit
6fc7c54c3a
|
@ -72,7 +72,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
- Deprecated the `precision_plugin` constructor argument from `Accelerator` ([#10570](https://github.com/PyTorchLightning/pytorch-lightning/pull/10570))
|
||||
|
||||
|
||||
-
|
||||
- Deprecated the property `Trainer.slurm_job_id` in favor of the new `SLURMEnvironment.job_id()` method ([#10622](https://github.com/PyTorchLightning/pytorch-lightning/pull/10622))
|
||||
|
||||
|
||||
-
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
import logging
|
||||
import os
|
||||
import re
|
||||
from typing import Optional
|
||||
|
||||
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
|
||||
|
||||
|
@ -37,6 +38,21 @@ class SLURMEnvironment(ClusterEnvironment):
|
|||
def creates_processes_externally(self) -> bool:
|
||||
return True
|
||||
|
||||
@staticmethod
def job_id() -> Optional[int]:
    """Return the SLURM job id read from the environment, if one applies.

    The value of the ``SLURM_JOB_ID`` environment variable is parsed as an integer.

    Returns:
        The job id as an ``int``, or ``None`` when ``SLURM_JOB_ID`` is unset, empty or
        not an integer, or when ``SLURM_JOB_NAME`` equals ``"bash"``.
    """
    # in interactive mode, don't make logs use the same job id
    if os.environ.get("SLURM_JOB_NAME") == "bash":
        return None

    raw_job_id = os.environ.get("SLURM_JOB_ID")
    if raw_job_id is None:
        return None

    try:
        return int(raw_job_id)
    except ValueError:
        # Covers non-numeric values and the empty string; an empty value used to
        # be returned unchanged as ``""`` instead of ``None``.
        return None
|
||||
|
||||
@property
|
||||
def main_address(self) -> str:
|
||||
# figure out the root node addr
|
||||
|
|
|
@ -18,6 +18,7 @@ import torch
|
|||
|
||||
import pytorch_lightning as pl
|
||||
from pytorch_lightning.loggers import LightningLoggerBase, LoggerCollection, TensorBoardLogger
|
||||
from pytorch_lightning.plugins.environments.slurm_environment import SLURMEnvironment
|
||||
from pytorch_lightning.trainer.connectors.logger_connector.result import _METRICS, _OUT_DICT, _PBAR_DICT
|
||||
from pytorch_lightning.trainer.states import RunningStage, TrainerFn
|
||||
from pytorch_lightning.utilities import DeviceType, memory
|
||||
|
@ -81,7 +82,7 @@ class LoggerConnector:
|
|||
# default logger
|
||||
self.trainer.logger = (
|
||||
TensorBoardLogger(
|
||||
save_dir=self.trainer.default_root_dir, version=self.trainer.slurm_job_id, name="lightning_logs"
|
||||
save_dir=self.trainer.default_root_dir, version=SLURMEnvironment.job_id(), name="lightning_logs"
|
||||
)
|
||||
if logger
|
||||
else None
|
||||
|
|
|
@ -39,6 +39,7 @@ from pytorch_lightning.loops import PredictionLoop, TrainingBatchLoop, TrainingE
|
|||
from pytorch_lightning.loops.dataloader.evaluation_loop import EvaluationLoop
|
||||
from pytorch_lightning.loops.fit_loop import FitLoop
|
||||
from pytorch_lightning.plugins import DDPSpawnPlugin, ParallelPlugin, PLUGIN_INPUT, PrecisionPlugin, TrainingTypePlugin
|
||||
from pytorch_lightning.plugins.environments.slurm_environment import SLURMEnvironment
|
||||
from pytorch_lightning.profiler import (
|
||||
AdvancedProfiler,
|
||||
BaseProfiler,
|
||||
|
@ -1725,18 +1726,8 @@ class Trainer(
|
|||
|
||||
@property
def slurm_job_id(self) -> Optional[int]:
    """Deprecated accessor for the SLURM job id.

    Emits a deprecation warning on every access and then delegates to
    ``SLURMEnvironment.job_id()``, so the environment parsing lives in one place.

    Returns:
        Whatever ``SLURMEnvironment.job_id()`` returns.
    """
    # Warn first: with the old parsing body kept in place, its early ``return``
    # made the warning and the delegation below unreachable.
    rank_zero_deprecation("Method `slurm_job_id` is deprecated in v1.6.0 and will be removed in v1.7.0.")
    return SLURMEnvironment.job_id()
|
||||
|
||||
@property
|
||||
def lightning_optimizers(self) -> List[LightningOptimizer]:
|
||||
|
|
|
@ -378,6 +378,12 @@ def test_v1_7_0_trainer_log_gpu_memory(tmpdir):
|
|||
_ = Trainer(log_gpu_memory="min_max")
|
||||
|
||||
|
||||
def test_v1_7_0_deprecated_slurm_job_id():
    """Reading ``Trainer.slurm_job_id`` must emit the v1.7.0 deprecation warning."""
    expected_message = "Method `slurm_job_id` is deprecated in v1.6.0 and will be removed in v1.7.0."
    trainer = Trainer()
    with pytest.deprecated_call(match=expected_message):
        # the property access itself triggers the warning; the value is irrelevant
        _ = trainer.slurm_job_id
|
||||
|
||||
|
||||
@RunIf(min_gpus=1)
|
||||
def test_v1_7_0_deprecate_gpu_stats_monitor(tmpdir):
|
||||
with pytest.deprecated_call(match="The `GPUStatsMonitor` callback was deprecated in v1.5"):
|
||||
|
|
Loading…
Reference in New Issue