mini refactor for _running_stage access (#5724)
* running stage
* circular import
* running stage cleanup
* fix unused import
* fix running stage access
* add return type
* Revert "add return type"
  This reverts commit 65b0fe269c.
* try fix typing
Parent: 423ecf995a
Commit: 0456b4598f
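In short: LightningModule stops carrying its own `running_stage` attribute and instead exposes a read-only `running_stage` property that reads the trainer's `_running_stage`, and the `Trainer._set_running_stage(stage, model)` helper is dropped in favour of assigning `self._running_stage` directly. Below is a minimal, self-contained sketch of the resulting access pattern; the classes are toy stand-ins for illustration only, not the real Lightning classes.

from enum import Enum
from typing import Optional


class RunningStage(Enum):  # stand-in for pytorch_lightning.trainer.states.RunningStage
    TRAINING = "train"
    EVALUATING = "eval"
    TESTING = "test"
    PREDICTING = "predict"


class Trainer:  # toy trainer: the single source of truth for the stage
    def __init__(self) -> None:
        self._running_stage: Optional[RunningStage] = None


class LightningModule:  # toy module: no stored stage, only a delegating property
    def __init__(self) -> None:
        self.trainer: Optional[Trainer] = None

    @property
    def running_stage(self) -> Optional[RunningStage]:
        return self.trainer._running_stage if self.trainer else None


trainer = Trainer()
model = LightningModule()
model.trainer = trainer
trainer._running_stage = RunningStage.TRAINING   # callers now assign the trainer attribute directly
assert model.running_stage is RunningStage.TRAINING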
@@ -24,7 +24,7 @@ from abc import ABC
 from argparse import Namespace
 from functools import partial
 from pathlib import Path
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING, Union
 
 import torch
 from torch import ScriptModule, Tensor
@@ -44,6 +44,9 @@ from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.parsing import AttributeDict, collect_init_args, get_init_args
 
+if TYPE_CHECKING:
+    from pytorch_lightning.trainer.states import RunningStage
+
 
 class LightningModule(
     ABC,
@@ -103,7 +106,6 @@ class LightningModule(
         self._running_manual_backward = False
         self._current_hook_fx_name = None
         self._current_dataloader_idx = None
-        self.running_stage = None
         self._automatic_optimization: bool = True
 
     def optimizers(self, use_pl_optimizer: bool = True) -> Union[Optimizer, List[Optimizer], List[LightningOptimizer]]:
@@ -169,6 +171,10 @@ class LightningModule(
         """
         return self._automatic_optimization
 
+    @property
+    def running_stage(self) -> Optional["RunningStage"]:
+        return self.trainer._running_stage if self.trainer else None
+
     @automatic_optimization.setter
     def automatic_optimization(self, automatic_optimization: bool) -> None:
         self._automatic_optimization = automatic_optimization
@@ -59,7 +59,6 @@ from pytorch_lightning.tuner.tuning import Tuner
 from pytorch_lightning.utilities import DeviceType, rank_zero_warn
 from pytorch_lightning.utilities.cloud_io import load as pl_load
 from pytorch_lightning.utilities.debugging import InternalDebugger
-from pytorch_lightning.utilities.enums import LightningEnum
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.memory import recursive_detach
 from pytorch_lightning.utilities.model_helpers import is_overridden
@@ -450,7 +449,7 @@ class Trainer(
         # bookkeeping
         # we reuse fit in .test() and .predict(). When already set, it shouldn't be modified.
         if self._running_stage is None:
-            self._set_running_stage(RunningStage.TRAINING, model)
+            self._running_stage = RunningStage.TRAINING
 
         # set local properties on the model
         self.model_connector.copy_trainer_model_properties(model)
@@ -531,7 +530,7 @@ class Trainer(
         if self._state != TrainerState.INTERRUPTED:
             self._state = TrainerState.FINISHED
 
-        self._set_running_stage(None, model)
+        self._running_stage = None
 
         return self.accelerator.results or 1
 
@@ -564,14 +563,6 @@ class Trainer(
 
         return results
 
-    def _set_running_stage(self, stage: LightningEnum, model_ref: LightningModule):
-        """
-        This function is used to set the running_state on both
-        the trainer and the model
-        """
-        model_ref.running_stage = stage
-        self._running_stage = stage
-
     def _pre_training_routine(self):
         # wait for all to join if on distributed
         self.accelerator.barrier("setup_training")
@@ -614,7 +605,7 @@ class Trainer(
         self.run_sanity_check(self.lightning_module)
 
         # set stage for logging
-        self._set_running_stage(RunningStage.TRAINING, self.lightning_module)
+        self._running_stage = RunningStage.TRAINING
 
         self.checkpoint_connector.has_trained = False
 
@@ -678,9 +669,7 @@ class Trainer(
     def run_evaluation(self, max_batches=None, on_epoch=False):
 
         # used to know if we are logging for val, test + reset cached results
-        self._set_running_stage(
-            RunningStage.TESTING if self.testing else RunningStage.EVALUATING, self.lightning_module
-        )
+        self._running_stage = RunningStage.TESTING if self.testing else RunningStage.EVALUATING
         self.logger_connector.reset()
 
         # bookkeeping
@@ -907,7 +896,7 @@ class Trainer(
         # --------------------
         self.verbose_test = verbose
 
-        self._set_running_stage(RunningStage.TESTING, model or self.lightning_module)
+        self._running_stage = RunningStage.TESTING
 
         # If you supply a datamodule you can't supply train_dataloader or val_dataloaders
         if test_dataloaders and datamodule:
@@ -924,7 +913,7 @@ class Trainer(
             results = self.__test_using_best_weights(ckpt_path, test_dataloaders)
 
         self.teardown('test')
-        self._set_running_stage(None, model or self.lightning_module)
+        self._running_stage = None
         return results
 
     def __test_using_best_weights(self, ckpt_path, test_dataloaders):
@@ -1016,7 +1005,7 @@ class Trainer(
 
         model = model or self.lightning_module
 
-        self._set_running_stage(RunningStage.PREDICTING, model)
+        self._running_stage = RunningStage.PREDICTING
 
         if dataloaders and datamodule:
             raise MisconfigurationException(
@@ -1033,7 +1022,7 @@ class Trainer(
 
         self.model = model
         results = self.fit(model)
-        self._set_running_stage(None, model)
+        self._running_stage = None
 
         return results
 
@@ -517,7 +517,7 @@ class TrainLoop:
                 self.trainer.run_evaluation()
 
                 # reset stage to train
-                self.trainer._set_running_stage(RunningStage.TRAINING, self.trainer.lightning_module)
+                self.trainer._running_stage = RunningStage.TRAINING
 
             # -----------------------------------------
             # SAVE LOGGERS (ie: Tensorboard, etc...)
@@ -564,7 +564,7 @@ class TrainLoop:
            self.trainer.run_evaluation(on_epoch=True)
 
            # reset stage to train
-           self.trainer._set_running_stage(RunningStage.TRAINING, self.trainer.lightning_module)
+           self.trainer._running_stage = RunningStage.TRAINING
 
        should_skip_eval = self.trainer.evaluation_loop.should_skip_evaluation(self.trainer.num_val_batches)
        should_train_only = self.trainer.disable_validation or should_skip_eval
@@ -453,7 +453,7 @@ def test_dp_resume(tmpdir):
             # haven't trained with the new loaded model
             dp_model = new_trainer.model
             dp_model.eval()
-            dp_model.module.module.running_stage = RunningStage.EVALUATING
+            new_trainer._running_stage = RunningStage.EVALUATING
 
             dataloader = self.train_dataloader()
             tpipes.run_prediction(self.trainer.lightning_module, dataloader)
@@ -1,4 +1,4 @@
-from unittest.mock import MagicMock
+from unittest.mock import MagicMock, Mock
 
 import pytest
 import torch
@@ -103,7 +103,8 @@ def test_lightning_parallel_module_unsqueeze_scalar():
             return {"loss": loss}
 
     model = TestModel()
-    model.running_stage = RunningStage.TRAINING
+    model.trainer = Mock()
+    model.trainer._running_stage = RunningStage.TRAINING
     batch = torch.rand(2, 32).cuda()
     batch_idx = 0
 
@@ -146,7 +147,8 @@ def test_lightning_parallel_module_python_scalar_conversion(device):
 
     model = TestModel()
     model.to(device)
-    model.running_stage = RunningStage.TRAINING
+    model.trainer = Mock()
+    model.trainer._running_stage = RunningStage.TRAINING
     batch = torch.rand(2, 32).to(device)
     batch_idx = 0
 
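For downstream code that used to assign `model.running_stage` directly, the equivalent after this change is to set the stage on the trainer the module is attached to, which is what the updated tests above do with a `Mock` trainer. A hedged sketch of that adaptation, assuming the PyTorch Lightning version this PR targets; the `DemoModule` subclass is hypothetical, added only so the snippet is self-contained:

from unittest.mock import Mock

from pytorch_lightning import LightningModule
from pytorch_lightning.trainer.states import RunningStage


class DemoModule(LightningModule):  # hypothetical minimal subclass, for illustration only
    pass


model = DemoModule()
model.trainer = Mock()                                # stand-in trainer, as in the updated tests
model.trainer._running_stage = RunningStage.TRAINING  # set the stage on the trainer, not the module
assert model.running_stage is RunningStage.TRAINING   # read back through the new delegating property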