diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 42c474e68c..44d586886a 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -498,7 +498,7 @@ class ModelCheckpoint(Callback): def _get_metric_interpolated_filepath_name( self, - monitor_candidates: Dict[str, Any], + ckpt_name_metrics: Dict[str, Any], epoch: int, step: int, trainer, @@ -506,7 +506,7 @@ class ModelCheckpoint(Callback): ) -> str: filepath = self.format_checkpoint_name(epoch, step, ckpt_name_metrics) - version_cnt = 0 + version_cnt = self.STARTING_VERSION while self.file_exists(filepath, trainer) and filepath != del_filepath: filepath = self.format_checkpoint_name(epoch, step, ckpt_name_metrics, ver=version_cnt) version_cnt += 1 @@ -518,7 +518,7 @@ class ModelCheckpoint(Callback): monitor_candidates.update(step=trainer.global_step, epoch=trainer.current_epoch) return monitor_candidates - def _save_last_checkpoint(self, trainer, pl_module, monitor_candidates): + def _save_last_checkpoint(self, trainer, pl_module, ckpt_name_metrics): should_save_last = self.monitor is None or self.save_last if not should_save_last: return @@ -529,8 +529,8 @@ class ModelCheckpoint(Callback): self.CHECKPOINT_NAME_LAST, trainer.current_epoch, trainer.global_step, - monitor_candidates, - prefix=self.prefix, + ckpt_name_metrics, + prefix=self.prefix ) last_filepath = os.path.join(self.dirpath, f"{last_filepath}{self.FILE_EXTENSION}") else: @@ -546,10 +546,10 @@ class ModelCheckpoint(Callback): else: self._save_model(last_filepath, trainer, pl_module) if ( - self.last_model_path - and self.last_model_path != last_filepath - and (self.save_top_k != -1 or self.save_last) - and trainer.is_global_zero + self.last_model_path + and self.last_model_path != last_filepath + and (self.save_top_k != -1 or self.save_last) + and trainer.is_global_zero ): self._del_model(self.last_model_path) self.last_model_path = last_filepath