diff --git a/pytorch_lightning/utilities/auto_restart.py b/pytorch_lightning/utilities/auto_restart.py index 0cb8f522e2..36db9e986b 100644 --- a/pytorch_lightning/utilities/auto_restart.py +++ b/pytorch_lightning/utilities/auto_restart.py @@ -230,9 +230,7 @@ class CaptureMapDataset(Dataset): set_rng_states(self._cached_state_dict[self.worker_id]["rng_states"]) self._cached_state_dict = None - data = self.dataset[item] - state_dict = self._state_dict() - return data, state_dict + return self.dataset[item] def __len__(self) -> int: return len(self.dataset) @@ -250,7 +248,7 @@ class CaptureMapDataset(Dataset): } self._cached_state_dict = state_dict - def _state_dict(self) -> Dict[int, Dict[str, Any]]: + def state_dict(self) -> Dict[int, Dict[str, Any]]: return {self.worker_id: {"rng_states": collect_rng_states()}} @@ -464,17 +462,10 @@ def _capture_metadata_collate(samples: List, dataset: Dataset, default_collate: "__pl_restart_meta": {"sampler_name0": state_dict0, "sampler_name1": state_dict1}, } """ - if isinstance(dataset, CaptureIterableDataset): - data = default_collate(samples) - metadata = dataset.state_dict() - - elif isinstance(dataset, CaptureMapDataset): - samples, states = zip(*samples) - data = default_collate(samples) - metadata = states[-1] - else: - return default_collate(samples) - + data = default_collate(samples) + if not isinstance(dataset, (CaptureIterableDataset, CaptureMapDataset)): + return data + metadata = dataset.state_dict() return {"data": data, AutoRestartBatchKeys.PL_RESTART_META: metadata}