mini refactor for _running_stage access (#5724)

* running stage

* circular import

* running stage cleanup

* fix unused import

* fix running stage access

* add return type

* Revert "add return type"

This reverts commit 65b0fe269c.

* try fix typing
Adrian Wälchli 2021-02-22 12:01:54 +01:00 committed by GitHub
parent 423ecf995a
commit 0456b4598f
5 changed files with 24 additions and 27 deletions
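
In essence, the refactor makes the model's running_stage a read-only view of the trainer's _running_stage instead of a second, manually synchronized attribute. A minimal, self-contained sketch of the resulting pattern (the RunningStage, Trainer, and LightningModule stubs below are simplified stand-ins for illustration, not the library classes):

from enum import Enum
from typing import Optional


class RunningStage(Enum):
    # simplified stand-in for pytorch_lightning.trainer.states.RunningStage
    TRAINING = "train"
    EVALUATING = "validate"
    TESTING = "test"
    PREDICTING = "predict"


class Trainer:
    def __init__(self) -> None:
        # the single source of truth for the current stage
        self._running_stage: Optional[RunningStage] = None


class LightningModule:
    def __init__(self) -> None:
        self.trainer: Optional[Trainer] = None

    @property
    def running_stage(self) -> Optional[RunningStage]:
        # read-only view onto the attached trainer; None while detached
        return self.trainer._running_stage if self.trainer else None


trainer, model = Trainer(), LightningModule()
model.trainer = trainer
trainer._running_stage = RunningStage.TRAINING
assert model.running_stage is RunningStage.TRAINING  # stays in sync automatically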

View File

@@ -24,7 +24,7 @@ from abc import ABC
 from argparse import Namespace
 from functools import partial
 from pathlib import Path
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING, Union
 
 import torch
 from torch import ScriptModule, Tensor
@@ -44,6 +44,9 @@ from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.parsing import AttributeDict, collect_init_args, get_init_args
 
+if TYPE_CHECKING:
+    from pytorch_lightning.trainer.states import RunningStage
+
 
 class LightningModule(
     ABC,
@@ -103,7 +106,6 @@ class LightningModule(
         self._running_manual_backward = False
         self._current_hook_fx_name = None
         self._current_dataloader_idx = None
-        self.running_stage = None
         self._automatic_optimization: bool = True
 
     def optimizers(self, use_pl_optimizer: bool = True) -> Union[Optimizer, List[Optimizer], List[LightningOptimizer]]:
@@ -169,6 +171,10 @@ class LightningModule(
         """
         return self._automatic_optimization
 
+    @property
+    def running_stage(self) -> Optional["RunningStage"]:
+        return self.trainer._running_stage if self.trainer else None
+
     @automatic_optimization.setter
     def automatic_optimization(self, automatic_optimization: bool) -> None:
         self._automatic_optimization = automatic_optimization

View File

@@ -59,7 +59,6 @@ from pytorch_lightning.tuner.tuning import Tuner
 from pytorch_lightning.utilities import DeviceType, rank_zero_warn
 from pytorch_lightning.utilities.cloud_io import load as pl_load
 from pytorch_lightning.utilities.debugging import InternalDebugger
-from pytorch_lightning.utilities.enums import LightningEnum
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.memory import recursive_detach
 from pytorch_lightning.utilities.model_helpers import is_overridden
@@ -450,7 +449,7 @@ class Trainer(
         # bookkeeping
         # we reuse fit in .test() and .predict(). When already set, it shouldn't be modified.
         if self._running_stage is None:
-            self._set_running_stage(RunningStage.TRAINING, model)
+            self._running_stage = RunningStage.TRAINING
 
         # set local properties on the model
         self.model_connector.copy_trainer_model_properties(model)
@@ -531,7 +530,7 @@ class Trainer(
         if self._state != TrainerState.INTERRUPTED:
             self._state = TrainerState.FINISHED
-        self._set_running_stage(None, model)
+        self._running_stage = None
 
         return self.accelerator.results or 1
@@ -564,14 +563,6 @@ class Trainer(
 
         return results
 
-    def _set_running_stage(self, stage: LightningEnum, model_ref: LightningModule):
-        """
-        This function is used to set the running_state on both
-        the trainer and the model
-        """
-        model_ref.running_stage = stage
-        self._running_stage = stage
-
     def _pre_training_routine(self):
         # wait for all to join if on distributed
         self.accelerator.barrier("setup_training")
@@ -614,7 +605,7 @@ class Trainer(
         self.run_sanity_check(self.lightning_module)
 
         # set stage for logging
-        self._set_running_stage(RunningStage.TRAINING, self.lightning_module)
+        self._running_stage = RunningStage.TRAINING
 
         self.checkpoint_connector.has_trained = False
@@ -678,9 +669,7 @@ class Trainer(
     def run_evaluation(self, max_batches=None, on_epoch=False):
 
         # used to know if we are logging for val, test + reset cached results
-        self._set_running_stage(
-            RunningStage.TESTING if self.testing else RunningStage.EVALUATING, self.lightning_module
-        )
+        self._running_stage = RunningStage.TESTING if self.testing else RunningStage.EVALUATING
         self.logger_connector.reset()
 
         # bookkeeping
@@ -907,7 +896,7 @@ class Trainer(
         # --------------------
         self.verbose_test = verbose
 
-        self._set_running_stage(RunningStage.TESTING, model or self.lightning_module)
+        self._running_stage = RunningStage.TESTING
 
         # If you supply a datamodule you can't supply train_dataloader or val_dataloaders
         if test_dataloaders and datamodule:
@@ -924,7 +913,7 @@ class Trainer(
             results = self.__test_using_best_weights(ckpt_path, test_dataloaders)
 
         self.teardown('test')
-        self._set_running_stage(None, model or self.lightning_module)
+        self._running_stage = None
 
         return results
 
     def __test_using_best_weights(self, ckpt_path, test_dataloaders):
@@ -1016,7 +1005,7 @@ class Trainer(
 
         model = model or self.lightning_module
 
-        self._set_running_stage(RunningStage.PREDICTING, model)
+        self._running_stage = RunningStage.PREDICTING
 
         if dataloaders and datamodule:
             raise MisconfigurationException(
@@ -1033,7 +1022,7 @@ class Trainer(
         self.model = model
 
         results = self.fit(model)
-        self._set_running_stage(None, model)
+        self._running_stage = None
 
         return results
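
With the _set_running_stage helper gone, every trainer call site above assigns _running_stage directly; the only branching case is run_evaluation, which picks between the test and validation stages. A small sketch of that selection logic under simplified stubs (FakeTrainer and this RunningStage enum are illustrative, not the real classes):

from enum import Enum
from typing import Optional


class RunningStage(Enum):
    TRAINING = "train"
    EVALUATING = "validate"
    TESTING = "test"


class FakeTrainer:
    def __init__(self, testing: bool = False) -> None:
        self.testing = testing
        self._running_stage: Optional[RunningStage] = None

    def run_evaluation(self) -> None:
        # mirrors the changed line in the diff: test stage during .test(), else validation
        self._running_stage = RunningStage.TESTING if self.testing else RunningStage.EVALUATING


val_trainer, test_trainer = FakeTrainer(), FakeTrainer(testing=True)
val_trainer.run_evaluation()
test_trainer.run_evaluation()
assert val_trainer._running_stage is RunningStage.EVALUATING
assert test_trainer._running_stage is RunningStage.TESTING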

View File

@@ -517,7 +517,7 @@ class TrainLoop:
                    self.trainer.run_evaluation()
 
                    # reset stage to train
-                    self.trainer._set_running_stage(RunningStage.TRAINING, self.trainer.lightning_module)
+                    self.trainer._running_stage = RunningStage.TRAINING
 
                # -----------------------------------------
                # SAVE LOGGERS (ie: Tensorboard, etc...)
@@ -564,7 +564,7 @@ class TrainLoop:
            self.trainer.run_evaluation(on_epoch=True)
 
            # reset stage to train
-            self.trainer._set_running_stage(RunningStage.TRAINING, self.trainer.lightning_module)
+            self.trainer._running_stage = RunningStage.TRAINING
 
        should_skip_eval = self.trainer.evaluation_loop.should_skip_evaluation(self.trainer.num_val_batches)
        should_train_only = self.trainer.disable_validation or should_skip_eval

View File

@@ -453,7 +453,7 @@ def test_dp_resume(tmpdir):
            # haven't trained with the new loaded model
            dp_model = new_trainer.model
            dp_model.eval()
-            dp_model.module.module.running_stage = RunningStage.EVALUATING
+            new_trainer._running_stage = RunningStage.EVALUATING
 
            dataloader = self.train_dataloader()
            tpipes.run_prediction(self.trainer.lightning_module, dataloader)

View File

@@ -1,4 +1,4 @@
-from unittest.mock import MagicMock
+from unittest.mock import MagicMock, Mock
 
 import pytest
 import torch
@@ -103,7 +103,8 @@ def test_lightning_parallel_module_unsqueeze_scalar():
            return {"loss": loss}
 
    model = TestModel()
-    model.running_stage = RunningStage.TRAINING
+    model.trainer = Mock()
+    model.trainer._running_stage = RunningStage.TRAINING
    batch = torch.rand(2, 32).cuda()
    batch_idx = 0
@@ -146,7 +147,8 @@ def test_lightning_parallel_module_python_scalar_conversion(device):
 
    model = TestModel()
    model.to(device)
-    model.running_stage = RunningStage.TRAINING
+    model.trainer = Mock()
+    model.trainer._running_stage = RunningStage.TRAINING
    batch = torch.rand(2, 32).to(device)
    batch_idx = 0
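
The test updates follow the same idea: rather than writing running_stage onto the model (now a read-only property), a test attaches a Mock trainer and sets _running_stage on it. A sketch of that setup, assuming pytorch-lightning at the version of this commit; TinyModel is an illustrative stand-in for the tests' TestModel:

from unittest.mock import Mock

import torch

from pytorch_lightning import LightningModule
from pytorch_lightning.trainer.states import RunningStage


class TinyModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)


model = TinyModel()
model.trainer = Mock()                                 # stand-in trainer, as in the updated tests
model.trainer._running_stage = RunningStage.TRAINING   # stage now lives on the trainer

# the running_stage property reads through to the mocked trainer
assert model.running_stage == RunningStage.TRAINING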