resolve bug (#6781)

thomas chaton 2021-04-01 11:43:23 +01:00 committed by GitHub
parent 13f67ad313
commit 3e3175d074
1 changed file with 5 additions and 12 deletions


@@ -118,11 +118,13 @@ class TPUSpawnPlugin(DDPSpawnPlugin):
         self.__save_end_of_training_weights(self.lightning_module)
         self.transfer_distrib_spawn_state_on_fit_end(results)
 
-        # https://github.com/pytorch/xla/issues/1801#issuecomment-602799542
-        self.barrier("end-process")
+        # https://github.com/pytorch/xla/issues/2190#issuecomment-641665358
+        if self.global_rank == 0:
+            time.sleep(2)
+        self.barrier("end-process")
 
     def __save_end_of_training_weights(self, model: LightningModule) -> None:
         # when training ends on these platforms dump weights to get out of the main process
         if on_colab_kaggle():
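
A minimal sketch of the shutdown pattern this first hunk introduces, assuming a torch_xla environment; the helper name _finalize_tpu_process is hypothetical, and xm.rendezvous stands in for the plugin's barrier():

import time

import torch_xla.core.xla_model as xm


def _finalize_tpu_process(global_rank: int) -> None:
    # Workaround from https://github.com/pytorch/xla/issues/2190#issuecomment-641665358:
    # rank 0 sleeps briefly so the worker ranks reach the final rendezvous
    # and exit cleanly before the main process starts tearing down.
    if global_rank == 0:
        time.sleep(2)
    xm.rendezvous("end-process")
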
@@ -158,16 +160,7 @@ class TPUSpawnPlugin(DDPSpawnPlugin):
         self.mp_queue.put(results)
 
     def save(self, state_dict: Dict, path: str) -> None:
-        """
-        Saving with ``xm.save`` can be unstable and miss the rendez-vous after ``torch.save``.
-        The rendez-vous doesn't affect directly saving.
-        We can ignore the ``RuntimeError`` to reduce friction with TPUs.
-        """
-        try:
-            xm.save(state_dict, path)
-        except RuntimeError as e:
-            if "Failed to meet rendezvous" not in str(e):
-                raise e
+        xm.save(state_dict, path)
 
     def broadcast(self, obj: object, src: int = 0) -> object:
         buffer = io.BytesIO()
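
Design note on the second hunk: per the removed docstring, xm.save writes with torch.save and then takes part in a rendezvous, which the old code sometimes missed, hence the swallowed "Failed to meet rendezvous" RuntimeError. With the end-of-process synchronization added in the first hunk, that failure mode presumably no longer occurs, so xm.save is called directly and any RuntimeError now propagates. A sketch of the simplified call, written here as a free function and assuming torch_xla:

from typing import Dict

import torch_xla.core.xla_model as xm


def save(state_dict: Dict, path: str) -> None:
    # xm.save moves tensors to CPU, saves on the master ordinal, and
    # synchronizes the ranks; errors now surface instead of being ignored.
    xm.save(state_dict, path)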