diff --git a/pytorch_lightning/callbacks/gpu_stats_monitor.py b/pytorch_lightning/callbacks/gpu_stats_monitor.py
index ace69b0234..498b536994 100644
--- a/pytorch_lightning/callbacks/gpu_stats_monitor.py
+++ b/pytorch_lightning/callbacks/gpu_stats_monitor.py
@@ -23,7 +23,7 @@ import os
 import shutil
 import subprocess
 import time
-from typing import Dict, List, Tuple
+from typing import Any, Dict, List, Tuple
 
 from pytorch_lightning.callbacks.base import Callback
 from pytorch_lightning.utilities import DeviceType, rank_zero_only
@@ -101,7 +101,7 @@ class GPUStatsMonitor(Callback):
             'temperature': temperature
         })
 
-    def on_train_start(self, trainer, *args, **kwargs):
+    def on_train_start(self, trainer, pl_module) -> None:
         if not trainer.logger:
             raise MisconfigurationException('Cannot use GPUStatsMonitor callback with Trainer that has no logger.')
 
@@ -113,12 +113,14 @@ class GPUStatsMonitor(Callback):
 
         self._gpu_ids = ','.join(map(str, trainer.data_parallel_device_ids))
 
-    def on_train_epoch_start(self, *args, **kwargs):
+    def on_train_epoch_start(self, trainer, pl_module) -> None:
         self._snap_intra_step_time = None
         self._snap_inter_step_time = None
 
     @rank_zero_only
-    def on_train_batch_start(self, trainer, *args, **kwargs):
+    def on_train_batch_start(
+        self, trainer, pl_module, batch: Any, batch_idx: int, dataloader_idx: int
+    ) -> None:
         if self._log_stats.intra_step_time:
             self._snap_intra_step_time = time.time()
 
@@ -136,7 +138,9 @@ class GPUStatsMonitor(Callback):
         trainer.logger.log_metrics(logs, step=trainer.global_step)
 
     @rank_zero_only
-    def on_train_batch_end(self, trainer, *args, **kwargs):
+    def on_train_batch_end(
+        self, trainer, pl_module, outputs: Any, batch: Any, batch_idx: int, dataloader_idx: int
+    ) -> None:
         if self._log_stats.inter_step_time:
             self._snap_inter_step_time = time.time()
 
diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py
index 5f0318e7ac..a1a44fd70b 100644
--- a/pytorch_lightning/callbacks/model_checkpoint.py
+++ b/pytorch_lightning/callbacks/model_checkpoint.py
@@ -215,7 +215,9 @@ class ModelCheckpoint(Callback):
         self.__resolve_ckpt_dir(trainer)
         self.save_function = trainer.save_checkpoint
 
-    def on_train_batch_end(self, trainer, *args, **kwargs) -> None:
+    def on_train_batch_end(
+        self, trainer, pl_module, outputs: Any, batch: Any, batch_idx: int, dataloader_idx: int
+    ) -> None:
         """ Save checkpoint on train batch end if we meet the criteria for `every_n_train_steps` """
         if self._should_skip_saving_checkpoint(trainer):
             return
@@ -225,7 +227,7 @@ class ModelCheckpoint(Callback):
             return
         self.save_checkpoint(trainer)
 
-    def on_validation_end(self, trainer, *args, **kwargs) -> None:
+    def on_validation_end(self, trainer, pl_module) -> None:
         """ checkpoints can be saved at the end of the val loop """