Add tracking of basic states in Trainer [wip - to-be-merged after v0.9] (#2541)

* Add initial tracking of states in Trainer.

* Add INTERRUPTED state, improve tests, move state switching from callback to a trainer.

* Move part of a trainer state switching to a decorator.

* Add documentation.

* Fix docs, rename state enum, restore state to previous on exit if None, add tests for decorator only.

* Fix callback typing.

Co-authored-by: William Falcon <waf2107@columbia.edu>
This commit is contained in:
Uladzislau Sazanovich 2020-08-09 13:24:09 +03:00 committed by GitHub
parent 13fe0a4da5
commit e9846dd758
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 266 additions and 0 deletions

View File

@ -0,0 +1,49 @@
from enum import Enum
from functools import wraps
from typing import Callable, Optional
import pytorch_lightning
class TrainerState(Enum):
""" State which is set in the :class:`~pytorch_lightning.trainer.trainer.Trainer`
to indicate what is currently or was executed. """
INITIALIZING = 'INITIALIZING'
RUNNING = 'RUNNING'
FINISHED = 'FINISHED'
INTERRUPTED = 'INTERRUPTED'
def trainer_state(*, entering: Optional[TrainerState] = None, exiting: Optional[TrainerState] = None) -> Callable:
""" Decorator for :class:`~pytorch_lightning.trainer.trainer.Trainer` methods
which changes state to `entering` before the function execution and `exiting`
after the function is executed. If `None` is passed to `entering`, the state is not changed.
If `None` is passed to `exiting`, the state is restored to the state before function execution.
If `INTERRUPTED` state is set inside a run function, the state remains `INTERRUPTED`.
"""
def wrapper(fn) -> Callable:
@wraps(fn)
def wrapped_fn(self, *args, **kwargs):
if not isinstance(self, pytorch_lightning.Trainer):
return fn(self, *args, **kwargs)
state_before = self.state
if entering is not None:
self.state = entering
result = fn(self, *args, **kwargs)
# The INTERRUPTED state can be set inside the run function. To indicate that run was interrupted
# we retain INTERRUPTED state
if self.state == TrainerState.INTERRUPTED:
return result
if exiting is not None:
self.state = exiting
else:
self.state = state_before
return result
return wrapped_fn
return wrapper

View File

@ -45,6 +45,7 @@ from pytorch_lightning.trainer.logging import TrainerLoggingMixin
from pytorch_lightning.trainer.lr_finder import TrainerLRFinderMixin
from pytorch_lightning.trainer.model_hooks import TrainerModelHooksMixin
from pytorch_lightning.trainer.optimizers import TrainerOptimizersMixin
from pytorch_lightning.trainer.states import TrainerState, trainer_state
from pytorch_lightning.trainer.supporters import TensorRunningAccum
from pytorch_lightning.trainer.training_io import TrainerIOMixin
from pytorch_lightning.trainer.training_loop import TrainerTrainLoopMixin
@ -395,6 +396,7 @@ class Trainer(
self.interrupted = False
self.should_stop = False
self.running_sanity_check = False
self.state = TrainerState.INITIALIZING
self._default_root_dir = default_root_dir or os.getcwd()
self._weights_save_path = weights_save_path or self._default_root_dir
@ -888,6 +890,7 @@ class Trainer(
# -----------------------------
# MODEL TRAINING
# -----------------------------
@trainer_state(entering=TrainerState.RUNNING, exiting=TrainerState.FINISHED)
def fit(
self,
model: LightningModule,
@ -1240,6 +1243,7 @@ class Trainer(
self.on_sanity_check_end()
self.running_sanity_check = False
@trainer_state(entering=TrainerState.RUNNING, exiting=TrainerState.FINISHED)
def test(
self,
model: Optional[LightningModule] = None,

View File

@ -174,6 +174,7 @@ from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.core.step_result import EvalResult, Result
from pytorch_lightning.loggers import LightningLoggerBase
from pytorch_lightning.trainer.states import TrainerState
from pytorch_lightning.trainer.supporters import TensorRunningAccum, Accumulator
from pytorch_lightning.utilities import rank_zero_warn, AMPType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@ -253,6 +254,7 @@ class TrainerTrainLoopMixin(ABC):
terminate_on_nan: bool
tpu_id: int
interactive_ddp_procs: ...
state: TrainerState
amp_type: AMPType
on_tpu: bool
@ -418,6 +420,7 @@ class TrainerTrainLoopMixin(ABC):
# user could press ctrl+c many times... only shutdown once
if not self.interrupted:
self.interrupted = True
self.state = TrainerState.INTERRUPTED
self.on_keyboard_interrupt()
self.run_training_teardown()

View File

@ -0,0 +1,210 @@
import pytest
from pytorch_lightning import Trainer, Callback
from pytorch_lightning.trainer.states import TrainerState, trainer_state
from tests.base import EvalModelTemplate
class StateSnapshotCallback(Callback):
""" Allows to shapshot the state inside a particular trainer method. """
def __init__(self, snapshot_method: str):
super().__init__()
assert snapshot_method in ['on_batch_start', 'on_test_batch_start']
self.snapshot_method = snapshot_method
self.trainer_state = None
def on_batch_start(self, trainer, pl_module):
if self.snapshot_method == 'on_batch_start':
self.trainer_state = trainer.state
def on_test_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
if self.snapshot_method == 'on_test_batch_start':
self.trainer_state = trainer.state
def test_state_decorator_nothing_passed(tmpdir):
""" Test that state is not changed if nothing is passed to a decorator"""
@trainer_state()
def test_method(self):
return self.state
trainer = Trainer(default_root_dir=tmpdir)
trainer.state = TrainerState.INITIALIZING
snapshot_state = test_method(trainer)
assert snapshot_state == TrainerState.INITIALIZING
assert trainer.state == TrainerState.INITIALIZING
def test_state_decorator_entering_only(tmpdir):
""" Tests that state is set to entering inside a run function and restored to the previous value after. """
@trainer_state(entering=TrainerState.RUNNING)
def test_method(self):
return self.state
trainer = Trainer(default_root_dir=tmpdir)
trainer.state = TrainerState.INITIALIZING
snapshot_state = test_method(trainer)
assert snapshot_state == TrainerState.RUNNING
assert trainer.state == TrainerState.INITIALIZING
def test_state_decorator_exiting_only(tmpdir):
""" Tests that state is not changed inside a run function and set to `exiting` after. """
@trainer_state(exiting=TrainerState.FINISHED)
def test_method(self):
return self.state
trainer = Trainer(default_root_dir=tmpdir)
trainer.state = TrainerState.INITIALIZING
snapshot_state = test_method(trainer)
assert snapshot_state == TrainerState.INITIALIZING
assert trainer.state == TrainerState.FINISHED
def test_state_decorator_entering_and_exiting(tmpdir):
""" Tests that state is set to `entering` inside a run function and set ot `exiting` after. """
@trainer_state(entering=TrainerState.RUNNING, exiting=TrainerState.FINISHED)
def test_method(self):
return self.state
trainer = Trainer(default_root_dir=tmpdir)
trainer.state = TrainerState.INITIALIZING
snapshot_state = test_method(trainer)
assert snapshot_state == TrainerState.RUNNING
assert trainer.state == TrainerState.FINISHED
def test_state_decorator_interrupt(tmpdir):
""" Tests that state remains `INTERRUPTED` is its set in run function. """
@trainer_state(exiting=TrainerState.FINISHED)
def test_method(self):
self.state = TrainerState.INTERRUPTED
trainer = Trainer(default_root_dir=tmpdir)
trainer.state = TrainerState.INITIALIZING
test_method(trainer)
assert trainer.state == TrainerState.INTERRUPTED
def test_initialize_state(tmpdir):
""" Tests that state is INITIALIZE after Trainer creation """
trainer = Trainer(default_root_dir=tmpdir)
assert trainer.state == TrainerState.INITIALIZING
@pytest.mark.parametrize("extra_params", [
pytest.param(dict(fast_dev_run=True), id='Fast-Run'),
pytest.param(dict(max_steps=1), id='Single-Step'),
])
def test_running_state_during_fit(tmpdir, extra_params):
""" Tests that state is set to RUNNING during fit """
hparams = EvalModelTemplate.get_default_hparams()
model = EvalModelTemplate(**hparams)
snapshot_callback = StateSnapshotCallback(snapshot_method='on_batch_start')
trainer = Trainer(
callbacks=[snapshot_callback],
default_root_dir=tmpdir,
**extra_params
)
trainer.fit(model)
assert snapshot_callback.trainer_state == TrainerState.RUNNING
@pytest.mark.parametrize("extra_params", [
pytest.param(dict(fast_dev_run=True), id='Fast-Run'),
pytest.param(dict(max_steps=1), id='Single-Step'),
])
def test_finished_state_after_fit(tmpdir, extra_params):
""" Tests that state is FINISHED after fit """
hparams = EvalModelTemplate.get_default_hparams()
model = EvalModelTemplate(**hparams)
trainer = Trainer(
default_root_dir=tmpdir,
**extra_params
)
trainer.fit(model)
assert trainer.state == TrainerState.FINISHED
def test_running_state_during_test(tmpdir):
""" Tests that state is set to RUNNING during test """
hparams = EvalModelTemplate.get_default_hparams()
model = EvalModelTemplate(**hparams)
snapshot_callback = StateSnapshotCallback(snapshot_method='on_test_batch_start')
trainer = Trainer(
callbacks=[snapshot_callback],
default_root_dir=tmpdir,
fast_dev_run=True,
)
trainer.test(model)
assert snapshot_callback.trainer_state == TrainerState.RUNNING
def test_finished_state_after_test(tmpdir):
""" Tests that state is FINISHED after fit """
hparams = EvalModelTemplate.get_default_hparams()
model = EvalModelTemplate(**hparams)
trainer = Trainer(
default_root_dir=tmpdir,
fast_dev_run=True,
)
trainer.test(model)
assert trainer.state == TrainerState.FINISHED
@pytest.mark.parametrize("extra_params", [
pytest.param(dict(fast_dev_run=True), id='Fast-Run'),
pytest.param(dict(max_steps=1), id='Single-Step'),
])
def test_interrupt_state_on_keyboard_interrupt(tmpdir, extra_params):
""" Tests that state is set to INTERRUPTED on KeyboardInterrupt """
hparams = EvalModelTemplate.get_default_hparams()
model = EvalModelTemplate(**hparams)
class InterruptCallback(Callback):
def __init__(self):
super().__init__()
def on_batch_start(self, trainer, pl_module):
raise KeyboardInterrupt
trainer = Trainer(
callbacks=[InterruptCallback()],
default_root_dir=tmpdir,
**extra_params
)
trainer.fit(model)
assert trainer.state == TrainerState.INTERRUPTED