Remove rank_zero_only on DataModule prepare_data (#7945)
Signed-off-by: Max Ehrlich <max.ehr@gmail.com>
This commit is contained in:
parent
96433d03ea
commit
6856ccedfd
|
@ -209,6 +209,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
||||||
|
|
||||||
### Fixed
|
### Fixed
|
||||||
|
|
||||||
|
- Fixed `DataModule.prepare_data` could only be called on the global rank 0 process ([#7945](https://github.com/PyTorchLightning/pytorch-lightning/pull/7945))
|
||||||
|
|
||||||
- Fixed `_check_training_step_output` to be called after `train_step_end` to support more flexible accommodations ([#7868](https://github.com/PyTorchLightning/pytorch-lightning/pull/7868))
|
- Fixed `_check_training_step_output` to be called after `train_step_end` to support more flexible accommodations ([#7868](https://github.com/PyTorchLightning/pytorch-lightning/pull/7868))
|
||||||
|
|
||||||
- Fixed `apply_to_collection` works on Custom Collections now ([#7851](https://github.com/PyTorchLightning/pytorch-lightning/pull/7851))
|
- Fixed `apply_to_collection` works on Custom Collections now ([#7851](https://github.com/PyTorchLightning/pytorch-lightning/pull/7851))
|
||||||
|
|
|
@ -21,7 +21,7 @@ from torch.utils.data import DataLoader, Dataset, IterableDataset
|
||||||
|
|
||||||
from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks
|
from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks
|
||||||
from pytorch_lightning.utilities.argparse import add_argparse_args, from_argparse_args, get_init_arguments_and_types
|
from pytorch_lightning.utilities.argparse import add_argparse_args, from_argparse_args, get_init_arguments_and_types
|
||||||
from pytorch_lightning.utilities.distributed import rank_zero_deprecation, rank_zero_only
|
from pytorch_lightning.utilities.distributed import rank_zero_deprecation
|
||||||
|
|
||||||
|
|
||||||
class LightningDataModule(CheckpointHooks, DataHooks):
|
class LightningDataModule(CheckpointHooks, DataHooks):
|
||||||
|
@ -381,7 +381,7 @@ class LightningDataModule(CheckpointHooks, DataHooks):
|
||||||
def __new__(cls, *args: Any, **kwargs: Any) -> 'LightningDataModule':
|
def __new__(cls, *args: Any, **kwargs: Any) -> 'LightningDataModule':
|
||||||
obj = super().__new__(cls)
|
obj = super().__new__(cls)
|
||||||
# track `DataHooks` calls and run `prepare_data` only on rank zero
|
# track `DataHooks` calls and run `prepare_data` only on rank zero
|
||||||
obj.prepare_data = cls._track_data_hook_calls(obj, rank_zero_only(obj.prepare_data))
|
obj.prepare_data = cls._track_data_hook_calls(obj, obj.prepare_data)
|
||||||
obj.setup = cls._track_data_hook_calls(obj, obj.setup)
|
obj.setup = cls._track_data_hook_calls(obj, obj.setup)
|
||||||
obj.teardown = cls._track_data_hook_calls(obj, obj.teardown)
|
obj.teardown = cls._track_data_hook_calls(obj, obj.teardown)
|
||||||
return obj
|
return obj
|
||||||
|
|
|
@ -34,6 +34,7 @@ from tests.helpers.utils import reset_seed
|
||||||
@mock.patch("pytorch_lightning.trainer.trainer.Trainer.local_rank", new_callable=PropertyMock)
|
@mock.patch("pytorch_lightning.trainer.trainer.Trainer.local_rank", new_callable=PropertyMock)
|
||||||
def test_can_prepare_data(local_rank, node_rank):
|
def test_can_prepare_data(local_rank, node_rank):
|
||||||
|
|
||||||
|
model = BoringModel()
|
||||||
dm = BoringDataModule()
|
dm = BoringDataModule()
|
||||||
trainer = Trainer()
|
trainer = Trainer()
|
||||||
trainer.datamodule = dm
|
trainer.datamodule = dm
|
||||||
|
@ -43,30 +44,54 @@ def test_can_prepare_data(local_rank, node_rank):
|
||||||
# local rank = 0 (True)
|
# local rank = 0 (True)
|
||||||
trainer.prepare_data_per_node = True
|
trainer.prepare_data_per_node = True
|
||||||
|
|
||||||
|
dm.random_full = None
|
||||||
|
dm._has_prepared_data = False
|
||||||
local_rank.return_value = 0
|
local_rank.return_value = 0
|
||||||
assert trainer.local_rank == 0
|
assert trainer.local_rank == 0
|
||||||
assert trainer.data_connector.can_prepare_data()
|
assert trainer.data_connector.can_prepare_data()
|
||||||
|
|
||||||
|
trainer.data_connector.prepare_data(model)
|
||||||
|
assert dm.random_full is not None
|
||||||
|
|
||||||
# local rank = 1 (False)
|
# local rank = 1 (False)
|
||||||
|
dm.random_full = None
|
||||||
|
dm._has_prepared_data = False
|
||||||
local_rank.return_value = 1
|
local_rank.return_value = 1
|
||||||
assert trainer.local_rank == 1
|
assert trainer.local_rank == 1
|
||||||
assert not trainer.data_connector.can_prepare_data()
|
assert not trainer.data_connector.can_prepare_data()
|
||||||
|
|
||||||
|
trainer.data_connector.prepare_data(model)
|
||||||
|
assert dm.random_full is None
|
||||||
|
|
||||||
# prepare_data_per_node = False (prepare across all nodes)
|
# prepare_data_per_node = False (prepare across all nodes)
|
||||||
# global rank = 0 (True)
|
# global rank = 0 (True)
|
||||||
|
dm.random_full = None
|
||||||
|
dm._has_prepared_data = False
|
||||||
trainer.prepare_data_per_node = False
|
trainer.prepare_data_per_node = False
|
||||||
node_rank.return_value = 0
|
node_rank.return_value = 0
|
||||||
local_rank.return_value = 0
|
local_rank.return_value = 0
|
||||||
assert trainer.data_connector.can_prepare_data()
|
assert trainer.data_connector.can_prepare_data()
|
||||||
|
|
||||||
|
trainer.data_connector.prepare_data(model)
|
||||||
|
assert dm.random_full is not None
|
||||||
|
|
||||||
# global rank = 1 (False)
|
# global rank = 1 (False)
|
||||||
|
dm.random_full = None
|
||||||
|
dm._has_prepared_data = False
|
||||||
node_rank.return_value = 1
|
node_rank.return_value = 1
|
||||||
local_rank.return_value = 0
|
local_rank.return_value = 0
|
||||||
assert not trainer.data_connector.can_prepare_data()
|
assert not trainer.data_connector.can_prepare_data()
|
||||||
|
|
||||||
|
trainer.data_connector.prepare_data(model)
|
||||||
|
assert dm.random_full is None
|
||||||
|
|
||||||
node_rank.return_value = 0
|
node_rank.return_value = 0
|
||||||
local_rank.return_value = 1
|
local_rank.return_value = 1
|
||||||
assert not trainer.data_connector.can_prepare_data()
|
assert not trainer.data_connector.can_prepare_data()
|
||||||
|
|
||||||
|
trainer.data_connector.prepare_data(model)
|
||||||
|
assert dm.random_full is None
|
||||||
|
|
||||||
# 2 dm
|
# 2 dm
|
||||||
# prepare per node = True
|
# prepare per node = True
|
||||||
# local rank = 0 (True)
|
# local rank = 0 (True)
|
||||||
|
|
Loading…
Reference in New Issue