354 lines
15 KiB
354 lines
15 KiB
# Copyright The PyTorch Lightning team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import multiprocessing
import platform
from abc import ABC, abstractmethod
from typing import Union, List, Tuple, Callable, Optional

import torch.distributed as torch_distrib
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler

from pytorch_lightning.core import LightningModule
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.data import has_iterable_dataset, has_len
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.debugging import InternalDebugger
from pytorch_lightning.utilities.model_utils import is_overridden

try:
    from apex import amp
except ImportError:
    amp = None

try:
    import torch_xla
    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.xla_multiprocessing as xmp
except ImportError:
    XLA_AVAILABLE = False
else:
    XLA_AVAILABLE = True

try:
    import horovod.torch as hvd
except (ModuleNotFoundError, ImportError):
    HOROVOD_AVAILABLE = False
else:
    HOROVOD_AVAILABLE = True

class TrainerDataLoadingMixin(ABC):
    """Trainer mixin that requests dataloaders from a `LightningModule`, installs
    distributed samplers when needed and computes how many batches to run.

    The attributes annotated below are only a summary of what this mixin reads
    and writes; their actual values are initialised by the child class.
    """

    # this is just a summary on variables used in this abstract class,
    # the proper values/initialisation should be done in child class
    global_rank: int
    use_ddp: bool
    use_ddp2: bool
    use_horovod: bool
    shown_warnings: ...
    val_check_interval: Union[int, float]
    use_tpu: bool
    tpu_local_core_rank: int
    train_dataloader: DataLoader
    num_training_batches: Union[int, float]
    val_check_batch: ...
    val_dataloaders: List[DataLoader]
    num_val_batches: List[Union[int, float]]
    test_dataloaders: List[DataLoader]
    num_test_batches: List[Union[int, float]]
    limit_train_batches: Union[int, float]
    limit_val_batches: Union[int, float]
    limit_test_batches: Union[int, float]
    overfit_batches: Union[int, float]
    replace_sampler_ddp: bool
    num_nodes: int
    num_processes: int
    distributed_backend: Optional[str]
    dev_debugger: InternalDebugger

    def _worker_check(self, dataloader: DataLoader, name: str) -> None:
        """Warn when the loader's ``num_workers`` setting is likely to hurt performance.

        Nothing is checked on Windows or for objects that are not a ``DataLoader``.

        Args:
            dataloader: The loader to inspect.
            name: Human-readable loader name used in the warning text.
        """
        on_windows = platform.system() == 'Windows'

        # ddp_spawn + num_workers > 0 don't mix! tell the user
        is_dataloader = isinstance(dataloader, DataLoader)
        using_spawn = self.distributed_backend == 'ddp_spawn'
        if is_dataloader and not on_windows:
            if dataloader.num_workers > 0 and using_spawn:
                rank_zero_warn('Dataloader(num_workers>0) and ddp_spawn do not mix well!'
                               ' Your performance might suffer dramatically.'
                               ' Please consider setting distributed_backend=ddp to use num_workers > 0'
                               ' (this is a bottleneck of Python .spawn() and PyTorch)')

            elif dataloader.num_workers == 0 and using_spawn:
                rank_zero_warn('You are using `distributed_backend=ddp_spawn` with num_workers=0.'
                               ' For much faster performance, switch to `distributed_backend=ddp`'
                               ' and set `num_workers>0`')

            elif dataloader.num_workers <= 2 and multiprocessing.cpu_count() > 2 and not using_spawn:
                num_cpus = multiprocessing.cpu_count()
                rank_zero_warn(f'The dataloader, {name}, does not have many workers which may be a bottleneck.'
                               ' Consider increasing the value of the `num_workers` argument'
                               f' (try {num_cpus} which is the number of cpus on this machine)'
                               ' in the `DataLoader` init to improve performance.')

    def auto_add_sampler(self, dataloader: DataLoader, train: bool) -> DataLoader:
        """Replace the loader's sampler with a ``DistributedSampler`` for distributed training.

        The loader is returned unchanged when it is not a ``DataLoader``, when it wraps
        an iterable-style dataset, when no distributed mode (ddp, ddp2, horovod, tpu)
        is active, or when ``replace_sampler_ddp`` is False.

        Args:
            dataloader: The loader to (possibly) rebuild.
            train: Whether this is the training loader; controls shuffling of the new sampler.

        Returns:
            The original loader, or a rebuilt copy that uses a ``DistributedSampler``.

        Raises:
            MisconfigurationException: If the sampler would be replaced but the loader
                uses a sampler other than ``SequentialSampler`` / ``RandomSampler``.
        """
        # don't do anything if it's not a dataloader
        is_dataloader = isinstance(dataloader, DataLoader)
        # don't manipulate iterable datasets
        is_iterable_ds = has_iterable_dataset(dataloader)

        if not is_dataloader or is_iterable_ds:
            return dataloader

        need_dist_sampler = self.use_ddp or self.use_ddp2 or self.use_horovod or self.use_tpu

        if self.replace_sampler_ddp and need_dist_sampler:
            if not isinstance(dataloader.sampler, (SequentialSampler, RandomSampler)):
                raise MisconfigurationException(
                    'You seem to have configured a sampler in your DataLoader. This will be replaced '
                    ' by `DistributedSampler` since `replace_sampler_ddp` is True and you are using'
                    ' distributed training. Either remove the sampler from your DataLoader or set'
                    ' `replace_sampler_ddp`=False if you want to use your custom sampler.')

            # replace with distributed sampler
            sampler = self._get_distributed_sampler(dataloader, train)
            dataloader = self.replace_sampler(dataloader, sampler)

        return dataloader

    def replace_sampler(self, dataloader, sampler):
        """Rebuild ``dataloader`` (same class) with ``sampler`` in place of its current one.

        The new loader is constructed from the old loader's public ``__dict__`` entries
        (names not starting with ``_``), minus ``sampler``, ``batch_sampler`` and
        ``dataset_kind``. NOTE(review): this assumes the loader class's ``__init__``
        accepts each of those attributes as a keyword argument.
        """
        skip_keys = ['sampler', 'batch_sampler', 'dataset_kind']

        dl_args = {
            k: v for k, v in dataloader.__dict__.items() if not k.startswith('_') and k not in skip_keys
        }

        dl_args['sampler'] = sampler
        dataloader = type(dataloader)(**dl_args)
        return dataloader

    def _get_distributed_sampler(self, dataloader, train):
        """Create a ``DistributedSampler`` over ``dataloader.dataset`` for the active backend.

        The world size and rank come from XLA on TPU, from Horovod when it is in use,
        and otherwise from ``distributed_backend``, ``num_nodes``, ``num_processes``
        and ``global_rank``. Shuffling is enabled only for the training loader and
        only when not overfitting.
        """
        if self.use_tpu:
            kwargs = dict(num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())
        elif self.use_horovod:
            kwargs = dict(num_replicas=hvd.size(), rank=hvd.rank())
        else:
            # ddp2 uses one replica per node; the other backends use one per process
            world_size = {
                'ddp': self.num_nodes * self.num_processes,
                'ddp_spawn': self.num_nodes * self.num_processes,
                'ddp2': self.num_nodes,
                'ddp_cpu': self.num_processes * self.num_nodes
            }
            assert self.distributed_backend is not None
            kwargs = dict(num_replicas=world_size[self.distributed_backend], rank=self.global_rank)

        kwargs['shuffle'] = train and not self.overfit_batches
        sampler = DistributedSampler(dataloader.dataset, **kwargs)
        return sampler

    def reset_train_dataloader(self, model: LightningModule) -> None:
        """Resets the train dataloader and initialises required variables
        (number of batches, when to validate, etc.).

        Args:
            model: The current `LightningModule`

        Raises:
            MisconfigurationException: If the train loader has no length and
                ``limit_train_batches`` is a float other than 0.0/1.0, or
                ``val_check_interval`` is a float other than 1.0.
            ValueError: If an int ``val_check_interval`` exceeds the number of training batches.
        """
        self.train_dataloader = self.request_dataloader(model.train_dataloader)
        if self.overfit_batches > 0:
            if hasattr(self.train_dataloader, 'sampler') and isinstance(self.train_dataloader.sampler, RandomSampler):
                rank_zero_warn('You requested to overfit but enabled training dataloader shuffling.'
                               ' We are turning it off for you.')
                self.train_dataloader = self.replace_sampler(
                    self.train_dataloader, SequentialSampler(self.train_dataloader.dataset))

        # debugging
        self.dev_debugger.track_load_dataloader_call('train_dataloader', dataloaders=[self.train_dataloader])

        self.num_training_batches = 0

        # automatically add samplers
        self.train_dataloader = self.auto_add_sampler(self.train_dataloader, train=True)

        # loaders without __len__ (e.g. iterable datasets) are treated as infinite
        self.num_training_batches = len(self.train_dataloader) if has_len(self.train_dataloader) else float('inf')
        self._worker_check(self.train_dataloader, 'train dataloader')

        # an int limit is an absolute batch count; a float limit is a fraction of the loader
        if isinstance(self.limit_train_batches, int) or self.limit_train_batches == 0.0:
            self.num_training_batches = min(self.num_training_batches, int(self.limit_train_batches))
        elif self.num_training_batches != float('inf'):
            self.num_training_batches = int(self.num_training_batches * self.limit_train_batches)
        elif self.limit_train_batches != 1.0:
            raise MisconfigurationException(
                'When using an IterableDataset for `limit_train_batches`,'
                ' `Trainer(limit_train_batches)` must be `0.0`, `1.0` or an int. An int k specifies'
                ' `num_training_batches` to use.')

        # determine when to check validation
        # if int passed in, val checks that often
        # otherwise, it checks in [0, 1.0] % range of a training epoch
        if isinstance(self.val_check_interval, int):
            self.val_check_batch = self.val_check_interval
            if self.val_check_batch > self.num_training_batches:
                raise ValueError(
                    f'`val_check_interval` ({self.val_check_interval}) must be less than or equal '
                    f'to the number of the training batches ({self.num_training_batches}). '
                    'If you want to disable validation set `limit_val_batches` to 0.0 instead.')
        elif not has_len(self.train_dataloader):
            if self.val_check_interval == 1.0:
                self.val_check_batch = float('inf')
            else:
                raise MisconfigurationException(
                    'When using an IterableDataset for `train_dataloader`,'
                    ' `Trainer(val_check_interval)` must be `1.0` or an int. An int k specifies'
                    ' checking validation every k training batches.')
        else:
            self.val_check_batch = int(self.num_training_batches * self.val_check_interval)
            # never check more often than every batch
            self.val_check_batch = max(1, self.val_check_batch)

    def _reset_eval_dataloader(
            self,
            model: LightningModule,
            mode: str
    ) -> Tuple[List[Union[int, float]], List[DataLoader]]:
        """Generic method to reset a dataloader for evaluation.

        Args:
            model: The current `LightningModule`
            mode: Either `'val'` or `'test'`

        Returns:
            Tuple (num_batches, dataloaders): one batch count per non-None loader
            (``float('inf')`` for loaders without a length, after applying
            ``limit_{mode}_batches``) and the loaders themselves.

        Raises:
            MisconfigurationException: If a float ``limit_{mode}_batches`` is invalid for a
                loader without a length, or truncates a non-empty loader to zero batches.
        """
        # use the training loader as val and test when overfitting
        loader_name = f'{mode}_dataloader'
        if self.overfit_batches > 0:
            loader_name = 'train_dataloader'

        # load loaders
        dataloaders = self.request_dataloader(getattr(model, loader_name))

        if not isinstance(dataloaders, list):
            dataloaders = [dataloaders]

        self.dev_debugger.track_load_dataloader_call(loader_name, dataloaders=dataloaders)

        for loader_i, loader in enumerate(dataloaders):
            # shuffling in val and test set is bad practice
            if mode in ('val', 'test') and hasattr(loader, 'sampler') and isinstance(loader.sampler, RandomSampler):

                # when overfitting, the dataloader should not have sampler
                if self.overfit_batches > 0:
                    rank_zero_warn('You requested to overfit but enabled test/val dataloader shuffling.'
                                   ' We are turning it off for you.')
                    dataloaders[loader_i] = self.replace_sampler(loader, SequentialSampler(loader.dataset))
                else:
                    rank_zero_warn(f'Your {mode}_dataloader has `shuffle=True`, it is best practice to turn'
                                   ' this off for validation and test dataloaders.')

        if any(dl is None for dl in dataloaders):
            rank_zero_warn("One of given dataloaders is None and it will be skipped.")

        # add samplers
        dataloaders = [self.auto_add_sampler(dl, train=False) for dl in dataloaders if dl is not None]

        # percent or num_steps (the same limit applies to every loader of this mode)
        limit_eval_batches = getattr(self, f'limit_{mode}_batches')

        # determine number of batches
        # datasets could be none, 1 or 2+
        loader_num_batches = []
        for i, dataloader in enumerate(dataloaders):
            full_num_batches = len(dataloader) if has_len(dataloader) else float('inf')
            num_batches = full_num_batches
            self._worker_check(dataloader, f'{mode} dataloader {i}')

            # limit num batches either as a percent or num steps
            if isinstance(limit_eval_batches, int) or limit_eval_batches == 0.0:
                num_batches = min(num_batches, int(limit_eval_batches))
            elif num_batches != float('inf'):
                num_batches = int(num_batches * limit_eval_batches)
            elif limit_eval_batches != 1.0:
                raise MisconfigurationException(
                    f'When using an IterableDataset for `limit_{mode}_batches`,'
                    f' `Trainer(limit_{mode}_batches)` must be `0.0`, `1.0` or an int. An int k specifies'
                    f' `num_{mode}_batches` to use.')

            # a positive fractional limit truncated a non-empty loader down to zero batches;
            # an empty loader simply contributes zero batches (no ZeroDivisionError below)
            if (num_batches == 0 and full_num_batches > 0
                    and limit_eval_batches > 0.0 and isinstance(limit_eval_batches, float)):
                min_pct = 1.0 / full_num_batches
                raise MisconfigurationException(
                    f'you requested to check {limit_eval_batches} of the {mode} dataloader but'
                    f' {limit_eval_batches}*{full_num_batches} = 0. Please increase the limit_{mode}_batches.'
                    f' Try at least limit_{mode}_batches={min_pct}'
                )

            loader_num_batches.append(num_batches)

        return loader_num_batches, dataloaders

    def reset_val_dataloader(self, model: LightningModule) -> None:
        """Resets the validation dataloader and determines the number of batches.

        Only runs when the model overrides both ``val_dataloader`` and ``validation_step``.

        Args:
            model: The current `LightningModule`
        """
        has_loader = is_overridden('val_dataloader', model)
        has_step = is_overridden('validation_step', model)
        if has_loader and has_step:
            self.num_val_batches, self.val_dataloaders = self._reset_eval_dataloader(model, 'val')

    def reset_test_dataloader(self, model) -> None:
        """Resets the test dataloader and determines the number of batches.

        Only runs when the model overrides both ``test_dataloader`` and ``test_step``.

        Args:
            model: The current `LightningModule`
        """
        has_loader = is_overridden('test_dataloader', model)
        has_step = is_overridden('test_step', model)
        if has_loader and has_step:
            self.num_test_batches, self.test_dataloaders = self._reset_eval_dataloader(model, 'test')

    def request_dataloader(self, dataloader_fx: Callable) -> DataLoader:
        """Handles downloading data in the GPU or TPU case.

        Calls ``dataloader_fx`` and then, in distributed modes, blocks until every
        process has done the same, so no process moves on before the data is ready.

        Args:
            dataloader_fx: The bound dataloader getter

        Returns:
            The dataloader
        """
        dataloader = dataloader_fx()

        # get the function we'll use to get data
        if self.use_ddp or self.use_ddp2:
            # all processes wait until data download has happened
            torch_distrib.barrier()

        # data download/load on TPU
        elif self.use_tpu and XLA_AVAILABLE:
            # all processes wait until data download has happened
            xm.rendezvous('pl.TrainerDataLoadingMixin.get_dataloaders')

        elif self.use_horovod:
            # all processes wait until data download has happened
            hvd.join()

        return dataloader