Make trainer.state a read-only property (#3109)

* Make trainer.state a read-only property

* Update states.py

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
This commit is contained in:
Uladzislau Sazanovich 2020-08-24 17:49:33 +03:00 committed by GitHub
parent 8ebf4fe173
commit 2d42ec008f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 12 additions and 16 deletions

View File

@ -42,20 +42,17 @@ def trainer_state(*, entering: Optional[TrainerState] = None, exiting: Optional[
if not isinstance(self, pytorch_lightning.Trainer):
return fn(self, *args, **kwargs)
state_before = self.state
state_before = self._state
if entering is not None:
self.state = entering
self._state = entering
result = fn(self, *args, **kwargs)
# The INTERRUPTED state can be set inside the run function. To indicate that run was interrupted
# we retain INTERRUPTED state
if self.state == TrainerState.INTERRUPTED:
if self._state == TrainerState.INTERRUPTED:
return result
if exiting is not None:
self.state = exiting
else:
self.state = state_before
self._state = exiting if exiting is not None else state_before
return result
return wrapped_fn

View File

@ -402,7 +402,7 @@ class Trainer(
self.interrupted = False
self.should_stop = False
self.running_sanity_check = False
self.state = TrainerState.INITIALIZING
self._state = TrainerState.INITIALIZING
self._default_root_dir = default_root_dir or os.getcwd()
self._weights_save_path = weights_save_path or self._default_root_dir
@ -611,6 +611,10 @@ class Trainer(
# Callback system
self.on_init_end()
@property
def state(self) -> TrainerState:
return self._state
@property
def is_global_zero(self) -> bool:
return self.global_rank == 0

View File

@ -255,7 +255,7 @@ class TrainerTrainLoopMixin(ABC):
terminate_on_nan: bool
tpu_id: int
interactive_ddp_procs: ...
state: TrainerState
_state: TrainerState
amp_backend: AMPType
on_tpu: bool
accelerator_backend: ...
@ -418,7 +418,7 @@ class TrainerTrainLoopMixin(ABC):
# user could press ctrl+c many times... only shutdown once
if not self.interrupted:
self.interrupted = True
self.state = TrainerState.INTERRUPTED
self._state = TrainerState.INTERRUPTED
self.on_keyboard_interrupt()
self.run_training_teardown()

View File

@ -31,7 +31,6 @@ def test_state_decorator_nothing_passed(tmpdir):
return self.state
trainer = Trainer(default_root_dir=tmpdir)
trainer.state = TrainerState.INITIALIZING
snapshot_state = test_method(trainer)
@ -47,7 +46,6 @@ def test_state_decorator_entering_only(tmpdir):
return self.state
trainer = Trainer(default_root_dir=tmpdir)
trainer.state = TrainerState.INITIALIZING
snapshot_state = test_method(trainer)
@ -63,7 +61,6 @@ def test_state_decorator_exiting_only(tmpdir):
return self.state
trainer = Trainer(default_root_dir=tmpdir)
trainer.state = TrainerState.INITIALIZING
snapshot_state = test_method(trainer)
@ -79,7 +76,6 @@ def test_state_decorator_entering_and_exiting(tmpdir):
return self.state
trainer = Trainer(default_root_dir=tmpdir)
trainer.state = TrainerState.INITIALIZING
snapshot_state = test_method(trainer)
@ -92,10 +88,9 @@ def test_state_decorator_interrupt(tmpdir):
@trainer_state(exiting=TrainerState.FINISHED)
def test_method(self):
self.state = TrainerState.INTERRUPTED
self._state = TrainerState.INTERRUPTED
trainer = Trainer(default_root_dir=tmpdir)
trainer.state = TrainerState.INITIALIZING
test_method(trainer)
assert trainer.state == TrainerState.INTERRUPTED