diff --git a/pytorch_lightning/trainer/training_loop_temp.py b/pytorch_lightning/trainer/training_loop_temp.py
index a43d9c256c..e59d2485ab 100644
--- a/pytorch_lightning/trainer/training_loop_temp.py
+++ b/pytorch_lightning/trainer/training_loop_temp.py
@@ -87,11 +87,10 @@ class TrainLoop:
         model = self.trainer.get_model()
 
         # set seed for distributed sampler (enables shuffling for each epoch)
-        # TODO: move to accelerators
-        if (self.trainer.use_ddp or self.trainer.use_horovod or self.trainer.on_tpu) \
-                and hasattr(self.trainer.train_dataloader, 'sampler') \
-                and hasattr(self.trainer.train_dataloader.sampler, 'set_epoch'):
+        try:
             self.trainer.train_dataloader.sampler.set_epoch(epoch)
+        except AttributeError:
+            pass
 
         # update training progress in trainer and model
         model.current_epoch = epoch