Fix `prepare_data` implementation in `BoringDataModule` (#10915)
This commit is contained in:
parent
b7331d80dc
commit
553d429ecb
|
@ -16,7 +16,7 @@ from argparse import ArgumentParser, Namespace
|
|||
from dataclasses import dataclass
|
||||
from typing import Any, Dict
|
||||
from unittest import mock
|
||||
from unittest.mock import call, PropertyMock
|
||||
from unittest.mock import call, Mock, PropertyMock
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
@ -40,51 +40,52 @@ if _OMEGACONF_AVAILABLE:
|
|||
@mock.patch("pytorch_lightning.trainer.trainer.Trainer.node_rank", new_callable=PropertyMock)
|
||||
@mock.patch("pytorch_lightning.trainer.trainer.Trainer.local_rank", new_callable=PropertyMock)
|
||||
def test_can_prepare_data(local_rank, node_rank):
|
||||
dm = BoringDataModule()
|
||||
dm = Mock(spec=LightningDataModule)
|
||||
dm.prepare_data_per_node = True
|
||||
trainer = Trainer()
|
||||
trainer.datamodule = dm
|
||||
|
||||
# 1 no DM
|
||||
# prepare_data_per_node = True
|
||||
# local rank = 0 (True)
|
||||
dm.random_full = None
|
||||
dm.prepare_data.assert_not_called()
|
||||
local_rank.return_value = 0
|
||||
assert trainer.local_rank == 0
|
||||
|
||||
trainer._data_connector.prepare_data()
|
||||
assert dm.random_full is not None
|
||||
dm.prepare_data.assert_called_once()
|
||||
|
||||
# local rank = 1 (False)
|
||||
dm.random_full = None
|
||||
dm.reset_mock()
|
||||
local_rank.return_value = 1
|
||||
assert trainer.local_rank == 1
|
||||
|
||||
trainer._data_connector.prepare_data()
|
||||
assert dm.random_full is None
|
||||
dm.prepare_data.assert_not_called()
|
||||
|
||||
# prepare_data_per_node = False (prepare across all nodes)
|
||||
# global rank = 0 (True)
|
||||
dm.random_full = None
|
||||
dm.reset_mock()
|
||||
dm.prepare_data_per_node = False
|
||||
node_rank.return_value = 0
|
||||
local_rank.return_value = 0
|
||||
|
||||
trainer._data_connector.prepare_data()
|
||||
assert dm.random_full is not None
|
||||
dm.prepare_data.assert_called_once()
|
||||
|
||||
# global rank = 1 (False)
|
||||
dm.random_full = None
|
||||
dm.reset_mock()
|
||||
node_rank.return_value = 1
|
||||
local_rank.return_value = 0
|
||||
|
||||
trainer._data_connector.prepare_data()
|
||||
assert dm.random_full is None
|
||||
dm.prepare_data.assert_not_called()
|
||||
|
||||
node_rank.return_value = 0
|
||||
local_rank.return_value = 1
|
||||
|
||||
trainer._data_connector.prepare_data()
|
||||
assert dm.random_full is None
|
||||
dm.prepare_data.assert_not_called()
|
||||
|
||||
# 2 dm
|
||||
# prepar per node = True
|
||||
|
@ -92,10 +93,9 @@ def test_can_prepare_data(local_rank, node_rank):
|
|||
dm.prepare_data_per_node = True
|
||||
local_rank.return_value = 0
|
||||
|
||||
with mock.patch.object(trainer.datamodule, "prepare_data") as dm_mock:
|
||||
# is_overridden prepare data = True
|
||||
trainer._data_connector.prepare_data()
|
||||
dm_mock.assert_called_once()
|
||||
# is_overridden prepare data = True
|
||||
trainer._data_connector.prepare_data()
|
||||
dm.prepare_data.assert_called_once()
|
||||
|
||||
|
||||
def test_hooks_no_recursion_error():
|
||||
|
|
|
@ -151,8 +151,6 @@ class BoringDataModule(LightningDataModule):
|
|||
self.data_dir = data_dir
|
||||
self.non_picklable = None
|
||||
self.checkpoint_state: Optional[str] = None
|
||||
|
||||
def prepare_data(self):
|
||||
self.random_full = RandomDataset(32, 64 * 4)
|
||||
|
||||
def setup(self, stage: Optional[str] = None):
|
||||
|
|
Loading…
Reference in New Issue