Refactor dataloading (#955)
* Refactor dataloading
* Refactor dataloading
* Refactor dataloading
* Add shuffle to test
This commit is contained in:
parent be244560b2
commit b2e9607362
@@ -1,22 +1,10 @@
import warnings
from abc import ABC

import torch.distributed as dist
from torch.utils.data import SequentialSampler, DataLoader
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data import RandomSampler, SequentialSampler, DataLoader, BatchSampler
from pytorch_lightning.utilities.debugging import MisconfigurationException

try:
    # loading for pyTorch 1.3
    from torch.utils.data import IterableDataset
except ImportError:
    # loading for pyTorch 1.1
    import torch
    warnings.warn('Your version of pyTorch %s does not support `IterableDataset`,'
                  ' please upgrade to 1.2+' % torch.__version__, ImportWarning)
    EXIST_ITER_DATASET = False
else:
    EXIST_ITER_DATASET = True
from pytorch_lightning.utilities.debugging import MisconfigurationException

try:
    from apex import amp
@@ -90,7 +78,7 @@ class TrainerDataLoadingMixin(ABC):
        model.prepare_data()

    def auto_add_sampler(self, dataloader, train):
        # do nothing when user gives a sampler
        if self.use_ddp or self.use_ddp2 or self.use_tpu:
            dl_args = {
                'dataset': dataloader.dataset,
                'batch_size': dataloader.batch_size,
@@ -103,23 +91,6 @@ class TrainerDataLoadingMixin(ABC):
                'worker_init_fn': dataloader.worker_init_fn
            }

            if train:
                if self.use_ddp or self.use_ddp2:
                    sampler = DistributedSampler(dataloader.dataset)
                    dl_args['shuffle'] = False

                elif self.use_tpu:
                    sampler = DistributedSampler(
                        dataloader.dataset,
                        num_replicas=xm.xrt_world_size(),
                        rank=xm.get_ordinal()
                    )
                    dl_args['shuffle'] = False
                else:
                    sampler = RandomSampler(dataloader.dataset)

            # on not train
            else:
                if self.use_tpu:
                    sampler = DistributedSampler(
                        dataloader.dataset,
@@ -127,13 +98,17 @@ class TrainerDataLoadingMixin(ABC):
                    rank=xm.get_ordinal()
                )
                dl_args['shuffle'] = False
            else:
                if train:
                    sampler = DistributedSampler(dataloader.dataset)
                    dl_args['shuffle'] = False
                else:
                    sampler = SequentialSampler(dataloader.dataset)

            dl_args['sampler'] = sampler

            new_dataloader = DataLoader(**dl_args)
            return new_dataloader
            dataloader = DataLoader(**dl_args)
            return dataloader

    def reset_train_dataloader(self, model):
        """
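The hunks above rebuild the user's DataLoader so a distributed-aware sampler can be injected for DDP/DDP2/TPU runs. Below is a minimal sketch of the same idea outside the Trainer; the helper name and the `num_replicas`/`rank` arguments are illustrative only, not part of this diff.

```python
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler


def add_distributed_sampler(loader, num_replicas, rank):
    """Rebuild `loader` with a DistributedSampler (illustrative helper)."""
    dl_args = {
        'dataset': loader.dataset,
        'batch_size': loader.batch_size,
        'num_workers': loader.num_workers,
        'collate_fn': loader.collate_fn,
        'pin_memory': loader.pin_memory,
        'drop_last': loader.drop_last,
        # shuffle must stay False: DistributedSampler shuffles internally,
        # and DataLoader forbids shuffle=True together with an explicit sampler.
        'shuffle': False,
        'sampler': DistributedSampler(loader.dataset, num_replicas=num_replicas, rank=rank),
    }
    return DataLoader(**dl_args)
```

`auto_add_sampler` does essentially this, with the TPU branch asking `xm.xrt_world_size()` / `xm.get_ordinal()` for the replica count and rank, and a `SequentialSampler` for non-training loaders.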
@@ -148,12 +123,12 @@ class TrainerDataLoadingMixin(ABC):
        # automatically add samplers
        self.train_dataloader = self.auto_add_sampler(self.train_dataloader, train=True)

        # determine number of training batches
        if EXIST_ITER_DATASET and isinstance(self.train_dataloader.dataset, IterableDataset):
            self.num_training_batches = float('inf')
        else:
            self._percent_range_check('train_percent_check')

        if self.is_infinite_dataloader(self.train_dataloader):
            self.num_training_batches = float('inf')
        else:
            # try getting the length
            self.num_training_batches = len(self.train_dataloader)
            self.num_training_batches = int(self.num_training_batches * self.train_percent_check)
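For a loader with a defined length, the new code scales the batch count by `train_percent_check`, while an infinite loader keeps `num_training_batches = float('inf')`. A worked example with made-up numbers (1000 batches is purely hypothetical):

```python
# Stand-ins for len(self.train_dataloader) and the Trainer flag.
num_training_batches = 1000
train_percent_check = 0.1

num_training_batches = int(num_training_batches * train_percent_check)
assert num_training_batches == 100  # only 10% of the batches are used per epoch
```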
@@ -168,27 +143,26 @@ class TrainerDataLoadingMixin(ABC):
                    f"to the number of the training batches ({self.num_training_batches}). "
                    f"If you want to disable validation set `val_percent_check` to 0.0 instead.")
        else:
            if self.is_infinite_dataloader(self.train_dataloader):
                m = '''
                When using an infinite DataLoader (e.g. with an IterableDataset or when DataLoader
                does not implement `__len__`) for `train_dataloader`, `Trainer(val_check_interval)`
                must be an int. An int k specifies checking validation every k training batches.
                '''
                raise MisconfigurationException(m)

            self._percent_range_check('val_check_interval')

            self.val_check_batch = int(self.num_training_batches * self.val_check_interval)
            self.val_check_batch = max(1, self.val_check_batch)

        # support IterableDataset for train data
        self.is_iterable_train_dataloader = (
            EXIST_ITER_DATASET and isinstance(self.train_dataloader.dataset, IterableDataset)
        )
        if self.is_iterable_dataloader(self.train_dataloader) and not isinstance(self.val_check_interval, int):
            m = '''
            When using an iterableDataset for `train_dataloader`,
            `Trainer(val_check_interval)` must be an int.
            An int k specifies checking validation every k training batches
            '''
            raise MisconfigurationException(m)

    def is_iterable_dataloader(self, dataloader):
        return (
            EXIST_ITER_DATASET and isinstance(dataloader.dataset, IterableDataset)
        )
    def is_infinite_dataloader(self, dataloader):
        try:
            # try getting the length
            _ = len(dataloader)
            return False
        except TypeError as e:
            return True

    def reset_val_dataloader(self, model):
        """
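The new `is_infinite_dataloader` helper replaces the `EXIST_ITER_DATASET`/`isinstance` check with duck typing: a loader is treated as infinite whenever `len()` raises `TypeError`. A minimal sketch of that behaviour with stock PyTorch classes (the dataset shapes and sizes are made up for the example):

```python
import torch
from torch.utils.data import DataLoader, Dataset, IterableDataset


class MapStyle(Dataset):
    def __len__(self):
        return 8

    def __getitem__(self, idx):
        return torch.tensor(idx)


class Stream(IterableDataset):
    # no __len__, so the wrapping DataLoader cannot report a length either
    def __iter__(self):
        while True:
            yield torch.rand(3)


def is_infinite(dataloader):
    try:
        _ = len(dataloader)
        return False
    except TypeError:
        return True


print(is_infinite(DataLoader(MapStyle())))  # False
print(is_infinite(DataLoader(Stream())))    # True
```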
@@ -1114,19 +1114,14 @@ class Trainer(TrainerIOMixin,
            self.run_evaluation(test_mode=True)
            return

        # load the dataloaders
        self.reset_train_dataloader(ref_model)
        self.reset_val_dataloader(ref_model)

        # check if we should run validation during training
        self.disable_validation = self.num_val_batches == 0 or not self.is_overriden('validation_step')
        self.disable_validation = self.disable_validation and not self.fast_dev_run
        self.disable_validation = not self.is_overriden('validation_step') and not self.fast_dev_run

        # run tiny validation (if validation defined)
        # to make sure program won't crash during val
        ref_model.on_sanity_check_start()
        ref_model.on_train_start()
        if not self.disable_validation and self.num_sanity_val_steps > 0:
            self.reset_val_dataloader(ref_model)
            # init progress bars for validation sanity check
            pbar = tqdm(desc='Validation sanity check',
                        total=self.num_sanity_val_steps * len(self.val_dataloaders),
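With the new condition, validation is skipped only when `validation_step` is not overridden (and `fast_dev_run` is off); the val dataloaders are now reset lazily, right before the sanity check. A small usage sketch of the flags involved, assuming the 0.6-era argument names shown in this diff (the values are arbitrary):

```python
from pytorch_lightning import Trainer

# num_sanity_val_steps controls how many val batches run before training
# starts (0 disables the sanity check entirely).
trainer = Trainer(num_sanity_val_steps=5)

# fast_dev_run limits the run to 1 train and 1 val batch,
# as noted in the training-loop hunk further below.
debug_trainer = Trainer(fast_dev_run=True)
```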
@@ -271,7 +271,7 @@ class TrainerTrainLoopMixin(ABC):
        pass

    @abstractmethod
    def is_iterable_dataloader(self, dataloader):
    def is_infinite_dataloader(self, dataloader):
        # this is just empty shell for code from other class
        pass
@@ -325,6 +325,11 @@ class TrainerTrainLoopMixin(ABC):
        # this is just empty shell for code from other class
        pass

    @abstractmethod
    def reset_val_dataloader(self, model):
        # this is just empty shell for code from other class
        pass

    @abstractmethod
    def has_arg(self, f_name, arg_name):
        # this is just empty shell for code from other class
@@ -334,11 +339,17 @@ class TrainerTrainLoopMixin(ABC):
            warnings.warn('Displayed epoch numbers in the progress bar start from "1" until v0.6.x,'
                          ' but will start from "0" in v0.8.0.', DeprecationWarning)

        # Train begin callbacks
        self.on_train_start()

        # get model
        model = self.get_model()

        # load data
        self.reset_train_dataloader(model)
        self.reset_val_dataloader(model)

        # Train begin callbacks
        model.on_train_start()
        self.on_train_start()

        try:
            # run all epochs
            for epoch in range(self.current_epoch, self.max_epochs):
@@ -347,9 +358,6 @@ class TrainerTrainLoopMixin(ABC):
                        and hasattr(self.train_dataloader.sampler, 'set_epoch'):
                    self.train_dataloader.sampler.set_epoch(epoch)

                # get model
                model = self.get_model()

                # update training progress in trainer and model
                model.current_epoch = epoch
                self.current_epoch = epoch
@@ -370,8 +378,8 @@ class TrainerTrainLoopMixin(ABC):
                if self.fast_dev_run:
                    # limit the number of batches to 2 (1 train and 1 val) in fast_dev_run
                    num_iterations = 2
                elif self.is_iterable_dataloader(self.train_dataloader):
                    # for iterable train loader, the progress bar never ends
                elif self.is_infinite_dataloader(self.train_dataloader):
                    # for infinite train loader, the progress bar never ends
                    num_iterations = None
                else:
                    num_iterations = self.total_batches
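Passing `num_iterations = None` through to the progress bar is what keeps it open-ended: tqdm treats an unknown total as a bar without an end, showing a running count instead of a percentage. A tiny standalone sketch of that behaviour (the loop bound is arbitrary):

```python
from tqdm import tqdm

# With total=None the bar displays only the iteration count and rate,
# which is the desired display for a dataloader whose length is unknown.
bar = tqdm(total=None, desc='Epoch 1')
for _ in range(5):
    bar.update(1)
bar.close()
```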
@@ -380,7 +388,7 @@ class TrainerTrainLoopMixin(ABC):
                # .reset() doesn't work on disabled progress bar so we should check
                if not self.main_progress_bar.disable:
                    self.main_progress_bar.reset(num_iterations)
                desc = f'Epoch {epoch + 1}' if not self.is_iterable_dataloader(self.train_dataloader) else ''
                desc = f'Epoch {epoch + 1}' if not self.is_infinite_dataloader(self.train_dataloader) else ''
                self.main_progress_bar.set_description(desc)

                # changing gradient according accumulation_scheduler
@@ -168,6 +168,7 @@ class TestModelBase(LightningModule):
        loader = DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=True
        )

        return loader
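The `shuffle=True` added to the test loader matters for the sampler refactor: when shuffling is requested and no sampler is given, DataLoader builds a `RandomSampler` internally, which is the kind of default `auto_add_sampler` has to swap for a `DistributedSampler` under DDP/TPU. A quick check of that default wiring, using a toy dataset for illustration only:

```python
import torch
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset

dataset = TensorDataset(torch.arange(10).float())
assert isinstance(DataLoader(dataset, shuffle=True).sampler, RandomSampler)
assert isinstance(DataLoader(dataset).sampler, SequentialSampler)
```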
@@ -380,6 +380,60 @@ def test_model_freeze_unfreeze():
    model.unfreeze()


def test_inf_train_dataloader(tmpdir):
    """Test inf train data loader (e.g. IterableDataset)"""
    tutils.reset_seed()

    class CurrentTestModel(LightningTestModel):
        def train_dataloader(self):
            dataloader = self._dataloader(train=True)

            class CustomInfDataLoader:
                def __init__(self, dataloader):
                    self.dataloader = dataloader
                    self.iter = iter(dataloader)
                    self.count = 0

                def __iter__(self):
                    self.count = 0
                    return self

                def __next__(self):
                    if self.count >= 5:
                        raise StopIteration
                    self.count = self.count + 1
                    try:
                        return next(self.iter)
                    except StopIteration:
                        self.iter = iter(self.dataloader)
                        return next(self.iter)

            return CustomInfDataLoader(dataloader)

    hparams = tutils.get_hparams()
    model = CurrentTestModel(hparams)

    # fit model
    with pytest.raises(MisconfigurationException):
        trainer = Trainer(
            default_save_path=tmpdir,
            max_epochs=1,
            val_check_interval=0.5
        )
        trainer.fit(model)

    # logger file to get meta
    trainer = Trainer(
        default_save_path=tmpdir,
        max_epochs=1,
        val_check_interval=50,
    )
    result = trainer.fit(model)

    # verify training completed
    assert result == 1


def test_multiple_val_dataloader(tmpdir):
    """Verify multiple val_dataloader."""
    tutils.reset_seed()
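The `CustomInfDataLoader` wrapper in the test fakes an endless stream by re-seeding its inner iterator and exposing no `__len__`. On PyTorch 1.2+ the same effect can come from a real `IterableDataset` without `__len__`, which the new `is_infinite_dataloader` check also catches. A sketch of that alternative (the feature size and batch size are arbitrary):

```python
import torch
from torch.utils.data import DataLoader, IterableDataset


class InfiniteStream(IterableDataset):
    # No __len__, so len(DataLoader(...)) raises TypeError and the trainer
    # falls back to the infinite-dataloader code path.
    def __iter__(self):
        while True:
            yield torch.rand(28 * 28)


train_loader = DataLoader(InfiniteStream(), batch_size=32)
```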