Fix typing in `pl.callbacks.xla_stats_monitor` (#11219)
Co-authored-by: Carlos Mocholi <carlossmocholi@gmail.com>
This commit is contained in:
parent
9b873dcfcc
commit
3a2df4f75d
|
@ -51,7 +51,6 @@ module = [
|
|||
"pytorch_lightning.callbacks.progress.tqdm_progress",
|
||||
"pytorch_lightning.callbacks.quantization",
|
||||
"pytorch_lightning.callbacks.stochastic_weight_avg",
|
||||
"pytorch_lightning.callbacks.xla_stats_monitor",
|
||||
"pytorch_lightning.core.datamodule",
|
||||
"pytorch_lightning.core.decorators",
|
||||
"pytorch_lightning.core.lightning",
|
||||
|
|
|
@ -20,6 +20,7 @@ Monitor and logs XLA stats during training.
|
|||
"""
|
||||
import time
|
||||
|
||||
import pytorch_lightning as pl
|
||||
from pytorch_lightning.callbacks.base import Callback
|
||||
from pytorch_lightning.utilities import _AcceleratorType, _TPU_AVAILABLE, rank_zero_deprecation, rank_zero_info
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
|
@ -66,7 +67,7 @@ class XLAStatsMonitor(Callback):
|
|||
|
||||
self._verbose = verbose
|
||||
|
||||
def on_train_start(self, trainer, pl_module) -> None:
|
||||
def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
|
||||
if not trainer.logger:
|
||||
raise MisconfigurationException("Cannot use XLAStatsMonitor callback with Trainer that has no logger.")
|
||||
|
||||
|
@ -80,11 +81,13 @@ class XLAStatsMonitor(Callback):
|
|||
total_memory = trainer.strategy.reduce(memory_info["kb_total"]) * 0.001
|
||||
rank_zero_info(f"Average Total memory: {total_memory:.2f} MB")
|
||||
|
||||
def on_train_epoch_start(self, trainer, pl_module) -> None:
|
||||
def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
|
||||
self._start_time = time.time()
|
||||
|
||||
def on_train_epoch_end(self, trainer, pl_module) -> None:
|
||||
logs = {}
|
||||
def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
|
||||
if not trainer.logger:
|
||||
raise MisconfigurationException("Cannot use XLAStatsMonitor callback with Trainer that has no logger.")
|
||||
|
||||
memory_info = xm.get_memory_info(pl_module.device)
|
||||
epoch_time = time.time() - self._start_time
|
||||
|
||||
|
@ -95,9 +98,10 @@ class XLAStatsMonitor(Callback):
|
|||
peak_memory = trainer.strategy.reduce(peak_memory) * 0.001
|
||||
epoch_time = trainer.strategy.reduce(epoch_time)
|
||||
|
||||
logs["avg. free memory (MB)"] = free_memory
|
||||
logs["avg. peak memory (MB)"] = peak_memory
|
||||
trainer.logger.log_metrics(logs, step=trainer.current_epoch)
|
||||
trainer.logger.log_metrics(
|
||||
{"avg. free memory (MB)": float(free_memory), "avg. peak memory (MB)": float(peak_memory)},
|
||||
step=trainer.current_epoch,
|
||||
)
|
||||
|
||||
if self._verbose:
|
||||
rank_zero_info(f"Average Epoch time: {epoch_time:.2f} seconds")
|
||||
|
|
Loading…
Reference in New Issue