lightning/pytorch_lightning/trainer/data_loading.py

254 lines
9.5 KiB
Python

from abc import ABC, abstractmethod
from typing import Union, List, Tuple, Callable
import torch.distributed as torch_distrib
from torch.utils.data import SequentialSampler, DataLoader
from torch.utils.data.distributed import DistributedSampler
from pytorch_lightning.core import LightningModule
from pytorch_lightning.utilities.debugging import MisconfigurationException
# Probe for the optional `apex` package. `APEX_AVAILABLE` records whether
# `from apex import amp` succeeded, so the package is not a hard requirement.
# NOTE(review): `amp` itself is not referenced in the visible part of this
# module - confirm whether the import is only needed for the flag.
try:
    from apex import amp
except ImportError:
    APEX_AVAILABLE = False
else:
    APEX_AVAILABLE = True
# Probe for the optional `torch_xla` package. `XLA_AVAILABLE` is True only when
# all three imports succeed. `xm` is used by `auto_add_sampler` and `torch_xla`
# by `request_dataloader` further down in this module.
try:
    import torch_xla
    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.xla_multiprocessing as xmp
except ImportError:
    XLA_AVAILABLE = False
else:
    XLA_AVAILABLE = True
def _has_len(dataloader: DataLoader) -> bool:
try:
# try getting the length
_ = len(dataloader)
return True
except TypeError:
return False
class TrainerDataLoadingMixin(ABC):
    """Trainer mixin that prepares the train, validation and test dataloaders.

    The attributes annotated below are only declared here; the class that
    inherits this mixin is expected to initialise them.
    """

    # this is just a summary on variables used in this abstract class,
    # the proper values/initialisation should be done in child class
    proc_rank: int
    use_ddp: bool
    use_ddp2: bool
    shown_warnings: ...
    val_check_interval: float
    use_tpu: bool
    tpu_local_core_rank: int
    train_dataloader: DataLoader
    num_training_batches: Union[int, float]
    val_check_batch: ...
    val_dataloaders: List[DataLoader]
    num_val_batches: Union[int, float]
    test_dataloaders: List[DataLoader]
    num_test_batches: Union[int, float]
    train_percent_check: float
    val_percent_check: float
    test_percent_check: float

    @abstractmethod
    def is_overriden(self, *args):
        """Warning: this is just empty shell for code implemented in other class."""

    def _percent_range_check(self, name: str) -> None:
        """Validate that the attribute called ``name`` lies within ``[0.0, 1.0]``.

        Args:
            name: Name of the attribute on ``self`` to validate.

        Raises:
            ValueError: If the attribute value is outside ``[0.0, 1.0]``.
        """
        value = getattr(self, name)
        if 0. <= value <= 1.:
            return

        # The message is only needed (and only formatted) once the check failed.
        msg = f'`{name}` must lie in the range [0.0, 1.0], but got {value:.3f}.'
        if name == 'val_check_interval':
            msg += ' If you want to disable validation set `val_percent_check` to 0.0 instead.'
        raise ValueError(msg)

    def auto_add_sampler(self, dataloader: DataLoader, train: bool) -> DataLoader:
        """Rebuild ``dataloader`` with a sampler chosen for distributed runs.

        The loader is returned untouched unless ``use_ddp``, ``use_ddp2`` or
        ``use_tpu`` is set. Otherwise a new ``DataLoader`` is created from the
        attributes copied below, always with ``shuffle=False``; the sampler and
        batch sampler of the incoming loader are not carried over.

        Args:
            dataloader: The loader to inspect and possibly replace.
            train: Whether the loader is used for training. Outside of TPU runs
                a training loader gets a ``DistributedSampler`` and any other
                loader a ``SequentialSampler``.

        Returns:
            The original loader, or the rebuilt one.
        """
        if not (self.use_ddp or self.use_ddp2 or self.use_tpu):
            return dataloader

        # `shuffle` stays False because an explicit sampler is passed below.
        dl_args = {
            'dataset': dataloader.dataset,
            'batch_size': dataloader.batch_size,
            'shuffle': False,
            'num_workers': dataloader.num_workers,
            'collate_fn': dataloader.collate_fn,
            'pin_memory': dataloader.pin_memory,
            'drop_last': dataloader.drop_last,
            'timeout': dataloader.timeout,
            'worker_init_fn': dataloader.worker_init_fn
        }

        if self.use_tpu:
            # On TPU the replica count and rank are queried from XLA.
            sampler = DistributedSampler(
                dataloader.dataset,
                num_replicas=xm.xrt_world_size(),
                rank=xm.get_ordinal()
            )
        elif train:
            sampler = DistributedSampler(dataloader.dataset)
        else:
            sampler = SequentialSampler(dataloader.dataset)

        dl_args['sampler'] = sampler
        return DataLoader(**dl_args)

    def reset_train_dataloader(self, model: LightningModule) -> None:
        """Resets the train dataloader and initialises required variables
        (number of batches, when to validate, etc.).

        Sets ``train_dataloader``, ``num_training_batches`` and
        ``val_check_batch``.

        Args:
            model: The current `LightningModule`

        Raises:
            ValueError: If ``train_percent_check`` or a fractional
                ``val_check_interval`` is outside ``[0.0, 1.0]``, or if an int
                ``val_check_interval`` exceeds the number of training batches.
            MisconfigurationException: If the train loader has no length and
                ``val_check_interval`` is neither an int nor ``1.0``.
        """
        self.train_dataloader = self.request_dataloader(model.train_dataloader)
        self.num_training_batches = 0

        # automatically add samplers
        self.train_dataloader = self.auto_add_sampler(self.train_dataloader, train=True)

        self._percent_range_check('train_percent_check')

        has_len = _has_len(self.train_dataloader)
        if has_len:
            self.num_training_batches = int(len(self.train_dataloader) * self.train_percent_check)
        else:
            # a loader without a length is treated as infinite
            self.num_training_batches = float('inf')

        # determine when to check validation
        # if int passed in, val checks that often
        # otherwise, it checks in [0, 1.0] % range of a training epoch
        if isinstance(self.val_check_interval, int):
            self.val_check_batch = self.val_check_interval
            if self.val_check_batch > self.num_training_batches:
                raise ValueError(
                    f'`val_check_interval` ({self.val_check_interval}) must be less than or equal '
                    f'to the number of the training batches ({self.num_training_batches}). '
                    'If you want to disable validation set `val_percent_check` to 0.0 instead.')
        elif has_len:
            self._percent_range_check('val_check_interval')
            # validate at least once per epoch, even for very small fractions
            self.val_check_batch = max(1, int(self.num_training_batches * self.val_check_interval))
        elif self.val_check_interval == 1.0:
            self.val_check_batch = float('inf')
        else:
            raise MisconfigurationException(
                'When using an infinite DataLoader (e.g. with an IterableDataset or when '
                'DataLoader does not implement `__len__`) for `train_dataloader`, '
                '`Trainer(val_check_interval)` must be `1.0` or an int. An int k specifies '
                'checking validation every k training batches.')

    def _reset_eval_dataloader(self, model: LightningModule,
                               mode: str) -> Tuple[Union[int, float], List[DataLoader]]:
        """Generic method to reset a dataloader for evaluation.

        Args:
            model: The current `LightningModule`
            mode: Either `'val'` or `'test'`

        Returns:
            Tuple (num_batches, dataloaders). ``num_batches`` is ``0`` when no
            loader is left after dropping ``None`` entries, ``float('inf')``
            when at least one loader has no length, and otherwise the summed
            loader lengths scaled by ``{mode}_percent_check``.

        Raises:
            ValueError: If all loaders have a length and ``{mode}_percent_check``
                is outside ``[0.0, 1.0]``.
            MisconfigurationException: If a loader has no length and
                ``{mode}_percent_check`` is neither ``0.0`` nor ``1.0``.
        """
        dataloaders = self.request_dataloader(getattr(model, f'{mode}_dataloader'))

        if not isinstance(dataloaders, list):
            dataloaders = [dataloaders]

        # add samplers
        # Drop missing loaders with an identity check. Truth-testing a loader
        # would call its `__len__`, which raises for exactly the length-less
        # loaders handled below and would silently discard zero-length loaders.
        dataloaders = [self.auto_add_sampler(dl, train=False) for dl in dataloaders if dl is not None]

        num_batches = 0

        # determine number of batches
        # datasets could be none, 1 or 2+
        if not dataloaders:
            return num_batches, dataloaders

        percent_check = getattr(self, f'{mode}_percent_check')

        if all(_has_len(dataloader) for dataloader in dataloaders):
            self._percent_range_check(f'{mode}_percent_check')
            num_batches = sum(len(dataloader) for dataloader in dataloaders)
            num_batches = int(num_batches * percent_check)
        else:
            # one loader without a length makes the total unknown
            num_batches = float('inf')
            if percent_check not in (0.0, 1.0):
                raise MisconfigurationException(
                    'When using an infinite DataLoader (e.g. with an IterableDataset or when '
                    f'DataLoader does not implement `__len__`) for `{mode}_dataloader`, '
                    f'`Trainer({mode}_percent_check)` must be `0.0` or `1.0`.')
        return num_batches, dataloaders

    def reset_val_dataloader(self, model: LightningModule) -> None:
        """Resets the validation dataloader and determines the number of batches.

        Does nothing unless ``is_overriden('validation_step')`` is truthy.

        Args:
            model: The current `LightningModule`
        """
        if self.is_overriden('validation_step'):
            self.num_val_batches, self.val_dataloaders =\
                self._reset_eval_dataloader(model, 'val')

    def reset_test_dataloader(self, model: LightningModule) -> None:
        """Resets the test dataloader and determines the number of batches.

        Does nothing unless ``is_overriden('test_step')`` is truthy.

        Args:
            model: The current `LightningModule`
        """
        if self.is_overriden('test_step'):
            self.num_test_batches, self.test_dataloaders =\
                self._reset_eval_dataloader(model, 'test')

    def request_dataloader(self, dataloader_fx: Callable) -> DataLoader:
        """Handles downloading data in the GPU or TPU case.

        Calls ``dataloader_fx`` and then synchronises the processes of a
        distributed run before handing the result back.

        Args:
            dataloader_fx: The bound dataloader getter

        Returns:
            The dataloader
        """
        dataloader = dataloader_fx()

        if self.use_ddp or self.use_ddp2:
            # all processes wait until data download has happened
            torch_distrib.barrier()
        elif self.use_tpu and XLA_AVAILABLE:
            # data download/load on TPU:
            # all processes wait until data download has happened
            torch_xla.core.xla_model.rendezvous('pl.TrainerDataLoadingMixin.get_dataloaders')

        return dataloader

    def determine_data_use_amount(self, train_percent_check: float, val_percent_check: float,
                                  test_percent_check: float, overfit_pct: float) -> None:
        """Use less data for debugging purposes

        Stores the three fractions on ``self``. A positive ``overfit_pct``
        replaces all three of them.

        Args:
            train_percent_check: Fraction stored as ``self.train_percent_check``.
            val_percent_check: Fraction stored as ``self.val_percent_check``.
            test_percent_check: Fraction stored as ``self.test_percent_check``.
            overfit_pct: When greater than 0, overrides the three fractions above.

        Raises:
            ValueError: If ``overfit_pct`` is greater than 1.
        """
        self.train_percent_check = train_percent_check
        self.val_percent_check = val_percent_check
        self.test_percent_check = test_percent_check

        if overfit_pct > 0:
            if overfit_pct > 1:
                raise ValueError(
                    f'`overfit_pct` must be not greater than 1.0, but got {overfit_pct:.3f}.')

            self.train_percent_check = overfit_pct
            self.val_percent_check = overfit_pct
            self.test_percent_check = overfit_pct