From 3a2df4f75d104b0ecf0af9b35e8ebd535a99a591 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Thu, 6 Jan 2022 13:51:02 +0100
Subject: [PATCH] Fix typing in `pl.callbacks.xla_stats_monitor` (#11219)

Co-authored-by: Carlos Mocholi
---
 pyproject.toml                     |  1 -
 .../callbacks/xla_stats_monitor.py | 18 +++++++++++-------
 2 files changed, 11 insertions(+), 8 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 1ae00f3e51..1c1f615104 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -51,7 +51,6 @@ module = [
     "pytorch_lightning.callbacks.progress.tqdm_progress",
     "pytorch_lightning.callbacks.quantization",
     "pytorch_lightning.callbacks.stochastic_weight_avg",
-    "pytorch_lightning.callbacks.xla_stats_monitor",
     "pytorch_lightning.core.datamodule",
     "pytorch_lightning.core.decorators",
     "pytorch_lightning.core.lightning",
diff --git a/pytorch_lightning/callbacks/xla_stats_monitor.py b/pytorch_lightning/callbacks/xla_stats_monitor.py
index a14ee42e9a..22591592d7 100644
--- a/pytorch_lightning/callbacks/xla_stats_monitor.py
+++ b/pytorch_lightning/callbacks/xla_stats_monitor.py
@@ -20,6 +20,7 @@ Monitor and logs XLA stats during training.
 """
 import time
 
+import pytorch_lightning as pl
 from pytorch_lightning.callbacks.base import Callback
 from pytorch_lightning.utilities import _AcceleratorType, _TPU_AVAILABLE, rank_zero_deprecation, rank_zero_info
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -66,7 +67,7 @@ class XLAStatsMonitor(Callback):
 
         self._verbose = verbose
 
-    def on_train_start(self, trainer, pl_module) -> None:
+    def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         if not trainer.logger:
             raise MisconfigurationException("Cannot use XLAStatsMonitor callback with Trainer that has no logger.")
 
@@ -80,11 +81,13 @@ class XLAStatsMonitor(Callback):
         total_memory = trainer.strategy.reduce(memory_info["kb_total"]) * 0.001
         rank_zero_info(f"Average Total memory: {total_memory:.2f} MB")
 
-    def on_train_epoch_start(self, trainer, pl_module) -> None:
+    def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         self._start_time = time.time()
 
-    def on_train_epoch_end(self, trainer, pl_module) -> None:
-        logs = {}
+    def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+        if not trainer.logger:
+            raise MisconfigurationException("Cannot use XLAStatsMonitor callback with Trainer that has no logger.")
+
         memory_info = xm.get_memory_info(pl_module.device)
         epoch_time = time.time() - self._start_time
 
@@ -95,9 +98,10 @@ class XLAStatsMonitor(Callback):
         peak_memory = trainer.strategy.reduce(peak_memory) * 0.001
         epoch_time = trainer.strategy.reduce(epoch_time)
 
-        logs["avg. free memory (MB)"] = free_memory
-        logs["avg. peak memory (MB)"] = peak_memory
-        trainer.logger.log_metrics(logs, step=trainer.current_epoch)
+        trainer.logger.log_metrics(
+            {"avg. free memory (MB)": float(free_memory), "avg. peak memory (MB)": float(peak_memory)},
+            step=trainer.current_epoch,
+        )
 
         if self._verbose:
             rank_zero_info(f"Average Epoch time: {epoch_time:.2f} seconds")
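
Note on the typing pattern this patch applies: the module imports the package as `pl` and annotates the hook arguments with string forward references ("pl.Trainer", "pl.LightningModule"). Type checkers resolve these lazily, so the callback module gains full type information without importing Trainer at runtime and creating a circular import with the package root. The extra `if not trainer.logger` guard in `on_train_epoch_end` narrows `trainer.logger` from Optional before `log_metrics` is called, which is what allows the module to be dropped from the mypy exclusion list in pyproject.toml. A minimal sketch of the same pattern, using a hypothetical TypedCallback class that is not part of the patch:

import pytorch_lightning as pl
from pytorch_lightning.callbacks.base import Callback


class TypedCallback(Callback):
    """Hypothetical callback illustrating the forward-reference typing pattern."""

    def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        # Mirror the patch's guard: trainer.logger is None when the Trainer was
        # built with logger=False, so fail early (and satisfy the type checker)
        # instead of crashing later on a None attribute access.
        if not trainer.logger:
            raise RuntimeError("TypedCallback requires a Trainer with a logger.")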