move state extraction for CaptureMapDataset (#9484)
This commit is contained in:
parent
4cd635347b
commit
847a3fdc07
|
@ -230,9 +230,7 @@ class CaptureMapDataset(Dataset):
|
|||
set_rng_states(self._cached_state_dict[self.worker_id]["rng_states"])
|
||||
self._cached_state_dict = None
|
||||
|
||||
data = self.dataset[item]
|
||||
state_dict = self._state_dict()
|
||||
return data, state_dict
|
||||
return self.dataset[item]
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.dataset)
|
||||
|
@ -250,7 +248,7 @@ class CaptureMapDataset(Dataset):
|
|||
}
|
||||
self._cached_state_dict = state_dict
|
||||
|
||||
def _state_dict(self) -> Dict[int, Dict[str, Any]]:
|
||||
def state_dict(self) -> Dict[int, Dict[str, Any]]:
|
||||
return {self.worker_id: {"rng_states": collect_rng_states()}}
|
||||
|
||||
|
||||
|
@ -464,17 +462,10 @@ def _capture_metadata_collate(samples: List, dataset: Dataset, default_collate:
|
|||
"__pl_restart_meta": {"sampler_name0": state_dict0, "sampler_name1": state_dict1},
|
||||
}
|
||||
"""
|
||||
if isinstance(dataset, CaptureIterableDataset):
|
||||
data = default_collate(samples)
|
||||
metadata = dataset.state_dict()
|
||||
|
||||
elif isinstance(dataset, CaptureMapDataset):
|
||||
samples, states = zip(*samples)
|
||||
data = default_collate(samples)
|
||||
metadata = states[-1]
|
||||
else:
|
||||
return default_collate(samples)
|
||||
|
||||
data = default_collate(samples)
|
||||
if not isinstance(dataset, (CaptureIterableDataset, CaptureMapDataset)):
|
||||
return data
|
||||
metadata = dataset.state_dict()
|
||||
return {"data": data, AutoRestartBatchKeys.PL_RESTART_META: metadata}
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue