From f6a86e8551e207fb121415eee7fa8e931e0b149f Mon Sep 17 00:00:00 2001
From: Justus Schock <12886177+justusschock@users.noreply.github.com>
Date: Fri, 3 Apr 2020 23:55:08 +0200
Subject: [PATCH] generalize reinstantiation of dataloader (#1346)

* generalize reinstantiation of dataloader
* fix condition
* add test
* update changelog
* fix changelog

Co-authored-by: J. Borovec
---
 CHANGELOG.md                              | 14 ++++----
 pytorch_lightning/trainer/data_loading.py | 16 +++-------
 tests/trainer/test_dataloaders.py         | 39 +++++++++++++++++++++++
 3 files changed, 50 insertions(+), 19 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index b2cec1027f..71b5c4a687 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -22,21 +22,20 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added informative errors if user defined dataloader has zero length ([#1280](https://github.com/PyTorchLightning/pytorch-lightning/pull/1280))
 - Added testing for python 3.8 ([#915](https://github.com/PyTorchLightning/pytorch-lightning/pull/915))
 - Added a `training_epoch_end` method which is the mirror of `validation_epoch_end`. ([#1357](https://github.com/PyTorchLightning/pytorch-lightning/pull/1357))
+- Added model configuration checking ([#1199](https://github.com/PyTorchLightning/pytorch-lightning/pull/1199))
+- Added support for optimizer frequencies through `LightningModule.configure_optimizers()` ([#1269](https://github.com/PyTorchLightning/pytorch-lightning/pull/1269))
+- Added option to run without an optimizer by returning `None` from `configure_optimizers`. ([#1279](https://github.com/PyTorchLightning/pytorch-lightning/pull/1279))
+
 ### Changed
 
 - Changed `progress_bar_refresh_rate` trainer flag to disable progress bar when set to 0. ([#1108](https://github.com/PyTorchLightning/pytorch-lightning/pull/1108))
 - Enhanced `load_from_checkpoint` to also forward params to the model ([#1307](https://github.com/PyTorchLightning/pytorch-lightning/pull/1307))
 - Updated references to self.forward() to instead use the `__call__` interface. ([#1211](https://github.com/PyTorchLightning/pytorch-lightning/pull/1211))
-- Added option to run without an optimizer by returning `None` from `configure_optimizers`. ([#1279](https://github.com/PyTorchLightning/pytorch-lightning/pull/1279))
 - Changed default behaviour of `configure_optimizers` to use no optimizer rather than Adam. ([#1279](https://github.com/PyTorchLightning/pytorch-lightning/pull/1279))
-- Added support for optimizer frequencies through `LightningModule.configure_optimizers()` ([#1269](https://github.com/PyTorchLightning/pytorch-lightning/pull/1269))
-- Added support for `IterableDataset` when `val_check_interval=1.0` (default), this will trigger validation at the end of each epoch. ([#1283](https://github.com/PyTorchLightning/pytorch-lightning/pull/1283))
-- Added `summary` method to Profilers. ([#1259](https://github.com/PyTorchLightning/pytorch-lightning/pull/1259))
-- Added informative errors if user defined dataloader has zero length ([#1280](https://github.com/PyTorchLightning/pytorch-lightning/pull/1280))
 - Allow to upload models on W&B ([#1339](https://github.com/PyTorchLightning/pytorch-lightning/pull/1339))
-- Added model configuration checking ([#1199](https://github.com/PyTorchLightning/pytorch-lightning/pull/1199))
 - On DP and DDP2 unsqueeze is automated now ([#1319](https://github.com/PyTorchLightning/pytorch-lightning/pull/1319))
-- Does not interfere with a default sampler ([#1318](https://github.com/PyTorchLightning/pytorch-lightning/pull/1318))
+- Did not always create a DataLoader during reinstantiation, but the same type as before (if subclass of DataLoader) ([#1346](https://github.com/PyTorchLightning/pytorch-lightning/pull/1346))
+- Did not interfere with a default sampler ([#1318](https://github.com/PyTorchLightning/pytorch-lightning/pull/1318))
 - Remove default Adam optimizer ([#1317](https://github.com/PyTorchLightning/pytorch-lightning/pull/1317))
 - Give warnings for unimplemented required lightning methods ([#1317](https://github.com/PyTorchLightning/pytorch-lightning/pull/1317))
 - Enhanced load_from_checkpoint to also forward params to the model ([#1307](https://github.com/PyTorchLightning/pytorch-lightning/pull/1307))
@@ -314,6 +313,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 ### Added
 
 - Added the flag `log_gpu_memory` to `Trainer` to deactivate logging of GPU memory utilization
+- Added SLURM resubmit functionality (port from test-tube)
 - Added optional weight_save_path to trainer to remove the need for a checkpoint_callback when using cluster training
 - Added option to use single gpu per node with `DistributedDataParallel`
 
diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py
index 1612a2c5c3..fe1adf75c3 100644
--- a/pytorch_lightning/trainer/data_loading.py
+++ b/pytorch_lightning/trainer/data_loading.py
@@ -84,16 +84,10 @@ class TrainerDataLoadingMixin(ABC):
 
         if need_dist_sampler and no_sampler_added:
 
+            skip_keys = ['sampler', 'batch_sampler', 'dataset_kind']
+
             dl_args = {
-                'dataset': dataloader.dataset,
-                'batch_size': dataloader.batch_size,
-                'shuffle': False,
-                'num_workers': dataloader.num_workers,
-                'collate_fn': dataloader.collate_fn,
-                'pin_memory': dataloader.pin_memory,
-                'drop_last': dataloader.drop_last,
-                'timeout': dataloader.timeout,
-                'worker_init_fn': dataloader.worker_init_fn
+                k: v for k, v in dataloader.__dict__.items() if not k.startswith('_') and k not in skip_keys
             }
 
             if self.use_tpu:
@@ -102,13 +96,11 @@ class TrainerDataLoadingMixin(ABC):
                     num_replicas=xm.xrt_world_size(),
                     rank=xm.get_ordinal()
                 )
-                dl_args['shuffle'] = False
             else:
                 sampler = DistributedSampler(dataloader.dataset)
-                dl_args['shuffle'] = False
 
             dl_args['sampler'] = sampler
-            dataloader = DataLoader(**dl_args)
+            dataloader = type(dataloader)(**dl_args)
 
         return dataloader
 
diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py
index fd6f05cc92..d0da0044f2 100644
--- a/tests/trainer/test_dataloaders.py
+++ b/tests/trainer/test_dataloaders.py
@@ -1,4 +1,5 @@
 import pytest
+import torch
 
 import tests.base.utils as tutils
 from pytorch_lightning import Trainer
@@ -482,3 +483,41 @@ def test_error_on_zero_len_dataloader(tmpdir):
         test_percent_check=0.5
     )
     trainer.fit(model)
+
+
+@pytest.mark.skipif(torch.cuda.device_count() < 2, reason='Test requires multiple GPUs')
+def test_dataloader_reinit_for_subclass():
+
+    class CustomDataLoader(torch.utils.data.DataLoader):
+        def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
+                     batch_sampler=None, num_workers=0, collate_fn=None,
+                     pin_memory=False, drop_last=False, timeout=0,
+                     worker_init_fn=None, dummy_kwarg=None):
+            super().__init__(dataset,
+                             batch_size,
+                             shuffle,
+                             sampler,
+                             batch_sampler,
+                             num_workers,
+                             collate_fn,
+                             pin_memory,
+                             drop_last,
+                             timeout,
+                             worker_init_fn)
+
+            self.dummy_kwarg = dummy_kwarg
+
+    trainer = Trainer(gpus=[0, 1],
+                      num_nodes=1,
+                      distributed_backend='ddp')
+
+    class CustomDummyObj:
+        sampler = None
+
+    result = trainer.auto_add_sampler(CustomDummyObj(), train=True)
+    assert isinstance(result, CustomDummyObj), "Wrongly reinstantiated data loader"
+
+    result = trainer.auto_add_sampler(CustomDataLoader(list(range(1000))), train=True)
+    assert isinstance(result, torch.utils.data.DataLoader)
+    assert isinstance(result, CustomDataLoader)
+    assert hasattr(result, 'dummy_kwarg')
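
Below is a minimal, standalone sketch (not part of the patch) of the reinstantiation pattern that the `data_loading.py` hunk introduces: copy the public attributes of an existing loader, drop the keys that are replaced or not accepted by the constructor, and rebuild with `type(dataloader)(**dl_args)` so that subclasses keep their type. The helper name `rebuild_with_sampler` is illustrative only, and a `SequentialSampler` stands in for the trainer's `DistributedSampler` so the example runs without an initialized process group.

from torch.utils.data import DataLoader, SequentialSampler


class CustomDataLoader(DataLoader):
    """A DataLoader subclass, analogous to the one used in the new test."""


def rebuild_with_sampler(dataloader, sampler):
    # Copy every public attribute of the existing loader, skipping the ones
    # that are replaced below or not meant for the constructor (this mirrors
    # `skip_keys` in trainer/data_loading.py).
    skip_keys = ['sampler', 'batch_sampler', 'dataset_kind']
    dl_args = {k: v for k, v in dataloader.__dict__.items()
               if not k.startswith('_') and k not in skip_keys}
    dl_args['sampler'] = sampler
    # type(dataloader) preserves subclasses instead of always returning a
    # plain DataLoader, which is the point of this change.
    return type(dataloader)(**dl_args)


if __name__ == '__main__':
    loader = CustomDataLoader(list(range(100)), batch_size=10)
    rebuilt = rebuild_with_sampler(loader, SequentialSampler(loader.dataset))
    assert isinstance(rebuilt, CustomDataLoader)        # subclass type is kept
    assert sum(len(batch) for batch in rebuilt) == 100  # all samples still served

Rebuilding from `__dict__` rather than from a hard-coded argument list is also what lets extra keyword arguments such as the test's `dummy_kwarg` survive the round trip, at the cost of assuming the subclass constructor accepts the same keywords it stores as public attributes.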