diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index f0c25f6c70..3b79dd30d0 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -38,6 +38,7 @@ from pytorch_lightning.core.memory import ModelSummary from pytorch_lightning.core.optimizer import LightningOptimizer from pytorch_lightning.core.saving import ALLOWED_CONFIG_TYPES, ModelIO, PRIMITIVE_TYPES from pytorch_lightning.core.step_result import Result +from pytorch_lightning.trainer.states import RunningStage from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.apply_func import apply_to_collection, convert_to_tensors from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin @@ -169,7 +170,7 @@ class LightningModule( return self._automatic_optimization @property - def running_stage(self): + def running_stage(self) -> Optional[RunningStage]: return self.trainer._running_stage if self.trainer else None @automatic_optimization.setter