diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py index 4a09c0ca1f..8e86d1b722 100644 --- a/pytorch_lightning/loops/fit_loop.py +++ b/pytorch_lightning/loops/fit_loop.py @@ -13,7 +13,6 @@ # limitations under the License. import logging -from contextlib import suppress from typing import Any, Dict, Optional from pytorch_lightning.loops import Loop @@ -181,8 +180,7 @@ class FitLoop(Loop): self.trainer.train_dataloader.load_state_dict(self._dataloader_state_dict) self._dataloader_state_dict = {} - # TODO: specify the possible exception - with suppress(Exception): + if callable(getattr(self.trainer.train_dataloader.sampler, "set_epoch", None)): # set seed for distributed sampler (enables shuffling for each epoch) self.trainer.train_dataloader.sampler.set_epoch(self.current_epoch)