From c71bd73acb5a89bb2a8ff44beab37fd2ceba352b Mon Sep 17 00:00:00 2001 From: Justus Schock <12886177+justusschock@users.noreply.github.com> Date: Sun, 19 Apr 2020 22:58:57 +0200 Subject: [PATCH] DDP sampler (#1513) * Add explicit flag for ddp sampler replacement * Add flag for sampler replacement in ddp * Update data_loading.py * Update CHANGELOG.md * pep8 fixes * pep8 --- CHANGELOG.md | 1 + pytorch_lightning/trainer/data_loading.py | 7 +++---- pytorch_lightning/trainer/trainer.py | 5 +++++ 3 files changed, 9 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c910896884..096488e19f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added +- Added flag `replace_sampler_ddp` to manually disable sampler replacement in ddp ([#1513](https://github.com/PyTorchLightning/pytorch-lightning/pull/1513)) - Added `auto_select_gpus` flag to trainer that enables automatic selection of available GPUs on exclusive mode systems. 
- Added learning rate finder ([#1347](https://github.com/PyTorchLightning/pytorch-lightning/pull/1347)) diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index ca9e09939d..bfbf058407 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -61,6 +61,7 @@ class TrainerDataLoadingMixin(ABC): train_percent_check: float val_percent_check: float test_percent_check: float + replace_sampler_ddp: bool @abstractmethod def is_overriden(self, *args): @@ -88,10 +89,8 @@ class TrainerDataLoadingMixin(ABC): # don't do anything if it's not a dataloader if not isinstance(dataloader, DataLoader): return dataloader - - need_dist_sampler = self.use_ddp or self.use_ddp2 or self.use_tpu - - if need_dist_sampler: + need_dist_sampler = (self.use_ddp or self.use_ddp2 or self.use_tpu) + if self.replace_sampler_ddp and need_dist_sampler: skip_keys = ['sampler', 'batch_sampler', 'dataset_kind'] diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 1964af8168..3c9da89ca5 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -127,6 +127,7 @@ class Trainer( benchmark: bool = False, reload_dataloaders_every_epoch: bool = False, auto_lr_find: Union[bool, str] = False, + replace_sampler_ddp: bool = True, default_save_path=None, # backward compatible, todo: remove in v0.8.0 gradient_clip=None, # backward compatible, todo: remove in v0.8.0 nb_gpu_nodes=None, # backward compatible, todo: remove in v0.8.0 @@ -282,6 +283,9 @@ class Trainer( rate in self.hparams.lr | self.hparams.learning_rate in the lightning module. To use a different key, set a string instead of True with the key name. + replace_sampler_ddp: Explicitly enables or disables sampler replacement. + If not specified this will be toggled automatically when ddp is used + benchmark: If true enables cudnn.benchmark. 
terminate_on_nan: If set to True, will terminate training (by raising a `ValueError`) at the @@ -362,6 +366,7 @@ class Trainer( self.reload_dataloaders_every_epoch = reload_dataloaders_every_epoch self.auto_lr_find = auto_lr_find + self.replace_sampler_ddp = replace_sampler_ddp self.truncated_bptt_steps = truncated_bptt_steps self.resume_from_checkpoint = resume_from_checkpoint