diff --git a/pytorch_lightning/callbacks/device_stats_monitor.py b/pytorch_lightning/callbacks/device_stats_monitor.py index 00fd79d0f7..7b181fe463 100644 --- a/pytorch_lightning/callbacks/device_stats_monitor.py +++ b/pytorch_lightning/callbacks/device_stats_monitor.py @@ -100,12 +100,7 @@ class DeviceStatsMonitor(Callback): logger.log_metrics(prefixed_device_stats, step=trainer.fit_loop.epoch_loop._batches_that_stepped) def on_train_batch_start( - self, - trainer: "pl.Trainer", - pl_module: "pl.LightningModule", - batch: Any, - batch_idx: int, - unused: Optional[int] = 0, + self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int ) -> None: self._get_and_log_device_stats(trainer, "on_train_batch_start")