move state extraction for CaptureMapDataset (#9484)

This commit is contained in:
Adrian Wälchli 2021-09-14 18:04:19 +02:00 committed by GitHub
parent 4cd635347b
commit 847a3fdc07
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 6 additions and 15 deletions

View File

@ -230,9 +230,7 @@ class CaptureMapDataset(Dataset):
set_rng_states(self._cached_state_dict[self.worker_id]["rng_states"])
self._cached_state_dict = None
data = self.dataset[item]
state_dict = self._state_dict()
return data, state_dict
return self.dataset[item]
def __len__(self) -> int:
return len(self.dataset)
@ -250,7 +248,7 @@ class CaptureMapDataset(Dataset):
}
self._cached_state_dict = state_dict
def _state_dict(self) -> Dict[int, Dict[str, Any]]:
def state_dict(self) -> Dict[int, Dict[str, Any]]:
return {self.worker_id: {"rng_states": collect_rng_states()}}
@ -464,17 +462,10 @@ def _capture_metadata_collate(samples: List, dataset: Dataset, default_collate:
"__pl_restart_meta": {"sampler_name0": state_dict0, "sampler_name1": state_dict1},
}
"""
if isinstance(dataset, CaptureIterableDataset):
data = default_collate(samples)
metadata = dataset.state_dict()
elif isinstance(dataset, CaptureMapDataset):
samples, states = zip(*samples)
data = default_collate(samples)
metadata = states[-1]
else:
return default_collate(samples)
data = default_collate(samples)
if not isinstance(dataset, (CaptureIterableDataset, CaptureMapDataset)):
return data
metadata = dataset.state_dict()
return {"data": data, AutoRestartBatchKeys.PL_RESTART_META: metadata}