# lightning/pytorch_lightning/trainer/data_loading_mixin.py
# (source listing metadata: 191 lines, 7.6 KiB, Python)

import warnings
import torch.distributed as dist
from torch.utils.data import IterableDataset
from torch.utils.data.distributed import DistributedSampler
from pytorch_lightning.utilities.debugging import MisconfigurationException
# apex is an optional dependency: record whether it can be imported instead of
# failing at import time. ``amp`` is not referenced elsewhere in this module.
try:
    from apex import amp  # noqa: F401
except ImportError:
    APEX_AVAILABLE = False
else:
    APEX_AVAILABLE = True
class TrainerDataLoadingMixin(object):
    """Trainer mixin that fetches dataloaders from the model and sizes each split.

    Attributes read from ``self`` that this mixin does not set itself are
    presumably initialised by the owning Trainer (TODO confirm):
    ``val_check_interval``, ``use_ddp``, ``use_ddp2``, ``proc_rank`` and
    ``shown_warnings`` (a set of warning messages already emitted).
    ``train_percent_check`` / ``val_percent_check`` / ``test_percent_check``
    are set by :meth:`determine_data_use_amount`.
    """

    # Message texts are kept byte-for-byte stable: ``shown_warnings`` dedupes
    # on the exact string, and users may filter on it.
    _TRAIN_SAMPLER_WARNING = """
You're using multiple gpus and multiple nodes without using a DistributedSampler
to assign a subset of your data to each process. To silence this warning, pass a
DistributedSampler to your DataLoader.
ie: this:
dataset = myDataset()
dataloader = Dataloader(dataset)
becomes:
dataset = myDataset()
dist_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
dataloader = Dataloader(dataset, sampler=dist_sampler)
If you want each process to load the full dataset, ignore this warning.
"""

    _VAL_SAMPLER_WARNING = """
Your val_dataloader(s) don't use DistributedSampler.
You're using multiple gpus and multiple nodes without using a
DistributedSampler to assign a subset of your data to each process.
To silence this warning, pass a DistributedSampler to your DataLoader.
ie: this:
dataset = myDataset()
dataloader = Dataloader(dataset)
becomes:
dataset = myDataset()
dist_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
dataloader = Dataloader(dataset, sampler=dist_sampler)
If you want each process to load the full dataset, ignore this warning.
"""

    _TEST_SAMPLER_WARNING = """
Your test_dataloader(s) don't use DistributedSampler.
You're using multiple gpus and multiple nodes without using a
DistributedSampler to assign a subset of your data to each process.
To silence this warning, pass a DistributedSampler to your DataLoader.
ie: this:
dataset = myDataset()
dataloader = Dataloader(dataset)
becomes:
dataset = myDataset()
dist_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
dataloader = Dataloader(dataset, sampler=dist_sampler)
If you want each process to load the full dataset, ignore this warning.
"""

    # Raised when the training data has no length but validation is scheduled
    # as a fraction of an epoch (a fraction of an infinite epoch is undefined).
    _ITERABLE_VAL_CHECK_MSG = '''
When using an iterableDataset for train_dataloader,
Trainer(val_check_interval) must be an int.
An int k specifies checking validation every k training batches
'''

    @staticmethod
    def _is_iterable_dataloader(dataloader):
        """Return True if ``dataloader`` is, or wraps, an ``IterableDataset``.

        A ``DataLoader`` is never itself an ``IterableDataset``, so the wrapped
        ``.dataset`` must be inspected. The direct check is kept for models
        that return a bare ``IterableDataset``.
        """
        if isinstance(dataloader, IterableDataset):
            return True
        return isinstance(getattr(dataloader, 'dataset', None), IterableDataset)

    def _warn_once(self, msg):
        """Emit ``msg`` with ``warnings.warn`` at most once, and only on rank 0."""
        if msg not in self.shown_warnings and self.proc_rank == 0:
            self.shown_warnings.add(msg)
            warnings.warn(msg)

    def _warn_if_not_distributed(self, dataloaders, msg):
        """Warn (once) if any of ``dataloaders`` lacks a ``DistributedSampler``.

        Callers invoke this only when running under DDP / DDP2.
        """
        # any() short-circuits on the first offender, like the original ``break``
        if any(not isinstance(dl.sampler, DistributedSampler) for dl in dataloaders):
            self._warn_once(msg)

    def init_train_dataloader(self, model):
        """
        Fetch the training dataloader from the model and derive batch counts.

        Sets ``get_train_dataloader``, ``nb_training_batches`` (``inf`` when the
        data is iterable-style and has no length) and ``val_check_batch``.

        :param model: object exposing ``train_dataloader()``.
        :raises MisconfigurationException: if the training data is iterable-style
            and ``val_check_interval`` is not an int.
        """
        self.get_train_dataloader = model.train_dataloader
        train_loader = self.get_train_dataloader()

        # determine number of training batches
        if self._is_iterable_dataloader(train_loader):
            self.nb_training_batches = float('inf')
        else:
            self.nb_training_batches = int(len(train_loader) * self.train_percent_check)

        # determine when to check validation
        # if int passed in, val checks that often
        # otherwise, it checks in [0, 1.0] % range of a training epoch
        if isinstance(self.val_check_interval, int):
            self.val_check_batch = self.val_check_interval
        else:
            if self.nb_training_batches == float('inf'):
                # int(inf * fraction) would otherwise raise an opaque OverflowError
                raise MisconfigurationException(self._ITERABLE_VAL_CHECK_MSG)
            self.val_check_batch = int(self.nb_training_batches * self.val_check_interval)
        # never 0: the interval is used as a batch period
        self.val_check_batch = max(1, self.val_check_batch)

        if self.use_ddp or self.use_ddp2:
            self._warn_if_not_distributed([train_loader], self._TRAIN_SAMPLER_WARNING)

    def init_val_dataloader(self, model):
        """
        Fetch the validation dataloader(s) from the model and count their batches.

        If ``model.val_dataloader()`` returns ``None``, ``nb_val_batches`` is left
        untouched. Otherwise the return value is iterated as a collection of
        sized dataloaders.

        :param model: object exposing ``val_dataloader()``.
        """
        self.get_val_dataloaders = model.val_dataloader
        val_loaders = self.get_val_dataloaders()
        if val_loaders is None:
            return

        # total batches across every val loader, scaled, clamped to >= 1
        nb_batches = sum(len(dataloader) for dataloader in val_loaders)
        self.nb_val_batches = max(1, int(nb_batches * self.val_percent_check))

        if self.use_ddp or self.use_ddp2:
            self._warn_if_not_distributed(val_loaders, self._VAL_SAMPLER_WARNING)

    def init_test_dataloader(self, model):
        """
        Fetch the test dataloader(s) from the model and count their batches.

        If ``model.test_dataloader()`` returns ``None``, ``nb_test_batches`` is
        left untouched. Otherwise the return value is iterated as a collection
        of sized dataloaders.

        :param model: object exposing ``test_dataloader()``.
        """
        self.get_test_dataloaders = model.test_dataloader
        test_loaders = self.get_test_dataloaders()
        if test_loaders is None:
            return

        # total batches across every test loader, scaled, clamped to >= 1
        nb_batches = sum(len(dataloader) for dataloader in test_loaders)
        self.nb_test_batches = max(1, int(nb_batches * self.test_percent_check))

        if self.use_ddp or self.use_ddp2:
            self._warn_if_not_distributed(test_loaders, self._TEST_SAMPLER_WARNING)

    def get_dataloaders(self, model):
        """
        Initialise the train, test and val dataloaders provided by ``model``.

        Under DDP / DDP2 all processes meet at ``dist.barrier()`` and then
        fetch each loader again. Also sets ``is_iterable_train_dataloader``.

        :param model: object exposing ``train_dataloader()``,
            ``test_dataloader()`` and ``val_dataloader()``.
        :raises MisconfigurationException: if the training data is iterable-style
            and ``val_check_interval`` is not an int.
        """
        self.init_train_dataloader(model)
        self.init_test_dataloader(model)
        self.init_val_dataloader(model)

        if self.use_ddp or self.use_ddp2:
            # wait for all processes to catch up
            dist.barrier()

            # load each dataloader
            self.get_train_dataloader()
            self.get_test_dataloaders()
            self.get_val_dataloaders()

        # support IterableDataset for train data
        self.is_iterable_train_dataloader = self._is_iterable_dataloader(
            self.get_train_dataloader())
        if self.is_iterable_train_dataloader and not isinstance(self.val_check_interval, int):
            raise MisconfigurationException(self._ITERABLE_VAL_CHECK_MSG)

    def determine_data_use_amount(self, train_percent_check, val_percent_check,
                                  test_percent_check, overfit_pct):
        """
        Set the fraction of each split to use (e.g. less data for debugging).

        :param train_percent_check: fraction of training batches to use.
        :param val_percent_check: fraction of validation batches to use.
        :param test_percent_check: fraction of test batches to use.
        :param overfit_pct: if > 0, overrides all three fractions with this value.
        """
        self.train_percent_check = train_percent_check
        self.val_percent_check = val_percent_check
        self.test_percent_check = test_percent_check
        if overfit_pct > 0:
            self.train_percent_check = overfit_pct
            self.val_percent_check = overfit_pct
            self.test_percent_check = overfit_pct