Skip hanging spawn tests (#10838)

* Skip hanging spawn tests

* Docstring fix

* Add back to TPU spawn
This commit is contained in:
Carlos Mocholí 2021-11-30 19:36:12 +01:00 committed by GitHub
parent 38ed26ec5a
commit 8e1b9b306c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 49 additions and 34 deletions

View File

@ -25,7 +25,6 @@ from torch.nn import Module
from torch.nn.parallel.distributed import DistributedDataParallel
import pytorch_lightning as pl
from pytorch_lightning.loggers import LoggerCollection, TensorBoardLogger
from pytorch_lightning.overrides import LightningDistributedModule
from pytorch_lightning.overrides.distributed import prepare_for_backward
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
@ -149,17 +148,14 @@ class DDPSpawnPlugin(ParallelPlugin):
return {"nprocs": self.num_processes}
def start_training(self, trainer: "pl.Trainer") -> None:
self._clean_logger(trainer)
self.spawn(self.new_process, trainer, self.mp_queue, return_result=False)
# reset optimizers, since main process is never used for training and thus does not have a valid optim state
trainer.optimizers = []
def start_evaluating(self, trainer: "pl.Trainer") -> None:
self._clean_logger(trainer)
self.spawn(self.new_process, trainer, self.mp_queue, return_result=False)
def start_predicting(self, trainer: "pl.Trainer") -> None:
self._clean_logger(trainer)
self.spawn(self.new_process, trainer, self.mp_queue, return_result=False)
def spawn(self, function: Callable, *args: Any, return_result: bool = True, **kwargs: Any) -> Optional[Any]:
@ -420,16 +416,3 @@ class DDPSpawnPlugin(ParallelPlugin):
self.lightning_module.cpu()
# clean up memory
torch.cuda.empty_cache()
@staticmethod
def _clean_logger(trainer: "pl.Trainer") -> None:
loggers = trainer.logger._logger_iterable if isinstance(trainer.logger, LoggerCollection) else [trainer.logger]
for logger in loggers:
if isinstance(logger, TensorBoardLogger) and logger._experiment is not None:
rank_zero_warn(
"When using `ddp_spawn`, the `TensorBoardLogger` experiment should be `None`. Setting it to `None`."
)
# the experiment class of `TensorBoard` holds a multiprocessing queue which can make ours hang.
# we want to make sure these are closed before we spawn our own threads.
# assuming nothing else references the experiment object, python should instantly `__del__` it.
logger._experiment = None

View File

@ -24,6 +24,7 @@ from torch.nn import Module
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from pytorch_lightning.loggers import LoggerCollection, TensorBoardLogger
from pytorch_lightning.overrides import LightningDistributedModule
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.io.xla_plugin import XLACheckpointIO
@ -304,8 +305,17 @@ class TPUSpawnPlugin(DDPSpawnPlugin):
# todo: precision pluging is call in accelerator setup and should be moved
if "XLA_USE_BF16" in os.environ:
del os.environ["XLA_USE_BF16"]
self._clean_logger(trainer)
return super().start_training(trainer)
def start_evaluating(self, trainer: "pl.Trainer") -> None:
self._clean_logger(trainer)
return super().start_evaluating(trainer)
def start_predicting(self, trainer: "pl.Trainer") -> None:
self._clean_logger(trainer)
return super().start_predicting(trainer)
def training_step(self, *args, **kwargs):
return self.model(*args, **kwargs)
@ -381,3 +391,13 @@ class TPUSpawnPlugin(DDPSpawnPlugin):
@checkpoint_io.setter
def checkpoint_io(self, plugin: CheckpointIO) -> None:
raise MisconfigurationException("TPU Spawn Plugin currently does not support custom checkpoint plugins.")
@staticmethod
def _clean_logger(trainer: "pl.Trainer") -> None:
loggers = trainer.logger._logger_iterable if isinstance(trainer.logger, LoggerCollection) else [trainer.logger]
for logger in loggers:
if isinstance(logger, TensorBoardLogger) and logger._experiment is not None:
# the experiment class of `TensorBoard` holds a multiprocessing queue which can make ours hang.
# we want to make sure these are closed before we spawn our own threads.
# assuming nothing else references the experiment object, python should instantly `__del__` it.
logger._experiment = None

View File

@ -71,6 +71,7 @@ class RunIf:
deepspeed: bool = False,
rich: bool = False,
skip_49370: bool = False,
skip_hanging_spawn: bool = False,
omegaconf: bool = False,
**kwargs,
):
@ -94,6 +95,7 @@ class RunIf:
deepspeed: Require that microsoft/DeepSpeed is installed.
rich: Require that willmcgugan/rich is installed.
skip_49370: Skip the test as it's impacted by https://github.com/pytorch/pytorch/issues/49370.
skip_hanging_spawn: Skip the test as it's impacted by hanging loggers on spawn.
omegaconf: Require that omry/omegaconf is installed.
**kwargs: Any :class:`pytest.mark.skipif` keyword arguments.
"""
@ -180,6 +182,15 @@ class RunIf:
conditions.append(ge_3_9 and old_torch)
reasons.append("Impacted by https://github.com/pytorch/pytorch/issues/49370")
if skip_hanging_spawn:
# strategy=ddp_spawn, accelerator=cpu, python>=3.8, torch<1.9 does not work
py_version = f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}"
ge_3_8 = Version(py_version) >= Version("3.8")
torch_version = get_distribution("torch").version
old_torch = Version(torch_version) < Version("1.9")
conditions.append(ge_3_8 and old_torch)
reasons.append("Impacted by hanging DDP spawn")
if omegaconf:
conditions.append(not _OMEGACONF_AVAILABLE)
reasons.append("omegaconf")

View File

@ -329,7 +329,7 @@ class RankZeroLoggerCheck(Callback):
assert pl_module.logger.experiment.something(foo="bar") is None
@RunIf(skip_windows=True, skip_49370=True)
@RunIf(skip_windows=True, skip_49370=True, skip_hanging_spawn=True)
@pytest.mark.parametrize("logger_class", [CometLogger, CSVLogger, MLFlowLogger, TensorBoardLogger, TestTubeLogger])
def test_logger_created_on_rank_zero_only(tmpdir, monkeypatch, logger_class):
"""Test that loggers get replaced by dummy loggers on global rank > 0."""

View File

@ -24,7 +24,6 @@ import yaml
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.loggers.base import LoggerCollection
from pytorch_lightning.utilities.imports import _compare_version, _OMEGACONF_AVAILABLE
from tests.helpers import BoringModel
from tests.helpers.runif import RunIf
@ -335,17 +334,3 @@ def test_tensorboard_missing_folder_warning(tmpdir, caplog):
assert logger.version == 0
assert "Missing logger folder:" in caplog.text
@pytest.mark.parametrize("use_list", [False, True])
def test_tensorboard_ddp_spawn_cleanup(use_list, tmpdir):
tensorboard_logger = TensorBoardLogger(save_dir=tmpdir)
assert tensorboard_logger._experiment is None
tensorboard_logger.experiment # this property access will create the experiment
assert tensorboard_logger._experiment is not None
logger = [tensorboard_logger] if use_list else tensorboard_logger
trainer = Trainer(strategy="ddp_spawn", devices=2, accelerator="auto", logger=logger)
trainer.training_type_plugin._clean_logger(trainer)
if use_list:
assert isinstance(trainer.logger, LoggerCollection)
assert tensorboard_logger._experiment is None

View File

@ -20,6 +20,7 @@ import torch
from torch.utils.data import DataLoader
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import LoggerCollection, TensorBoardLogger
from pytorch_lightning.plugins.training_type import TPUSpawnPlugin
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers.boring_model import BoringModel, RandomDataset
@ -102,3 +103,18 @@ def test_model_tpu_one_core():
model = BoringModelTPU()
trainer.fit(model)
assert "PT_XLA_DEBUG" not in os.environ
@RunIf(tpu=True)
@pytest.mark.parametrize("use_list", [False, True])
def test_tensorboard_ddp_spawn_cleanup(use_list, tmpdir):
tensorboard_logger = TensorBoardLogger(save_dir=tmpdir)
assert tensorboard_logger._experiment is None
tensorboard_logger.experiment # this property access will create the experiment
assert tensorboard_logger._experiment is not None
logger = [tensorboard_logger] if use_list else tensorboard_logger
trainer = Trainer(strategy="ddp_spawn", accelerator="tpu", devices="auto", logger=logger)
trainer.training_type_plugin._clean_logger(trainer)
if use_list:
assert isinstance(trainer.logger, LoggerCollection)
assert tensorboard_logger._experiment is None

View File

@ -54,7 +54,7 @@ def _test_all_gather_ddp(rank, world_size):
assert torch.allclose(grad2, tensor2.grad)
@RunIf(skip_windows=True, skip_49370=True)
@RunIf(skip_windows=True, skip_49370=True, skip_hanging_spawn=True)
def test_all_gather_ddp_spawn():
world_size = 3
torch.multiprocessing.spawn(_test_all_gather_ddp, args=(world_size,), nprocs=world_size)