# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import multiprocessing
import platform
from abc import ABC, abstractmethod
from copy import deepcopy
from typing import Callable, Iterable, List, Optional, Tuple, Union

from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler

from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.core import LightningModule
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.data import has_iterable_dataset, has_len
from pytorch_lightning.utilities.debugging import InternalDebugger
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_utils import is_overridden
from pytorch_lightning.utilities.xla_device_utils import XLADeviceUtils

TPU_AVAILABLE = XLADeviceUtils.tpu_device_exists()

try:
    from apex import amp
except ImportError:
    amp = None

if TPU_AVAILABLE:
    import torch_xla
    import torch_xla.core.xla_model as xm

try:
    import horovod.torch as hvd
except (ModuleNotFoundError, ImportError):
    HOROVOD_AVAILABLE = False
else:
    HOROVOD_AVAILABLE = True

class TrainerDataLoadingMixin(ABC):

    # this is just a summary of the variables used in this abstract class;
    # the proper values/initialisation should be done in the child class
    global_rank: int
    use_ddp: bool
    use_ddp2: bool
    use_horovod: bool
    shown_warnings: ...
    val_check_interval: float
    use_tpu: bool
    tpu_local_core_rank: int
    train_dataloader: DataLoader
    num_training_batches: Union[int, float]
    val_check_batch: ...
    val_dataloaders: List[DataLoader]
    num_val_batches: List[Union[int, float]]
    test_dataloaders: List[DataLoader]
    num_test_batches: List[Union[int, float]]
    limit_train_batches: Union[int, float]
    limit_val_batches: Union[int, float]
    limit_test_batches: Union[int, float]
    replace_sampler_ddp: bool
    accelerator_backend: Accelerator
    num_nodes: int
    num_processes: int
    distributed_backend: Optional[str]
    dev_debugger: InternalDebugger

    def _worker_check(self, dataloader: DataLoader, name: str) -> None:
        on_windows = platform.system() == 'Windows'

        # ddp_spawn + num_workers > 0 don't mix! tell the user
        is_dataloader = isinstance(dataloader, DataLoader)
        using_spawn = self.distributed_backend == 'ddp_spawn'
        if is_dataloader and not on_windows:
            if dataloader.num_workers > 0 and using_spawn:
                rank_zero_warn('Dataloader(num_workers>0) and ddp_spawn do not mix well!'
                               ' Your performance might suffer dramatically.'
                               ' Please consider setting distributed_backend=ddp to use num_workers > 0'
                               ' (this is a bottleneck of Python .spawn() and PyTorch).')

            elif dataloader.num_workers == 0 and using_spawn:
                rank_zero_warn('You are using `distributed_backend=ddp_spawn` with num_workers=0.'
                               ' For much faster performance, switch to `distributed_backend=ddp`'
                               ' and set `num_workers>0`.')

            elif dataloader.num_workers <= 2 and multiprocessing.cpu_count() > 2 and not using_spawn:
                num_cpus = multiprocessing.cpu_count()
                rank_zero_warn(f'The dataloader, {name}, does not have many workers which may be a bottleneck.'
                               ' Consider increasing the value of the `num_workers` argument'
                               f' (try {num_cpus} which is the number of cpus on this machine)'
                               ' in the `DataLoader` init to improve performance.')

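    # In distributed training every process would otherwise iterate over the
    # full dataset; `auto_add_sampler` injects a `DistributedSampler` so each
    # replica only sees its own shard of the data.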
    def auto_add_sampler(self, dataloader: DataLoader, train: bool) -> DataLoader:
        # don't do anything if it's not a dataloader
        is_dataloader = isinstance(dataloader, DataLoader)
        # don't manipulate iterable datasets
        is_iterable_ds = has_iterable_dataset(dataloader)

        if not is_dataloader or is_iterable_ds:
            return dataloader

        need_dist_sampler = self.use_ddp or self.use_ddp2 or self.use_horovod or self.use_tpu

        if self.replace_sampler_ddp and need_dist_sampler:
            if not isinstance(dataloader.sampler, (SequentialSampler, RandomSampler)):
                raise MisconfigurationException(
                    'You seem to have configured a sampler in your DataLoader. This will be replaced'
                    ' by `DistributedSampler` since `replace_sampler_ddp` is True and you are using'
                    ' distributed training. Either remove the sampler from your DataLoader or set'
                    ' `replace_sampler_ddp=False` if you want to use your custom sampler.')

            # replace with distributed sampler
            sampler = self._get_distributed_sampler(dataloader, train)
            dataloader = self.replace_sampler(dataloader, sampler)

        return dataloader

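    # NOTE on `replace_sampler` below: swapping the sampler on an existing
    # `DataLoader` is not supported, so we rebuild the loader from its public
    # attributes. A minimal sketch of the idea (illustrative only, assuming a
    # plain `DataLoader`):
    #
    #     new_loader = DataLoader(old_loader.dataset, batch_size=old_loader.batch_size,
    #                             num_workers=old_loader.num_workers, sampler=new_sampler)
    #
    # `batch_sampler` is skipped because the `DataLoader` constructor rejects
    # it in combination with `sampler`.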
    def replace_sampler(self, dataloader, sampler):
        skip_keys = ['sampler', 'batch_sampler', 'dataset_kind']

        dl_args = {
            k: v for k, v in dataloader.__dict__.items() if not k.startswith('_') and k not in skip_keys
        }

        dl_args['sampler'] = sampler
        dataloader = type(dataloader)(**dl_args)
        return dataloader

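    # `DistributedSampler(num_replicas=N, rank=r)` gives each process a
    # disjoint 1/N-th shard of the dataset; the helper below works out N and r
    # for the active backend. E.g. under `ddp` with num_nodes=2 and
    # num_processes=4, N = 2 * 4 = 8 and each replica sees 1/8th of the data.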
    def _get_distributed_sampler(self, dataloader, train):
        if self.use_tpu:
            kwargs = dict(num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())
        elif self.use_horovod:
            kwargs = dict(num_replicas=hvd.size(), rank=hvd.rank())
        else:
            world_size = {
                'ddp': self.num_nodes * self.num_processes,
                'ddp_spawn': self.num_nodes * self.num_processes,
                'ddp2': self.num_nodes,
                'ddp_cpu': self.num_processes * self.num_nodes,
            }
            assert self.distributed_backend is not None
            kwargs = dict(num_replicas=world_size[self.distributed_backend], rank=self.global_rank)

        kwargs['shuffle'] = train and not self.overfit_batches
        sampler = DistributedSampler(dataloader.dataset, **kwargs)
        return sampler

    def reset_train_dataloader(self, model: LightningModule) -> None:
        """Resets the train dataloader and initialises required variables
        (number of batches, when to validate, etc.).

        Args:
            model: The current `LightningModule`
        """
        self.train_dataloader = self.request_dataloader(model.train_dataloader)

        if self.overfit_batches > 0:
            if hasattr(self.train_dataloader, 'sampler') and isinstance(self.train_dataloader.sampler, RandomSampler):
                rank_zero_warn('You requested to overfit but enabled training dataloader shuffling.'
                               ' We are turning it off for you.')
                self.train_dataloader = self.replace_sampler(
                    self.train_dataloader, SequentialSampler(self.train_dataloader.dataset))

        # debugging
        self.dev_debugger.track_load_dataloader_call('train_dataloader', dataloaders=[self.train_dataloader])

        self.num_training_batches = 0

        # automatically add samplers
        self.train_dataloader = self.auto_add_sampler(self.train_dataloader, train=True)

        self.num_training_batches = len(self.train_dataloader) if has_len(self.train_dataloader) else float('inf')
        self._worker_check(self.train_dataloader, 'train dataloader')

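        # `limit_train_batches` is either an int (absolute number of batches)
        # or a float (fraction of the epoch): e.g. with 1000 batches,
        # `limit_train_batches=0.1` and `limit_train_batches=100` both yield
        # 100 training batches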
        if isinstance(self.limit_train_batches, int) or self.limit_train_batches == 0.0:
            self.num_training_batches = min(self.num_training_batches, int(self.limit_train_batches))
        elif self.num_training_batches != float('inf'):
            self.num_training_batches = int(self.num_training_batches * self.limit_train_batches)
        elif self.limit_train_batches != 1.0:
            raise MisconfigurationException(
                'When using an IterableDataset for `limit_train_batches`,'
                ' `Trainer(limit_train_batches)` must be `0.0`, `1.0` or an int. An int k specifies'
                ' `num_training_batches` to use.')

        # determine when to check validation:
        # if an int is passed, validation runs every that-many training batches;
        # otherwise it is a fraction in [0.0, 1.0] of a training epoch
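        # e.g. `val_check_interval=100` runs validation every 100 training
        # batches, while `val_check_interval=0.25` runs it four times per epoch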
        if isinstance(self.val_check_interval, int):
            self.val_check_batch = self.val_check_interval
            if self.val_check_batch > self.num_training_batches:
                raise ValueError(
                    f'`val_check_interval` ({self.val_check_interval}) must be less than or equal'
                    f' to the number of the training batches ({self.num_training_batches}).'
                    ' If you want to disable validation, set `limit_val_batches` to 0.0 instead.')
        else:
            if not has_len(self.train_dataloader):
                if self.val_check_interval == 1.0:
                    self.val_check_batch = float('inf')
                else:
                    raise MisconfigurationException(
                        'When using an IterableDataset for `train_dataloader`,'
                        ' `Trainer(val_check_interval)` must be `1.0` or an int. An int k specifies'
                        ' checking validation every k training batches.')
            else:
                self.val_check_batch = int(self.num_training_batches * self.val_check_interval)
                self.val_check_batch = max(1, self.val_check_batch)

    def _reset_eval_dataloader(
        self,
        model: LightningModule,
        mode: str
    ) -> Tuple[List[Union[int, float]], List[DataLoader]]:
        """Generic method to reset a dataloader for evaluation.

        Args:
            model: The current `LightningModule`
            mode: Either `'val'` or `'test'`

        Returns:
            Tuple (num_batches, dataloaders)
        """
        # always get the loaders first so we can count how many there are
        loader_name = f'{mode}_dataloader'
        dataloaders = self.request_dataloader(getattr(model, loader_name))

        if not isinstance(dataloaders, list):
            dataloaders = [dataloaders]

        # when overfitting, use the training loader as val and test;
        # duplicate it the number of times needed to match the eval loaders requested
        if self.overfit_batches > 0:
            num_loaders = len(dataloaders)
            train_dataloader = self.request_dataloader(getattr(model, 'train_dataloader'))
            dataloaders = [deepcopy(train_dataloader) for _ in range(num_loaders)]

        self.dev_debugger.track_load_dataloader_call(loader_name, dataloaders=dataloaders)

        for loader_i in range(len(dataloaders)):
            loader = dataloaders[loader_i]

            # shuffling in val and test sets is bad practice
            if mode in ('val', 'test') and hasattr(loader, 'sampler') and isinstance(loader.sampler, RandomSampler):

                # when overfitting, the dataloader should not shuffle
                if self.overfit_batches > 0:
                    rank_zero_warn('You requested to overfit but enabled test/val dataloader shuffling.'
                                   ' We are turning it off for you.')
                    dataloaders[loader_i] = self.replace_sampler(loader, SequentialSampler(loader.dataset))

                else:
                    rank_zero_warn(f'Your {mode}_dataloader has `shuffle=True`, it is best practice to turn'
                                   ' this off for validation and test dataloaders.')

        if any(dl is None for dl in dataloaders):
            rank_zero_warn('One of the given dataloaders is None and it will be skipped.')

        # add samplers
        dataloaders = [self.auto_add_sampler(dl, train=False) for dl in dataloaders if dl is not None]

        loader_num_batches = []

        # determine number of batches
        # datasets could be none, 1 or 2+
        if len(dataloaders) != 0:
            for i, dataloader in enumerate(dataloaders):
                num_batches = len(dataloader) if has_len(dataloader) else float('inf')
                self._worker_check(dataloader, f'{mode} dataloader {i}')

                # percent or num_steps
                limit_eval_batches = getattr(self, f'limit_{mode}_batches')

                # limit num batches either as a percent or num steps
                if isinstance(limit_eval_batches, int) or limit_eval_batches == 0.0:
                    num_batches = min(num_batches, int(limit_eval_batches))
                elif num_batches != float('inf'):
                    num_batches = int(num_batches * limit_eval_batches)
                elif limit_eval_batches != 1.0:
                    raise MisconfigurationException(
                        f'When using an IterableDataset for `limit_{mode}_batches`,'
                        f' `Trainer(limit_{mode}_batches)` must be `0.0`, `1.0` or an int. An int k specifies'
                        f' `num_{mode}_batches` to use.')

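                # a float limit so small that it truncates to zero batches is
                # almost certainly a user error, so fail with a hint for the
                # smallest usable fraction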
                if num_batches == 0 and limit_eval_batches > 0.0 and isinstance(limit_eval_batches, float):
                    min_pct = 1.0 / len(dataloader)
                    raise MisconfigurationException(
                        f'You requested to check {limit_eval_batches} of the {mode} dataloader but'
                        f' {limit_eval_batches} * {len(dataloader)} < 1. Please increase `limit_{mode}_batches`.'
                        f' Try at least `limit_{mode}_batches={min_pct}`'
                    )

                loader_num_batches.append(num_batches)

        return loader_num_batches, dataloaders

    def reset_val_dataloader(self, model: LightningModule) -> None:
        """Resets the validation dataloader and determines the number of batches.

        Args:
            model: The current `LightningModule`
        """
        has_loader = is_overridden('val_dataloader', model)
        has_step = is_overridden('validation_step', model)
        if has_loader and has_step:
            self.num_val_batches, self.val_dataloaders = self._reset_eval_dataloader(model, 'val')

    def reset_test_dataloader(self, model: LightningModule) -> None:
        """Resets the test dataloader and determines the number of batches.

        Args:
            model: The current `LightningModule`
        """
        has_loader = is_overridden('test_dataloader', model)
        has_step = is_overridden('test_step', model)
        if has_loader and has_step:
            self.num_test_batches, self.test_dataloaders = self._reset_eval_dataloader(model, 'test')

    def request_dataloader(self, dataloader_fx: Callable) -> DataLoader:
        """Handles downloading data in the GPU or TPU case.

        Args:
            dataloader_fx: The bound dataloader getter

        Returns:
            The dataloader
        """
        dataloader = dataloader_fx()
        dataloader = self._flatten_dl_only(dataloader)

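        # in distributed settings, wait here until every process has built its
        # dataloaders, so no rank races ahead before the data is ready on all
        # of them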
        if self.accelerator_backend is not None:
            self.accelerator_backend.barrier('get_dataloaders')
        return dataloader

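    # `val_dataloader`/`test_dataloader` may return several loaders; a bare
    # `return dl1, dl2` arrives here as a tuple, which the helper below
    # normalises to the list the rest of the trainer expects.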
    def _flatten_dl_only(self, dataloaders):
        # handles user error when multiple loaders are returned as a tuple,
        # i.e. `return dl1, dl2` instead of the expected `return [dl1, dl2]`
        if isinstance(dataloaders, tuple):
            all_dls = all(isinstance(x, Iterable) for x in dataloaders)
            if all_dls:
                dataloaders = list(dataloaders)

        return dataloaders