Fix `prepare_data` implementation in `BoringDataModule` (#10915)

This commit is contained in:
Adrian Wälchli 2021-12-03 14:51:34 +01:00 committed by GitHub
parent b7331d80dc
commit 553d429ecb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 15 additions and 17 deletions

View File

@ -16,7 +16,7 @@ from argparse import ArgumentParser, Namespace
from dataclasses import dataclass
from typing import Any, Dict
from unittest import mock
from unittest.mock import call, PropertyMock
from unittest.mock import call, Mock, PropertyMock
import pytest
import torch
@ -40,51 +40,52 @@ if _OMEGACONF_AVAILABLE:
@mock.patch("pytorch_lightning.trainer.trainer.Trainer.node_rank", new_callable=PropertyMock)
@mock.patch("pytorch_lightning.trainer.trainer.Trainer.local_rank", new_callable=PropertyMock)
def test_can_prepare_data(local_rank, node_rank):
dm = BoringDataModule()
dm = Mock(spec=LightningDataModule)
dm.prepare_data_per_node = True
trainer = Trainer()
trainer.datamodule = dm
# 1 no DM
# prepare_data_per_node = True
# local rank = 0 (True)
dm.random_full = None
dm.prepare_data.assert_not_called()
local_rank.return_value = 0
assert trainer.local_rank == 0
trainer._data_connector.prepare_data()
assert dm.random_full is not None
dm.prepare_data.assert_called_once()
# local rank = 1 (False)
dm.random_full = None
dm.reset_mock()
local_rank.return_value = 1
assert trainer.local_rank == 1
trainer._data_connector.prepare_data()
assert dm.random_full is None
dm.prepare_data.assert_not_called()
# prepare_data_per_node = False (prepare across all nodes)
# global rank = 0 (True)
dm.random_full = None
dm.reset_mock()
dm.prepare_data_per_node = False
node_rank.return_value = 0
local_rank.return_value = 0
trainer._data_connector.prepare_data()
assert dm.random_full is not None
dm.prepare_data.assert_called_once()
# global rank = 1 (False)
dm.random_full = None
dm.reset_mock()
node_rank.return_value = 1
local_rank.return_value = 0
trainer._data_connector.prepare_data()
assert dm.random_full is None
dm.prepare_data.assert_not_called()
node_rank.return_value = 0
local_rank.return_value = 1
trainer._data_connector.prepare_data()
assert dm.random_full is None
dm.prepare_data.assert_not_called()
# 2 dm
# prepar per node = True
@ -92,10 +93,9 @@ def test_can_prepare_data(local_rank, node_rank):
dm.prepare_data_per_node = True
local_rank.return_value = 0
with mock.patch.object(trainer.datamodule, "prepare_data") as dm_mock:
# is_overridden prepare data = True
trainer._data_connector.prepare_data()
dm_mock.assert_called_once()
# is_overridden prepare data = True
trainer._data_connector.prepare_data()
dm.prepare_data.assert_called_once()
def test_hooks_no_recursion_error():

View File

@ -151,8 +151,6 @@ class BoringDataModule(LightningDataModule):
self.data_dir = data_dir
self.non_picklable = None
self.checkpoint_state: Optional[str] = None
def prepare_data(self):
self.random_full = RandomDataset(32, 64 * 4)
def setup(self, stage: Optional[str] = None):