Refactor dataloading (#955)

* Refactor dataloading

* Refactor dataloading

* Refactor dataloading

* Add shuffle to test
Ethan Harris 2020-02-26 21:55:18 +00:00 committed by GitHub
parent be244560b2
commit b2e9607362
5 changed files with 116 additions and 84 deletions

View File

@@ -1,22 +1,10 @@
import warnings
from abc import ABC
import torch.distributed as dist
from torch.utils.data import SequentialSampler, DataLoader
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data import RandomSampler, SequentialSampler, DataLoader, BatchSampler
from pytorch_lightning.utilities.debugging import MisconfigurationException
try:
# loading for pyTorch 1.3
from torch.utils.data import IterableDataset
except ImportError:
# loading for pyTorch 1.1
import torch
warnings.warn('Your version of pyTorch %s does not support `IterableDataset`,'
' please upgrade to 1.2+' % torch.__version__, ImportWarning)
EXIST_ITER_DATASET = False
else:
EXIST_ITER_DATASET = True
from pytorch_lightning.utilities.debugging import MisconfigurationException
try:
from apex import amp
@@ -90,36 +78,19 @@ class TrainerDataLoadingMixin(ABC):
model.prepare_data()
def auto_add_sampler(self, dataloader, train):
# do nothing when user gives a sampler
dl_args = {
'dataset': dataloader.dataset,
'batch_size': dataloader.batch_size,
'shuffle': False,
'num_workers': dataloader.num_workers,
'collate_fn': dataloader.collate_fn,
'pin_memory': dataloader.pin_memory,
'drop_last': dataloader.drop_last,
'timeout': dataloader.timeout,
'worker_init_fn': dataloader.worker_init_fn
}
if self.use_ddp or self.use_ddp2 or self.use_tpu:
dl_args = {
'dataset': dataloader.dataset,
'batch_size': dataloader.batch_size,
'shuffle': False,
'num_workers': dataloader.num_workers,
'collate_fn': dataloader.collate_fn,
'pin_memory': dataloader.pin_memory,
'drop_last': dataloader.drop_last,
'timeout': dataloader.timeout,
'worker_init_fn': dataloader.worker_init_fn
}
if train:
if self.use_ddp or self.use_ddp2:
sampler = DistributedSampler(dataloader.dataset)
dl_args['shuffle'] = False
elif self.use_tpu:
sampler = DistributedSampler(
dataloader.dataset,
num_replicas=xm.xrt_world_size(),
rank=xm.get_ordinal()
)
dl_args['shuffle'] = False
else:
sampler = RandomSampler(dataloader.dataset)
# on not train
else:
if self.use_tpu:
sampler = DistributedSampler(
dataloader.dataset,
@@ -128,12 +99,16 @@ class TrainerDataLoadingMixin(ABC):
)
dl_args['shuffle'] = False
else:
sampler = SequentialSampler(dataloader.dataset)
if train:
sampler = DistributedSampler(dataloader.dataset)
dl_args['shuffle'] = False
else:
sampler = SequentialSampler(dataloader.dataset)
dl_args['sampler'] = sampler
dl_args['sampler'] = sampler
new_dataloader = DataLoader(**dl_args)
return new_dataloader
dataloader = DataLoader(**dl_args)
return dataloader
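Assembled from the hunks above, a minimal standalone sketch of the refactored `auto_add_sampler` logic; the `distributed` flag and the omission of the TPU/`xm` branch are simplifications for readability, not part of the PR:

    from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
    from torch.utils.data.distributed import DistributedSampler

    def auto_add_sampler(dataloader, train, distributed=False):
        # Copy the user's DataLoader settings; shuffling is delegated to the
        # sampler, so 'shuffle' stays False in the rebuilt loader.
        dl_args = {
            'dataset': dataloader.dataset,
            'batch_size': dataloader.batch_size,
            'shuffle': False,
            'num_workers': dataloader.num_workers,
            'collate_fn': dataloader.collate_fn,
            'pin_memory': dataloader.pin_memory,
            'drop_last': dataloader.drop_last,
            'timeout': dataloader.timeout,
            'worker_init_fn': dataloader.worker_init_fn,
        }

        if train:
            # training data: shard across processes under DDP, otherwise shuffle
            sampler = (DistributedSampler(dataloader.dataset) if distributed
                       else RandomSampler(dataloader.dataset))
        else:
            # eval data: keep a deterministic order
            sampler = SequentialSampler(dataloader.dataset)

        dl_args['sampler'] = sampler
        return DataLoader(**dl_args)

The net effect of the refactor, as the diff shows, is that the loader is always rebuilt with an explicit sampler, so training-time shuffling comes from `RandomSampler` rather than from the user passing `shuffle=True`.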
def reset_train_dataloader(self, model):
"""
@@ -148,12 +123,12 @@ class TrainerDataLoadingMixin(ABC):
# automatically add samplers
self.train_dataloader = self.auto_add_sampler(self.train_dataloader, train=True)
# determine number of training batches
if EXIST_ITER_DATASET and isinstance(self.train_dataloader.dataset, IterableDataset):
self._percent_range_check('train_percent_check')
if self.is_infinite_dataloader(self.train_dataloader):
self.num_training_batches = float('inf')
else:
self._percent_range_check('train_percent_check')
# try getting the length
self.num_training_batches = len(self.train_dataloader)
self.num_training_batches = int(self.num_training_batches * self.train_percent_check)
@@ -168,27 +143,26 @@ class TrainerDataLoadingMixin(ABC):
f"to the number of the training batches ({self.num_training_batches}). "
f"If you want to disable validation set `val_percent_check` to 0.0 instead.")
else:
if self.is_infinite_dataloader(self.train_dataloader):
m = '''
When using an infinite DataLoader (e.g. with an IterableDataset or when DataLoader
does not implement `__len__`) for `train_dataloader`, `Trainer(val_check_interval)`
must be an int. An int k specifies checking validation every k training batches.
'''
raise MisconfigurationException(m)
self._percent_range_check('val_check_interval')
self.val_check_batch = int(self.num_training_batches * self.val_check_interval)
self.val_check_batch = max(1, self.val_check_batch)
# support IterableDataset for train data
self.is_iterable_train_dataloader = (
EXIST_ITER_DATASET and isinstance(self.train_dataloader.dataset, IterableDataset)
)
if self.is_iterable_dataloader(self.train_dataloader) and not isinstance(self.val_check_interval, int):
m = '''
When using an iterableDataset for `train_dataloader`,
`Trainer(val_check_interval)` must be an int.
An int k specifies checking validation every k training batches
'''
raise MisconfigurationException(m)
def is_iterable_dataloader(self, dataloader):
return (
EXIST_ITER_DATASET and isinstance(dataloader.dataset, IterableDataset)
)
def is_infinite_dataloader(self, dataloader):
try:
# try getting the length
_ = len(dataloader)
return False
except TypeError as e:
return True
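A minimal sketch of how this probe classifies loaders; the `EndlessLoader` wrapper is hypothetical, analogous to the `CustomInfDataLoader` used in the tests further down:

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    def is_infinite_dataloader(dataloader):
        # same probe as above: "infinite" means len() is not defined
        try:
            len(dataloader)
            return False
        except TypeError:
            return True

    class EndlessLoader:
        # wraps a DataLoader and restarts it forever; defines no __len__
        def __init__(self, dataloader):
            self.dataloader = dataloader

        def __iter__(self):
            while True:
                yield from self.dataloader

    finite = DataLoader(TensorDataset(torch.randn(8, 2)), batch_size=4)
    print(is_infinite_dataloader(finite))                 # False
    print(is_infinite_dataloader(EndlessLoader(finite)))  # True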
def reset_val_dataloader(self, model):
"""

View File

@@ -1114,19 +1114,14 @@ class Trainer(TrainerIOMixin,
self.run_evaluation(test_mode=True)
return
# load the dataloaders
self.reset_train_dataloader(ref_model)
self.reset_val_dataloader(ref_model)
# check if we should run validation during training
self.disable_validation = self.num_val_batches == 0 or not self.is_overriden('validation_step')
self.disable_validation = self.disable_validation and not self.fast_dev_run
self.disable_validation = not self.is_overriden('validation_step') and not self.fast_dev_run
# run tiny validation (if validation defined)
# to make sure program won't crash during val
ref_model.on_sanity_check_start()
ref_model.on_train_start()
if not self.disable_validation and self.num_sanity_val_steps > 0:
self.reset_val_dataloader(ref_model)
# init progress bars for validation sanity check
pbar = tqdm(desc='Validation sanity check',
total=self.num_sanity_val_steps * len(self.val_dataloaders),

View File

@@ -271,7 +271,7 @@ class TrainerTrainLoopMixin(ABC):
pass
@abstractmethod
def is_iterable_dataloader(self, dataloader):
def is_infinite_dataloader(self, dataloader):
# this is just empty shell for code from other class
pass
@@ -325,6 +325,11 @@ class TrainerTrainLoopMixin(ABC):
# this is just empty shell for code from other class
pass
@abstractmethod
def reset_val_dataloader(self, model):
# this is just empty shell for code from other class
pass
@abstractmethod
def has_arg(self, f_name, arg_name):
# this is just empty shell for code from other class
@@ -334,11 +339,17 @@ class TrainerTrainLoopMixin(ABC):
warnings.warn('Displayed epoch numbers in the progress bar start from "1" until v0.6.x,'
' but will start from "0" in v0.8.0.', DeprecationWarning)
# Train begin callbacks
self.on_train_start()
# get model
model = self.get_model()
# load data
self.reset_train_dataloader(model)
self.reset_val_dataloader(model)
# Train begin callbacks
model.on_train_start()
self.on_train_start()
try:
# run all epochs
for epoch in range(self.current_epoch, self.max_epochs):
@@ -347,9 +358,6 @@ class TrainerTrainLoopMixin(ABC):
and hasattr(self.train_dataloader.sampler, 'set_epoch'):
self.train_dataloader.sampler.set_epoch(epoch)
# get model
model = self.get_model()
# update training progress in trainer and model
model.current_epoch = epoch
self.current_epoch = epoch
@@ -370,8 +378,8 @@ class TrainerTrainLoopMixin(ABC):
if self.fast_dev_run:
# limit the number of batches to 2 (1 train and 1 val) in fast_dev_run
num_iterations = 2
elif self.is_iterable_dataloader(self.train_dataloader):
# for iterable train loader, the progress bar never ends
elif self.is_infinite_dataloader(self.train_dataloader):
# for infinite train loader, the progress bar never ends
num_iterations = None
else:
num_iterations = self.total_batches
@@ -380,7 +388,7 @@ class TrainerTrainLoopMixin(ABC):
# .reset() doesn't work on disabled progress bar so we should check
if not self.main_progress_bar.disable:
self.main_progress_bar.reset(num_iterations)
desc = f'Epoch {epoch + 1}' if not self.is_iterable_dataloader(self.train_dataloader) else ''
desc = f'Epoch {epoch + 1}' if not self.is_infinite_dataloader(self.train_dataloader) else ''
self.main_progress_bar.set_description(desc)
# changing gradient according accumulation_scheduler

View File

@@ -168,6 +168,7 @@ class TestModelBase(LightningModule):
loader = DataLoader(
dataset=dataset,
batch_size=batch_size,
shuffle=True
)
return loader
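For context, a standalone sketch (not part of the diff) of what `shuffle=True` means at the DataLoader level, and how the refactored `auto_add_sampler` above reproduces it with an explicit sampler:

    import torch
    from torch.utils.data import DataLoader, RandomSampler, TensorDataset

    dataset = TensorDataset(torch.arange(8).float().unsqueeze(1))

    # shuffle=True just installs a RandomSampler under the hood ...
    implicit = DataLoader(dataset, batch_size=4, shuffle=True)
    # ... equivalent to passing the sampler explicitly with shuffle left False,
    # which is what the refactored auto_add_sampler does for local training
    explicit = DataLoader(dataset, batch_size=4, sampler=RandomSampler(dataset))

    print(type(implicit.sampler).__name__)  # RandomSampler
    print(type(explicit.sampler).__name__)  # RandomSampler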

View File

@@ -380,6 +380,60 @@ def test_model_freeze_unfreeze():
model.unfreeze()
def test_inf_train_dataloader(tmpdir):
"""Test inf train data loader (e.g. IterableDataset)"""
tutils.reset_seed()
class CurrentTestModel(LightningTestModel):
def train_dataloader(self):
dataloader = self._dataloader(train=True)
class CustomInfDataLoader:
def __init__(self, dataloader):
self.dataloader = dataloader
self.iter = iter(dataloader)
self.count = 0
def __iter__(self):
self.count = 0
return self
def __next__(self):
if self.count >= 5:
raise StopIteration
self.count = self.count + 1
try:
return next(self.iter)
except StopIteration:
self.iter = iter(self.dataloader)
return next(self.iter)
return CustomInfDataLoader(dataloader)
hparams = tutils.get_hparams()
model = CurrentTestModel(hparams)
# fit model
with pytest.raises(MisconfigurationException):
trainer = Trainer(
default_save_path=tmpdir,
max_epochs=1,
val_check_interval=0.5
)
trainer.fit(model)
# logger file to get meta
trainer = Trainer(
default_save_path=tmpdir,
max_epochs=1,
val_check_interval=50,
)
result = trainer.fit(model)
# verify training completed
assert result == 1
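For reference, a hedged sketch of an `IterableDataset`-based loader (the `RandomStream` class is hypothetical, not part of this PR) that takes the same infinite-dataloader path, since its DataLoader cannot report a length:

    import torch
    from torch.utils.data import DataLoader, IterableDataset

    class RandomStream(IterableDataset):
        # an endless stream of random feature vectors; no __len__ defined
        def __init__(self, size=10):
            self.size = size

        def __iter__(self):
            while True:
                yield torch.randn(self.size)

    # the trainer would treat this loader as infinite and therefore require an
    # integer val_check_interval (e.g. 50), just as the test above asserts
    train_loader = DataLoader(RandomStream(), batch_size=32)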
def test_multiple_val_dataloader(tmpdir):
"""Verify multiple val_dataloader."""
tutils.reset_seed()