Broadcast dirpath for tighter consistency in model checkpoint callback (#6978)

* Update model_checkpoint.py

* Update model_checkpoint.py

* Update model_checkpoint.py
This commit is contained in:
ananthsub 2021-04-21 10:20:27 -07:00 committed by GitHub
parent e4f3a8d3dd
commit 2f84459d26
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 3 additions and 4 deletions

View File

@@ -552,13 +552,12 @@ class ModelCheckpoint(Callback):
trainer.logger.version
if isinstance(trainer.logger.version, str) else f"version_{trainer.logger.version}"
)
version, name = trainer.training_type_plugin.broadcast((version, trainer.logger.name))
ckpt_path = os.path.join(save_dir, str(name), version, "checkpoints")
ckpt_path = os.path.join(save_dir, str(trainer.logger.name), version, "checkpoints")
else:
ckpt_path = os.path.join(trainer.weights_save_path, "checkpoints")
ckpt_path = trainer.training_type_plugin.broadcast(ckpt_path)
self.dirpath = ckpt_path
if not trainer.fast_dev_run and trainer.is_global_zero: