diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 7d207502f9..7d658d4c57 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -929,6 +929,10 @@ class Trainer(TrainerIO): def __train(self): # run all epochs for epoch_nb in range(self.current_epoch, self.max_nb_epochs): + # set seed for distributed sampler (enables shuffling for each epoch) + if self.use_ddp: + self.tng_dataloader.sampler.set_epoch(epoch_nb) + # get model model = self.__get_model()