# lightning/pytorch_lightning/trainer/connectors/data_connector.py

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Union

import pytorch_lightning as pl
from pytorch_lightning.trainer.supporters import prefetch_iterator
from pytorch_lightning.utilities import rank_zero_deprecation
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS


class DataConnector:

    def __init__(self, trainer: "pl.Trainer", multiple_trainloader_mode: str = "max_size_cycle"):
        self.trainer = trainer
        self.multiple_trainloader_mode = multiple_trainloader_mode

    def on_trainer_init(
        self,
        check_val_every_n_epoch: int,
        reload_dataloaders_every_n_epochs: int,
        reload_dataloaders_every_epoch: bool,
        prepare_data_per_node: bool,
    ) -> None:
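        """Validate and store the data-related arguments passed to the ``Trainer`` constructor.

        A hedged sketch of how this is assumed to be invoked from ``Trainer.__init__``
        (argument values are illustrative defaults, not requirements)::

            trainer.data_connector.on_trainer_init(
                check_val_every_n_epoch=1,
                reload_dataloaders_every_n_epochs=0,
                reload_dataloaders_every_epoch=False,
                prepare_data_per_node=True,
            )
        """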
        self.trainer.datamodule = None
        self.trainer.prepare_data_per_node = prepare_data_per_node

        if not isinstance(check_val_every_n_epoch, int):
            raise MisconfigurationException(
                f"check_val_every_n_epoch should be an integer. Found {check_val_every_n_epoch}"
            )
        self.trainer.check_val_every_n_epoch = check_val_every_n_epoch

        if reload_dataloaders_every_epoch:
            reload_dataloaders_every_n_epochs = int(reload_dataloaders_every_epoch)
            rank_zero_deprecation(
                "`reload_dataloaders_every_epoch` is deprecated in v1.4 and will be removed in v1.6."
                " Please use `reload_dataloaders_every_n_epochs` in Trainer."
            )

        if not isinstance(reload_dataloaders_every_n_epochs, int) or (reload_dataloaders_every_n_epochs < 0):
            raise MisconfigurationException(
                "`reload_dataloaders_every_n_epochs` should be an int >= 0,"
                f" got {reload_dataloaders_every_n_epochs}."
            )

        self.trainer.reload_dataloaders_every_n_epochs = reload_dataloaders_every_n_epochs
        self.trainer._is_data_prepared = False

    def get_profiled_train_dataloader(self, train_dataloader):
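        """Wrap the train dataloader in a prefetching iterator and profile batch fetching.

        A sketch of the consumption pattern, assuming ``prefetch_iterator`` keeps its
        contract of yielding ``(batch, is_last)`` pairs::

            for batch_idx, (batch, is_last) in self.get_profiled_train_dataloader(dl):
                ...  # ``is_last`` is True for the final batch
        """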
        profiled_dl = self.trainer.profiler.profile_iterable(
            enumerate(prefetch_iterator(train_dataloader)), "get_train_batch"
        )
        return profiled_dl

    def prepare_data(self, model):
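        """Call ``prepare_data`` on the datamodule (if any) and the model, on eligible ranks only.

        A hedged sketch of manual use (normally the ``Trainer`` triggers this itself
        during ``fit``; ``model`` is any ``LightningModule``)::

            trainer.data_connector.prepare_data(model)
        """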
        # On multi-GPU jobs we only want to manipulate data (download, etc.) on
        # node_rank=0, local_rank=0, or, when each node needs its own copy of the
        # data, on local_rank=0 of every node.
        if self.can_prepare_data():
            if self.trainer.datamodule is not None:
                self.trainer.datamodule.prepare_data()
            model.prepare_data()
            self.trainer._is_data_prepared = True

    def can_prepare_data(self):
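        """Decide whether ``prepare_data`` should run on this process.

        The rank logic below, restated as an illustrative sketch (the datamodule
        ``_has_prepared_data`` guard is applied on top of this)::

            if self.trainer.prepare_data_per_node:
                return self.trainer.local_rank == 0   # every node prepares its own data
            # otherwise, only the globally-first process prepares
            return self.trainer.node_rank == 0 and self.trainer.local_rank == 0
        """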
        should_call_dm_prepare_data = True
        if self.trainer.datamodule is not None and is_overridden('prepare_data', self.trainer.datamodule):
            should_call_dm_prepare_data = not self.trainer.datamodule._has_prepared_data

        if self.trainer.prepare_data_per_node:
            return self.trainer.local_rank == 0 and should_call_dm_prepare_data
        return self.trainer.node_rank == 0 and self.trainer.local_rank == 0 and should_call_dm_prepare_data

    def attach_data(
        self,
        model: 'pl.LightningModule',
        train_dataloaders: Optional[TRAIN_DATALOADERS] = None,
        val_dataloaders: Optional[EVAL_DATALOADERS] = None,
        test_dataloaders: Optional[EVAL_DATALOADERS] = None,
        predict_dataloaders: Optional[EVAL_DATALOADERS] = None,
        datamodule: Optional['pl.LightningDataModule'] = None,
    ) -> None:
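        """Attach any passed-in dataloaders and datamodule to the model, then copy trainer properties onto it.

        A hedged sketch of the assumed internal call made during ``Trainer.fit``
        (``train_dl`` is an illustrative ``DataLoader``)::

            self.attach_data(model, train_dataloaders=train_dl)
        """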
        # set up the passed in dataloaders (if needed)
        self.attach_dataloaders(
            model,
            train_dataloaders=train_dataloaders,
            val_dataloaders=val_dataloaders,
            test_dataloaders=test_dataloaders,
            predict_dataloaders=predict_dataloaders,
        )
        self.attach_datamodule(model, datamodule=datamodule)
        # set local properties on the model
        self.trainer.model_connector.copy_trainer_model_properties(model)

    def attach_dataloaders(
        self,
        model: 'pl.LightningModule',
        train_dataloaders: Optional[TRAIN_DATALOADERS] = None,
        val_dataloaders: Optional[EVAL_DATALOADERS] = None,
        test_dataloaders: Optional[EVAL_DATALOADERS] = None,
        predict_dataloaders: Optional[EVAL_DATALOADERS] = None,
    ) -> None:
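        """Patch the model's ``*_dataloader`` hooks with dataloaders passed to the trainer.

        Illustrative sketch (``train_dl`` is an assumed ``DataLoader`` instance)::

            self.attach_dataloaders(model, train_dataloaders=train_dl)
            assert model.train_dataloader() is train_dl  # the hook now returns the patched loader
        """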
        # when dataloaders are passed via fit, patch the model's dataloader hooks
        # so that they return these implementations instead
        if train_dataloaders is not None:
            model.train_dataloader = _PatchDataLoader(train_dataloaders)
        if val_dataloaders is not None:
            model.val_dataloader = _PatchDataLoader(val_dataloaders)
        if test_dataloaders is not None:
            model.test_dataloader = _PatchDataLoader(test_dataloaders)
        if predict_dataloaders is not None:
            model.predict_dataloader = _PatchDataLoader(predict_dataloaders)

    def attach_datamodule(
        self, model: 'pl.LightningModule', datamodule: Optional['pl.LightningDataModule'] = None
    ) -> None:
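        """Attach a ``LightningDataModule``: its dataloader and batch-transfer hooks replace the model's.

        Hedged sketch (``MyDataModule`` is a hypothetical datamodule overriding
        ``train_dataloader``)::

            dm = MyDataModule()
            self.attach_datamodule(model, datamodule=dm)
            # model.train_dataloader is now dm.train_dataloader
        """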
        # If we have a datamodule, attach necessary hooks + dataloaders
        if datamodule is None:
            return

        # Override loader hooks
        dl_methods = ('train_dataloader', 'val_dataloader', 'test_dataloader', 'predict_dataloader')
        for method in dl_methods:
            if is_overridden(method, datamodule):
                setattr(model, method, getattr(datamodule, method))

        # Override data transfer hooks if dataset-specific to_device logic has been defined in the datamodule
        batch_transfer_hooks = ('on_before_batch_transfer', 'transfer_batch_to_device', 'on_after_batch_transfer')
        for hook in batch_transfer_hooks:
            if is_overridden(hook, datamodule):
                setattr(model, hook, getattr(datamodule, hook))

        self.trainer.datamodule = datamodule
        datamodule.trainer = self.trainer

        # experimental feature for Flash
        if hasattr(datamodule, "data_pipeline"):
            model.data_pipeline = datamodule.data_pipeline


class _PatchDataLoader:
    r"""
    Callable object for patching dataloaders passed into trainer.fit().
    Use this class to override model.*_dataloader() and be pickle-compatible.

    Args:
        dataloader: Dataloader object to return when called.
    """

    def __init__(self, dataloader: Union[TRAIN_DATALOADERS, EVAL_DATALOADERS]) -> None:
        self.dataloader = dataloader
        # `__code__` objects cannot be pickled, so we cannot inspect a hook's code
        # object directly to tell whether it has been replaced by a _PatchDataLoader.
        # Instead, store the string representation of `__call__.__code__` and compare that.
        self.patch_loader_code = str(self.__call__.__code__)

    def __call__(self) -> Union[TRAIN_DATALOADERS, EVAL_DATALOADERS]:
        return self.dataloader
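
# A minimal sketch of how the ``patch_loader_code`` marker above is assumed to be
# used to detect a patched hook (illustrative only; the real check lives elsewhere,
# alongside ``is_overridden``):
#
#     code = getattr(model.train_dataloader, "patch_loader_code", None)
#     is_patched = code == str(_PatchDataLoader.__call__.__code__)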