From 2d42ec008f7d6a653caa80f6ebd41171391d7bf3 Mon Sep 17 00:00:00 2001 From: Uladzislau Sazanovich Date: Mon, 24 Aug 2020 17:49:33 +0300 Subject: [PATCH] Make trainer.state a read-only property (#3109) * Make trainer.state a read-only property * Update states.py Co-authored-by: Jirka Borovec --- pytorch_lightning/trainer/states.py | 11 ++++------- pytorch_lightning/trainer/trainer.py | 6 +++++- pytorch_lightning/trainer/training_loop.py | 4 ++-- tests/trainer/test_states.py | 7 +------ 4 files changed, 12 insertions(+), 16 deletions(-) diff --git a/pytorch_lightning/trainer/states.py b/pytorch_lightning/trainer/states.py index 461c17e86a..c99c6b3644 100644 --- a/pytorch_lightning/trainer/states.py +++ b/pytorch_lightning/trainer/states.py @@ -42,20 +42,17 @@ def trainer_state(*, entering: Optional[TrainerState] = None, exiting: Optional[ if not isinstance(self, pytorch_lightning.Trainer): return fn(self, *args, **kwargs) - state_before = self.state + state_before = self._state if entering is not None: - self.state = entering + self._state = entering result = fn(self, *args, **kwargs) # The INTERRUPTED state can be set inside the run function. To indicate that run was interrupted # we retain INTERRUPTED state - if self.state == TrainerState.INTERRUPTED: + if self._state == TrainerState.INTERRUPTED: return result - if exiting is not None: - self.state = exiting - else: - self.state = state_before + self._state = exiting if exiting is not None else state_before return result return wrapped_fn diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 1a526c2807..3a47b8f374 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -402,7 +402,7 @@ class Trainer( self.interrupted = False self.should_stop = False self.running_sanity_check = False - self.state = TrainerState.INITIALIZING + self._state = TrainerState.INITIALIZING self._default_root_dir = default_root_dir or os.getcwd() self._weights_save_path = weights_save_path or self._default_root_dir @@ -611,6 +611,10 @@ class Trainer( # Callback system self.on_init_end() + @property + def state(self) -> TrainerState: + return self._state + @property def is_global_zero(self) -> bool: return self.global_rank == 0 diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index aa1dc3df8c..ecbfee0c45 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -255,7 +255,7 @@ class TrainerTrainLoopMixin(ABC): terminate_on_nan: bool tpu_id: int interactive_ddp_procs: ... - state: TrainerState + _state: TrainerState amp_backend: AMPType on_tpu: bool accelerator_backend: ... @@ -418,7 +418,7 @@ class TrainerTrainLoopMixin(ABC): # user could press ctrl+c many times... only shutdown once if not self.interrupted: self.interrupted = True - self.state = TrainerState.INTERRUPTED + self._state = TrainerState.INTERRUPTED self.on_keyboard_interrupt() self.run_training_teardown() diff --git a/tests/trainer/test_states.py b/tests/trainer/test_states.py index 2b2ad545c7..dd4fca4efd 100644 --- a/tests/trainer/test_states.py +++ b/tests/trainer/test_states.py @@ -31,7 +31,6 @@ def test_state_decorator_nothing_passed(tmpdir): return self.state trainer = Trainer(default_root_dir=tmpdir) - trainer.state = TrainerState.INITIALIZING snapshot_state = test_method(trainer) @@ -47,7 +46,6 @@ def test_state_decorator_entering_only(tmpdir): return self.state trainer = Trainer(default_root_dir=tmpdir) - trainer.state = TrainerState.INITIALIZING snapshot_state = test_method(trainer) @@ -63,7 +61,6 @@ def test_state_decorator_exiting_only(tmpdir): return self.state trainer = Trainer(default_root_dir=tmpdir) - trainer.state = TrainerState.INITIALIZING snapshot_state = test_method(trainer) @@ -79,7 +76,6 @@ def test_state_decorator_entering_and_exiting(tmpdir): return self.state trainer = Trainer(default_root_dir=tmpdir) - trainer.state = TrainerState.INITIALIZING snapshot_state = test_method(trainer) @@ -92,10 +88,9 @@ def test_state_decorator_interrupt(tmpdir): @trainer_state(exiting=TrainerState.FINISHED) def test_method(self): - self.state = TrainerState.INTERRUPTED + self._state = TrainerState.INTERRUPTED trainer = Trainer(default_root_dir=tmpdir) - trainer.state = TrainerState.INITIALIZING test_method(trainer) assert trainer.state == TrainerState.INTERRUPTED