# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from dataclasses import dataclass
from functools import partial
from typing import Iterable, Optional, Union
from weakref import proxy

import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_deprecation
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.fetching import (
    AbstractDataFetcher,
    DataFetcher,
    DataLoaderIterDataFetcher,
    InterBatchParallelDataFetcher,
)
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS
from pytorch_lightning.utilities.warnings import rank_zero_warn


class DataConnector:
    def __init__(
        self,
        trainer: "pl.Trainer",
        multiple_trainloader_mode: str = "max_size_cycle",
        train_data_fetcher: Optional[AbstractDataFetcher] = None,
        validate_data_fetcher: Optional[AbstractDataFetcher] = None,
        test_data_fetcher: Optional[AbstractDataFetcher] = None,
    ):
        self.trainer = trainer
        self.multiple_trainloader_mode = multiple_trainloader_mode

        self.train_data_fetcher = train_data_fetcher
        self.validate_data_fetcher = validate_data_fetcher
        self.test_data_fetcher = test_data_fetcher
        self.sanity_check_data_fetcher: Optional[AbstractDataFetcher] = None

        self._train_dataloader_source = _DataLoaderSource(None, "")
        self._val_dataloader_source = _DataLoaderSource(None, "")
        self._test_dataloader_source = _DataLoaderSource(None, "")
        self._predict_dataloader_source = _DataLoaderSource(None, "")

    @property
    def evaluation_data_fetcher(self) -> Optional[AbstractDataFetcher]:
        if self.trainer.sanity_checking:
            return self.sanity_check_data_fetcher
        return self.test_data_fetcher if self.trainer.testing else self.validate_data_fetcher

    def on_trainer_init(
        self,
        check_val_every_n_epoch: int,
        reload_dataloaders_every_n_epochs: int,
        reload_dataloaders_every_epoch: bool,
        prepare_data_per_node: Optional[bool] = None,
    ) -> None:
        self.trainer.datamodule = None

        if prepare_data_per_node is not None:
            rank_zero_deprecation(
                "Setting `prepare_data_per_node` with the trainer flag is deprecated and will be removed in v1.7.0! "
                "Please set `prepare_data_per_node` in LightningDataModule or LightningModule directly instead."
            )
        self.trainer.prepare_data_per_node = prepare_data_per_node

        if not isinstance(check_val_every_n_epoch, int):
            raise MisconfigurationException(
                f"check_val_every_n_epoch should be an integer. Found {check_val_every_n_epoch}"
            )

        self.trainer.check_val_every_n_epoch = check_val_every_n_epoch

        if reload_dataloaders_every_epoch:
            reload_dataloaders_every_n_epochs = int(reload_dataloaders_every_epoch)
            rank_zero_deprecation(
                "`reload_dataloaders_every_epoch` is deprecated in v1.4 and will be removed in v1.6."
                " Please use `reload_dataloaders_every_n_epochs` in Trainer."
            )

        if not isinstance(reload_dataloaders_every_n_epochs, int) or (reload_dataloaders_every_n_epochs < 0):
            raise MisconfigurationException(
                f"`reload_dataloaders_every_n_epochs` should be an int >= 0, got {reload_dataloaders_every_n_epochs}."
            )

        self.trainer.reload_dataloaders_every_n_epochs = reload_dataloaders_every_n_epochs
        self.trainer._is_data_prepared = False

    def _select_data_fetcher(self) -> AbstractDataFetcher:
        if self.trainer.sanity_checking:
            return DataFetcher()

        training_step_fx = getattr(self.trainer.lightning_module, "training_step")
        if self.trainer.training and is_param_in_hook_signature(training_step_fx, "dataloader_iter", explicit=True):
            rank_zero_warn(
                "Found `dataloader_iter` argument in the `training_step`. Note that the support for "
                "this signature is experimental and the behavior is subject to change."
            )
            return DataLoaderIterDataFetcher()
        elif self.trainer.training and os.getenv("PL_INTER_BATCH_PARALLELISM", "0") == "1":
            # note: this is an experimental feature
            if not self.trainer.training_type_plugin.on_gpu:
                raise MisconfigurationException("Inter batch parallelism is available only when using Nvidia GPUs.")
            return InterBatchParallelDataFetcher()

        return DataFetcher()

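    # Illustrative sketch (``LitModel`` is a hypothetical user module, not part of this file):
    # declaring ``dataloader_iter`` explicitly in ``training_step`` is what makes
    # ``is_param_in_hook_signature`` above select the experimental ``DataLoaderIterDataFetcher``:
    #
    #     class LitModel(pl.LightningModule):
    #         def training_step(self, dataloader_iter):
    #             batch = next(dataloader_iter)
    #             ...
    #
    # Similarly, exporting ``PL_INTER_BATCH_PARALLELISM=1`` selects ``InterBatchParallelDataFetcher``,
    # but only on GPU; in every other case the default ``DataFetcher`` is returned.
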
    def get_profiled_dataloader(self, dataloader: Iterable, dataloader_idx: int = 0) -> Iterable:
        stage: str = self.trainer.state.stage.value
        # reuse the fetcher already cached for this stage if there is one, otherwise create a new one
        data_fetcher = getattr(self, f"{stage}_data_fetcher", None) or self._select_data_fetcher()
        data_fetcher.setup(
            dataloader,
            stage=stage,
            batch_to_device=partial(self.trainer.accelerator.batch_to_device, dataloader_idx=dataloader_idx),
            profiler=self.trainer.profiler,
        )
        setattr(self, f"{stage}_data_fetcher", data_fetcher)
        return data_fetcher

    def prepare_data(self) -> None:
        # on multi-gpu jobs we only want to manipulate (download, etc) on node_rank=0, local_rank=0
        # or in the case where each node needs to do its own manipulation in which case just local_rank=0
        local_rank_zero = self.trainer.local_rank == 0
        global_rank_zero = self.trainer.local_rank == 0 and self.trainer.node_rank == 0

        datamodule = self.trainer.datamodule
        lightning_module = self.trainer.lightning_module
        # handle datamodule prepare data:
        # check for prepare_data_per_node & datamodule lifecycle properties before calling datamodule.prepare_data
        if datamodule is not None:
            dm_prepare_data_per_node = datamodule.prepare_data_per_node
            dm_eq_prepare_data = datamodule.prepare_data_per_node == self.trainer.prepare_data_per_node
            if self.trainer.prepare_data_per_node is not None and not dm_eq_prepare_data:
                raise MisconfigurationException(
                    "Inconsistent settings found for `prepare_data_per_node`."
                    f" Value was set with both `Trainer(prepare_data_per_node={self.trainer.prepare_data_per_node})`"
                    f" and `DataModule.prepare_data_per_node={datamodule.prepare_data_per_node}`."
                    " Move `prepare_data_per_node` setting to DataModule property."
                )
            if (dm_prepare_data_per_node and local_rank_zero) or (not dm_prepare_data_per_node and global_rank_zero):
                self.trainer.datamodule.prepare_data()
        # handle lightning module prepare data:
        # check for prepare_data_per_node before calling lightning_module.prepare_data
        if lightning_module is not None:
            lm_prepare_data_per_node = lightning_module.prepare_data_per_node
            lm_eq_prepare_data = lightning_module.prepare_data_per_node == self.trainer.prepare_data_per_node
            if (self.trainer.prepare_data_per_node is not None) and not lm_eq_prepare_data:
                raise MisconfigurationException(
                    "Inconsistent settings found for `prepare_data_per_node`."
                    f" Value was set with both `Trainer(prepare_data_per_node={self.trainer.prepare_data_per_node})`"
                    f" and `LightningModule.prepare_data_per_node={lightning_module.prepare_data_per_node}`."
                    " Move `prepare_data_per_node` setting to LightningModule property."
                )
            if (lm_prepare_data_per_node and local_rank_zero) or (not lm_prepare_data_per_node and global_rank_zero):
                self.trainer.call_hook("prepare_data")
                self.trainer._is_data_prepared = True

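    # Worked example of the rank logic above (hypothetical 2-node x 4-GPU job):
    #   prepare_data_per_node=True  -> ``prepare_data`` runs on local_rank 0 of each node (2 calls total)
    #   prepare_data_per_node=False -> ``prepare_data`` runs only on node_rank 0 / local_rank 0 (1 call total)
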
    def attach_data(
        self,
        model: "pl.LightningModule",
        train_dataloaders: Optional[TRAIN_DATALOADERS] = None,
        val_dataloaders: Optional[EVAL_DATALOADERS] = None,
        test_dataloaders: Optional[EVAL_DATALOADERS] = None,
        predict_dataloaders: Optional[EVAL_DATALOADERS] = None,
        datamodule: Optional["pl.LightningDataModule"] = None,
    ) -> None:
        # set up the passed in dataloaders (if needed)
        self.attach_dataloaders(
            model,
            train_dataloaders=train_dataloaders,
            val_dataloaders=val_dataloaders,
            test_dataloaders=test_dataloaders,
            predict_dataloaders=predict_dataloaders,
        )
        self.attach_datamodule(model, datamodule=datamodule)
        # set local properties on the model
        self._copy_trainer_model_properties(model)

    def _copy_trainer_model_properties(self, model):
        ref_model = self.trainer.lightning_module or model

        for m in [model, ref_model]:
            m.trainer = proxy(self.trainer)
            m.use_amp = self.trainer.amp_backend is not None
            m.precision = self.trainer.precision

    def attach_dataloaders(
        self,
        model: "pl.LightningModule",
        train_dataloaders: Optional[TRAIN_DATALOADERS] = None,
        val_dataloaders: Optional[EVAL_DATALOADERS] = None,
        test_dataloaders: Optional[EVAL_DATALOADERS] = None,
        predict_dataloaders: Optional[EVAL_DATALOADERS] = None,
    ) -> None:
        self.trainer.train_dataloader = None
        self.trainer.val_dataloaders = None
        self.trainer.test_dataloaders = None
        self.trainer.predict_dataloaders = None

        self._train_dataloader_source = _DataLoaderSource(
            train_dataloaders if train_dataloaders is not None else model, "train_dataloader"
        )
        self._val_dataloader_source = _DataLoaderSource(
            val_dataloaders if val_dataloaders is not None else model, "val_dataloader"
        )
        self._test_dataloader_source = _DataLoaderSource(
            test_dataloaders if test_dataloaders is not None else model, "test_dataloader"
        )
        self._predict_dataloader_source = _DataLoaderSource(
            predict_dataloaders if predict_dataloaders is not None else model, "predict_dataloader"
        )

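    # Illustrative sketch (``LitModel`` and ``train_ds`` are hypothetical): a dataloader passed
    # directly to the Trainer becomes the source; otherwise the model's own hook is used.
    #
    #     trainer.fit(LitModel(), train_dataloaders=DataLoader(train_ds))  # explicit loader is the source
    #     trainer.fit(LitModel())  # falls back to LitModel.train_dataloader()
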
    def attach_datamodule(
        self, model: "pl.LightningModule", datamodule: Optional["pl.LightningDataModule"] = None
    ) -> None:
        # If we have a datamodule, attach necessary hooks + dataloaders
        if datamodule is None:
            return

        self._train_dataloader_source = _DataLoaderSource(datamodule, "train_dataloader")
        self._val_dataloader_source = _DataLoaderSource(datamodule, "val_dataloader")
        self._test_dataloader_source = _DataLoaderSource(datamodule, "test_dataloader")
        self._predict_dataloader_source = _DataLoaderSource(datamodule, "predict_dataloader")

        # Override data transfer hooks if dataset-specific to_device logic has been defined in datamodule
        batch_transfer_hooks = ("on_before_batch_transfer", "transfer_batch_to_device", "on_after_batch_transfer")
        for hook in batch_transfer_hooks:
            if is_overridden(hook, datamodule):
                setattr(model, hook, getattr(datamodule, hook))

        self.trainer.datamodule = datamodule
        datamodule.trainer = self.trainer

        # experimental feature for Flash
        if hasattr(datamodule, "data_pipeline"):
            model.data_pipeline = datamodule.data_pipeline

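    # Illustrative sketch (``MyDataModule`` is hypothetical; the hook signature may differ by version):
    # a batch-transfer hook overridden on the datamodule is patched onto the model above, e.g.
    #
    #     class MyDataModule(pl.LightningDataModule):
    #         def transfer_batch_to_device(self, batch, device, dataloader_idx):
    #             return batch.to(device)
    #
    # After ``attach_datamodule(model, datamodule=MyDataModule())``, calls to
    # ``model.transfer_batch_to_device`` dispatch to the datamodule's implementation.
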
    def teardown(self) -> None:
        if self.train_data_fetcher:
            self.train_data_fetcher.teardown()
            self.train_data_fetcher = None
        if self.validate_data_fetcher:
            self.validate_data_fetcher.teardown()
            self.validate_data_fetcher = None
        if self.test_data_fetcher:
            self.test_data_fetcher.teardown()
            self.test_data_fetcher = None
        if self.sanity_check_data_fetcher:
            self.sanity_check_data_fetcher.teardown()
            self.sanity_check_data_fetcher = None


@dataclass
class _DataLoaderSource:
    """Stores the information where the dataloaders come from.

    The source can be

    1. from a ``*_dataloader()`` method on the :class:`~pytorch_lightning.core.lightning.LightningModule`,
    2. from a ``*_dataloader()`` method on the :class:`~pytorch_lightning.core.datamodule.LightningDataModule`,
    3. a direct instance of a :class:`~torch.utils.data.DataLoader` or supported collections thereof.

    Arguments:
        instance: A LightningModule, LightningDataModule, or (a collection of) dataloader(s).
        name: A name for this dataloader source. If the instance is a module, the name corresponds to the hook
            that returns the desired dataloader(s).
    """

    instance: Optional[Union[TRAIN_DATALOADERS, EVAL_DATALOADERS, "pl.LightningModule", "pl.LightningDataModule"]]
    name: str

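    # Illustrative sketch (``model``, ``datamodule`` and ``loader`` are assumed objects): the three
    # kinds of sources listed in the docstring are constructed the same way, only ``instance`` differs:
    #
    #     _DataLoaderSource(model, "train_dataloader")       # 1. hook on a LightningModule
    #     _DataLoaderSource(datamodule, "train_dataloader")  # 2. hook on a LightningDataModule
    #     _DataLoaderSource(loader, "train_dataloader")      # 3. DataLoader instance passed directly
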
    def dataloader(self) -> Union[TRAIN_DATALOADERS, EVAL_DATALOADERS]:
        """Returns the dataloader from the source.

        If the source is a module, the method with the corresponding :attr:`name` gets called.
        """
        from pytorch_lightning import LightningDataModule, LightningModule  # prevent cyclic import

        if not self.name:
            return self.instance

        if isinstance(self.instance, LightningModule):
            return self.instance.trainer.call_hook(self.name, pl_module=self.instance)

        if isinstance(self.instance, LightningDataModule):
            method = getattr(self.instance, self.name)
            return method()

        return self.instance

    def is_defined(self) -> bool:
        """Returns whether the source dataloader can be retrieved or not.

        If the source is a module it checks that the method with given :attr:`name` is overridden.
        """
        return not self.is_module() or is_overridden(self.name, self.instance)

    def is_module(self) -> bool:
        """Returns whether the DataLoader source is a LightningModule or a LightningDataModule.

        It does not check whether ``*_dataloader`` methods are actually overridden.
        """
        from pytorch_lightning import LightningDataModule, LightningModule  # prevent cyclic import

        return isinstance(self.instance, (LightningModule, LightningDataModule))