From 553d429ecb5e832c83356672df54056c92704313 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 3 Dec 2021 14:51:34 +0100 Subject: [PATCH] Fix `prepare_data` implementation in `BoringDataModule` (#10915) --- tests/core/test_datamodules.py | 30 +++++++++++++++--------------- tests/helpers/boring_model.py | 2 -- 2 files changed, 15 insertions(+), 17 deletions(-) diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py index 57574c074c..c99cc39516 100644 --- a/tests/core/test_datamodules.py +++ b/tests/core/test_datamodules.py @@ -16,7 +16,7 @@ from argparse import ArgumentParser, Namespace from dataclasses import dataclass from typing import Any, Dict from unittest import mock -from unittest.mock import call, PropertyMock +from unittest.mock import call, Mock, PropertyMock import pytest import torch @@ -40,51 +40,52 @@ if _OMEGACONF_AVAILABLE: @mock.patch("pytorch_lightning.trainer.trainer.Trainer.node_rank", new_callable=PropertyMock) @mock.patch("pytorch_lightning.trainer.trainer.Trainer.local_rank", new_callable=PropertyMock) def test_can_prepare_data(local_rank, node_rank): - dm = BoringDataModule() + dm = Mock(spec=LightningDataModule) + dm.prepare_data_per_node = True trainer = Trainer() trainer.datamodule = dm # 1 no DM # prepare_data_per_node = True # local rank = 0 (True) - dm.random_full = None + dm.prepare_data.assert_not_called() local_rank.return_value = 0 assert trainer.local_rank == 0 trainer._data_connector.prepare_data() - assert dm.random_full is not None + dm.prepare_data.assert_called_once() # local rank = 1 (False) - dm.random_full = None + dm.reset_mock() local_rank.return_value = 1 assert trainer.local_rank == 1 trainer._data_connector.prepare_data() - assert dm.random_full is None + dm.prepare_data.assert_not_called() # prepare_data_per_node = False (prepare across all nodes) # global rank = 0 (True) - dm.random_full = None + dm.reset_mock() dm.prepare_data_per_node = False node_rank.return_value = 0 local_rank.return_value = 0 trainer._data_connector.prepare_data() - assert dm.random_full is not None + dm.prepare_data.assert_called_once() # global rank = 1 (False) - dm.random_full = None + dm.reset_mock() node_rank.return_value = 1 local_rank.return_value = 0 trainer._data_connector.prepare_data() - assert dm.random_full is None + dm.prepare_data.assert_not_called() node_rank.return_value = 0 local_rank.return_value = 1 trainer._data_connector.prepare_data() - assert dm.random_full is None + dm.prepare_data.assert_not_called() # 2 dm # prepar per node = True @@ -92,10 +93,9 @@ def test_can_prepare_data(local_rank, node_rank): dm.prepare_data_per_node = True local_rank.return_value = 0 - with mock.patch.object(trainer.datamodule, "prepare_data") as dm_mock: - # is_overridden prepare data = True - trainer._data_connector.prepare_data() - dm_mock.assert_called_once() + # is_overridden prepare data = True + trainer._data_connector.prepare_data() + dm.prepare_data.assert_called_once() def test_hooks_no_recursion_error(): diff --git a/tests/helpers/boring_model.py b/tests/helpers/boring_model.py index d51fb44bff..3a1c4f30fe 100644 --- a/tests/helpers/boring_model.py +++ b/tests/helpers/boring_model.py @@ -151,8 +151,6 @@ class BoringDataModule(LightningDataModule): self.data_dir = data_dir self.non_picklable = None self.checkpoint_state: Optional[str] = None - - def prepare_data(self): self.random_full = RandomDataset(32, 64 * 4) def setup(self, stage: Optional[str] = None):