clarify Trainer running state atribs. (#5589)
* update Trainer is_ attributes * tests * more * isort * split * rename * check * fix
This commit is contained in:
parent
671887fd9b
commit
6386f45de7
|
@ -57,7 +57,7 @@ class DDPShardedPlugin(DDPPlugin):
|
|||
|
||||
def _wrap_optimizers(self, model):
|
||||
trainer = model.trainer
|
||||
if trainer.testing is True:
|
||||
if trainer.testing:
|
||||
return
|
||||
|
||||
self._reinit_with_fairscale_oss(trainer)
|
||||
|
|
|
@ -130,27 +130,3 @@ class DeprecatedDistDeviceAttributes:
|
|||
)
|
||||
if val:
|
||||
self._device_type = DeviceType.GPU
|
||||
|
||||
@property
|
||||
def training(self) -> bool:
|
||||
# todo: consider rename as `is_training`
|
||||
return self._running_stage == RunningStage.TRAINING
|
||||
|
||||
@training.setter
|
||||
def training(self, val: bool) -> None:
|
||||
if val:
|
||||
self._running_stage = RunningStage.TRAINING
|
||||
else:
|
||||
self._running_stage = None
|
||||
|
||||
@property
|
||||
def testing(self) -> bool:
|
||||
# todo: consider rename as `is_testing`
|
||||
return self._running_stage == RunningStage.TESTING
|
||||
|
||||
@testing.setter
|
||||
def testing(self, val: bool) -> None:
|
||||
if val:
|
||||
self._running_stage = RunningStage.TESTING
|
||||
else:
|
||||
self._running_stage = None
|
||||
|
|
|
@ -38,7 +38,6 @@ class EvaluationLoop(object):
|
|||
self.trainer.test_dataloaders = None
|
||||
self.trainer.val_dataloaders = None
|
||||
self.trainer.running_sanity_check = False
|
||||
self.trainer.testing = False
|
||||
|
||||
# when .test() is called, it sets this
|
||||
self.trainer.tested_ckpt_path = None
|
||||
|
|
|
@ -53,7 +53,7 @@ from pytorch_lightning.trainer.logging import TrainerLoggingMixin
|
|||
from pytorch_lightning.trainer.model_hooks import TrainerModelHooksMixin
|
||||
from pytorch_lightning.trainer.optimizers import TrainerOptimizersMixin
|
||||
from pytorch_lightning.trainer.properties import TrainerProperties
|
||||
from pytorch_lightning.trainer.states import TrainerState
|
||||
from pytorch_lightning.trainer.states import RunningStage, TrainerState
|
||||
from pytorch_lightning.trainer.training_loop import TrainLoop
|
||||
from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin
|
||||
from pytorch_lightning.tuner.tuning import Tuner
|
||||
|
@ -921,3 +921,47 @@ class Trainer(
|
|||
Returns: List of all available plugins that are supported as string arguments.
|
||||
"""
|
||||
return PluginConnector.available_plugins()
|
||||
|
||||
@property
|
||||
def training(self) -> bool:
|
||||
return self._running_stage == RunningStage.TRAINING
|
||||
|
||||
@training.setter
|
||||
def training(self, val: bool) -> None:
|
||||
if val:
|
||||
self._running_stage = RunningStage.TRAINING
|
||||
elif self.training:
|
||||
self._running_stage = None
|
||||
|
||||
@property
|
||||
def testing(self) -> bool:
|
||||
return self._running_stage == RunningStage.TESTING
|
||||
|
||||
@testing.setter
|
||||
def testing(self, val: bool) -> None:
|
||||
if val:
|
||||
self._running_stage = RunningStage.TESTING
|
||||
elif self.testing:
|
||||
self._running_stage = None
|
||||
|
||||
@property
|
||||
def tuning(self) -> bool:
|
||||
return self._running_stage == RunningStage.TUNING
|
||||
|
||||
@tuning.setter
|
||||
def tuning(self, val: bool) -> None:
|
||||
if val:
|
||||
self._running_stage = RunningStage.TUNING
|
||||
elif self.tuning:
|
||||
self._running_stage = None
|
||||
|
||||
@property
|
||||
def evaluating(self) -> bool:
|
||||
return self._running_stage == RunningStage.EVALUATING
|
||||
|
||||
@evaluating.setter
|
||||
def evaluating(self, val: bool) -> None:
|
||||
if val:
|
||||
self._running_stage = RunningStage.EVALUATING
|
||||
elif self.evaluating:
|
||||
self._running_stage = None
|
||||
|
|
|
@ -89,22 +89,6 @@ def test_v1_4_0_deprecated_trainer_device_distrib():
|
|||
assert trainer.use_horovod
|
||||
|
||||
|
||||
def test_v1_4_0_deprecated_trainer_phase():
|
||||
"""Test that Trainer attributes works fine."""
|
||||
trainer = Trainer()
|
||||
|
||||
assert not trainer.training
|
||||
assert not trainer.testing
|
||||
|
||||
trainer.training = True
|
||||
assert trainer.training
|
||||
assert not trainer.testing
|
||||
|
||||
trainer.testing = True
|
||||
assert not trainer.training
|
||||
assert trainer.testing
|
||||
|
||||
|
||||
def test_v1_4_0_deprecated_metrics():
|
||||
from pytorch_lightning.metrics.functional.classification import stat_scores_multiple_classes
|
||||
with pytest.deprecated_call(match='will be removed in v1.4'):
|
||||
|
|
Loading…
Reference in New Issue