# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional, Union

from torch.utils.data import DataLoader

from pytorch_lightning.core.datamodule import LightningDataModule
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_utils import is_overridden


class DataConnector(object):

    def __init__(self, trainer):
        self.trainer = trainer

    def on_trainer_init(self, check_val_every_n_epoch, reload_dataloaders_every_epoch, prepare_data_per_node):
        self.trainer.datamodule = None
        self.trainer.prepare_data_per_node = prepare_data_per_node
        self.trainer.check_val_every_n_epoch = check_val_every_n_epoch
        self.trainer.reload_dataloaders_every_epoch = reload_dataloaders_every_epoch
        self.trainer._is_data_prepared = False

    def get_profiled_train_dataloader(self, train_dataloader):
        profiled_dl = self.trainer.profiler.profile_iterable(
            enumerate(self._with_is_last(train_dataloader)),
            "get_train_batch"
        )
        return profiled_dl

    def _with_is_last(self, iterable):
        """Pass through values from the given iterable with an added boolean indicating if this is the last item.

        See `https://stackoverflow.com/a/1630350 <https://stackoverflow.com/a/1630350>`_
        """
        it = iter(iterable)
        last = next(it)
        for val in it:
            # yield last and has next
            yield last, False
            last = val
        # yield last, no longer has next
        yield last, True

    def prepare_data(self, model):
        # on multi-gpu jobs we only want to manipulate (download, etc.) data on node_rank=0, local_rank=0,
        # or, when each node needs to do its own manipulation, just on local_rank=0
        if self.can_prepare_data():
            if self.trainer.datamodule is not None:
                self.trainer.datamodule.prepare_data()
            model.prepare_data()
            self.trainer._is_data_prepared = True

    def can_prepare_data(self):
        should_call_dm_prepare_data = True
        if self.trainer.datamodule is not None and is_overridden('prepare_data', self.trainer.datamodule):
            should_call_dm_prepare_data = not self.trainer.datamodule.has_prepared_data

        if self.trainer.prepare_data_per_node:
            return self.trainer.local_rank == 0 and should_call_dm_prepare_data
        else:
            return self.trainer.node_rank == 0 and self.trainer.local_rank == 0 and should_call_dm_prepare_data

    def attach_data(self, model, train_dataloader, val_dataloaders, datamodule):
        # if a datamodule comes in as the second arg, then fix it for the user
        if isinstance(train_dataloader, LightningDataModule):
            datamodule = train_dataloader
            train_dataloader = None

        self.__enforce_datamodule_dataloader_override(train_dataloader, val_dataloaders, datamodule)

        # set up the passed-in dataloaders (if needed)
        self.attach_dataloaders(model, train_dataloader, val_dataloaders)
        self.attach_datamodule(model, datamodule, 'fit')

    def __enforce_datamodule_dataloader_override(self, train_dataloader, val_dataloaders, datamodule):
        # if you supply a datamodule, you can't also supply train_dataloader or val_dataloaders
        if (train_dataloader is not None or val_dataloaders is not None) and datamodule is not None:
            raise MisconfigurationException(
                'You cannot pass train_dataloader or val_dataloaders to trainer.fit if you supply a datamodule'
            )

    def attach_dataloaders(self, model, train_dataloader=None, val_dataloaders=None, test_dataloaders=None):
        # when dataloaders are passed via fit, patch the model's *_dataloader
        # hooks to return these implementations instead
        if train_dataloader is not None:
            model.train_dataloader = _PatchDataLoader(train_dataloader)

        if val_dataloaders is not None:
            model.val_dataloader = _PatchDataLoader(val_dataloaders)

        if test_dataloaders is not None:
            model.test_dataloader = _PatchDataLoader(test_dataloaders)

    def attach_datamodule(self, model, datamodule: Optional[LightningDataModule], stage: str) -> None:
        # We use the datamodule if it's been provided on .fit or .test, otherwise we check the model for it
        datamodule = datamodule or getattr(model, 'datamodule', None)

        # If we have a datamodule, attach necessary hooks + dataloaders
        if datamodule:

            # Override loader hooks
            if is_overridden('train_dataloader', datamodule):
                model.train_dataloader = datamodule.train_dataloader
            if is_overridden('val_dataloader', datamodule):
                model.val_dataloader = datamodule.val_dataloader
            if is_overridden('test_dataloader', datamodule):
                model.test_dataloader = datamodule.test_dataloader

            # Override transfer_batch_to_device if dataset-specific to_device logic has been defined in the datamodule
            if is_overridden('transfer_batch_to_device', datamodule):
                model.transfer_batch_to_device = datamodule.transfer_batch_to_device

            self.trainer.datamodule = datamodule
            datamodule.trainer = self.trainer


class _PatchDataLoader(object):
    r"""
    Callable object for patching dataloaders passed into trainer.fit().
    Use this class to override model.*_dataloader() and be pickle-compatible.

    Args:
        dataloader: Dataloader object to return when called.
    """

    def __init__(self, dataloader: Union[List[DataLoader], DataLoader]):
        self.dataloader = dataloader

        # cannot pickle __code__, so we cannot directly verify whether a _PatchDataLoader
        # instance exists (which would show the dataloader methods have been overwritten).
        # instead, we work around it by storing the string representation of the code object
        self.patch_loader_code = str(self.__call__.__code__)

    def __call__(self) -> Union[List[DataLoader], DataLoader]:
        return self.dataloader
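
# ---------------------------------------------------------------------------
# Illustrative sketch (kept as comments, not executed, and not part of the
# public API): how `attach_dataloaders` patches a model's dataloader hooks via
# `_PatchDataLoader`. `MyModel` and the `Trainer()` configuration below are
# hypothetical placeholders; in normal use the Trainer builds its own
# DataConnector internally.
#
#     import torch
#     from torch.utils.data import DataLoader, TensorDataset
#     from pytorch_lightning import Trainer
#
#     trainer = Trainer()
#     model = MyModel()  # any LightningModule subclass
#     loader = DataLoader(TensorDataset(torch.randn(8, 2)), batch_size=4)
#
#     connector = DataConnector(trainer)
#     connector.attach_dataloaders(model, train_dataloader=loader)
#
#     # the hook is now a pickle-friendly callable that returns the passed loader
#     assert isinstance(model.train_dataloader, _PatchDataLoader)
#     assert model.train_dataloader() is loader
# ---------------------------------------------------------------------------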