lightning/pytorch_lightning/trainer/data_loading.py

370 lines
16 KiB
Python

import multiprocessing
import platform
from abc import ABC, abstractmethod
from distutils.version import LooseVersion
from typing import Union, List, Tuple, Callable, Optional
import torch
import torch.distributed as torch_distrib
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler
from pytorch_lightning.core import LightningModule
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
try:
from torch.utils.data import IterableDataset
ITERABLE_DATASET_EXISTS = True
except ImportError:
ITERABLE_DATASET_EXISTS = False
try:
from apex import amp
except ImportError:
amp = None
try:
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
except ImportError:
XLA_AVAILABLE = False
else:
XLA_AVAILABLE = True
try:
import horovod.torch as hvd
except (ModuleNotFoundError, ImportError):
HOROVOD_AVAILABLE = False
else:
HOROVOD_AVAILABLE = True
def _has_iterable_dataset(dataloader: DataLoader):
    """Tell whether ``dataloader`` wraps an ``IterableDataset``.

    Always ``False`` when the installed torch does not provide
    ``IterableDataset`` (``ITERABLE_DATASET_EXISTS`` is ``False``) or when the
    given object has no ``dataset`` attribute.
    """
    # Guard first: `IterableDataset` is only bound when its import succeeded.
    if not ITERABLE_DATASET_EXISTS:
        return False
    if not hasattr(dataloader, 'dataset'):
        return False
    return isinstance(dataloader.dataset, IterableDataset)
def _has_len(dataloader: DataLoader) -> bool:
    """Check whether ``dataloader`` implements a usable ``__len__``.

    A dataloader with a length is treated as finite, one without as infinite.

    Args:
        dataloader: The object to probe with ``len()``.

    Returns:
        ``True`` if ``len(dataloader)`` succeeds with a non-zero value,
        ``False`` if it raises ``TypeError`` or ``NotImplementedError``.

    Raises:
        ValueError: If ``len(dataloader)`` is ``0``.
    """
    try:
        # try getting the length
        if len(dataloader) == 0:
            raise ValueError('`Dataloader` returned 0 length.'
                             ' Please make sure that your Dataloader at least returns 1 batch')
        has_len = True
    except (TypeError, NotImplementedError):
        # TypeError: no `__len__` at all.
        # NotImplementedError: e.g. raised by torchtext if a batch_size_fn is used.
        has_len = False

    if has_len and _has_iterable_dataset(dataloader) and LooseVersion(torch.__version__) >= LooseVersion("1.4.0"):
        # Multi-process loading is controlled by `num_workers`, not by the batch size.
        rank_zero_warn(
            'Your `IterableDataset` has `__len__` defined.'
            ' In combination with multi-processing data loading (e.g. num_workers > 1),'
            ' this can lead to unintended side effects since the samples will be duplicated.'
        )
    return has_len
class TrainerDataLoadingMixin(ABC):
    """Trainer mixin that requests, wraps and sizes the train/val/test dataloaders."""

    # this is just a summary on variables used in this abstract class,
    # the proper values/initialisation should be done in child class
    global_rank: int
    use_ddp: bool
    use_ddp2: bool
    use_horovod: bool
    shown_warnings: ...
    val_check_interval: float
    use_tpu: bool
    tpu_local_core_rank: int
    train_dataloader: DataLoader
    num_training_batches: Union[int, float]
    val_check_batch: ...
    val_dataloaders: List[DataLoader]
    num_val_batches: List[Union[int, float]]
    test_dataloaders: List[DataLoader]
    num_test_batches: List[Union[int, float]]
    limit_train_batches: Union[int, float]
    limit_val_batches: Union[int, float]
    limit_test_batches: Union[int, float]
    replace_sampler_ddp: bool
    num_nodes: int
    num_processes: int
    distributed_backend: Optional[str]
    # read by `_reset_eval_dataloader` to decide whether to reuse the train loader
    overfit_batches: Union[int, float]

    @abstractmethod
    def is_overridden(self, *args):
        """Warning: this is just empty shell for code implemented in other class."""

    def _worker_check(self, dataloader: DataLoader, name: str) -> None:
        """Warn about ``num_workers`` settings that are likely to hurt performance.

        Does nothing on Windows or when ``dataloader`` is not a ``DataLoader``.

        Args:
            dataloader: The loader whose ``num_workers`` is inspected.
            name: Human readable loader name used in the warning text.
        """
        on_windows = platform.system() == 'Windows'
        if on_windows or not isinstance(dataloader, DataLoader):
            return

        num_workers = dataloader.num_workers

        if self.distributed_backend == 'ddp_spawn':
            # ddp_spawn + num_workers > 0 don't mix! tell the user
            if num_workers > 0:
                rank_zero_warn('Dataloader(num_workers>0) and ddp_spawn do not mix well!'
                               ' Your performance might suffer dramatically.'
                               ' Please consider setting distributed_backend=ddp to use num_workers > 0'
                               ' (this is a bottleneck of Python .spawn() and PyTorch)')
            elif num_workers == 0:
                rank_zero_warn('You are using `distributed_backend=ddp_spawn` with num_workers=0.'
                               ' For much faster performance, switch to `distributed_backend=ddp`'
                               ' and set `num_workers>0`')
            return

        num_cpus = multiprocessing.cpu_count()
        if num_workers <= 2 and num_cpus > 2:
            rank_zero_warn(f'The dataloader, {name}, does not have many workers which may be a bottleneck.'
                           ' Consider increasing the value of the `num_workers` argument'
                           f' (try {num_cpus} which is the number of cpus on this machine)'
                           ' in the `DataLoader` init to improve performance.')

    def auto_add_sampler(self, dataloader: DataLoader, train: bool) -> DataLoader:
        """Swap in a ``DistributedSampler`` when running distributed.

        Args:
            dataloader: The loader to inspect. Anything that is not a
                ``DataLoader``, or that wraps an iterable dataset, is returned as is.
            train: Forwarded as the ``shuffle`` flag of the distributed sampler.

        Returns:
            The original loader, or a re-created loader using a ``DistributedSampler``.

        Raises:
            MisconfigurationException: If a replacement is required but the loader
                uses a sampler other than ``SequentialSampler``/``RandomSampler``.
        """
        # don't do anything if it's not a dataloader, don't manipulate iterable datasets
        if not isinstance(dataloader, DataLoader) or _has_iterable_dataset(dataloader):
            return dataloader

        need_dist_sampler = (self.use_ddp or self.use_ddp2 or self.use_horovod or self.use_tpu)

        if self.replace_sampler_ddp and need_dist_sampler:
            if not isinstance(dataloader.sampler, (SequentialSampler, RandomSampler)):
                raise MisconfigurationException(
                    'You seem to have configured a sampler in your DataLoader. This will be replaced '
                    ' by `DistributedSampler` since `replace_sampler_ddp` is True and you are using'
                    ' distributed training. Either remove the sampler from your DataLoader or set'
                    ' `replace_sampler_ddp`=False if you want to use your custom sampler.')

            # replace with distributed sampler
            sampler = self._get_distributed_sampler(dataloader, train)
            dataloader = self.replace_sampler(dataloader, sampler)

        return dataloader

    def replace_sampler(self, dataloader, sampler):
        """Re-create ``dataloader`` with ``sampler`` in place of its current one.

        The new loader is built as ``type(dataloader)(**kwargs)`` where the kwargs
        are the loader's public instance attributes (names not starting with
        ``_``), minus the sampler-related ones, plus the new ``sampler``.
        """
        skip_keys = ('sampler', 'batch_sampler', 'dataset_kind')

        dl_args = {
            k: v for k, v in dataloader.__dict__.items() if not k.startswith('_') and k not in skip_keys
        }
        dl_args['sampler'] = sampler

        return type(dataloader)(**dl_args)

    def _get_distributed_sampler(self, dataloader, train):
        """Build a ``DistributedSampler`` over ``dataloader.dataset``.

        The replica count and rank come from XLA on TPU, from Horovod when it is
        in use, and otherwise from ``num_nodes``/``num_processes``/``global_rank``
        according to ``distributed_backend``.

        Args:
            dataloader: Object exposing the ``dataset`` to sample from.
            train: Used as the sampler's ``shuffle`` flag.
        """
        if self.use_tpu:
            kwargs = dict(num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())
        elif self.use_horovod:
            kwargs = dict(num_replicas=hvd.size(), rank=hvd.rank())
        else:
            world_size = {
                'ddp': self.num_nodes * self.num_processes,
                'ddp_spawn': self.num_nodes * self.num_processes,
                'ddp2': self.num_nodes,
                'ddp_cpu': self.num_processes * self.num_nodes
            }
            assert self.distributed_backend is not None
            kwargs = dict(num_replicas=world_size[self.distributed_backend], rank=self.global_rank)

        kwargs['shuffle'] = train
        return DistributedSampler(dataloader.dataset, **kwargs)

    def reset_train_dataloader(self, model: 'LightningModule') -> None:
        """Resets the train dataloader and initialises required variables
        (number of batches, when to validate, etc.).

        Args:
            model: The current `LightningModule`

        Raises:
            MisconfigurationException: If the loader has no length and
                ``limit_train_batches``/``val_check_interval`` is a fraction other than the
                values allowed in the error messages below.
            ValueError: If an int ``val_check_interval`` exceeds the number of training batches.
        """
        self.train_dataloader = self.request_dataloader(model.train_dataloader)
        self.num_training_batches = 0

        # automatically add samplers
        self.train_dataloader = self.auto_add_sampler(self.train_dataloader, train=True)

        # probe the length once; the loader is not replaced again below
        has_len = _has_len(self.train_dataloader)
        self.num_training_batches = len(self.train_dataloader) if has_len else float('inf')
        self._worker_check(self.train_dataloader, 'train dataloader')

        if isinstance(self.limit_train_batches, int) or self.limit_train_batches == 0.0:
            self.num_training_batches = min(self.num_training_batches, int(self.limit_train_batches))
        elif self.num_training_batches != float('inf'):
            self.num_training_batches = int(self.num_training_batches * self.limit_train_batches)
        elif self.limit_train_batches != 1.0:
            raise MisconfigurationException(
                'When using an IterableDataset for `limit_train_batches`,'
                ' `Trainer(limit_train_batches)` must be `0.0`, `1.0` or an int. An int k specifies'
                ' `num_training_batches` to use.')

        # determine when to check validation
        # if int passed in, val checks that often
        # otherwise, it checks in [0, 1.0] % range of a training epoch
        if isinstance(self.val_check_interval, int):
            self.val_check_batch = self.val_check_interval
            if self.val_check_batch > self.num_training_batches:
                raise ValueError(
                    f'`val_check_interval` ({self.val_check_interval}) must be less than or equal '
                    f'to the number of the training batches ({self.num_training_batches}). '
                    'If you want to disable validation set `limit_val_batches` to 0.0 instead.')
        else:
            if not has_len:
                if self.val_check_interval == 1.0:
                    self.val_check_batch = float('inf')
                else:
                    raise MisconfigurationException(
                        'When using an IterableDataset for `train_dataloader`,'
                        ' `Trainer(val_check_interval)` must be `1.0` or an int. An int k specifies'
                        ' checking validation every k training batches.')
            else:
                self.val_check_batch = int(self.num_training_batches * self.val_check_interval)
                self.val_check_batch = max(1, self.val_check_batch)

    def _reset_eval_dataloader(
            self,
            model: 'LightningModule',
            mode: str
    ) -> Tuple[List[Union[int, float]], List[DataLoader]]:
        """Generic method to reset a dataloader for evaluation.

        Args:
            model: The current `LightningModule`
            mode: Either `'val'` or `'test'`

        Returns:
            Tuple (num_batches, dataloaders)

        Raises:
            MisconfigurationException: If a fractional limit is used with a loader
                that has no length, or if a fractional limit rounds down to 0 batches.
        """
        # use the training loader as val and test when overfitting
        if self.overfit_batches > 0:
            dataloaders = self.request_dataloader(getattr(model, 'train_dataloader'))
        else:
            dataloaders = self.request_dataloader(getattr(model, f'{mode}_dataloader'))

        if not isinstance(dataloaders, list):
            dataloaders = [dataloaders]

        for loader_i, loader in enumerate(dataloaders):
            # shuffling in val and test set is bad practice
            if mode in ('val', 'test') and hasattr(loader, 'sampler') and isinstance(loader.sampler, RandomSampler):

                # when overfitting, the dataloader should not have sampler
                if self.overfit_batches > 0:
                    rank_zero_warn('You requested to overfit but enabled training dataloader shuffling.'
                                   ' We are turning it off for you.')
                    dataloaders[loader_i] = self.replace_sampler(loader, SequentialSampler(loader.dataset))

                else:
                    rank_zero_warn(f'Your {mode}_dataloader has `shuffle=True`, it is best practice to turn'
                                   ' this off for validation and test dataloaders.')

        if any(dl is None for dl in dataloaders):
            rank_zero_warn("One of given dataloaders is None and it will be skipped.")

        # add samplers
        dataloaders = [self.auto_add_sampler(dl, train=False) for dl in dataloaders if dl is not None]

        # percent or num_steps
        limit_eval_batches = getattr(self, f'limit_{mode}_batches')

        loader_num_batches = []

        # determine number of batches
        # datasets could be none, 1 or 2+
        for i, dataloader in enumerate(dataloaders):
            orig_num_batches = len(dataloader) if _has_len(dataloader) else float('inf')
            num_batches = orig_num_batches
            self._worker_check(dataloader, f'{mode} dataloader {i}')

            # limit num batches either as a percent or num steps
            if isinstance(limit_eval_batches, int) or limit_eval_batches == 0.0:
                num_batches = min(num_batches, int(limit_eval_batches))
            elif num_batches != float('inf'):
                num_batches = int(num_batches * limit_eval_batches)
            elif limit_eval_batches != 1.0:
                raise MisconfigurationException(
                    f'When using an IterableDataset for `limit_{mode}_batches`,'
                    f' `Trainer(limit_{mode}_batches)` must be `0.0`, `1.0` or an int. An int k specifies'
                    f' `num_{mode}_batches` to use.')

            if num_batches == 0 and limit_eval_batches > 0.0 and isinstance(limit_eval_batches, float):
                # report the loader's real length: `num_batches` is already 0 here
                min_pct = 1.0 / orig_num_batches
                raise MisconfigurationException(
                    f'you requested to check {limit_eval_batches} of the {mode} dataloader but'
                    f' {limit_eval_batches}*{orig_num_batches} < 1. Please increase the limit_{mode}_batches.'
                    f' Try at least limit_{mode}_batches={min_pct}'
                )

            loader_num_batches.append(num_batches)

        return loader_num_batches, dataloaders

    def reset_val_dataloader(self, model: 'LightningModule') -> None:
        """Resets the validation dataloader and determines the number of batches.

        Nothing is changed unless the model overrides both `val_dataloader`
        and `validation_step`.

        Args:
            model: The current `LightningModule`
        """
        has_loader = self.is_overridden('val_dataloader', model)
        has_step = self.is_overridden('validation_step', model)
        if has_loader and has_step:
            self.num_val_batches, self.val_dataloaders = self._reset_eval_dataloader(model, 'val')

    def reset_test_dataloader(self, model: 'LightningModule') -> None:
        """Resets the test dataloader and determines the number of batches.

        Nothing is changed unless the model overrides both `test_dataloader`
        and `test_step`.

        Args:
            model: The current `LightningModule`
        """
        has_loader = self.is_overridden('test_dataloader', model)
        has_step = self.is_overridden('test_step', model)
        if has_loader and has_step:
            self.num_test_batches, self.test_dataloaders = self._reset_eval_dataloader(model, 'test')

    def request_dataloader(self, dataloader_fx: Callable) -> DataLoader:
        """Handles downloading data in the GPU or TPU case.

        Args:
            dataloader_fx: The bound dataloader getter

        Returns:
            The dataloader
        """
        dataloader = dataloader_fx()

        # synchronise after the getter has run, using the active backend's primitive
        if self.use_ddp or self.use_ddp2:
            # all processes wait until data download has happened
            torch_distrib.barrier()

        # data download/load on TPU
        elif self.use_tpu and XLA_AVAILABLE:
            # all processes wait until data download has happened
            torch_xla.core.xla_model.rendezvous('pl.TrainerDataLoadingMixin.get_dataloaders')

        elif self.use_horovod:
            # all processes wait until data download has happened
            hvd.join()

        return dataloader

    def determine_data_use_amount(self, overfit_batches: float) -> None:
        """Use less data for debugging purposes.

        When ``overfit_batches`` is positive it overrides the train, val and
        test batch limits; otherwise the limits are left untouched.
        """
        if overfit_batches > 0:
            self.limit_train_batches = overfit_batches
            self.limit_val_batches = overfit_batches
            self.limit_test_batches = overfit_batches