from abc import ABC

import torch.distributed as dist
from torch.utils.data import SequentialSampler, DataLoader
from torch.utils.data.distributed import DistributedSampler

from pytorch_lightning.utilities.debugging import MisconfigurationException

try:
    from apex import amp

    APEX_AVAILABLE = True
except ImportError:
    APEX_AVAILABLE = False

try:
    import torch_xla
    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.xla_multiprocessing as xmp

    XLA_AVAILABLE = True
except ImportError:
    XLA_AVAILABLE = False


class TrainerDataLoadingMixin(ABC):

    def __init__(self):
        # this is just a summary of the variables used in this abstract class;
        # the proper values/initialisation should be done in the child class
        self.proc_rank = None
        self.use_ddp = None
        self.use_ddp2 = None
        self.shown_warnings = None
        self.val_check_interval = None
        self.use_tpu = None
        self.tpu_local_core_rank = None
        self.train_dataloader = None
        self.num_training_batches = None
        self.val_check_batch = None
        self.val_dataloaders = None
        self.num_val_batches = None
        self.test_dataloaders = None
        self.num_test_batches = None
    def _percent_range_check(self, name):
        value = getattr(self, name)
        msg = f"`{name}` must lie in the range [0.0, 1.0], but got {value:.3f}."
        if name == "val_check_interval":
            msg += " If you want to disable validation, set `val_percent_check` to 0.0 instead."

        if not 0. <= value <= 1.:
            raise ValueError(msg)
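    # A minimal sketch of how the range check above behaves, assuming a trainer
    # instance `trainer` with the attribute already set (names here are purely
    # illustrative, not part of the library API):
    #
    #     trainer.val_percent_check = 0.5
    #     trainer._percent_range_check('val_percent_check')   # passes silently
    #
    #     trainer.val_percent_check = 1.5
    #     trainer._percent_range_check('val_percent_check')   # raises ValueError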
    def call_prepare_data(self, model):
        """
        Let model download the data on proc==0 only
        :param model: the LightningModule whose ``prepare_data`` should be called
        """
        # download data on DDP/DDP2
        if self.use_ddp or self.use_ddp2:
            if self.proc_rank == 0:
                model.prepare_data()

            # all processes wait until data download has happened
            dist.barrier()

        # data download/load on TPU
        elif self.use_tpu and XLA_AVAILABLE:
            if self.tpu_local_core_rank == 0:
                model.prepare_data()

            # all processes wait until data download has happened
            torch_xla.core.xla_model.rendezvous("pl.TrainerDataLoadingMixin.get_dataloaders")

        else:
            # regular download
            model.prepare_data()
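    # A hedged sketch of the model-side counterpart this method expects: a
    # LightningModule that does its one-time, non-rank-aware work in
    # `prepare_data` (the dataset and path below are made up for illustration):
    #
    #     class MyModel(LightningModule):
    #         def prepare_data(self):
    #             # runs on proc 0 / local TPU core 0 only; the other processes
    #             # wait at the barrier/rendezvous above until it finishes
    #             MNIST('/tmp/data', train=True, download=True)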
    def auto_add_sampler(self, dataloader, train):
        """Rebuild the dataloader with an appropriate sampler when running on DDP/DDP2/TPU."""
        if self.use_ddp or self.use_ddp2 or self.use_tpu:
            dl_args = {
                'dataset': dataloader.dataset,
                'batch_size': dataloader.batch_size,
                'shuffle': False,
                'num_workers': dataloader.num_workers,
                'collate_fn': dataloader.collate_fn,
                'pin_memory': dataloader.pin_memory,
                'drop_last': dataloader.drop_last,
                'timeout': dataloader.timeout,
                'worker_init_fn': dataloader.worker_init_fn
            }

            if self.use_tpu:
                sampler = DistributedSampler(
                    dataloader.dataset,
                    num_replicas=xm.xrt_world_size(),
                    rank=xm.get_ordinal()
                )
                dl_args['shuffle'] = False
            else:
                if train:
                    sampler = DistributedSampler(dataloader.dataset)
                    dl_args['shuffle'] = False
                else:
                    sampler = SequentialSampler(dataloader.dataset)

            dl_args['sampler'] = sampler

            dataloader = DataLoader(**dl_args)
        return dataloader
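    # A rough sketch of the rebuild performed above, assuming two DDP processes
    # and a user-supplied loader (dataset and batch size are illustrative only):
    #
    #     user_loader = DataLoader(my_dataset, batch_size=32, shuffle=True)
    #     ddp_loader = trainer.auto_add_sampler(user_loader, train=True)
    #     # ddp_loader now uses DistributedSampler(my_dataset) with shuffle=False,
    #     # keeping batch_size, num_workers, collate_fn, etc. from user_loader;
    #     # each rank therefore iterates over a disjoint half of the dataset.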
    def reset_train_dataloader(self, model):
        """
        Dataloaders are provided by the model
        :param model: the LightningModule to request the train dataloader from
        :return:
        """

        self.train_dataloader = self.request_data_loader(model.train_dataloader)
        self.num_training_batches = 0

        # automatically add samplers
        self.train_dataloader = self.auto_add_sampler(self.train_dataloader, train=True)

        self._percent_range_check('train_percent_check')

        if self.is_infinite_dataloader(self.train_dataloader):
            self.num_training_batches = float('inf')
        else:
            # try getting the length
            self.num_training_batches = len(self.train_dataloader)
            self.num_training_batches = int(self.num_training_batches * self.train_percent_check)

        # determine when to check validation
        # if an int is passed in, validation runs every that many training batches
        # otherwise, it is interpreted as a fraction [0.0, 1.0] of a training epoch
        if isinstance(self.val_check_interval, int):
            self.val_check_batch = self.val_check_interval
            if self.val_check_batch > self.num_training_batches:
                raise ValueError(
                    f"`val_check_interval` ({self.val_check_interval}) must be less than or equal "
                    f"to the number of training batches ({self.num_training_batches}). "
                    f"If you want to disable validation, set `val_percent_check` to 0.0 instead.")
        else:
            if self.is_infinite_dataloader(self.train_dataloader):
                m = '''
                When using an infinite DataLoader (e.g. with an IterableDataset or when DataLoader
                does not implement `__len__`) for `train_dataloader`, `Trainer(val_check_interval)`
                must be an int. An int k specifies checking validation every k training batches.
                '''
                raise MisconfigurationException(m)

            self._percent_range_check('val_check_interval')

            self.val_check_batch = int(self.num_training_batches * self.val_check_interval)
            self.val_check_batch = max(1, self.val_check_batch)
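    # Worked example of the mapping above (numbers are illustrative): with 1000
    # training batches and `val_check_interval=0.25`, validation runs every
    # int(1000 * 0.25) = 250 batches. With an infinite train loader, only an
    # int such as `val_check_interval=500` is accepted, and validation then
    # runs every 500 training batches.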
    def is_infinite_dataloader(self, dataloader):
        """Check whether a dataloader has no usable length (e.g. it wraps an IterableDataset)."""
        try:
            # try getting the length
            _ = len(dataloader)
            return False
        except TypeError:
            return True
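    # A small sketch of the two cases handled above (the dataset class below is
    # made up for illustration; whether len() raises depends on the wrapped
    # dataset defining __len__):
    #
    #     class Stream(IterableDataset):
    #         def __iter__(self):
    #             while True:
    #                 yield 0
    #
    #     trainer.is_infinite_dataloader(DataLoader(Stream()))    # True, no __len__
    #     trainer.is_infinite_dataloader(DataLoader(range(10)))   # False, len() works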
    def reset_val_dataloader(self, model):
        """
        Dataloaders are provided by the model
        :param model:
        :return:
        """
        if not self.is_overriden('validation_step'):
            return

        self.val_dataloaders = self.request_data_loader(model.val_dataloader)
        if not isinstance(self.val_dataloaders, list):
            self.val_dataloaders = [self.val_dataloaders]
        self.num_val_batches = 0

        # add samplers
        self.val_dataloaders = [self.auto_add_sampler(dl, train=False)
                                for dl in self.val_dataloaders if dl]

        # determine number of validation batches
        # val datasets could be none, 1 or 2+
        if self.val_dataloaders is not None:
            self._percent_range_check('val_percent_check')

            self.num_val_batches = sum(len(dataloader) for dataloader in self.val_dataloaders)
            self.num_val_batches = int(self.num_val_batches * self.val_percent_check)
    def reset_test_dataloader(self, model):
        """Dataloaders are provided by the model.

        :param model:
        """
        if not self.is_overriden('test_step'):
            return

        # get actual loader
        self.test_dataloaders = self.request_data_loader(model.test_dataloader)
        if not isinstance(self.test_dataloaders, list):
            self.test_dataloaders = [self.test_dataloaders]
        self.num_test_batches = 0

        # add samplers
        self.test_dataloaders = [self.auto_add_sampler(dl, train=False)
                                 for dl in self.test_dataloaders if dl]

        # determine number of test batches
        if self.test_dataloaders is not None:
            self._percent_range_check('test_percent_check')

            len_sum = sum(len(dataloader) for dataloader in self.test_dataloaders)
            self.num_test_batches = len_sum
            self.num_test_batches = int(self.num_test_batches * self.test_percent_check)
    def request_data_loader(self, data_loader_fx):
        """
        Handles downloading data in the GPU or TPU case.

        :param data_loader_fx: function that returns a dataloader (or list of dataloaders)
        :return: the requested dataloader(s)
        """
        # get the function we'll use to get data
        if self.use_ddp or self.use_ddp2:
            data_loader = data_loader_fx()

            # all processes wait until data download has happened
            dist.barrier()

        # data download/load on TPU
        elif self.use_tpu and XLA_AVAILABLE:
            data_loader = data_loader_fx()

            # all processes wait until data download has happened
            torch_xla.core.xla_model.rendezvous("pl.TrainerDataLoadingMixin.get_dataloaders")

        # regular start
        else:
            data_loader = data_loader_fx()

        return data_loader
    def determine_data_use_amount(self, train_percent_check, val_percent_check,
                                  test_percent_check, overfit_pct):
        """
        Use less data for debugging purposes
        """
        self.train_percent_check = train_percent_check
        self.val_percent_check = val_percent_check
        self.test_percent_check = test_percent_check
        if overfit_pct > 0:
            if overfit_pct > 1:
                raise ValueError(f"`overfit_pct` must not be greater than 1.0, but got "
                                 f"{overfit_pct:.3f}.")

            self.train_percent_check = overfit_pct
            self.val_percent_check = overfit_pct
            self.test_percent_check = overfit_pct
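    # Illustrative call (values made up): with `overfit_pct=0.01`, the three
    # percent checks are all overridden, so every split uses 1% of its batches,
    # which is handy for overfitting a tiny subset while debugging:
    #
    #     trainer.determine_data_use_amount(train_percent_check=1.0,
    #                                       val_percent_check=1.0,
    #                                       test_percent_check=1.0,
    #                                       overfit_pct=0.01)
    #     # trainer.train_percent_check == trainer.val_percent_check
    #     #   == trainer.test_percent_check == 0.01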