diff --git a/CHANGELOG.md b/CHANGELOG.md index 9e63484ca8..3bd7ae373c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -209,6 +209,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Fixed +- Fixed `DataModule.prepare_data` could only be called on the global rank 0 process ([#7945](https://github.com/PyTorchLightning/pytorch-lightning/pull/7945)) + - Fixed `_check_training_step_output` to be called after `train_step_end` to support more flexible accomodations ([#7868](https://github.com/PyTorchLightning/pytorch-lightning/pull/7868)) - Fixed `apply_to_collection` works on Custom Collections now ([#7851](https://github.com/PyTorchLightning/pytorch-lightning/pull/7851)) diff --git a/pytorch_lightning/core/datamodule.py b/pytorch_lightning/core/datamodule.py index afa1238786..9dd8066f15 100644 --- a/pytorch_lightning/core/datamodule.py +++ b/pytorch_lightning/core/datamodule.py @@ -21,7 +21,7 @@ from torch.utils.data import DataLoader, Dataset, IterableDataset from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks from pytorch_lightning.utilities.argparse import add_argparse_args, from_argparse_args, get_init_arguments_and_types -from pytorch_lightning.utilities.distributed import rank_zero_deprecation, rank_zero_only +from pytorch_lightning.utilities.distributed import rank_zero_deprecation class LightningDataModule(CheckpointHooks, DataHooks): @@ -381,7 +381,7 @@ class LightningDataModule(CheckpointHooks, DataHooks): def __new__(cls, *args: Any, **kwargs: Any) -> 'LightningDataModule': obj = super().__new__(cls) # track `DataHooks` calls and run `prepare_data` only on rank zero - obj.prepare_data = cls._track_data_hook_calls(obj, rank_zero_only(obj.prepare_data)) + obj.prepare_data = cls._track_data_hook_calls(obj, obj.prepare_data) obj.setup = cls._track_data_hook_calls(obj, obj.setup) obj.teardown = cls._track_data_hook_calls(obj, obj.teardown) return obj diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py index d4e1a3ff0e..66abba8d2c 100644 --- a/tests/core/test_datamodules.py +++ b/tests/core/test_datamodules.py @@ -34,6 +34,7 @@ from tests.helpers.utils import reset_seed @mock.patch("pytorch_lightning.trainer.trainer.Trainer.local_rank", new_callable=PropertyMock) def test_can_prepare_data(local_rank, node_rank): + model = BoringModel() dm = BoringDataModule() trainer = Trainer() trainer.datamodule = dm @@ -43,30 +44,54 @@ def test_can_prepare_data(local_rank, node_rank): # local rank = 0 (True) trainer.prepare_data_per_node = True + dm.random_full = None + dm._has_prepared_data = False local_rank.return_value = 0 assert trainer.local_rank == 0 assert trainer.data_connector.can_prepare_data() + trainer.data_connector.prepare_data(model) + assert dm.random_full is not None + # local rank = 1 (False) + dm.random_full = None + dm._has_prepared_data = False local_rank.return_value = 1 assert trainer.local_rank == 1 assert not trainer.data_connector.can_prepare_data() + trainer.data_connector.prepare_data(model) + assert dm.random_full is None + # prepare_data_per_node = False (prepare across all nodes) # global rank = 0 (True) + dm.random_full = None + dm._has_prepared_data = False trainer.prepare_data_per_node = False node_rank.return_value = 0 local_rank.return_value = 0 assert trainer.data_connector.can_prepare_data() + trainer.data_connector.prepare_data(model) + assert dm.random_full is not None + # global rank = 1 (False) + dm.random_full = None + dm._has_prepared_data = False node_rank.return_value = 1 local_rank.return_value = 0 assert not trainer.data_connector.can_prepare_data() + + trainer.data_connector.prepare_data(model) + assert dm.random_full is None + node_rank.return_value = 0 local_rank.return_value = 1 assert not trainer.data_connector.can_prepare_data() + trainer.data_connector.prepare_data(model) + assert dm.random_full is None + # 2 dm # prepar per node = True # local rank = 0 (True)