Fix typing in `pl.callbacks.xla_stats_monitor` (#11219)

Co-authored-by: Carlos Mocholi <carlossmocholi@gmail.com>
Adrian Wälchli 2022-01-06 13:51:02 +01:00 committed by GitHub
parent 9b873dcfcc
commit 3a2df4f75d
2 changed files with 11 additions and 8 deletions

pyproject.toml

@@ -51,7 +51,6 @@ module = [
     "pytorch_lightning.callbacks.progress.tqdm_progress",
     "pytorch_lightning.callbacks.quantization",
     "pytorch_lightning.callbacks.stochastic_weight_avg",
-    "pytorch_lightning.callbacks.xla_stats_monitor",
     "pytorch_lightning.core.datamodule",
     "pytorch_lightning.core.decorators",
     "pytorch_lightning.core.lightning",

pytorch_lightning/callbacks/xla_stats_monitor.py

@@ -20,6 +20,7 @@ Monitor and logs XLA stats during training.
 """
 import time
 
+import pytorch_lightning as pl
 from pytorch_lightning.callbacks.base import Callback
 from pytorch_lightning.utilities import _AcceleratorType, _TPU_AVAILABLE, rank_zero_deprecation, rank_zero_info
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -66,7 +67,7 @@ class XLAStatsMonitor(Callback):
         self._verbose = verbose
 
-    def on_train_start(self, trainer, pl_module) -> None:
+    def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         if not trainer.logger:
             raise MisconfigurationException("Cannot use XLAStatsMonitor callback with Trainer that has no logger.")
@@ -80,11 +81,13 @@ class XLAStatsMonitor(Callback):
         total_memory = trainer.strategy.reduce(memory_info["kb_total"]) * 0.001
         rank_zero_info(f"Average Total memory: {total_memory:.2f} MB")
 
-    def on_train_epoch_start(self, trainer, pl_module) -> None:
+    def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         self._start_time = time.time()
 
-    def on_train_epoch_end(self, trainer, pl_module) -> None:
-        logs = {}
+    def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+        if not trainer.logger:
+            raise MisconfigurationException("Cannot use XLAStatsMonitor callback with Trainer that has no logger.")
+
         memory_info = xm.get_memory_info(pl_module.device)
         epoch_time = time.time() - self._start_time
@@ -95,9 +98,10 @@ class XLAStatsMonitor(Callback):
         peak_memory = trainer.strategy.reduce(peak_memory) * 0.001
         epoch_time = trainer.strategy.reduce(epoch_time)
 
-        logs["avg. free memory (MB)"] = free_memory
-        logs["avg. peak memory (MB)"] = peak_memory
-        trainer.logger.log_metrics(logs, step=trainer.current_epoch)
+        trainer.logger.log_metrics(
+            {"avg. free memory (MB)": float(free_memory), "avg. peak memory (MB)": float(peak_memory)},
+            step=trainer.current_epoch,
+        )
 
         if self._verbose:
             rank_zero_info(f"Average Epoch time: {epoch_time:.2f} seconds")