Make trainer.state a read-only property (#3109)
* Make trainer.state a read-only property * Update states.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
This commit is contained in:
parent
8ebf4fe173
commit
2d42ec008f
|
@ -42,20 +42,17 @@ def trainer_state(*, entering: Optional[TrainerState] = None, exiting: Optional[
|
|||
if not isinstance(self, pytorch_lightning.Trainer):
|
||||
return fn(self, *args, **kwargs)
|
||||
|
||||
state_before = self.state
|
||||
state_before = self._state
|
||||
if entering is not None:
|
||||
self.state = entering
|
||||
self._state = entering
|
||||
result = fn(self, *args, **kwargs)
|
||||
|
||||
# The INTERRUPTED state can be set inside the run function. To indicate that run was interrupted
|
||||
# we retain INTERRUPTED state
|
||||
if self.state == TrainerState.INTERRUPTED:
|
||||
if self._state == TrainerState.INTERRUPTED:
|
||||
return result
|
||||
|
||||
if exiting is not None:
|
||||
self.state = exiting
|
||||
else:
|
||||
self.state = state_before
|
||||
self._state = exiting if exiting is not None else state_before
|
||||
return result
|
||||
|
||||
return wrapped_fn
|
||||
|
|
|
@ -402,7 +402,7 @@ class Trainer(
|
|||
self.interrupted = False
|
||||
self.should_stop = False
|
||||
self.running_sanity_check = False
|
||||
self.state = TrainerState.INITIALIZING
|
||||
self._state = TrainerState.INITIALIZING
|
||||
|
||||
self._default_root_dir = default_root_dir or os.getcwd()
|
||||
self._weights_save_path = weights_save_path or self._default_root_dir
|
||||
|
@ -611,6 +611,10 @@ class Trainer(
|
|||
# Callback system
|
||||
self.on_init_end()
|
||||
|
||||
@property
|
||||
def state(self) -> TrainerState:
|
||||
return self._state
|
||||
|
||||
@property
|
||||
def is_global_zero(self) -> bool:
|
||||
return self.global_rank == 0
|
||||
|
|
|
@ -255,7 +255,7 @@ class TrainerTrainLoopMixin(ABC):
|
|||
terminate_on_nan: bool
|
||||
tpu_id: int
|
||||
interactive_ddp_procs: ...
|
||||
state: TrainerState
|
||||
_state: TrainerState
|
||||
amp_backend: AMPType
|
||||
on_tpu: bool
|
||||
accelerator_backend: ...
|
||||
|
@ -418,7 +418,7 @@ class TrainerTrainLoopMixin(ABC):
|
|||
# user could press ctrl+c many times... only shutdown once
|
||||
if not self.interrupted:
|
||||
self.interrupted = True
|
||||
self.state = TrainerState.INTERRUPTED
|
||||
self._state = TrainerState.INTERRUPTED
|
||||
self.on_keyboard_interrupt()
|
||||
|
||||
self.run_training_teardown()
|
||||
|
|
|
@ -31,7 +31,6 @@ def test_state_decorator_nothing_passed(tmpdir):
|
|||
return self.state
|
||||
|
||||
trainer = Trainer(default_root_dir=tmpdir)
|
||||
trainer.state = TrainerState.INITIALIZING
|
||||
|
||||
snapshot_state = test_method(trainer)
|
||||
|
||||
|
@ -47,7 +46,6 @@ def test_state_decorator_entering_only(tmpdir):
|
|||
return self.state
|
||||
|
||||
trainer = Trainer(default_root_dir=tmpdir)
|
||||
trainer.state = TrainerState.INITIALIZING
|
||||
|
||||
snapshot_state = test_method(trainer)
|
||||
|
||||
|
@ -63,7 +61,6 @@ def test_state_decorator_exiting_only(tmpdir):
|
|||
return self.state
|
||||
|
||||
trainer = Trainer(default_root_dir=tmpdir)
|
||||
trainer.state = TrainerState.INITIALIZING
|
||||
|
||||
snapshot_state = test_method(trainer)
|
||||
|
||||
|
@ -79,7 +76,6 @@ def test_state_decorator_entering_and_exiting(tmpdir):
|
|||
return self.state
|
||||
|
||||
trainer = Trainer(default_root_dir=tmpdir)
|
||||
trainer.state = TrainerState.INITIALIZING
|
||||
|
||||
snapshot_state = test_method(trainer)
|
||||
|
||||
|
@ -92,10 +88,9 @@ def test_state_decorator_interrupt(tmpdir):
|
|||
|
||||
@trainer_state(exiting=TrainerState.FINISHED)
|
||||
def test_method(self):
|
||||
self.state = TrainerState.INTERRUPTED
|
||||
self._state = TrainerState.INTERRUPTED
|
||||
|
||||
trainer = Trainer(default_root_dir=tmpdir)
|
||||
trainer.state = TrainerState.INITIALIZING
|
||||
|
||||
test_method(trainer)
|
||||
assert trainer.state == TrainerState.INTERRUPTED
|
||||
|
|
Loading…
Reference in New Issue