[CLI] Fix `SaveConfigCallback` with DDP spawn (#12011)
parent 01c31ae434
commit a9024ce870
--- a/pytorch_lightning/utilities/cli.py
+++ b/pytorch_lightning/utilities/cli.py
@@ -415,8 +415,6 @@ class SaveConfigCallback(Callback):
         self.multifile = multifile
 
     def setup(self, trainer: Trainer, pl_module: LightningModule, stage: Optional[str] = None) -> None:
-        # save the config in `setup` because (1) we want it to save regardless of the trainer function run
-        # and we want to save before processes are spawned
         log_dir = trainer.log_dir  # this broadcasts the directory
         assert log_dir is not None
         config_path = os.path.join(log_dir, self.config_filename)
@@ -437,7 +435,7 @@ class SaveConfigCallback(Callback):
 
         # save the file on rank 0
         if trainer.is_global_zero:
-            # save only on rank zero to avoid race conditions on DDP.
+            # save only on rank zero to avoid race conditions.
             # the `log_dir` needs to be created as we rely on the logger to do it usually
             # but it hasn't logged anything at this point
             fs.makedirs(log_dir, exist_ok=True)
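
The rank-zero guard above is the heart of the change: a single process writes the config file while the other ranks rely on `trainer.log_dir` having been broadcast to them. A minimal sketch of the same pattern, with a hypothetical `save_config` helper standing in for the callback's parser call:

```python
import os

def save_config(log_dir: str, filename: str, config_text: str, is_global_zero: bool) -> None:
    # Only rank zero writes, so concurrent ranks never race on the same file.
    if not is_global_zero:
        return
    # The logger normally creates `log_dir`, but nothing has been logged yet,
    # so the directory must be created explicitly before writing.
    os.makedirs(log_dir, exist_ok=True)
    with open(os.path.join(log_dir, filename), "w") as f:
        f.write(config_text)
```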
@@ -445,10 +443,6 @@ class SaveConfigCallback(Callback):
                 self.config, config_path, skip_none=False, overwrite=self.overwrite, multifile=self.multifile
             )
 
-    def __reduce__(self) -> Tuple[Type["SaveConfigCallback"], Tuple, Dict]:
-        # `ArgumentParser` is un-pickleable. Drop it
-        return self.__class__, (None, self.config, self.config_filename), {}
-
 
 class LightningCLI:
     """Implementation of a configurable command line tool for pytorch-lightning."""
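
The deleted `__reduce__` was a workaround for DDP spawn pickling the callback while it held an un-picklable `ArgumentParser`: it told pickle to re-create the object with the parser dropped. A toy illustration of that trick (the `HoldsParser` class and its arguments are hypothetical):

```python
import pickle

class HoldsParser:
    def __init__(self, parser, config):
        self.parser = parser  # e.g. a jsonargparse ArgumentParser, not picklable
        self.config = config

    def __reduce__(self):
        # On load, pickle calls HoldsParser(None, config) instead of trying
        # to serialize the attributes, so the parser is simply dropped.
        return self.__class__, (None, self.config), {}

obj = HoldsParser(parser=lambda s: s, config={"seed": 0})  # a lambda is un-picklable, like the parser
clone = pickle.loads(pickle.dumps(obj))
assert clone.parser is None and clone.config == {"seed": 0}
```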
--- a/tests/utilities/test_cli.py
+++ b/tests/utilities/test_cli.py
@@ -50,7 +50,7 @@ from pytorch_lightning.utilities.cli import (
     SaveConfigCallback,
 )
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.imports import _TORCHVISION_AVAILABLE
+from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_8, _TORCHVISION_AVAILABLE
 from tests.helpers import BoringDataModule, BoringModel
 from tests.helpers.runif import RunIf
 from tests.helpers.utils import no_warning_call
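
The test now imports `_TORCH_GREATER_EQUAL_1_8`; its real definition lives in `pytorch_lightning.utilities.imports`, but a flag like this is typically just a parsed version comparison. A sketch, not the actual implementation:

```python
from packaging.version import Version
import torch

# Strip any local build tag such as "+cu111" before comparing.
_TORCH_GREATER_EQUAL_1_8 = Version(torch.__version__.split("+")[0]) >= Version("1.8.0")
```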
@@ -576,21 +576,17 @@ class EarlyExitTestModel(BoringModel):
         raise MisconfigurationException("Error on fit start")
 
 
+@RunIf(skip_windows=True)
 @pytest.mark.parametrize("logger", (False, True))
-@pytest.mark.parametrize(
-    "trainer_kwargs",
-    (
-        # dict(strategy="ddp_spawn")
-        # dict(strategy="ddp")
-        # the previous accl_conn will choose singleDeviceStrategy for both strategy=ddp/ddp_spawn
-        # TODO revisit this test as it never worked with DDP or DDPSpawn
-        dict(strategy="single_device"),
-        pytest.param({"tpu_cores": 1}, marks=RunIf(tpu=True)),
-    ),
-)
-def test_cli_distributed_save_config_callback(tmpdir, logger, trainer_kwargs):
+@pytest.mark.parametrize("strategy", ("ddp_spawn", "ddp"))
+def test_cli_distributed_save_config_callback(tmpdir, logger, strategy):
+    if _TORCH_GREATER_EQUAL_1_8:
+        from torch.multiprocessing import ProcessRaisedException
+    else:
+        ProcessRaisedException = Exception
+
     with mock.patch("sys.argv", ["any.py", "fit"]), pytest.raises(
-        MisconfigurationException, match=r"Error on fit start"
+        (MisconfigurationException, ProcessRaisedException), match=r"Error on fit start"
     ):
         LightningCLI(
             EarlyExitTestModel,
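
The tuple passed to `pytest.raises` matters because `ddp_spawn` runs the failing code in a child process: on torch versions covered by the guard above, the spawn launcher re-raises the child's error in the parent wrapped in `ProcessRaisedException`, with the original traceback embedded in the message. A standalone sketch of that wrapping:

```python
import torch.multiprocessing as mp

def worker(rank: int) -> None:
    # Mirrors EarlyExitTestModel raising on fit start inside the spawned process.
    raise RuntimeError("Error on fit start")

if __name__ == "__main__":
    try:
        mp.spawn(worker, nprocs=1)
    except Exception as err:
        # On recent torch this is ProcessRaisedException; its message contains
        # the child traceback, which is why `match=` still finds the text.
        print(type(err).__name__)
```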
@@ -599,7 +595,9 @@ def test_cli_distributed_save_config_callback(tmpdir, logger, trainer_kwargs):
                 "logger": logger,
                 "max_steps": 1,
                 "max_epochs": 1,
-                **trainer_kwargs,
+                "strategy": strategy,
+                "accelerator": "auto",
+                "devices": 1,
             },
         )
     if logger:
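
Taken together, the test accepts either of two failure paths. `pytest.raises` takes a tuple of exception types, so one assertion covers both the in-process error and the wrapped one from the launcher; a self-contained sketch with stand-in exception types:

```python
import pytest

def might_fail(wrapped: bool) -> None:
    if wrapped:
        # Stand-in for ProcessRaisedException surfacing from the spawn launcher.
        raise RuntimeError("wrapped: Error on fit start")
    # Stand-in for MisconfigurationException raised in-process.
    raise ValueError("Error on fit start")

@pytest.mark.parametrize("wrapped", (False, True))
def test_either_exception(wrapped):
    # `match` uses re.search, so it finds the text in either message.
    with pytest.raises((ValueError, RuntimeError), match=r"Error on fit start"):
        might_fail(wrapped)
```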