From c0f3b6b035f955fc371dec412d3816712f3fc1dd Mon Sep 17 00:00:00 2001 From: Ananya Harsh Jha Date: Mon, 16 Sep 2019 10:21:00 -0400 Subject: [PATCH] added set_epoch for distributed sampler, fix for #224 (#225) --- pytorch_lightning/trainer/trainer.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 7d207502f9..7d658d4c57 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -929,6 +929,10 @@ class Trainer(TrainerIO): def __train(self): # run all epochs for epoch_nb in range(self.current_epoch, self.max_nb_epochs): + # set seed for distributed sampler (enables shuffling for each epoch) + if self.use_ddp: + self.tng_dataloader.sampler.set_epoch(epoch_nb) + # get model model = self.__get_model()