Remove rank_zero_only on DataModule prepare_data (#7945)

Signed-off-by: Max Ehrlich <max.ehr@gmail.com>
Max Ehrlich 2021-06-12 06:50:29 -04:00 committed by GitHub
parent 96433d03ea
commit 6856ccedfd
3 changed files with 29 additions and 2 deletions

CHANGELOG.md

@@ -209,6 +209,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 ### Fixed
+- Fixed `DataModule.prepare_data` could only be called on the global rank 0 process ([#7945](https://github.com/PyTorchLightning/pytorch-lightning/pull/7945))
 - Fixed `_check_training_step_output` to be called after `train_step_end` to support more flexible accomodations ([#7868](https://github.com/PyTorchLightning/pytorch-lightning/pull/7868))
 - Fixed `apply_to_collection` works on Custom Collections now ([#7851](https://github.com/PyTorchLightning/pytorch-lightning/pull/7851))
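
This changelog entry is the user-visible fix: because `LightningDataModule.__new__` wrapped `prepare_data` in `rank_zero_only`, the hook ran only on the process with global rank 0, so `Trainer(prepare_data_per_node=True)` could never actually prepare data on the other nodes. A minimal sketch of what such a gate does, assuming (for illustration only) that the launcher exports the global rank as the `RANK` environment variable:

    import os
    from functools import wraps
    from typing import Any, Callable, Optional

    def rank_zero_only_sketch(fn: Callable[..., Any]) -> Callable[..., Optional[Any]]:
        """Sketch of a rank-zero gate: the wrapped callable becomes a no-op
        on every process whose global rank is non-zero."""
        @wraps(fn)
        def wrapped(*args: Any, **kwargs: Any) -> Optional[Any]:
            # Assumption for this sketch: the launcher exports the global rank
            # as RANK; the real Lightning utility tracks the rank itself.
            if int(os.environ.get("RANK", "0")) == 0:
                return fn(*args, **kwargs)
            return None  # silently skipped on all other ranks
        return wrapped

On a second node, the local-rank-0 process has a non-zero global rank, so a gate like this silently skips its `prepare_data`; the commit removes the gate and lets the Trainer make the decision instead.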

pytorch_lightning/core/datamodule.py

@@ -21,7 +21,7 @@ from torch.utils.data import DataLoader, Dataset, IterableDataset
 from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks
 from pytorch_lightning.utilities.argparse import add_argparse_args, from_argparse_args, get_init_arguments_and_types
-from pytorch_lightning.utilities.distributed import rank_zero_deprecation, rank_zero_only
+from pytorch_lightning.utilities.distributed import rank_zero_deprecation


 class LightningDataModule(CheckpointHooks, DataHooks):
@@ -381,7 +381,7 @@ class LightningDataModule(CheckpointHooks, DataHooks):
     def __new__(cls, *args: Any, **kwargs: Any) -> 'LightningDataModule':
         obj = super().__new__(cls)
         # track `DataHooks` calls and run `prepare_data` only on rank zero
-        obj.prepare_data = cls._track_data_hook_calls(obj, rank_zero_only(obj.prepare_data))
+        obj.prepare_data = cls._track_data_hook_calls(obj, obj.prepare_data)
         obj.setup = cls._track_data_hook_calls(obj, obj.setup)
         obj.teardown = cls._track_data_hook_calls(obj, obj.teardown)
         return obj
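
With the wrapper gone, deciding where `prepare_data` runs is left to the Trainer's data connector, which the tests below exercise through `can_prepare_data()`. A free-function sketch of that decision, reconstructed from the behaviour the tests assert (the real check is a method on the data connector, not a standalone function):

    def can_prepare_data(prepare_data_per_node: bool, local_rank: int, node_rank: int) -> bool:
        # Sketch of the trainer-side gate that replaces rank_zero_only here.
        if prepare_data_per_node:
            # one process per node prepares its own copy of the data
            return local_rank == 0
        # otherwise only the single global-rank-zero process prepares
        return node_rank == 0 and local_rank == 0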

tests/core/test_datamodule.py

@@ -34,6 +34,7 @@ from tests.helpers.utils import reset_seed
 @mock.patch("pytorch_lightning.trainer.trainer.Trainer.local_rank", new_callable=PropertyMock)
 def test_can_prepare_data(local_rank, node_rank):
+    model = BoringModel()
     dm = BoringDataModule()
     trainer = Trainer()
     trainer.datamodule = dm
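
The `PropertyMock` patch above is what lets the test fake the process rank on a single machine: `Trainer.local_rank` is a read-only property, so it is replaced wholesale and driven via `return_value`. A self-contained sketch of the same pattern:

    from unittest import mock
    from unittest.mock import PropertyMock

    class Trainer:
        @property
        def local_rank(self) -> int:
            return 0  # a real trainer would derive this from the environment

    # Patch the property itself, then steer what it reports.
    with mock.patch(f"{__name__}.Trainer.local_rank", new_callable=PropertyMock) as local_rank:
        local_rank.return_value = 1
        assert Trainer().local_rank == 1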
@@ -43,30 +44,54 @@ def test_can_prepare_data(local_rank, node_rank):
     # local rank = 0 (True)
     trainer.prepare_data_per_node = True
+    dm.random_full = None
+    dm._has_prepared_data = False
     local_rank.return_value = 0
     assert trainer.local_rank == 0
     assert trainer.data_connector.can_prepare_data()
+    trainer.data_connector.prepare_data(model)
+    assert dm.random_full is not None

     # local rank = 1 (False)
+    dm.random_full = None
+    dm._has_prepared_data = False
     local_rank.return_value = 1
     assert trainer.local_rank == 1
     assert not trainer.data_connector.can_prepare_data()
+    trainer.data_connector.prepare_data(model)
+    assert dm.random_full is None

     # prepare_data_per_node = False (prepare across all nodes)
     # global rank = 0 (True)
+    dm.random_full = None
+    dm._has_prepared_data = False
     trainer.prepare_data_per_node = False
     node_rank.return_value = 0
     local_rank.return_value = 0
     assert trainer.data_connector.can_prepare_data()
+    trainer.data_connector.prepare_data(model)
+    assert dm.random_full is not None

     # global rank = 1 (False)
+    dm.random_full = None
+    dm._has_prepared_data = False
     node_rank.return_value = 1
     local_rank.return_value = 0
     assert not trainer.data_connector.can_prepare_data()
+    trainer.data_connector.prepare_data(model)
+    assert dm.random_full is None

     node_rank.return_value = 0
     local_rank.return_value = 1
     assert not trainer.data_connector.can_prepare_data()
+    trainer.data_connector.prepare_data(model)
+    assert dm.random_full is None

     # 2 dm
     # prepar per node = True
     # local rank = 0 (True)
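
The new assertions work because `BoringDataModule.prepare_data` materializes the `random_full` dataset: if `random_full` is still `None` after `trainer.data_connector.prepare_data(model)`, the hook was skipped on that rank. A hypothetical stand-in with the same observable behaviour:

    import torch
    from torch.utils.data import TensorDataset
    from pytorch_lightning import LightningDataModule

    class ToyDataModule(LightningDataModule):
        """Hypothetical stand-in for BoringDataModule: prepare_data populates
        `random_full`, so a test can check whether the hook actually ran."""

        def __init__(self) -> None:
            super().__init__()
            self.random_full = None

        def prepare_data(self) -> None:
            # No rank gating here any more: whether this executes is decided
            # by the trainer-side check, per the datamodule change above.
            self.random_full = TensorDataset(torch.randn(64, 32))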