generalize reinstantiation of dataloader (#1346)
* generalize reinstantiation of dataloader
* fix condition
* add test
* update changelog
* fix changelog

Co-authored-by: J. Borovec <jirka.borovec@seznam.cz>

This commit is contained in:
parent e68ba1c836
commit f6a86e8551

CHANGELOG.md: 14 changed lines
@@ -22,21 +22,20 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added informative errors if user defined dataloader has zero length ([#1280](https://github.com/PyTorchLightning/pytorch-lightning/pull/1280))
 - Added testing for python 3.8 ([#915](https://github.com/PyTorchLightning/pytorch-lightning/pull/915))
 - Added a `training_epoch_end` method which is the mirror of `validation_epoch_end`. ([#1357](https://github.com/PyTorchLightning/pytorch-lightning/pull/1357))
+- Added model configuration checking ([#1199](https://github.com/PyTorchLightning/pytorch-lightning/pull/1199))
+- Added support for optimizer frequencies through `LightningModule.configure_optimizers()` ([#1269](https://github.com/PyTorchLightning/pytorch-lightning/pull/1269))
+- Added option to run without an optimizer by returning `None` from `configure_optimizers`. ([#1279](https://github.com/PyTorchLightning/pytorch-lightning/pull/1279))
+
 ### Changed

 - Changed `progress_bar_refresh_rate` trainer flag to disable progress bar when set to 0. ([#1108](https://github.com/PyTorchLightning/pytorch-lightning/pull/1108))
 - Enhanced `load_from_checkpoint` to also forward params to the model ([#1307](https://github.com/PyTorchLightning/pytorch-lightning/pull/1307))
 - Updated references to self.forward() to instead use the `__call__` interface. ([#1211](https://github.com/PyTorchLightning/pytorch-lightning/pull/1211))
-- Added option to run without an optimizer by returning `None` from `configure_optimizers`. ([#1279](https://github.com/PyTorchLightning/pytorch-lightning/pull/1279))
 - Changed default behaviour of `configure_optimizers` to use no optimizer rather than Adam. ([#1279](https://github.com/PyTorchLightning/pytorch-lightning/pull/1279))
-- Added support for optimizer frequencies through `LightningModule.configure_optimizers()` ([#1269](https://github.com/PyTorchLightning/pytorch-lightning/pull/1269))
-- Added support for `IterableDataset` when `val_check_interval=1.0` (default), this will trigger validation at the end of each epoch. ([#1283](https://github.com/PyTorchLightning/pytorch-lightning/pull/1283))
-- Added `summary` method to Profilers. ([#1259](https://github.com/PyTorchLightning/pytorch-lightning/pull/1259))
-- Added informative errors if user defined dataloader has zero length ([#1280](https://github.com/PyTorchLightning/pytorch-lightning/pull/1280))
 - Allow to upload models on W&B ([#1339](https://github.com/PyTorchLightning/pytorch-lightning/pull/1339))
-- Added model configuration checking ([#1199](https://github.com/PyTorchLightning/pytorch-lightning/pull/1199))
 - On DP and DDP2 unsqueeze is automated now ([#1319](https://github.com/PyTorchLightning/pytorch-lightning/pull/1319))
-- Does not interfere with a default sampler ([#1318](https://github.com/PyTorchLightning/pytorch-lightning/pull/1318))
+- Did not always create a DataLoader during reinstantiation, but the same type as before (if subclass of DataLoader) ([#1346](https://github.com/PyTorchLightning/pytorch-lightning/pull/1346))
+- Did not interfere with a default sampler ([#1318](https://github.com/PyTorchLightning/pytorch-lightning/pull/1318))
 - Remove default Adam optimizer ([#1317](https://github.com/PyTorchLightning/pytorch-lightning/pull/1317))
 - Give warnings for unimplemented required lightning methods ([#1317](https://github.com/PyTorchLightning/pytorch-lightning/pull/1317))
 - Enhanced load_from_checkpoint to also forward params to the model ([#1307](https://github.com/PyTorchLightning/pytorch-lightning/pull/1307))
@@ -314,6 +313,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 ### Added

 - Added the flag `log_gpu_memory` to `Trainer` to deactivate logging of GPU memory utilization
+- Added SLURM resubmit functionality (port from test-tube)
 - Added optional weight_save_path to trainer to remove the need for a checkpoint_callback when using cluster training
 - Added option to use single gpu per node with `DistributedDataParallel`

@@ -84,16 +84,10 @@ class TrainerDataLoadingMixin(ABC):

         if need_dist_sampler and no_sampler_added:

+            skip_keys = ['sampler', 'batch_sampler', 'dataset_kind']
+
             dl_args = {
-                'dataset': dataloader.dataset,
-                'batch_size': dataloader.batch_size,
-                'shuffle': False,
-                'num_workers': dataloader.num_workers,
-                'collate_fn': dataloader.collate_fn,
-                'pin_memory': dataloader.pin_memory,
-                'drop_last': dataloader.drop_last,
-                'timeout': dataloader.timeout,
-                'worker_init_fn': dataloader.worker_init_fn
+                k: v for k, v in dataloader.__dict__.items() if not k.startswith('_') and k not in skip_keys
             }

             if self.use_tpu:
@@ -102,13 +96,11 @@ class TrainerDataLoadingMixin(ABC):
                     num_replicas=xm.xrt_world_size(),
                     rank=xm.get_ordinal()
                 )
-                dl_args['shuffle'] = False
             else:
                 sampler = DistributedSampler(dataloader.dataset)
-                dl_args['shuffle'] = False

             dl_args['sampler'] = sampler
-            dataloader = DataLoader(**dl_args)
+            dataloader = type(dataloader)(**dl_args)

         return dataloader

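Note on the hunks above: the hand-maintained `dl_args` dictionary is replaced by a generic harvest of the loader's public attributes, and the loader is rebuilt with `type(dataloader)(**dl_args)` instead of a hard-coded `DataLoader`, so subclasses keep their concrete type and any extra constructor arguments they stored on `self`. A minimal standalone sketch of the same pattern (the `TaggedLoader`/`tag` names are illustrative, not from the PR; the sketch assumes, as `torch.utils.data.DataLoader` does, that constructor arguments are stored under matching attribute names):

from torch.utils.data import DataLoader

class TaggedLoader(DataLoader):
    """A DataLoader subclass with one extra constructor argument."""
    def __init__(self, *args, tag=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.tag = tag

loader = TaggedLoader(list(range(100)), batch_size=10, tag='train')

# Harvest constructor arguments from the loader's __dict__, skipping private
# attributes and the sampler-related keys that are about to be replaced.
skip_keys = ['sampler', 'batch_sampler', 'dataset_kind']
dl_args = {k: v for k, v in loader.__dict__.items()
           if not k.startswith('_') and k not in skip_keys}

dl_args['sampler'] = None  # a DistributedSampler would be injected here in DDP/TPU runs
rebuilt = type(loader)(**dl_args)

assert isinstance(rebuilt, TaggedLoader)  # concrete subclass preserved
assert rebuilt.tag == 'train'             # extra kwarg survives the rebuild

Because `shuffle` is never stored on the loader (it only selects the default sampler), the rebuilt loader falls back to `shuffle=False` and simply uses whatever sampler is injected, which is why the explicit `dl_args['shuffle'] = False` lines could be dropped.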
@@ -1,4 +1,5 @@
 import pytest
+import torch

 import tests.base.utils as tutils
 from pytorch_lightning import Trainer
@@ -482,3 +483,41 @@ def test_error_on_zero_len_dataloader(tmpdir):
         test_percent_check=0.5
     )
     trainer.fit(model)
+
+
+@pytest.mark.skipif(torch.cuda.device_count() < 2, reason='Test requires multiple GPUs')
+def test_dataloader_reinit_for_subclass():
+
+    class CustomDataLoader(torch.utils.data.DataLoader):
+        def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
+                     batch_sampler=None, num_workers=0, collate_fn=None,
+                     pin_memory=False, drop_last=False, timeout=0,
+                     worker_init_fn=None, dummy_kwarg=None):
+            super().__init__(dataset,
+                             batch_size,
+                             shuffle,
+                             sampler,
+                             batch_sampler,
+                             num_workers,
+                             collate_fn,
+                             pin_memory,
+                             drop_last,
+                             timeout,
+                             worker_init_fn)
+
+            self.dummy_kwarg = dummy_kwarg
+
+    trainer = Trainer(gpus=[0, 1],
+                      num_nodes=1,
+                      distributed_backend='ddp')
+
+    class CustomDummyObj:
+        sampler = None
+
+    result = trainer.auto_add_sampler(CustomDummyObj(), train=True)
+    assert isinstance(result, CustomDummyObj), "Wrongly reinstantiated data loader"
+
+    result = trainer.auto_add_sampler(CustomDataLoader(list(range(1000))), train=True)
+    assert isinstance(result, torch.utils.data.DataLoader)
+    assert isinstance(result, CustomDataLoader)
+    assert hasattr(result, 'dummy_kwarg')
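The `CustomDummyObj` assertion also pins down the "fix condition" part of the commit: an object that merely looks like a dataloader must be passed through untouched, while genuine `DataLoader` instances (including subclasses such as `CustomDataLoader`) are rebuilt with their extra kwargs intact. The exact guard inside `auto_add_sampler` is not visible in this diff, so the following is only a sketch of a condition that would satisfy the test, with a hypothetical helper name:

from torch.utils.data import DataLoader, DistributedSampler

def add_dist_sampler_sketch(dataloader, needs_dist_sampler=True):
    """Hypothetical guard mirroring what the test expects of auto_add_sampler."""
    # Duck-typed objects (like CustomDummyObj, which only exposes `.sampler`)
    # come back unchanged; only real DataLoader instances are rebuilt.
    if not needs_dist_sampler or not isinstance(dataloader, DataLoader):
        return dataloader
    skip_keys = ['sampler', 'batch_sampler', 'dataset_kind']
    dl_args = {k: v for k, v in dataloader.__dict__.items()
               if not k.startswith('_') and k not in skip_keys}
    # DistributedSampler requires torch.distributed to be initialised at call time.
    dl_args['sampler'] = DistributedSampler(dataloader.dataset)
    return type(dataloader)(**dl_args)  # preserves the loader's concrete subclass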