diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index d1e548eefd..6861fbf33b 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -1,22 +1,10 @@ -import warnings from abc import ABC import torch.distributed as dist +from torch.utils.data import SequentialSampler, DataLoader from torch.utils.data.distributed import DistributedSampler -from torch.utils.data import RandomSampler, SequentialSampler, DataLoader, BatchSampler -from pytorch_lightning.utilities.debugging import MisconfigurationException -try: - # loading for pyTorch 1.3 - from torch.utils.data import IterableDataset -except ImportError: - # loading for pyTorch 1.1 - import torch - warnings.warn('Your version of pyTorch %s does not support `IterableDataset`,' - ' please upgrade to 1.2+' % torch.__version__, ImportWarning) - EXIST_ITER_DATASET = False -else: - EXIST_ITER_DATASET = True +from pytorch_lightning.utilities.debugging import MisconfigurationException try: from apex import amp @@ -90,36 +78,19 @@ class TrainerDataLoadingMixin(ABC): model.prepare_data() def auto_add_sampler(self, dataloader, train): - # do nothing when user gives a sampler - dl_args = { - 'dataset': dataloader.dataset, - 'batch_size': dataloader.batch_size, - 'shuffle': False, - 'num_workers': dataloader.num_workers, - 'collate_fn': dataloader.collate_fn, - 'pin_memory': dataloader.pin_memory, - 'drop_last': dataloader.drop_last, - 'timeout': dataloader.timeout, - 'worker_init_fn': dataloader.worker_init_fn - } + if self.use_ddp or self.use_ddp2 or self.use_tpu: + dl_args = { + 'dataset': dataloader.dataset, + 'batch_size': dataloader.batch_size, + 'shuffle': False, + 'num_workers': dataloader.num_workers, + 'collate_fn': dataloader.collate_fn, + 'pin_memory': dataloader.pin_memory, + 'drop_last': dataloader.drop_last, + 'timeout': dataloader.timeout, + 'worker_init_fn': dataloader.worker_init_fn + } - if train: - if self.use_ddp or self.use_ddp2: - sampler = DistributedSampler(dataloader.dataset) - dl_args['shuffle'] = False - - elif self.use_tpu: - sampler = DistributedSampler( - dataloader.dataset, - num_replicas=xm.xrt_world_size(), - rank=xm.get_ordinal() - ) - dl_args['shuffle'] = False - else: - sampler = RandomSampler(dataloader.dataset) - - # on not train - else: if self.use_tpu: sampler = DistributedSampler( dataloader.dataset, @@ -128,12 +99,16 @@ class TrainerDataLoadingMixin(ABC): ) dl_args['shuffle'] = False else: - sampler = SequentialSampler(dataloader.dataset) + if train: + sampler = DistributedSampler(dataloader.dataset) + dl_args['shuffle'] = False + else: + sampler = SequentialSampler(dataloader.dataset) - dl_args['sampler'] = sampler + dl_args['sampler'] = sampler - new_dataloader = DataLoader(**dl_args) - return new_dataloader + dataloader = DataLoader(**dl_args) + return dataloader def reset_train_dataloader(self, model): """ @@ -148,12 +123,12 @@ class TrainerDataLoadingMixin(ABC): # automatically add samplers self.train_dataloader = self.auto_add_sampler(self.train_dataloader, train=True) - # determine number of training batches - if EXIST_ITER_DATASET and isinstance(self.train_dataloader.dataset, IterableDataset): + self._percent_range_check('train_percent_check') + + if self.is_infinite_dataloader(self.train_dataloader): self.num_training_batches = float('inf') else: - self._percent_range_check('train_percent_check') - + # try getting the length self.num_training_batches = len(self.train_dataloader) self.num_training_batches = int(self.num_training_batches * self.train_percent_check) @@ -168,27 +143,26 @@ class TrainerDataLoadingMixin(ABC): f"to the number of the training batches ({self.num_training_batches}). " f"If you want to disable validation set `val_percent_check` to 0.0 instead.") else: + if self.is_infinite_dataloader(self.train_dataloader): + m = ''' + When using an infinite DataLoader (e.g. with an IterableDataset or when DataLoader + does not implement `__len__`) for `train_dataloader`, `Trainer(val_check_interval)` + must be an int. An int k specifies checking validation every k training batches. + ''' + raise MisconfigurationException(m) + self._percent_range_check('val_check_interval') self.val_check_batch = int(self.num_training_batches * self.val_check_interval) self.val_check_batch = max(1, self.val_check_batch) - # support IterableDataset for train data - self.is_iterable_train_dataloader = ( - EXIST_ITER_DATASET and isinstance(self.train_dataloader.dataset, IterableDataset) - ) - if self.is_iterable_dataloader(self.train_dataloader) and not isinstance(self.val_check_interval, int): - m = ''' - When using an iterableDataset for `train_dataloader`, - `Trainer(val_check_interval)` must be an int. - An int k specifies checking validation every k training batches - ''' - raise MisconfigurationException(m) - - def is_iterable_dataloader(self, dataloader): - return ( - EXIST_ITER_DATASET and isinstance(dataloader.dataset, IterableDataset) - ) + def is_infinite_dataloader(self, dataloader): + try: + # try getting the length + _ = len(dataloader) + return False + except TypeError as e: + return True def reset_val_dataloader(self, model): """ diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 5116df9fc0..562c5bcfa3 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1114,19 +1114,14 @@ class Trainer(TrainerIOMixin, self.run_evaluation(test_mode=True) return - # load the dataloaders - self.reset_train_dataloader(ref_model) - self.reset_val_dataloader(ref_model) - # check if we should run validation during training - self.disable_validation = self.num_val_batches == 0 or not self.is_overriden('validation_step') - self.disable_validation = self.disable_validation and not self.fast_dev_run + self.disable_validation = not self.is_overriden('validation_step') and not self.fast_dev_run # run tiny validation (if validation defined) # to make sure program won't crash during val ref_model.on_sanity_check_start() - ref_model.on_train_start() if not self.disable_validation and self.num_sanity_val_steps > 0: + self.reset_val_dataloader(ref_model) # init progress bars for validation sanity check pbar = tqdm(desc='Validation sanity check', total=self.num_sanity_val_steps * len(self.val_dataloaders), diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 847690968b..d2be9894a4 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -271,7 +271,7 @@ class TrainerTrainLoopMixin(ABC): pass @abstractmethod - def is_iterable_dataloader(self, dataloader): + def is_infinite_dataloader(self, dataloader): # this is just empty shell for code from other class pass @@ -325,6 +325,11 @@ class TrainerTrainLoopMixin(ABC): # this is just empty shell for code from other class pass + @abstractmethod + def reset_val_dataloader(self, model): + # this is just empty shell for code from other class + pass + @abstractmethod def has_arg(self, f_name, arg_name): # this is just empty shell for code from other class @@ -334,11 +339,17 @@ class TrainerTrainLoopMixin(ABC): warnings.warn('Displayed epoch numbers in the progress bar start from "1" until v0.6.x,' ' but will start from "0" in v0.8.0.', DeprecationWarning) - # Train begin callbacks - self.on_train_start() - # get model model = self.get_model() + + # load data + self.reset_train_dataloader(model) + self.reset_val_dataloader(model) + + # Train begin callbacks + model.on_train_start() + self.on_train_start() + try: # run all epochs for epoch in range(self.current_epoch, self.max_epochs): @@ -347,9 +358,6 @@ class TrainerTrainLoopMixin(ABC): and hasattr(self.train_dataloader.sampler, 'set_epoch'): self.train_dataloader.sampler.set_epoch(epoch) - # get model - model = self.get_model() - # update training progress in trainer and model model.current_epoch = epoch self.current_epoch = epoch @@ -370,8 +378,8 @@ class TrainerTrainLoopMixin(ABC): if self.fast_dev_run: # limit the number of batches to 2 (1 train and 1 val) in fast_dev_run num_iterations = 2 - elif self.is_iterable_dataloader(self.train_dataloader): - # for iterable train loader, the progress bar never ends + elif self.is_infinite_dataloader(self.train_dataloader): + # for infinite train loader, the progress bar never ends num_iterations = None else: num_iterations = self.total_batches @@ -380,7 +388,7 @@ class TrainerTrainLoopMixin(ABC): # .reset() doesn't work on disabled progress bar so we should check if not self.main_progress_bar.disable: self.main_progress_bar.reset(num_iterations) - desc = f'Epoch {epoch + 1}' if not self.is_iterable_dataloader(self.train_dataloader) else '' + desc = f'Epoch {epoch + 1}' if not self.is_infinite_dataloader(self.train_dataloader) else '' self.main_progress_bar.set_description(desc) # changing gradient according accumulation_scheduler diff --git a/tests/models/base.py b/tests/models/base.py index e0677da6c3..a4e31354f2 100644 --- a/tests/models/base.py +++ b/tests/models/base.py @@ -168,6 +168,7 @@ class TestModelBase(LightningModule): loader = DataLoader( dataset=dataset, batch_size=batch_size, + shuffle=True ) return loader diff --git a/tests/test_trainer.py b/tests/test_trainer.py index 7850638475..7a8881907a 100644 --- a/tests/test_trainer.py +++ b/tests/test_trainer.py @@ -380,6 +380,60 @@ def test_model_freeze_unfreeze(): model.unfreeze() +def test_inf_train_dataloader(tmpdir): + """Test inf train data loader (e.g. IterableDataset)""" + tutils.reset_seed() + + class CurrentTestModel(LightningTestModel): + def train_dataloader(self): + dataloader = self._dataloader(train=True) + + class CustomInfDataLoader: + def __init__(self, dataloader): + self.dataloader = dataloader + self.iter = iter(dataloader) + self.count = 0 + + def __iter__(self): + self.count = 0 + return self + + def __next__(self): + if self.count >= 5: + raise StopIteration + self.count = self.count + 1 + try: + return next(self.iter) + except StopIteration: + self.iter = iter(self.dataloader) + return next(self.iter) + + return CustomInfDataLoader(dataloader) + + hparams = tutils.get_hparams() + model = CurrentTestModel(hparams) + + # fit model + with pytest.raises(MisconfigurationException): + trainer = Trainer( + default_save_path=tmpdir, + max_epochs=1, + val_check_interval=0.5 + ) + trainer.fit(model) + + # logger file to get meta + trainer = Trainer( + default_save_path=tmpdir, + max_epochs=1, + val_check_interval=50, + ) + result = trainer.fit(model) + + # verify training completed + assert result == 1 + + def test_multiple_val_dataloader(tmpdir): """Verify multiple val_dataloader.""" tutils.reset_seed()