[fix] [easy] Update Model Checkpoint callback overrides to use base Callback signature (#6908)
* Update model_checkpoint.py
This commit is contained in:
parent
e35192dfcd
commit
2e53fd3332
|
@ -23,7 +23,7 @@ import os
|
|||
import shutil
|
||||
import subprocess
|
||||
import time
|
||||
from typing import Dict, List, Tuple
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
from pytorch_lightning.callbacks.base import Callback
|
||||
from pytorch_lightning.utilities import DeviceType, rank_zero_only
|
||||
|
@ -101,7 +101,7 @@ class GPUStatsMonitor(Callback):
|
|||
'temperature': temperature
|
||||
})
|
||||
|
||||
def on_train_start(self, trainer, *args, **kwargs):
|
||||
def on_train_start(self, trainer, pl_module) -> None:
|
||||
if not trainer.logger:
|
||||
raise MisconfigurationException('Cannot use GPUStatsMonitor callback with Trainer that has no logger.')
|
||||
|
||||
|
@ -113,12 +113,14 @@ class GPUStatsMonitor(Callback):
|
|||
|
||||
self._gpu_ids = ','.join(map(str, trainer.data_parallel_device_ids))
|
||||
|
||||
def on_train_epoch_start(self, *args, **kwargs):
|
||||
def on_train_epoch_start(self, trainer, pl_module) -> None:
|
||||
self._snap_intra_step_time = None
|
||||
self._snap_inter_step_time = None
|
||||
|
||||
@rank_zero_only
|
||||
def on_train_batch_start(self, trainer, *args, **kwargs):
|
||||
def on_train_batch_start(
|
||||
self, trainer, pl_module, batch: Any, batch_idx: int, dataloader_idx: int
|
||||
) -> None:
|
||||
if self._log_stats.intra_step_time:
|
||||
self._snap_intra_step_time = time.time()
|
||||
|
||||
|
@ -136,7 +138,9 @@ class GPUStatsMonitor(Callback):
|
|||
trainer.logger.log_metrics(logs, step=trainer.global_step)
|
||||
|
||||
@rank_zero_only
|
||||
def on_train_batch_end(self, trainer, *args, **kwargs):
|
||||
def on_train_batch_end(
|
||||
self, trainer, pl_module, outputs: Any, batch: Any, batch_idx: int, dataloader_idx: int
|
||||
) -> None:
|
||||
if self._log_stats.inter_step_time:
|
||||
self._snap_inter_step_time = time.time()
|
||||
|
||||
|
|
|
@ -215,7 +215,9 @@ class ModelCheckpoint(Callback):
|
|||
self.__resolve_ckpt_dir(trainer)
|
||||
self.save_function = trainer.save_checkpoint
|
||||
|
||||
def on_train_batch_end(self, trainer, *args, **kwargs) -> None:
|
||||
def on_train_batch_end(
|
||||
self, trainer, pl_module, outputs: Any, batch: Any, batch_idx: int, dataloader_idx: int
|
||||
) -> None:
|
||||
""" Save checkpoint on train batch end if we meet the criteria for `every_n_train_steps` """
|
||||
if self._should_skip_saving_checkpoint(trainer):
|
||||
return
|
||||
|
@ -225,7 +227,7 @@ class ModelCheckpoint(Callback):
|
|||
return
|
||||
self.save_checkpoint(trainer)
|
||||
|
||||
def on_validation_end(self, trainer, *args, **kwargs) -> None:
|
||||
def on_validation_end(self, trainer, pl_module) -> None:
|
||||
"""
|
||||
checkpoints can be saved at the end of the val loop
|
||||
"""
|
||||
|
|
Loading…
Reference in New Issue