Remove rank_zero_only on DataModule prepare_data (#7945)
Signed-off-by: Max Ehrlich <max.ehr@gmail.com>
This commit is contained in:
parent
96433d03ea
commit
6856ccedfd
|
@ -209,6 +209,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
|
||||
### Fixed
|
||||
|
||||
- Fixed `DataModule.prepare_data` could only be called on the global rank 0 process ([#7945](https://github.com/PyTorchLightning/pytorch-lightning/pull/7945))
|
||||
|
||||
- Fixed `_check_training_step_output` to be called after `train_step_end` to support more flexible accomodations ([#7868](https://github.com/PyTorchLightning/pytorch-lightning/pull/7868))
|
||||
|
||||
- Fixed `apply_to_collection` works on Custom Collections now ([#7851](https://github.com/PyTorchLightning/pytorch-lightning/pull/7851))
|
||||
|
|
|
@ -21,7 +21,7 @@ from torch.utils.data import DataLoader, Dataset, IterableDataset
|
|||
|
||||
from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks
|
||||
from pytorch_lightning.utilities.argparse import add_argparse_args, from_argparse_args, get_init_arguments_and_types
|
||||
from pytorch_lightning.utilities.distributed import rank_zero_deprecation, rank_zero_only
|
||||
from pytorch_lightning.utilities.distributed import rank_zero_deprecation
|
||||
|
||||
|
||||
class LightningDataModule(CheckpointHooks, DataHooks):
|
||||
|
@ -381,7 +381,7 @@ class LightningDataModule(CheckpointHooks, DataHooks):
|
|||
def __new__(cls, *args: Any, **kwargs: Any) -> 'LightningDataModule':
|
||||
obj = super().__new__(cls)
|
||||
# track `DataHooks` calls and run `prepare_data` only on rank zero
|
||||
obj.prepare_data = cls._track_data_hook_calls(obj, rank_zero_only(obj.prepare_data))
|
||||
obj.prepare_data = cls._track_data_hook_calls(obj, obj.prepare_data)
|
||||
obj.setup = cls._track_data_hook_calls(obj, obj.setup)
|
||||
obj.teardown = cls._track_data_hook_calls(obj, obj.teardown)
|
||||
return obj
|
||||
|
|
|
@ -34,6 +34,7 @@ from tests.helpers.utils import reset_seed
|
|||
@mock.patch("pytorch_lightning.trainer.trainer.Trainer.local_rank", new_callable=PropertyMock)
|
||||
def test_can_prepare_data(local_rank, node_rank):
|
||||
|
||||
model = BoringModel()
|
||||
dm = BoringDataModule()
|
||||
trainer = Trainer()
|
||||
trainer.datamodule = dm
|
||||
|
@ -43,30 +44,54 @@ def test_can_prepare_data(local_rank, node_rank):
|
|||
# local rank = 0 (True)
|
||||
trainer.prepare_data_per_node = True
|
||||
|
||||
dm.random_full = None
|
||||
dm._has_prepared_data = False
|
||||
local_rank.return_value = 0
|
||||
assert trainer.local_rank == 0
|
||||
assert trainer.data_connector.can_prepare_data()
|
||||
|
||||
trainer.data_connector.prepare_data(model)
|
||||
assert dm.random_full is not None
|
||||
|
||||
# local rank = 1 (False)
|
||||
dm.random_full = None
|
||||
dm._has_prepared_data = False
|
||||
local_rank.return_value = 1
|
||||
assert trainer.local_rank == 1
|
||||
assert not trainer.data_connector.can_prepare_data()
|
||||
|
||||
trainer.data_connector.prepare_data(model)
|
||||
assert dm.random_full is None
|
||||
|
||||
# prepare_data_per_node = False (prepare across all nodes)
|
||||
# global rank = 0 (True)
|
||||
dm.random_full = None
|
||||
dm._has_prepared_data = False
|
||||
trainer.prepare_data_per_node = False
|
||||
node_rank.return_value = 0
|
||||
local_rank.return_value = 0
|
||||
assert trainer.data_connector.can_prepare_data()
|
||||
|
||||
trainer.data_connector.prepare_data(model)
|
||||
assert dm.random_full is not None
|
||||
|
||||
# global rank = 1 (False)
|
||||
dm.random_full = None
|
||||
dm._has_prepared_data = False
|
||||
node_rank.return_value = 1
|
||||
local_rank.return_value = 0
|
||||
assert not trainer.data_connector.can_prepare_data()
|
||||
|
||||
trainer.data_connector.prepare_data(model)
|
||||
assert dm.random_full is None
|
||||
|
||||
node_rank.return_value = 0
|
||||
local_rank.return_value = 1
|
||||
assert not trainer.data_connector.can_prepare_data()
|
||||
|
||||
trainer.data_connector.prepare_data(model)
|
||||
assert dm.random_full is None
|
||||
|
||||
# 2 dm
|
||||
# prepar per node = True
|
||||
# local rank = 0 (True)
|
||||
|
|
Loading…
Reference in New Issue