# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest.mock import Mock

import pytest
from torch.utils.data import DataLoader

from pytorch_lightning import Trainer
from pytorch_lightning.trainer.connectors.data_connector import _DataLoaderSource
from tests.helpers import BoringDataModule, BoringModel


class NoDataLoaderModel(BoringModel):
    """A ``BoringModel`` whose four dataloader hooks are replaced by ``None`` on the instance."""

    def __init__(self):
        super().__init__()
        # Shadow each hook with ``None`` at instance level, in the same order
        # for every instance.
        for hook_name in (
            "train_dataloader",
            "val_dataloader",
            "test_dataloader",
            "predict_dataloader",
        ):
            setattr(self, hook_name, None)


@pytest.mark.parametrize(
    "instance,available",
    [
        (None, True),
        (BoringModel().train_dataloader(), True),
        (BoringModel(), True),
        (NoDataLoaderModel(), False),
        (BoringDataModule(), True),
    ],
)
def test_dataloader_source_available(instance, available):
    """Check that ``is_defined()`` matches the expected availability for each kind of instance."""
    assert _DataLoaderSource(instance=instance, name="train_dataloader").is_defined() is available


def test_dataloader_source_direct_access():
    """A source wrapping a dataloader object hands back that very object."""
    loader = BoringModel().train_dataloader()
    direct_source = _DataLoaderSource(instance=loader, name="any")

    assert not direct_source.is_module()
    assert direct_source.is_defined()
    # Identity, not equality: the wrapped dataloader itself must be returned.
    assert direct_source.dataloader() is loader


def test_dataloader_source_request_from_module():
    """A source wrapping a module calls the named hook only when the dataloader is requested."""
    model = BoringModel()
    model.trainer = Trainer()
    hook = Mock(return_value=model.train_dataloader())
    model.foo = hook

    module_source = _DataLoaderSource(model, "foo")
    assert module_source.is_module()
    # Constructing and inspecting the source must not have invoked the hook yet.
    hook.assert_not_called()

    requested = module_source.dataloader()
    assert isinstance(requested, DataLoader)
    hook.assert_called_once()