[CLI] Fix `SaveConfigCallback` with DDP spawn (#12011)

This commit is contained in:
Carlos Mocholí 2022-02-28 14:27:42 +01:00 committed by GitHub
parent 01c31ae434
commit a9024ce870
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 14 additions and 22 deletions

View File

@ -415,8 +415,6 @@ class SaveConfigCallback(Callback):
self.multifile = multifile self.multifile = multifile
def setup(self, trainer: Trainer, pl_module: LightningModule, stage: Optional[str] = None) -> None: def setup(self, trainer: Trainer, pl_module: LightningModule, stage: Optional[str] = None) -> None:
# save the config in `setup` because (1) we want it to save regardless of the trainer function run
# and (2) we want to save before processes are spawned
log_dir = trainer.log_dir # this broadcasts the directory log_dir = trainer.log_dir # this broadcasts the directory
assert log_dir is not None assert log_dir is not None
config_path = os.path.join(log_dir, self.config_filename) config_path = os.path.join(log_dir, self.config_filename)
@ -437,7 +435,7 @@ class SaveConfigCallback(Callback):
# save the file on rank 0 # save the file on rank 0
if trainer.is_global_zero: if trainer.is_global_zero:
# save only on rank zero to avoid race conditions on DDP. # save only on rank zero to avoid race conditions.
# the `log_dir` needs to be created as we rely on the logger to do it usually # the `log_dir` needs to be created as we rely on the logger to do it usually
# but it hasn't logged anything at this point # but it hasn't logged anything at this point
fs.makedirs(log_dir, exist_ok=True) fs.makedirs(log_dir, exist_ok=True)
@ -445,10 +443,6 @@ class SaveConfigCallback(Callback):
self.config, config_path, skip_none=False, overwrite=self.overwrite, multifile=self.multifile self.config, config_path, skip_none=False, overwrite=self.overwrite, multifile=self.multifile
) )
def __reduce__(self) -> Tuple[Type["SaveConfigCallback"], Tuple, Dict]:
# `ArgumentParser` is un-pickleable. Drop it
return self.__class__, (None, self.config, self.config_filename), {}
class LightningCLI: class LightningCLI:
"""Implementation of a configurable command line tool for pytorch-lightning.""" """Implementation of a configurable command line tool for pytorch-lightning."""

View File

@ -50,7 +50,7 @@ from pytorch_lightning.utilities.cli import (
SaveConfigCallback, SaveConfigCallback,
) )
from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _TORCHVISION_AVAILABLE from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_8, _TORCHVISION_AVAILABLE
from tests.helpers import BoringDataModule, BoringModel from tests.helpers import BoringDataModule, BoringModel
from tests.helpers.runif import RunIf from tests.helpers.runif import RunIf
from tests.helpers.utils import no_warning_call from tests.helpers.utils import no_warning_call
@ -576,21 +576,17 @@ class EarlyExitTestModel(BoringModel):
raise MisconfigurationException("Error on fit start") raise MisconfigurationException("Error on fit start")
@RunIf(skip_windows=True)
@pytest.mark.parametrize("logger", (False, True)) @pytest.mark.parametrize("logger", (False, True))
@pytest.mark.parametrize( @pytest.mark.parametrize("strategy", ("ddp_spawn", "ddp"))
"trainer_kwargs", def test_cli_distributed_save_config_callback(tmpdir, logger, strategy):
( if _TORCH_GREATER_EQUAL_1_8:
# dict(strategy="ddp_spawn") from torch.multiprocessing import ProcessRaisedException
# dict(strategy="ddp") else:
# the previous accl_conn will choose singleDeviceStrategy for both strategy=ddp/ddp_spawn ProcessRaisedException = Exception
# TODO revisit this test as it never worked with DDP or DDPSpawn
dict(strategy="single_device"),
pytest.param({"tpu_cores": 1}, marks=RunIf(tpu=True)),
),
)
def test_cli_distributed_save_config_callback(tmpdir, logger, trainer_kwargs):
with mock.patch("sys.argv", ["any.py", "fit"]), pytest.raises( with mock.patch("sys.argv", ["any.py", "fit"]), pytest.raises(
MisconfigurationException, match=r"Error on fit start" (MisconfigurationException, ProcessRaisedException), match=r"Error on fit start"
): ):
LightningCLI( LightningCLI(
EarlyExitTestModel, EarlyExitTestModel,
@ -599,7 +595,9 @@ def test_cli_distributed_save_config_callback(tmpdir, logger, trainer_kwargs):
"logger": logger, "logger": logger,
"max_steps": 1, "max_steps": 1,
"max_epochs": 1, "max_epochs": 1,
**trainer_kwargs, "strategy": strategy,
"accelerator": "auto",
"devices": 1,
}, },
) )
if logger: if logger: