lightning/pytorch_lightning/trainer/data_loading.py

from abc import ABC

import torch.distributed as dist
from torch.utils.data import SequentialSampler, DataLoader
from torch.utils.data.distributed import DistributedSampler

from pytorch_lightning.utilities.debugging import MisconfigurationException

try:
    from apex import amp
    APEX_AVAILABLE = True
except ImportError:
    APEX_AVAILABLE = False

try:
    import torch_xla
    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.xla_multiprocessing as xmp
    XLA_AVAILABLE = True
except ImportError:
    XLA_AVAILABLE = False
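
# The try/except imports above act as optional-dependency guards: apex (mixed
# precision) and torch_xla (TPU support) are only needed for specific
# accelerator setups, so their availability is recorded in module-level flags
# and checked at runtime rather than failing at import time.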


class TrainerDataLoadingMixin(ABC):

    def __init__(self):
        # this is just a summary of the variables used in this abstract class;
        # the proper values/initialisation should be done in the child class
        self.proc_rank = None
        self.use_ddp = None
        self.use_ddp2 = None
        self.shown_warnings = None
        self.val_check_interval = None
        self.use_tpu = None
        self.tpu_local_core_rank = None
        self.train_dataloader = None
        self.num_training_batches = None
        self.val_check_batch = None
        self.val_dataloaders = None
        self.num_val_batches = None
        self.test_dataloaders = None
        self.num_test_batches = None

    def _percent_range_check(self, name):
        value = getattr(self, name)
        msg = f"`{name}` must lie in the range [0.0, 1.0], but got {value:.3f}."
        if name == "val_check_interval":
            msg += " If you want to disable validation set `val_percent_check` to 0.0 instead."

        if not 0. <= value <= 1.:
            raise ValueError(msg)
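
    # Illustrative example (not part of the original file): a Trainer
    # configured with `val_percent_check=1.5` fails this check with
    # ValueError: `val_percent_check` must lie in the range [0.0, 1.0],
    # but got 1.500.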

    def call_prepare_data(self, model):
        """
        Let the model download the data on process 0 only.
        :param model:
        """
        # download data on DDP/DDP2
        if self.use_ddp or self.use_ddp2:
            if self.proc_rank == 0:
                model.prepare_data()

            # all processes wait until data download has happened
            dist.barrier()

        # data download/load on TPU
        elif self.use_tpu and XLA_AVAILABLE:
            if self.tpu_local_core_rank == 0:
                model.prepare_data()

            # all processes wait until data download has happened
            torch_xla.core.xla_model.rendezvous("pl.TrainerDataLoadingMixin.get_dataloaders")

        else:
            # regular download
            model.prepare_data()
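
    # Sketch of the contract this enforces (illustrative; `MyModule` and the
    # torchvision MNIST dataset are example names, not part of this file):
    # `prepare_data` should only download/write to disk, since in distributed
    # runs it executes on a single process:
    #
    #     class MyModule(LightningModule):
    #         def prepare_data(self):
    #             MNIST(os.getcwd(), train=True, download=True)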

    def auto_add_sampler(self, dataloader, train):
        if self.use_ddp or self.use_ddp2 or self.use_tpu:
            dl_args = {
                'dataset': dataloader.dataset,
                'batch_size': dataloader.batch_size,
                'shuffle': False,
                'num_workers': dataloader.num_workers,
                'collate_fn': dataloader.collate_fn,
                'pin_memory': dataloader.pin_memory,
                'drop_last': dataloader.drop_last,
                'timeout': dataloader.timeout,
                'worker_init_fn': dataloader.worker_init_fn
            }

            if self.use_tpu:
                sampler = DistributedSampler(
                    dataloader.dataset,
                    num_replicas=xm.xrt_world_size(),
                    rank=xm.get_ordinal()
                )
            elif train:
                sampler = DistributedSampler(dataloader.dataset)
            else:
                sampler = SequentialSampler(dataloader.dataset)

            dl_args['sampler'] = sampler
            dataloader = DataLoader(**dl_args)

        return dataloader
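
    # Illustrative usage (variable names are examples, not part of this file):
    # wrapping a plain DataLoader keeps its settings but swaps the sampler:
    #
    #     loader = DataLoader(dataset, batch_size=32, num_workers=4)
    #     loader = trainer.auto_add_sampler(loader, train=True)
    #     # under DDP, loader.sampler is now a DistributedSampler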

    def reset_train_dataloader(self, model):
        """
        Dataloaders are provided by the model.
        :param model:
        :return:
        """
        self.train_dataloader = self.request_data_loader(model.train_dataloader)
        self.num_training_batches = 0

        # automatically add samplers
        self.train_dataloader = self.auto_add_sampler(self.train_dataloader, train=True)

        self._percent_range_check('train_percent_check')

        if self.is_infinite_dataloader(self.train_dataloader):
            self.num_training_batches = float('inf')
        else:
            # try getting the length
            self.num_training_batches = len(self.train_dataloader)
            self.num_training_batches = int(self.num_training_batches * self.train_percent_check)

        # determine when to check validation:
        # an int k means validation runs every k training batches;
        # a float is interpreted as a fraction [0.0, 1.0] of a training epoch
        if isinstance(self.val_check_interval, int):
            self.val_check_batch = self.val_check_interval
            if self.val_check_batch > self.num_training_batches:
                raise ValueError(
                    f"`val_check_interval` ({self.val_check_interval}) must be less than or equal "
                    f"to the number of the training batches ({self.num_training_batches}). "
                    f"If you want to disable validation set `val_percent_check` to 0.0 instead.")
        else:
            if self.is_infinite_dataloader(self.train_dataloader):
                m = '''
                When using an infinite DataLoader (e.g. with an IterableDataset or when DataLoader
                does not implement `__len__`) for `train_dataloader`, `Trainer(val_check_interval)`
                must be an int. An int k specifies checking validation every k training batches.
                '''
                raise MisconfigurationException(m)

            self._percent_range_check('val_check_interval')

            self.val_check_batch = int(self.num_training_batches * self.val_check_interval)
            self.val_check_batch = max(1, self.val_check_batch)
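
    # Worked example (illustrative): with 1000 training batches and
    # `val_check_interval=0.25`, validation runs every
    # int(1000 * 0.25) = 250 training batches; with `val_check_interval=100`
    # it runs every 100 batches regardless of the epoch length.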

    def is_infinite_dataloader(self, dataloader):
        try:
            # try getting the length
            _ = len(dataloader)
            return False
        except TypeError:
            return True
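
    # Illustrative: a DataLoader over an IterableDataset with no `__len__`
    # raises TypeError on `len()` and is therefore treated as infinite
    # (`Stream` below is a hypothetical example class):
    #
    #     class Stream(IterableDataset):
    #         def __iter__(self):
    #             return itertools.count()
    #
    #     trainer.is_infinite_dataloader(DataLoader(Stream()))  # True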

    def reset_val_dataloader(self, model):
        """
        Dataloaders are provided by the model.
        :param model:
        :return:
        """
        if not self.is_overriden('validation_step'):
            return

        self.val_dataloaders = self.request_data_loader(model.val_dataloader)
        if not isinstance(self.val_dataloaders, list):
            self.val_dataloaders = [self.val_dataloaders]
        self.num_val_batches = 0

        # add samplers
        self.val_dataloaders = [self.auto_add_sampler(dl, train=False)
                                for dl in self.val_dataloaders if dl]

        # determine number of validation batches
        # val datasets could be none, 1 or 2+
        if self.val_dataloaders is not None:
            self._percent_range_check('val_percent_check')

            self.num_val_batches = sum(len(dataloader) for dataloader in self.val_dataloaders)
            self.num_val_batches = int(self.num_val_batches * self.val_percent_check)
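
    # Worked example (illustrative): two val dataloaders of 100 batches each
    # with `val_percent_check=0.5` give num_val_batches = int(200 * 0.5) = 100.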

    def reset_test_dataloader(self, model):
        """Dataloaders are provided by the model.
        :param model:
        """
        if not self.is_overriden('test_step'):
            return

        # get actual loader
        self.test_dataloaders = self.request_data_loader(model.test_dataloader)
        if not isinstance(self.test_dataloaders, list):
            self.test_dataloaders = [self.test_dataloaders]
        self.num_test_batches = 0

        # add samplers
        self.test_dataloaders = [self.auto_add_sampler(dl, train=False)
                                 for dl in self.test_dataloaders if dl]

        # determine number of test batches
        if self.test_dataloaders is not None:
            self._percent_range_check('test_percent_check')

            len_sum = sum(len(dataloader) for dataloader in self.test_dataloaders)
            self.num_test_batches = len_sum
            self.num_test_batches = int(self.num_test_batches * self.test_percent_check)

    def request_data_loader(self, data_loader_fx):
        """
        Calls the dataloader function, synchronising processes in the DDP or TPU case.
        :param data_loader_fx:
        :return:
        """
        if self.use_ddp or self.use_ddp2:
            data_loader = data_loader_fx()

            # all processes wait until data download has happened
            dist.barrier()

        # data download/load on TPU
        elif self.use_tpu and XLA_AVAILABLE:
            data_loader = data_loader_fx()

            # all processes wait until data download has happened
            torch_xla.core.xla_model.rendezvous("pl.TrainerDataLoadingMixin.get_dataloaders")

        # regular start
        else:
            data_loader = data_loader_fx()

        return data_loader

    def determine_data_use_amount(self, train_percent_check, val_percent_check,
                                  test_percent_check, overfit_pct):
        """
        Use less data for debugging purposes.
        """
        self.train_percent_check = train_percent_check
        self.val_percent_check = val_percent_check
        self.test_percent_check = test_percent_check
        if overfit_pct > 0:
            if overfit_pct > 1:
                raise ValueError(f"`overfit_pct` must not be greater than 1.0, but got "
                                 f"{overfit_pct:.3f}.")

            self.train_percent_check = overfit_pct
            self.val_percent_check = overfit_pct
            self.test_percent_check = overfit_pct
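
    # Worked example (illustrative): `overfit_pct=0.01` overrides all three
    # percent checks, so a train dataloader with 10000 batches is trimmed to
    # int(10000 * 0.01) = 100 batches per epoch for quick debugging runs.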