clarify Trainer running state attribs. (#5589)

* update Trainer is_ attributes
* tests
* more
* isort
* split
* rename
* check
* fix
Jirka Borovec 2021-01-24 11:45:42 +01:00 committed by GitHub
parent 671887fd9b
commit 6386f45de7
5 changed files with 46 additions and 43 deletions
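A rough usage sketch of the behaviour this commit introduces (not a test from the repository; the Trainer flag semantics are taken from the diff below, the concrete asserts are illustrative only and assume a default-constructed Trainer with no running stage set):

# Hedged usage sketch; training/testing are the new Trainer properties added
# in this commit, both backed by a single _running_stage attribute.
from pytorch_lightning import Trainer

trainer = Trainer()
assert not trainer.training and not trainer.testing   # no stage set yet

trainer.training = True     # switches the running stage to TRAINING
assert trainer.training and not trainer.testing

trainer.testing = True      # switches the running stage to TESTING
assert trainer.testing and not trainer.training

trainer.testing = False     # clears the stage, because TESTING is the active one
assert not trainer.testing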


@@ -57,7 +57,7 @@ class DDPShardedPlugin(DDPPlugin):
     def _wrap_optimizers(self, model):
         trainer = model.trainer
-        if trainer.testing is True:
+        if trainer.testing:
             return
         self._reinit_with_fairscale_oss(trainer)
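For context on the simplification above, a small sketch (toy class, not the plugin code): testing is now a property computed from the running stage, so it always returns a real bool and the identity check against True is redundant.

# Minimal illustrative sketch with a toy object of the same property shape.
class _ToyTrainer:
    _running_stage = "test"

    @property
    def testing(self) -> bool:
        return self._running_stage == "test"

toy = _ToyTrainer()
assert toy.testing is True   # the old, stricter spelling
assert toy.testing           # equivalent and simpler, since the property returns a bool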


@@ -130,27 +130,3 @@ class DeprecatedDistDeviceAttributes:
         )
         if val:
             self._device_type = DeviceType.GPU
-
-    @property
-    def training(self) -> bool:
-        # todo: consider rename as `is_training`
-        return self._running_stage == RunningStage.TRAINING
-
-    @training.setter
-    def training(self, val: bool) -> None:
-        if val:
-            self._running_stage = RunningStage.TRAINING
-        else:
-            self._running_stage = None
-
-    @property
-    def testing(self) -> bool:
-        # todo: consider rename as `is_testing`
-        return self._running_stage == RunningStage.TESTING
-
-    @testing.setter
-    def testing(self, val: bool) -> None:
-        if val:
-            self._running_stage = RunningStage.TESTING
-        else:
-            self._running_stage = None


@@ -38,7 +38,6 @@ class EvaluationLoop(object):
         self.trainer.test_dataloaders = None
         self.trainer.val_dataloaders = None
         self.trainer.running_sanity_check = False
-        self.trainer.testing = False
         # when .test() is called, it sets this
         self.trainer.tested_ckpt_path = None


@@ -53,7 +53,7 @@ from pytorch_lightning.trainer.logging import TrainerLoggingMixin
 from pytorch_lightning.trainer.model_hooks import TrainerModelHooksMixin
 from pytorch_lightning.trainer.optimizers import TrainerOptimizersMixin
 from pytorch_lightning.trainer.properties import TrainerProperties
-from pytorch_lightning.trainer.states import TrainerState
+from pytorch_lightning.trainer.states import RunningStage, TrainerState
 from pytorch_lightning.trainer.training_loop import TrainLoop
 from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin
 from pytorch_lightning.tuner.tuning import Tuner
@@ -921,3 +921,47 @@
         Returns: List of all available plugins that are supported as string arguments.
         """
         return PluginConnector.available_plugins()
+
+    @property
+    def training(self) -> bool:
+        return self._running_stage == RunningStage.TRAINING
+
+    @training.setter
+    def training(self, val: bool) -> None:
+        if val:
+            self._running_stage = RunningStage.TRAINING
+        elif self.training:
+            self._running_stage = None
+
+    @property
+    def testing(self) -> bool:
+        return self._running_stage == RunningStage.TESTING
+
+    @testing.setter
+    def testing(self, val: bool) -> None:
+        if val:
+            self._running_stage = RunningStage.TESTING
+        elif self.testing:
+            self._running_stage = None
+
+    @property
+    def tuning(self) -> bool:
+        return self._running_stage == RunningStage.TUNING
+
+    @tuning.setter
+    def tuning(self, val: bool) -> None:
+        if val:
+            self._running_stage = RunningStage.TUNING
+        elif self.tuning:
+            self._running_stage = None
+
+    @property
+    def evaluating(self) -> bool:
+        return self._running_stage == RunningStage.EVALUATING
+
+    @evaluating.setter
+    def evaluating(self, val: bool) -> None:
+        if val:
+            self._running_stage = RunningStage.EVALUATING
+        elif self.evaluating:
+            self._running_stage = None
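
One detail worth illustrating in the setters added above: unlike the removed deprecated setters, which reset _running_stage to None whenever a flag was assigned False, these only clear the stage if that particular flag is currently active. A minimal standalone sketch of the pattern (toy names, not the library code):

# Self-contained sketch of the guarded-setter pattern; RunningStage values and
# the StageHolder class are stand-ins, not pytorch_lightning objects.
from enum import Enum


class RunningStage(Enum):
    TRAINING = "train"
    TESTING = "test"


class StageHolder:

    def __init__(self) -> None:
        self._running_stage = None

    @property
    def training(self) -> bool:
        return self._running_stage == RunningStage.TRAINING

    @training.setter
    def training(self, val: bool) -> None:
        if val:
            self._running_stage = RunningStage.TRAINING
        elif self.training:            # only clear the stage this flag owns
            self._running_stage = None


holder = StageHolder()
holder._running_stage = RunningStage.TESTING
holder.training = False                # no-op: the TESTING stage is preserved
assert holder._running_stage is RunningStage.TESTING

With the old unconditional else branch, that same assignment would have wiped the TESTING stage.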


@@ -89,22 +89,6 @@ def test_v1_4_0_deprecated_trainer_device_distrib():
     assert trainer.use_horovod


-def test_v1_4_0_deprecated_trainer_phase():
-    """Test that Trainer attributes works fine."""
-    trainer = Trainer()
-    assert not trainer.training
-    assert not trainer.testing
-
-    trainer.training = True
-    assert trainer.training
-    assert not trainer.testing
-
-    trainer.testing = True
-    assert not trainer.training
-    assert trainer.testing
-
 def test_v1_4_0_deprecated_metrics():
     from pytorch_lightning.metrics.functional.classification import stat_scores_multiple_classes
     with pytest.deprecated_call(match='will be removed in v1.4'):