[fix] [easy] Update Model Checkpoint callback overrides to use base Callback signature (#6908)

* Update model_checkpoint.py
Author: ananthsub, 2021-04-09 04:24:59 -07:00 (committed by GitHub)
Parent: e35192dfcd
Commit: 2e53fd3332
2 changed files with 13 additions and 7 deletions
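
For reference (not part of the diff), a minimal sketch of a user-defined callback written against the explicit base Callback signatures that this commit switches the built-in callbacks to, assuming the 1.3-era PyTorch Lightning hook API; the class name MyPrintingCallback is hypothetical:

# Sketch only, not part of this commit; assumes the 1.3-era hook signatures.
from typing import Any

from pytorch_lightning.callbacks.base import Callback


class MyPrintingCallback(Callback):  # hypothetical example callback

    def on_train_start(self, trainer, pl_module) -> None:
        # Explicit (trainer, pl_module) instead of (trainer, *args, **kwargs).
        print("training started")

    def on_train_batch_end(
        self, trainer, pl_module, outputs: Any, batch: Any, batch_idx: int, dataloader_idx: int
    ) -> None:
        # Full base-class signature: outputs, batch and the index arguments are named explicitly.
        print(f"finished batch {batch_idx}")

    def on_validation_end(self, trainer, pl_module) -> None:
        print("validation finished")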

pytorch_lightning/callbacks/gpu_stats_monitor.py

@@ -23,7 +23,7 @@ import os
 import shutil
 import subprocess
 import time
-from typing import Dict, List, Tuple
+from typing import Any, Dict, List, Tuple

 from pytorch_lightning.callbacks.base import Callback
 from pytorch_lightning.utilities import DeviceType, rank_zero_only
@@ -101,7 +101,7 @@ class GPUStatsMonitor(Callback):
             'temperature': temperature
         })

-    def on_train_start(self, trainer, *args, **kwargs):
+    def on_train_start(self, trainer, pl_module) -> None:
         if not trainer.logger:
             raise MisconfigurationException('Cannot use GPUStatsMonitor callback with Trainer that has no logger.')
@@ -113,12 +113,14 @@ class GPUStatsMonitor(Callback):
         self._gpu_ids = ','.join(map(str, trainer.data_parallel_device_ids))

-    def on_train_epoch_start(self, *args, **kwargs):
+    def on_train_epoch_start(self, trainer, pl_module) -> None:
         self._snap_intra_step_time = None
         self._snap_inter_step_time = None

     @rank_zero_only
-    def on_train_batch_start(self, trainer, *args, **kwargs):
+    def on_train_batch_start(
+        self, trainer, pl_module, batch: Any, batch_idx: int, dataloader_idx: int
+    ) -> None:
         if self._log_stats.intra_step_time:
             self._snap_intra_step_time = time.time()
@@ -136,7 +138,9 @@ class GPUStatsMonitor(Callback):
         trainer.logger.log_metrics(logs, step=trainer.global_step)

     @rank_zero_only
-    def on_train_batch_end(self, trainer, *args, **kwargs):
+    def on_train_batch_end(
+        self, trainer, pl_module, outputs: Any, batch: Any, batch_idx: int, dataloader_idx: int
+    ) -> None:
         if self._log_stats.inter_step_time:
             self._snap_inter_step_time = time.time()
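
For context (not part of the diff), a hedged usage sketch assuming the 1.3-era API: GPUStatsMonitor logs through trainer.logger, so the Trainer must be constructed with a logger, otherwise on_train_start raises the MisconfigurationException shown above.

# Usage sketch only, not from this commit; assumes the 1.3-era API.
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import GPUStatsMonitor

gpu_stats = GPUStatsMonitor(intra_step_time=True, inter_step_time=True)
# logger=True keeps the default TensorBoardLogger, which GPUStatsMonitor needs for log_metrics.
trainer = Trainer(gpus=1, logger=True, callbacks=[gpu_stats])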

pytorch_lightning/callbacks/model_checkpoint.py

@@ -215,7 +215,9 @@ class ModelCheckpoint(Callback):
         self.__resolve_ckpt_dir(trainer)
         self.save_function = trainer.save_checkpoint

-    def on_train_batch_end(self, trainer, *args, **kwargs) -> None:
+    def on_train_batch_end(
+        self, trainer, pl_module, outputs: Any, batch: Any, batch_idx: int, dataloader_idx: int
+    ) -> None:
         """ Save checkpoint on train batch end if we meet the criteria for `every_n_train_steps` """
         if self._should_skip_saving_checkpoint(trainer):
             return
@@ -225,7 +227,7 @@ class ModelCheckpoint(Callback):
             return
         self.save_checkpoint(trainer)

-    def on_validation_end(self, trainer, *args, **kwargs) -> None:
+    def on_validation_end(self, trainer, pl_module) -> None:
         """
         checkpoints can be saved at the end of the val loop
         """