diff --git a/src/lightning/pytorch/loops/fit_loop.py b/src/lightning/pytorch/loops/fit_loop.py
index a78e25bfa1..04a123105c 100644
--- a/src/lightning/pytorch/loops/fit_loop.py
+++ b/src/lightning/pytorch/loops/fit_loop.py
@@ -262,12 +262,19 @@ class _FitLoop(_Loop):
     def setup_data_fetcher(self):
         trainer = self.trainer
+        pl_module = trainer.lightning_module
+
         combined_loader = self._combined_loader
         self._data_fetcher = _select_data_fetcher(trainer, RunningStage.TRAINING)
         self._data_fetcher.setup(combined_loader)
         iter(self._data_fetcher)  # creates the iterator inside the fetcher
         max_batches = sized_len(combined_loader)
         self.max_batches = max_batches if max_batches is not None else float("inf")
+
+        allow_zero_length = pl_module.allow_zero_length_dataloader_with_multiple_devices
+        if trainer.datamodule is not None:
+            allow_zero_length |= trainer.datamodule.allow_zero_length_dataloader_with_multiple_devices
+        has_len_all_ranks_ = has_len_all_ranks(combined_loader, trainer.strategy, allow_zero_length)
         if self.max_batches == 0: