resolve bug (#6781)
This commit is contained in:
parent
13f67ad313
commit
3e3175d074
|
@ -118,11 +118,13 @@ class TPUSpawnPlugin(DDPSpawnPlugin):
|
|||
self.__save_end_of_training_weights(self.lightning_module)
|
||||
self.transfer_distrib_spawn_state_on_fit_end(results)
|
||||
|
||||
# https://github.com/pytorch/xla/issues/1801#issuecomment-602799542
|
||||
self.barrier("end-process")
|
||||
|
||||
# https://github.com/pytorch/xla/issues/2190#issuecomment-641665358
|
||||
if self.global_rank == 0:
|
||||
time.sleep(2)
|
||||
|
||||
self.barrier("end-process")
|
||||
|
||||
def __save_end_of_training_weights(self, model: LightningModule) -> None:
|
||||
# when training ends on these platforms dump weights to get out of the main process
|
||||
if on_colab_kaggle():
|
||||
|
@ -158,16 +160,7 @@ class TPUSpawnPlugin(DDPSpawnPlugin):
|
|||
self.mp_queue.put(results)
|
||||
|
||||
def save(self, state_dict: Dict, path: str) -> None:
|
||||
"""
|
||||
Saving with ``xm.save`` can be unstable and miss the rendez-vous after ``torch.save``.
|
||||
The rendez-vous doesn't affect directly saving.
|
||||
We can ignore the ``RuntimeError`` to reduce friction with TPUs.
|
||||
"""
|
||||
try:
|
||||
xm.save(state_dict, path)
|
||||
except RuntimeError as e:
|
||||
if "Failed to meet rendezvous" not in str(e):
|
||||
raise e
|
||||
xm.save(state_dict, path)
|
||||
|
||||
def broadcast(self, obj: object, src: int = 0) -> object:
|
||||
buffer = io.BytesIO()
|
||||
|
|
Loading…
Reference in New Issue