Skip hanging spawn tests (#10838)
* Skip hanging spawn tests
* Docstring fix
* Add back to TPU spawn
This commit is contained in:
parent
38ed26ec5a
commit
8e1b9b306c
|
@ -25,7 +25,6 @@ from torch.nn import Module
|
|||
from torch.nn.parallel.distributed import DistributedDataParallel
|
||||
|
||||
import pytorch_lightning as pl
|
||||
from pytorch_lightning.loggers import LoggerCollection, TensorBoardLogger
|
||||
from pytorch_lightning.overrides import LightningDistributedModule
|
||||
from pytorch_lightning.overrides.distributed import prepare_for_backward
|
||||
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
|
||||
|
@ -149,17 +148,14 @@ class DDPSpawnPlugin(ParallelPlugin):
|
|||
return {"nprocs": self.num_processes}
|
||||
|
||||
def start_training(self, trainer: "pl.Trainer") -> None:
|
||||
self._clean_logger(trainer)
|
||||
self.spawn(self.new_process, trainer, self.mp_queue, return_result=False)
|
||||
# reset optimizers, since main process is never used for training and thus does not have a valid optim state
|
||||
trainer.optimizers = []
|
||||
|
||||
def start_evaluating(self, trainer: "pl.Trainer") -> None:
|
||||
self._clean_logger(trainer)
|
||||
self.spawn(self.new_process, trainer, self.mp_queue, return_result=False)
|
||||
|
||||
def start_predicting(self, trainer: "pl.Trainer") -> None:
|
||||
self._clean_logger(trainer)
|
||||
self.spawn(self.new_process, trainer, self.mp_queue, return_result=False)
|
||||
|
||||
def spawn(self, function: Callable, *args: Any, return_result: bool = True, **kwargs: Any) -> Optional[Any]:
|
||||
|
@ -420,16 +416,3 @@ class DDPSpawnPlugin(ParallelPlugin):
|
|||
self.lightning_module.cpu()
|
||||
# clean up memory
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
@staticmethod
|
||||
def _clean_logger(trainer: "pl.Trainer") -> None:
|
||||
loggers = trainer.logger._logger_iterable if isinstance(trainer.logger, LoggerCollection) else [trainer.logger]
|
||||
for logger in loggers:
|
||||
if isinstance(logger, TensorBoardLogger) and logger._experiment is not None:
|
||||
rank_zero_warn(
|
||||
"When using `ddp_spawn`, the `TensorBoardLogger` experiment should be `None`. Setting it to `None`."
|
||||
)
|
||||
# the experiment class of `TensorBoard` holds a multiprocessing queue which can make ours hang.
|
||||
# we want to make sure these are closed before we spawn our own processes.
|
||||
# assuming nothing else references the experiment object, python should instantly `__del__` it.
|
||||
logger._experiment = None
|
||||
|
|
|
@ -24,6 +24,7 @@ from torch.nn import Module
|
|||
from torch.utils.data import DataLoader
|
||||
|
||||
import pytorch_lightning as pl
|
||||
from pytorch_lightning.loggers import LoggerCollection, TensorBoardLogger
|
||||
from pytorch_lightning.overrides import LightningDistributedModule
|
||||
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
|
||||
from pytorch_lightning.plugins.io.xla_plugin import XLACheckpointIO
|
||||
|
@ -304,8 +305,17 @@ class TPUSpawnPlugin(DDPSpawnPlugin):
|
|||
# todo: the precision plugin is called in accelerator setup and should be moved
|
||||
if "XLA_USE_BF16" in os.environ:
|
||||
del os.environ["XLA_USE_BF16"]
|
||||
self._clean_logger(trainer)
|
||||
return super().start_training(trainer)
|
||||
|
||||
def start_evaluating(self, trainer: "pl.Trainer") -> None:
|
||||
self._clean_logger(trainer)
|
||||
return super().start_evaluating(trainer)
|
||||
|
||||
def start_predicting(self, trainer: "pl.Trainer") -> None:
|
||||
self._clean_logger(trainer)
|
||||
return super().start_predicting(trainer)
|
||||
|
||||
def training_step(self, *args, **kwargs):
|
||||
return self.model(*args, **kwargs)
|
||||
|
||||
|
@ -381,3 +391,13 @@ class TPUSpawnPlugin(DDPSpawnPlugin):
|
|||
@checkpoint_io.setter
|
||||
def checkpoint_io(self, plugin: CheckpointIO) -> None:
|
||||
raise MisconfigurationException("TPU Spawn Plugin currently does not support custom checkpoint plugins.")
|
||||
|
||||
@staticmethod
|
||||
def _clean_logger(trainer: "pl.Trainer") -> None:
|
||||
loggers = trainer.logger._logger_iterable if isinstance(trainer.logger, LoggerCollection) else [trainer.logger]
|
||||
for logger in loggers:
|
||||
if isinstance(logger, TensorBoardLogger) and logger._experiment is not None:
|
||||
# the experiment class of `TensorBoard` holds a multiprocessing queue which can make ours hang.
|
||||
# we want to make sure these are closed before we spawn our own processes.
|
||||
# assuming nothing else references the experiment object, python should instantly `__del__` it.
|
||||
logger._experiment = None
|
||||
|
|
|
@ -71,6 +71,7 @@ class RunIf:
|
|||
deepspeed: bool = False,
|
||||
rich: bool = False,
|
||||
skip_49370: bool = False,
|
||||
skip_hanging_spawn: bool = False,
|
||||
omegaconf: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
|
@ -94,6 +95,7 @@ class RunIf:
|
|||
deepspeed: Require that microsoft/DeepSpeed is installed.
|
||||
rich: Require that willmcgugan/rich is installed.
|
||||
skip_49370: Skip the test as it's impacted by https://github.com/pytorch/pytorch/issues/49370.
|
||||
skip_hanging_spawn: Skip the test as it's impacted by hanging loggers on spawn.
|
||||
omegaconf: Require that omry/omegaconf is installed.
|
||||
**kwargs: Any :class:`pytest.mark.skipif` keyword arguments.
|
||||
"""
|
||||
|
@ -180,6 +182,15 @@ class RunIf:
|
|||
conditions.append(ge_3_9 and old_torch)
|
||||
reasons.append("Impacted by https://github.com/pytorch/pytorch/issues/49370")
|
||||
|
||||
if skip_hanging_spawn:
|
||||
# strategy=ddp_spawn, accelerator=cpu, python>=3.8, torch<1.9 does not work
|
||||
py_version = f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}"
|
||||
ge_3_8 = Version(py_version) >= Version("3.8")
|
||||
torch_version = get_distribution("torch").version
|
||||
old_torch = Version(torch_version) < Version("1.9")
|
||||
conditions.append(ge_3_8 and old_torch)
|
||||
reasons.append("Impacted by hanging DDP spawn")
|
||||
|
||||
if omegaconf:
|
||||
conditions.append(not _OMEGACONF_AVAILABLE)
|
||||
reasons.append("omegaconf")
|
||||
|
|
|
@ -329,7 +329,7 @@ class RankZeroLoggerCheck(Callback):
|
|||
assert pl_module.logger.experiment.something(foo="bar") is None
|
||||
|
||||
|
||||
@RunIf(skip_windows=True, skip_49370=True)
|
||||
@RunIf(skip_windows=True, skip_49370=True, skip_hanging_spawn=True)
|
||||
@pytest.mark.parametrize("logger_class", [CometLogger, CSVLogger, MLFlowLogger, TensorBoardLogger, TestTubeLogger])
|
||||
def test_logger_created_on_rank_zero_only(tmpdir, monkeypatch, logger_class):
|
||||
"""Test that loggers get replaced by dummy loggers on global rank > 0."""
|
||||
|
|
|
@ -24,7 +24,6 @@ import yaml
|
|||
|
||||
from pytorch_lightning import Trainer
|
||||
from pytorch_lightning.loggers import TensorBoardLogger
|
||||
from pytorch_lightning.loggers.base import LoggerCollection
|
||||
from pytorch_lightning.utilities.imports import _compare_version, _OMEGACONF_AVAILABLE
|
||||
from tests.helpers import BoringModel
|
||||
from tests.helpers.runif import RunIf
|
||||
|
@ -335,17 +334,3 @@ def test_tensorboard_missing_folder_warning(tmpdir, caplog):
|
|||
assert logger.version == 0
|
||||
|
||||
assert "Missing logger folder:" in caplog.text
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_list", [False, True])
|
||||
def test_tensorboard_ddp_spawn_cleanup(use_list, tmpdir):
|
||||
tensorboard_logger = TensorBoardLogger(save_dir=tmpdir)
|
||||
assert tensorboard_logger._experiment is None
|
||||
tensorboard_logger.experiment # this property access will create the experiment
|
||||
assert tensorboard_logger._experiment is not None
|
||||
logger = [tensorboard_logger] if use_list else tensorboard_logger
|
||||
trainer = Trainer(strategy="ddp_spawn", devices=2, accelerator="auto", logger=logger)
|
||||
trainer.training_type_plugin._clean_logger(trainer)
|
||||
if use_list:
|
||||
assert isinstance(trainer.logger, LoggerCollection)
|
||||
assert tensorboard_logger._experiment is None
|
||||
|
|
|
@ -20,6 +20,7 @@ import torch
|
|||
from torch.utils.data import DataLoader
|
||||
|
||||
from pytorch_lightning import Trainer
|
||||
from pytorch_lightning.loggers import LoggerCollection, TensorBoardLogger
|
||||
from pytorch_lightning.plugins.training_type import TPUSpawnPlugin
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from tests.helpers.boring_model import BoringModel, RandomDataset
|
||||
|
@ -102,3 +103,18 @@ def test_model_tpu_one_core():
|
|||
model = BoringModelTPU()
|
||||
trainer.fit(model)
|
||||
assert "PT_XLA_DEBUG" not in os.environ
|
||||
|
||||
|
||||
@RunIf(tpu=True)
|
||||
@pytest.mark.parametrize("use_list", [False, True])
|
||||
def test_tensorboard_ddp_spawn_cleanup(use_list, tmpdir):
|
||||
tensorboard_logger = TensorBoardLogger(save_dir=tmpdir)
|
||||
assert tensorboard_logger._experiment is None
|
||||
tensorboard_logger.experiment # this property access will create the experiment
|
||||
assert tensorboard_logger._experiment is not None
|
||||
logger = [tensorboard_logger] if use_list else tensorboard_logger
|
||||
trainer = Trainer(strategy="ddp_spawn", accelerator="tpu", devices="auto", logger=logger)
|
||||
trainer.training_type_plugin._clean_logger(trainer)
|
||||
if use_list:
|
||||
assert isinstance(trainer.logger, LoggerCollection)
|
||||
assert tensorboard_logger._experiment is None
|
||||
|
|
|
@ -54,7 +54,7 @@ def _test_all_gather_ddp(rank, world_size):
|
|||
assert torch.allclose(grad2, tensor2.grad)
|
||||
|
||||
|
||||
@RunIf(skip_windows=True, skip_49370=True)
|
||||
@RunIf(skip_windows=True, skip_49370=True, skip_hanging_spawn=True)
|
||||
def test_all_gather_ddp_spawn():
|
||||
world_size = 3
|
||||
torch.multiprocessing.spawn(_test_all_gather_ddp, args=(world_size,), nprocs=world_size)
|
||||
|
|
Loading…
Reference in New Issue