lightning/pytorch_lightning/trainer/data_loading.py


# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import multiprocessing
import os
from abc import ABC
from copy import deepcopy
from functools import partial
from typing import Iterable, List, Optional, Tuple, Union

from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler

from pytorch_lightning.accelerators import Accelerator
from pytorch_lightning.core import LightningModule
from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector
from pytorch_lightning.trainer.supporters import CombinedLoader
from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_6, rank_zero_warn
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.data import has_iterable_dataset, has_len
from pytorch_lightning.utilities.debugging import InternalDebugger
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.seed import pl_worker_init_function


class TrainerDataLoadingMixin(ABC):

    # this is just a summary of the variables used in this abstract class;
    # the proper values/initialisation should be done in the child class
val_check_interval: float
tpu_local_core_rank: int
train_dataloader: DataLoader
num_training_batches: Union[int, float]
val_check_batch: float
val_dataloaders: Optional[List[DataLoader]]
num_val_batches: List[Union[int, float]]
test_dataloaders: Optional[List[DataLoader]]
num_test_batches: List[Union[int, float]]
limit_train_batches: Union[int, float]
overfit_batches: Union[int, float]
distributed_sampler_kwargs: dict
accelerator: Accelerator
accelerator_connector: AcceleratorConnector
dev_debugger: InternalDebugger

    def _worker_check(self, dataloader: DataLoader, name: str) -> None:
if not isinstance(dataloader, DataLoader):
return
using_spawn = self.accelerator_connector.distributed_backend == "ddp_spawn"
num_cpus = multiprocessing.cpu_count()
# ddp_spawn + num_workers > 0 don't mix! tell the user
if dataloader.num_workers > 0 and using_spawn:
# checks for the attr persistent_workers available in pytorch >= 1.7
if hasattr(dataloader, "persistent_workers"):
if not dataloader.persistent_workers:
rank_zero_warn(
'num_workers>0, persistent_workers=False, and accelerator=ddp_spawn'
' may result in data loading bottlenecks.'
' Consider setting persistent_workers=True'
' (this is a limitation of Python .spawn() and PyTorch)'
)
else:
rank_zero_warn(
'num_workers>0 and accelerator=ddp_spawn do not mix well'
' and may result in data loading bottlenecks.'
' Consider setting accelerator=ddp to use num_workers>0'
' (this is a limitation of Python .spawn() and PyTorch)'
)
elif dataloader.num_workers == 0 and using_spawn:
# checks for the attr persistent_workers available in pytorch >= 1.7
if hasattr(dataloader, "persistent_workers"):
if not dataloader.persistent_workers:
rank_zero_warn(
'accelerator=ddp_spawn and num_workers=0 may result in data loading bottlenecks.'
' Consider setting num_workers>0 and persistent_workers=True'
)
else:
rank_zero_warn(
'accelerator=ddp_spawn and num_workers=0 may result in data loading bottlenecks.'
                    ' Consider setting accelerator=ddp and setting num_workers>0'
)
elif dataloader.num_workers <= 2 < num_cpus and not using_spawn:
rank_zero_warn(
f'The dataloader, {name}, does not have many workers which may be a bottleneck.'
                ' Consider increasing the value of the `num_workers` argument'
f' (try {num_cpus} which is the number of cpus on this machine)'
f' in the `DataLoader` init to improve performance.'
)
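
    # Illustrative sketch (not part of the original module): the kind of DataLoader
    # configuration the warnings above steer users toward. `MyDataset` is a hypothetical
    # map-style dataset.
    #
    #     train_loader = DataLoader(
    #         MyDataset(),
    #         batch_size=32,
    #         num_workers=multiprocessing.cpu_count(),  # avoid the "not many workers" warning
    #         persistent_workers=True,                  # recommended with ddp_spawn (PyTorch >= 1.7)
    #     )
    #     trainer = Trainer(gpus=2, accelerator="ddp")  # prefer ddp over ddp_spawn when num_workers > 0
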
def auto_add_worker_init_fn(self, dataloader: DataLoader) -> None:
if int(os.environ.get("PL_SEED_WORKERS", 0)) and dataloader.worker_init_fn is None:
dataloader.worker_init_fn = partial(pl_worker_init_function, rank=self.global_rank)
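
    # Illustrative sketch (assumption, not part of the original module): `PL_SEED_WORKERS`
    # is typically exported by `seed_everything(..., workers=True)`; dataloaders without an
    # explicit `worker_init_fn` then get `pl_worker_init_function` attached here so each
    # worker process is seeded deterministically per rank.
    #
    #     from pytorch_lightning import Trainer, seed_everything
    #     seed_everything(42, workers=True)  # sets PL_SEED_WORKERS=1
    #     trainer = Trainer(...)             # worker_init_fn is added automatically
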
def auto_add_sampler(self, dataloader: DataLoader, shuffle: bool) -> DataLoader:
# don't do anything if it's not a dataloader
is_dataloader = isinstance(dataloader, DataLoader)
# don't manipulate iterable datasets
is_iterable_ds = has_iterable_dataset(dataloader)
if isinstance(dataloader, CombinedLoader):
dataloader.loaders = apply_to_collection(dataloader.loaders, DataLoader, self.auto_add_sampler, shuffle)
return dataloader
if not is_dataloader or is_iterable_ds:
return dataloader
need_dist_sampler = self.accelerator_connector.is_distributed and not isinstance(
dataloader.sampler, DistributedSampler
)
if self.accelerator_connector.replace_sampler_ddp and need_dist_sampler:
if not isinstance(dataloader.sampler, (SequentialSampler, RandomSampler)):
raise MisconfigurationException(
                    'You seem to have configured a sampler in your DataLoader. This will be replaced'
                    ' by `DistributedSampler` since `replace_sampler_ddp` is True and you are using'
                    ' distributed training. Either remove the sampler from your DataLoader or set'
                    ' `replace_sampler_ddp=False` if you want to use your custom sampler.'
)
# replace with distributed sampler
sampler = self._get_distributed_sampler(dataloader, shuffle)
dataloader = self.replace_sampler(dataloader, sampler)
return dataloader
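
    # Illustrative sketch (not part of the original module): how a user keeps a custom
    # sampler under distributed training, per the error raised above. `MyWeightedSampler`
    # is hypothetical.
    #
    #     trainer = Trainer(gpus=2, accelerator="ddp", replace_sampler_ddp=False)
    #     loader = DataLoader(dataset, sampler=MyWeightedSampler(dataset), batch_size=32)
    #
    # With `replace_sampler_ddp=False` this method returns the dataloader unchanged, and
    # the user is responsible for sharding the data across processes in their sampler.
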
@staticmethod
def _resolve_batch_sampler(dl_args, dataloader, sampler):
batch_sampler = getattr(dataloader, "batch_sampler")
if batch_sampler is not None and type(batch_sampler) is not BatchSampler:
batch_sampler = type(batch_sampler)(
sampler,
batch_size=batch_sampler.batch_size,
drop_last=batch_sampler.drop_last,
)
dl_args['batch_sampler'] = batch_sampler
dl_args['batch_size'] = 1
dl_args['shuffle'] = False
dl_args['sampler'] = None
dl_args['drop_last'] = False
else:
dl_args['sampler'] = sampler
dl_args['shuffle'] = False
dl_args['batch_sampler'] = None
return dl_args
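
    # Note (illustrative, not part of the original module): for a custom batch sampler,
    # e.g. a hypothetical `MyBucketBatchSampler(sampler, batch_size=8, drop_last=False)`,
    # the branch above rebuilds it around the new distributed `sampler` and resets
    # `batch_size`, `shuffle`, `sampler` and `drop_last` in `dl_args` to their `DataLoader`
    # defaults, because those arguments are mutually exclusive with `batch_sampler` in
    # `DataLoader.__init__`. This assumes the custom batch sampler class accepts
    # `(sampler, batch_size=..., drop_last=...)` in its own `__init__`.
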
def replace_sampler(self, dataloader, sampler):
skip_keys = ('sampler', 'batch_sampler', 'dataset_kind')
skip_signature_keys = ('args', 'kwargs', 'self')
attrs = {k: v for k, v in vars(dataloader).items() if not k.startswith("_")}
params = set(inspect.signature(dataloader.__init__).parameters)
contains_dataset = True
if type(dataloader) is not DataLoader:
contains_dataset = "dataset" in params
params.update(inspect.signature(DataLoader.__init__).parameters)
dl_args = {name: attrs[name] for name in params if name in attrs and name not in skip_keys}
dl_args = self._resolve_batch_sampler(dl_args, dataloader, sampler)
multiprocessing_context = dataloader.multiprocessing_context
dl_args['multiprocessing_context'] = multiprocessing_context
missing_kwargs = params.difference(skip_signature_keys).difference(dl_args)
if missing_kwargs:
"""
Example:
class CustomDataLoader(DataLoader):
def __init__(self, num_features, dataset, *args, **kwargs):
self.num_features = num_features
super().__init__(dataset, *args, **kwargs)
"""
dataloader_cls_name = dataloader.__class__.__name__
            raise MisconfigurationException(
                f"Trying to inject DistributedSampler within {dataloader_cls_name} class. "
                "This would fail as your DataLoader doesn't expose all its __init__ parameters as attributes. "
                f"Missing attributes are {missing_kwargs}. "
                f"HINT: If you wrote the {dataloader_cls_name} class, add the `__init__` arguments as attributes or "
                "manually add the DistributedSampler as "
                f"{dataloader_cls_name}(dataset, ..., sampler=DistributedSampler(dataset, ...))."
            )
if not contains_dataset:
dl_args.pop('dataset')
dataloader = type(dataloader)(**dl_args)
dataloader.multiprocessing_context = multiprocessing_context
return dataloader

    def _get_distributed_sampler(self, dataloader, shuffle):
kwargs = self.distributed_sampler_kwargs
kwargs["shuffle"] = shuffle and not self.overfit_batches
if _TORCH_GREATER_EQUAL_1_6:
kwargs.setdefault("seed", int(os.getenv("PL_GLOBAL_SEED", 0)))
sampler = DistributedSampler(dataloader.dataset, **kwargs)
return sampler
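
    # Illustrative sketch (assumption, not part of the original module): in a typical
    # 2-GPU DDP run, `distributed_sampler_kwargs` resolves to roughly
    # `{"num_replicas": 2, "rank": <global rank>}`, so the call above is equivalent to
    #
    #     DistributedSampler(dataloader.dataset, num_replicas=2, rank=rank,
    #                        shuffle=shuffle, seed=int(os.getenv("PL_GLOBAL_SEED", 0)))
    #
    # with `shuffle` forced to False while `overfit_batches` is enabled.
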
def reset_train_dataloader(self, model: LightningModule) -> None:
"""Resets the train dataloader and initialises required variables
(number of batches, when to validate, etc.).
Args:
model: The current `LightningModule`
"""
self.train_dataloader = self.request_dataloader(model, "train")
if self.overfit_batches > 0:
if hasattr(self.train_dataloader, 'sampler') and isinstance(self.train_dataloader.sampler, RandomSampler):
rank_zero_warn(
'You requested to overfit but enabled training dataloader shuffling.'
' We are turning it off for you.'
)
self.train_dataloader = self.replace_sampler(
self.train_dataloader, SequentialSampler(self.train_dataloader.dataset)
)
# debugging
self.dev_debugger.track_load_dataloader_call('train_dataloader', dataloaders=[self.train_dataloader])
# automatically add samplers
self.train_dataloader = apply_to_collection(
self.train_dataloader, DataLoader, self.auto_add_sampler, shuffle=True
)
# check the workers recursively
apply_to_collection(self.train_dataloader, DataLoader, self._worker_check, 'train dataloader')
# add worker_init_fn for correct seeding in worker processes
apply_to_collection(self.train_dataloader, DataLoader, self.auto_add_worker_init_fn)
        # wrap the sequence of train loaders into a CombinedLoader object for computing `num_training_batches`
self.train_dataloader = CombinedLoader(self.train_dataloader, self._multiple_trainloader_mode)
self.num_training_batches = len(self.train_dataloader) if has_len(self.train_dataloader) else float('inf')
if isinstance(self.limit_train_batches, int) or self.limit_train_batches == 0.0:
self.num_training_batches = min(self.num_training_batches, int(self.limit_train_batches))
elif self.num_training_batches != float('inf'):
self.num_training_batches = int(self.num_training_batches * self.limit_train_batches)
elif self.limit_train_batches != 1.0:
raise MisconfigurationException(
                'When using an IterableDataset for `train_dataloader`,'
                ' `Trainer(limit_train_batches)` must be `0.0`, `1.0` or an int. An int `k` specifies'
                ' `num_training_batches` to use.'
)
        # determine when to check validation:
        # if an int is passed in, validation runs every that many training batches;
        # otherwise it is interpreted as a fraction of the training epoch in [0.0, 1.0]
        # (see the worked example after this method)
if isinstance(self.val_check_interval, int):
self.val_check_batch = self.val_check_interval
if self.val_check_batch > self.num_training_batches:
raise ValueError(
f'`val_check_interval` ({self.val_check_interval}) must be less than or equal '
                    f'to the number of training batches ({self.num_training_batches}). '
                    'If you want to disable validation, set `limit_val_batches` to 0.0 instead.'
)
else:
if not has_len(self.train_dataloader):
if self.val_check_interval == 1.0:
self.val_check_batch = float('inf')
else:
raise MisconfigurationException(
'When using an IterableDataset for `train_dataloader`,'
' `Trainer(val_check_interval)` must be `1.0` or an int. An int k specifies'
' checking validation every k training batches.'
)
else:
self.val_check_batch = int(self.num_training_batches * self.val_check_interval)
self.val_check_batch = max(1, self.val_check_batch)
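
    # Worked example (illustrative, not part of the original module): with 1000 training
    # batches after `limit_train_batches` is applied,
    #
    #     Trainer(val_check_interval=0.25)  # float -> validate every int(1000 * 0.25) = 250 batches
    #     Trainer(val_check_interval=200)   # int   -> validate every 200 batches
    #     Trainer(val_check_interval=1.0)   # with an IterableDataset (unknown length) -> val_check_batch = inf,
    #                                       # i.e. no mid-epoch validation check
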
def _reset_eval_dataloader(
self,
model: LightningModule,
mode: str,
) -> Tuple[List[Union[int, float]], List[DataLoader]]:
"""Generic method to reset a dataloader for evaluation.
Args:
model: The current `LightningModule`
            mode: One of `'val'`, `'test'` or `'predict'`
Returns:
Tuple (num_batches, dataloaders)
"""
# always get the loaders first so we can count how many there are
loader_name = f'{mode}_dataloader'
dataloaders = self.request_dataloader(model, mode)
if not isinstance(dataloaders, list):
dataloaders = [dataloaders]
        # when overfitting, use the training loader as val and test,
        # duplicating it the number of times needed to match the train loaders
if self.overfit_batches > 0:
num_loaders = len(dataloaders)
train_dataloader = self.request_dataloader(model, 'train')
dataloaders = [deepcopy(train_dataloader) for _ in range(num_loaders)]
self.dev_debugger.track_load_dataloader_call(loader_name, dataloaders=dataloaders)
for loader_i in range(len(dataloaders)):
loader = dataloaders[loader_i]
            # shuffling in the val/test/predict sets is bad practice
modes = ('val', 'test', 'predict')
if mode in modes and hasattr(loader, 'sampler') and isinstance(loader.sampler, RandomSampler):
                # when overfitting, the dataloader should not use a random sampler
if self.overfit_batches > 0 and mode != 'predict':
rank_zero_warn(
'You requested to overfit but enabled val/test dataloader shuffling.'
' We are turning it off for you.'
)
dataloaders[loader_i] = self.replace_sampler(loader, SequentialSampler(loader.dataset))
else:
rank_zero_warn(
f'Your {mode}_dataloader has `shuffle=True`, it is best practice to turn'
' this off for val/test/predict dataloaders.'
)
        if any(dl is None for dl in dataloaders):
            rank_zero_warn("One of the given dataloaders is None and it will be skipped.")
# add samplers
dataloaders = [self.auto_add_sampler(dl, shuffle=False) for dl in dataloaders if dl is not None]
# add worker_init_fn for correct seeding in worker processes
apply_to_collection(dataloaders, dtype=DataLoader, function=self.auto_add_worker_init_fn)
loader_num_batches = []
# determine number of batches
# datasets could be none, 1 or 2+
if len(dataloaders) != 0:
for i, dataloader in enumerate(dataloaders):
num_batches = len(dataloader) if has_len(dataloader) else float('inf')
self._worker_check(dataloader, f'{mode} dataloader {i}')
# percent or num_steps
limit_eval_batches = getattr(self, f'limit_{mode}_batches')
# limit num batches either as a percent or num steps
if isinstance(limit_eval_batches, int) or limit_eval_batches == 0.0:
num_batches = min(num_batches, int(limit_eval_batches))
elif num_batches != float('inf'):
num_batches = int(num_batches * limit_eval_batches)
elif limit_eval_batches != 1.0:
raise MisconfigurationException(
                        f'When using an IterableDataset for `{mode}_dataloader`,'
                        f' `Trainer(limit_{mode}_batches)` must be `0.0`, `1.0` or an int. An int `k` specifies'
                        f' `num_{mode}_batches` to use.'
)
                if num_batches == 0 and limit_eval_batches > 0.0 and isinstance(limit_eval_batches, float):
                    min_pct = 1.0 / len(dataloader)
                    raise MisconfigurationException(
                        f'You requested to check {limit_eval_batches} of the {mode} dataloader but'
                        f' {limit_eval_batches} * {len(dataloader)} < 1. Please increase the'
                        f' `limit_{mode}_batches` argument. Try at least `limit_{mode}_batches={min_pct}`'
                    )
loader_num_batches.append(num_batches)
return loader_num_batches, dataloaders
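
    # Worked example (illustrative, not part of the original module): for a val dataloader
    # with 100 batches,
    #
    #     Trainer(limit_val_batches=0.1)    # float -> int(100 * 0.1) = 10 batches per run
    #     Trainer(limit_val_batches=5)      # int   -> 5 batches per run
    #     Trainer(limit_val_batches=0.001)  # int(100 * 0.001) = 0 -> MisconfigurationException
    #                                       # suggesting at least 1 / 100 = 0.01
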
def reset_val_dataloader(self, model: LightningModule) -> None:
"""Resets the validation dataloader and determines the number of batches.
Args:
model: The current `LightningModule`
"""
has_loader = is_overridden('val_dataloader', model)
has_step = is_overridden('validation_step', model)
if has_loader and has_step:
self.num_val_batches, self.val_dataloaders = self._reset_eval_dataloader(model, 'val')

    def reset_test_dataloader(self, model) -> None:
"""Resets the test dataloader and determines the number of batches.
Args:
model: The current `LightningModule`
"""
has_loader = is_overridden('test_dataloader', model)
has_step = is_overridden('test_step', model)
if has_loader and has_step:
self.num_test_batches, self.test_dataloaders = self._reset_eval_dataloader(model, 'test')

    def reset_predict_dataloader(self, model) -> None:
"""Resets the predict dataloader and determines the number of batches.
Args:
model: The current `LightningModule`
"""
has_loader = is_overridden('predict_dataloader', model)
if has_loader:
self.num_predict_batches, self.predict_dataloaders = self._reset_eval_dataloader(model, 'predict')

    def request_dataloader(self, model: LightningModule, stage: str) -> DataLoader:
"""Handles downloading data in the GPU or TPU case.
Args:
dataloader_fx: The bound dataloader getter
Returns:
The dataloader
"""
if model.trainer is not None:
model.trainer.call_hook(f"on_{stage}_dataloader")
dataloader: DataLoader = getattr(model, f'{stage}_dataloader')()
dataloader = self._flatten_dl_only(dataloader)
self.accelerator.barrier('get_dataloaders')
return dataloader

    def _flatten_dl_only(self, dataloaders):
        # handles the case where the user returns several dataloaders as a tuple,
        # e.g. `return dl1, dl2`, instead of a list
if isinstance(dataloaders, tuple):
            if all(isinstance(x, Iterable) for x in dataloaders):
dataloaders = list(dataloaders)
return dataloaders