diff --git a/pytorch_lightning/plugins/sharded_plugin.py b/pytorch_lightning/plugins/sharded_plugin.py index 510a44ad1b..ec1500ca7a 100644 --- a/pytorch_lightning/plugins/sharded_plugin.py +++ b/pytorch_lightning/plugins/sharded_plugin.py @@ -57,7 +57,7 @@ class DDPShardedPlugin(DDPPlugin): def _wrap_optimizers(self, model): trainer = model.trainer - if trainer.testing is True: + if trainer.testing: return self._reinit_with_fairscale_oss(trainer) diff --git a/pytorch_lightning/trainer/deprecated_api.py b/pytorch_lightning/trainer/deprecated_api.py index dbfa3258b2..e9407379cb 100644 --- a/pytorch_lightning/trainer/deprecated_api.py +++ b/pytorch_lightning/trainer/deprecated_api.py @@ -130,27 +130,3 @@ class DeprecatedDistDeviceAttributes: ) if val: self._device_type = DeviceType.GPU - - @property - def training(self) -> bool: - # todo: consider rename as `is_training` - return self._running_stage == RunningStage.TRAINING - - @training.setter - def training(self, val: bool) -> None: - if val: - self._running_stage = RunningStage.TRAINING - else: - self._running_stage = None - - @property - def testing(self) -> bool: - # todo: consider rename as `is_testing` - return self._running_stage == RunningStage.TESTING - - @testing.setter - def testing(self, val: bool) -> None: - if val: - self._running_stage = RunningStage.TESTING - else: - self._running_stage = None diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index a8fa9f4368..af1b42bab0 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -38,7 +38,6 @@ class EvaluationLoop(object): self.trainer.test_dataloaders = None self.trainer.val_dataloaders = None self.trainer.running_sanity_check = False - self.trainer.testing = False # when .test() is called, it sets this self.trainer.tested_ckpt_path = None diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 0a2362a438..7d7cec2335 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -53,7 +53,7 @@ from pytorch_lightning.trainer.logging import TrainerLoggingMixin from pytorch_lightning.trainer.model_hooks import TrainerModelHooksMixin from pytorch_lightning.trainer.optimizers import TrainerOptimizersMixin from pytorch_lightning.trainer.properties import TrainerProperties -from pytorch_lightning.trainer.states import TrainerState +from pytorch_lightning.trainer.states import RunningStage, TrainerState from pytorch_lightning.trainer.training_loop import TrainLoop from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin from pytorch_lightning.tuner.tuning import Tuner @@ -921,3 +921,47 @@ class Trainer( Returns: List of all available plugins that are supported as string arguments. """ return PluginConnector.available_plugins() + + @property + def training(self) -> bool: + return self._running_stage == RunningStage.TRAINING + + @training.setter + def training(self, val: bool) -> None: + if val: + self._running_stage = RunningStage.TRAINING + elif self.training: + self._running_stage = None + + @property + def testing(self) -> bool: + return self._running_stage == RunningStage.TESTING + + @testing.setter + def testing(self, val: bool) -> None: + if val: + self._running_stage = RunningStage.TESTING + elif self.testing: + self._running_stage = None + + @property + def tuning(self) -> bool: + return self._running_stage == RunningStage.TUNING + + @tuning.setter + def tuning(self, val: bool) -> None: + if val: + self._running_stage = RunningStage.TUNING + elif self.tuning: + self._running_stage = None + + @property + def evaluating(self) -> bool: + return self._running_stage == RunningStage.EVALUATING + + @evaluating.setter + def evaluating(self, val: bool) -> None: + if val: + self._running_stage = RunningStage.EVALUATING + elif self.evaluating: + self._running_stage = None diff --git a/tests/deprecated_api/test_remove_1-4.py b/tests/deprecated_api/test_remove_1-4.py index bb35907417..00f02076fc 100644 --- a/tests/deprecated_api/test_remove_1-4.py +++ b/tests/deprecated_api/test_remove_1-4.py @@ -89,22 +89,6 @@ def test_v1_4_0_deprecated_trainer_device_distrib(): assert trainer.use_horovod -def test_v1_4_0_deprecated_trainer_phase(): - """Test that Trainer attributes works fine.""" - trainer = Trainer() - - assert not trainer.training - assert not trainer.testing - - trainer.training = True - assert trainer.training - assert not trainer.testing - - trainer.testing = True - assert not trainer.training - assert trainer.testing - - def test_v1_4_0_deprecated_metrics(): from pytorch_lightning.metrics.functional.classification import stat_scores_multiple_classes with pytest.deprecated_call(match='will be removed in v1.4'):