diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index fad41f1230..b958d4808e 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -25,7 +25,6 @@ from torch.nn import Module from torch.nn.parallel.distributed import DistributedDataParallel import pytorch_lightning as pl -from pytorch_lightning.loggers import LoggerCollection, TensorBoardLogger from pytorch_lightning.overrides import LightningDistributedModule from pytorch_lightning.overrides.distributed import prepare_for_backward from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment @@ -149,17 +148,14 @@ class DDPSpawnPlugin(ParallelPlugin): return {"nprocs": self.num_processes} def start_training(self, trainer: "pl.Trainer") -> None: - self._clean_logger(trainer) self.spawn(self.new_process, trainer, self.mp_queue, return_result=False) # reset optimizers, since main process is never used for training and thus does not have a valid optim state trainer.optimizers = [] def start_evaluating(self, trainer: "pl.Trainer") -> None: - self._clean_logger(trainer) self.spawn(self.new_process, trainer, self.mp_queue, return_result=False) def start_predicting(self, trainer: "pl.Trainer") -> None: - self._clean_logger(trainer) self.spawn(self.new_process, trainer, self.mp_queue, return_result=False) def spawn(self, function: Callable, *args: Any, return_result: bool = True, **kwargs: Any) -> Optional[Any]: @@ -420,16 +416,3 @@ class DDPSpawnPlugin(ParallelPlugin): self.lightning_module.cpu() # clean up memory torch.cuda.empty_cache() - - @staticmethod - def _clean_logger(trainer: "pl.Trainer") -> None: - loggers = trainer.logger._logger_iterable if isinstance(trainer.logger, LoggerCollection) else [trainer.logger] - for logger in loggers: - if isinstance(logger, TensorBoardLogger) and logger._experiment is not None: - rank_zero_warn( - "When using `ddp_spawn`, the `TensorBoardLogger` experiment should be `None`. Setting it to `None`." - ) - # the experiment class of `TensorBoard` holds a multiprocessing queue which can make ours hang. - # we want to make sure these are closed before we spawn our own threads. - # assuming nothing else references the experiment object, python should instantly `__del__` it. - logger._experiment = None diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index 5ef8a46d71..a7258fc712 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -24,6 +24,7 @@ from torch.nn import Module from torch.utils.data import DataLoader import pytorch_lightning as pl +from pytorch_lightning.loggers import LoggerCollection, TensorBoardLogger from pytorch_lightning.overrides import LightningDistributedModule from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.io.xla_plugin import XLACheckpointIO @@ -304,8 +305,17 @@ class TPUSpawnPlugin(DDPSpawnPlugin): # todo: precision pluging is call in accelerator setup and should be moved if "XLA_USE_BF16" in os.environ: del os.environ["XLA_USE_BF16"] + self._clean_logger(trainer) return super().start_training(trainer) + def start_evaluating(self, trainer: "pl.Trainer") -> None: + self._clean_logger(trainer) + return super().start_evaluating(trainer) + + def start_predicting(self, trainer: "pl.Trainer") -> None: + self._clean_logger(trainer) + return super().start_predicting(trainer) + def training_step(self, *args, **kwargs): return self.model(*args, **kwargs) @@ -381,3 +391,13 @@ class TPUSpawnPlugin(DDPSpawnPlugin): @checkpoint_io.setter def checkpoint_io(self, plugin: CheckpointIO) -> None: raise MisconfigurationException("TPU Spawn Plugin currently does not support custom checkpoint plugins.") + + @staticmethod + def _clean_logger(trainer: "pl.Trainer") -> None: + loggers = trainer.logger._logger_iterable if isinstance(trainer.logger, LoggerCollection) else [trainer.logger] + for logger in loggers: + if isinstance(logger, TensorBoardLogger) and logger._experiment is not None: + # the experiment class of `TensorBoard` holds a multiprocessing queue which can make ours hang. + # we want to make sure these are closed before we spawn our own threads. + # assuming nothing else references the experiment object, python should instantly `__del__` it. + logger._experiment = None diff --git a/tests/helpers/runif.py b/tests/helpers/runif.py index 8a9f707e69..179535a63d 100644 --- a/tests/helpers/runif.py +++ b/tests/helpers/runif.py @@ -71,6 +71,7 @@ class RunIf: deepspeed: bool = False, rich: bool = False, skip_49370: bool = False, + skip_hanging_spawn: bool = False, omegaconf: bool = False, **kwargs, ): @@ -94,6 +95,7 @@ class RunIf: deepspeed: Require that microsoft/DeepSpeed is installed. rich: Require that willmcgugan/rich is installed. skip_49370: Skip the test as it's impacted by https://github.com/pytorch/pytorch/issues/49370. + skip_hanging_spawn: Skip the test as it's impacted by hanging loggers on spawn. omegaconf: Require that omry/omegaconf is installed. **kwargs: Any :class:`pytest.mark.skipif` keyword arguments. """ @@ -180,6 +182,15 @@ class RunIf: conditions.append(ge_3_9 and old_torch) reasons.append("Impacted by https://github.com/pytorch/pytorch/issues/49370") + if skip_hanging_spawn: + # strategy=ddp_spawn, accelerator=cpu, python>=3.8, torch<1.9 does not work + py_version = f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}" + ge_3_8 = Version(py_version) >= Version("3.8") + torch_version = get_distribution("torch").version + old_torch = Version(torch_version) < Version("1.9") + conditions.append(ge_3_8 and old_torch) + reasons.append("Impacted by hanging DDP spawn") + if omegaconf: conditions.append(not _OMEGACONF_AVAILABLE) reasons.append("omegaconf") diff --git a/tests/loggers/test_all.py b/tests/loggers/test_all.py index 803a13cbb1..6b3c547ed3 100644 --- a/tests/loggers/test_all.py +++ b/tests/loggers/test_all.py @@ -329,7 +329,7 @@ class RankZeroLoggerCheck(Callback): assert pl_module.logger.experiment.something(foo="bar") is None -@RunIf(skip_windows=True, skip_49370=True) +@RunIf(skip_windows=True, skip_49370=True, skip_hanging_spawn=True) @pytest.mark.parametrize("logger_class", [CometLogger, CSVLogger, MLFlowLogger, TensorBoardLogger, TestTubeLogger]) def test_logger_created_on_rank_zero_only(tmpdir, monkeypatch, logger_class): """Test that loggers get replaced by dummy loggers on global rank > 0.""" diff --git a/tests/loggers/test_tensorboard.py b/tests/loggers/test_tensorboard.py index d0119b3e86..a261b78cc0 100644 --- a/tests/loggers/test_tensorboard.py +++ b/tests/loggers/test_tensorboard.py @@ -24,7 +24,6 @@ import yaml from pytorch_lightning import Trainer from pytorch_lightning.loggers import TensorBoardLogger -from pytorch_lightning.loggers.base import LoggerCollection from pytorch_lightning.utilities.imports import _compare_version, _OMEGACONF_AVAILABLE from tests.helpers import BoringModel from tests.helpers.runif import RunIf @@ -335,17 +334,3 @@ def test_tensorboard_missing_folder_warning(tmpdir, caplog): assert logger.version == 0 assert "Missing logger folder:" in caplog.text - - -@pytest.mark.parametrize("use_list", [False, True]) -def test_tensorboard_ddp_spawn_cleanup(use_list, tmpdir): - tensorboard_logger = TensorBoardLogger(save_dir=tmpdir) - assert tensorboard_logger._experiment is None - tensorboard_logger.experiment # this property access will create the experiment - assert tensorboard_logger._experiment is not None - logger = [tensorboard_logger] if use_list else tensorboard_logger - trainer = Trainer(strategy="ddp_spawn", devices=2, accelerator="auto", logger=logger) - trainer.training_type_plugin._clean_logger(trainer) - if use_list: - assert isinstance(trainer.logger, LoggerCollection) - assert tensorboard_logger._experiment is None diff --git a/tests/plugins/test_tpu_spawn.py b/tests/plugins/test_tpu_spawn.py index 3f4ff354e3..5f4abf560d 100644 --- a/tests/plugins/test_tpu_spawn.py +++ b/tests/plugins/test_tpu_spawn.py @@ -20,6 +20,7 @@ import torch from torch.utils.data import DataLoader from pytorch_lightning import Trainer +from pytorch_lightning.loggers import LoggerCollection, TensorBoardLogger from pytorch_lightning.plugins.training_type import TPUSpawnPlugin from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.helpers.boring_model import BoringModel, RandomDataset @@ -102,3 +103,18 @@ def test_model_tpu_one_core(): model = BoringModelTPU() trainer.fit(model) assert "PT_XLA_DEBUG" not in os.environ + + +@RunIf(tpu=True) +@pytest.mark.parametrize("use_list", [False, True]) +def test_tensorboard_ddp_spawn_cleanup(use_list, tmpdir): + tensorboard_logger = TensorBoardLogger(save_dir=tmpdir) + assert tensorboard_logger._experiment is None + tensorboard_logger.experiment # this property access will create the experiment + assert tensorboard_logger._experiment is not None + logger = [tensorboard_logger] if use_list else tensorboard_logger + trainer = Trainer(strategy="ddp_spawn", accelerator="tpu", devices="auto", logger=logger) + trainer.training_type_plugin._clean_logger(trainer) + if use_list: + assert isinstance(trainer.logger, LoggerCollection) + assert tensorboard_logger._experiment is None diff --git a/tests/utilities/test_all_gather_grad.py b/tests/utilities/test_all_gather_grad.py index b7dfd5cbc3..70d4528f03 100644 --- a/tests/utilities/test_all_gather_grad.py +++ b/tests/utilities/test_all_gather_grad.py @@ -54,7 +54,7 @@ def _test_all_gather_ddp(rank, world_size): assert torch.allclose(grad2, tensor2.grad) -@RunIf(skip_windows=True, skip_49370=True) +@RunIf(skip_windows=True, skip_49370=True, skip_hanging_spawn=True) def test_all_gather_ddp_spawn(): world_size = 3 torch.multiprocessing.spawn(_test_all_gather_ddp, args=(world_size,), nprocs=world_size)