resolve conflict

resolve failing test
This commit is contained in:
tchaton 2021-02-05 08:49:06 +00:00 committed by Jirka Borovec
parent bb7d188318
commit df5cbf5368
1 changed files with 9 additions and 9 deletions

View File

@ -498,7 +498,7 @@ class ModelCheckpoint(Callback):
def _get_metric_interpolated_filepath_name(
self,
monitor_candidates: Dict[str, Any],
ckpt_name_metrics: Dict[str, Any],
epoch: int,
step: int,
trainer,
@ -506,7 +506,7 @@ class ModelCheckpoint(Callback):
) -> str:
filepath = self.format_checkpoint_name(epoch, step, ckpt_name_metrics)
version_cnt = 0
version_cnt = self.STARTING_VERSION
while self.file_exists(filepath, trainer) and filepath != del_filepath:
filepath = self.format_checkpoint_name(epoch, step, ckpt_name_metrics, ver=version_cnt)
version_cnt += 1
@ -518,7 +518,7 @@ class ModelCheckpoint(Callback):
monitor_candidates.update(step=trainer.global_step, epoch=trainer.current_epoch)
return monitor_candidates
def _save_last_checkpoint(self, trainer, pl_module, monitor_candidates):
def _save_last_checkpoint(self, trainer, pl_module, ckpt_name_metrics):
should_save_last = self.monitor is None or self.save_last
if not should_save_last:
return
@ -529,8 +529,8 @@ class ModelCheckpoint(Callback):
self.CHECKPOINT_NAME_LAST,
trainer.current_epoch,
trainer.global_step,
monitor_candidates,
prefix=self.prefix,
ckpt_name_metrics,
prefix=self.prefix
)
last_filepath = os.path.join(self.dirpath, f"{last_filepath}{self.FILE_EXTENSION}")
else:
@ -546,10 +546,10 @@ class ModelCheckpoint(Callback):
else:
self._save_model(last_filepath, trainer, pl_module)
if (
self.last_model_path
and self.last_model_path != last_filepath
and (self.save_top_k != -1 or self.save_last)
and trainer.is_global_zero
self.last_model_path
and self.last_model_path != last_filepath
and (self.save_top_k != -1 or self.save_last)
and trainer.is_global_zero
):
self._del_model(self.last_model_path)
self.last_model_path = last_filepath