Add tracking of basic states in Trainer [wip - to-be-merged after v0.9] (#2541)
* Add initial tracking of states in Trainer. * Add INTERRUPTED state, improve tests, move state switching from callback to a trainer. * Move part of a trainer state switching to a decorator. * Add documentation. * Fix docs, rename state enum, restore state to previous on exit if None, add tests for decorator only. * Fix callback typing. Co-authored-by: William Falcon <waf2107@columbia.edu>
This commit is contained in:
parent
13fe0a4da5
commit
e9846dd758
|
@ -0,0 +1,49 @@
|
|||
from enum import Enum
|
||||
from functools import wraps
|
||||
from typing import Callable, Optional
|
||||
|
||||
import pytorch_lightning
|
||||
|
||||
|
||||
class TrainerState(Enum):
    """ State which is set in the :class:`~pytorch_lightning.trainer.trainer.Trainer`
    to indicate what is currently or was executed.

    Each member's value equals its name, so the state serializes to a
    readable string.
    """
    # Trainer was constructed; no run method has been entered yet.
    INITIALIZING = 'INITIALIZING'
    # A run method (e.g. ``fit``/``test``) is currently executing.
    RUNNING = 'RUNNING'
    # The last run method finished executing.
    FINISHED = 'FINISHED'
    # A run was aborted (e.g. via KeyboardInterrupt); the ``trainer_state``
    # decorator deliberately keeps this state instead of overwriting it.
    INTERRUPTED = 'INTERRUPTED'
|
||||
|
||||
|
||||
def trainer_state(*, entering: Optional[TrainerState] = None, exiting: Optional[TrainerState] = None) -> Callable:
    """ Decorator for :class:`~pytorch_lightning.trainer.trainer.Trainer` methods
    which sets the trainer state to `entering` before the call and to `exiting`
    once the call returns.

    Passing `None` for `entering` leaves the state untouched on entry; passing
    `None` for `exiting` restores whatever state was active before the call.
    If the wrapped function itself sets the `INTERRUPTED` state, that state is
    kept so the interruption remains observable.
    """

    def decorator(fn) -> Callable:

        @wraps(fn)
        def state_managed(self, *args, **kwargs):
            # Only manage state on real Trainer instances; anything else is
            # passed straight through untouched.
            if not isinstance(self, pytorch_lightning.Trainer):
                return fn(self, *args, **kwargs)

            previous_state = self.state
            if entering is not None:
                self.state = entering
            output = fn(self, *args, **kwargs)

            # INTERRUPTED set inside fn is sticky: skip any state rewrite so
            # callers can still see that the run was interrupted.
            if self.state != TrainerState.INTERRUPTED:
                self.state = exiting if exiting is not None else previous_state
            return output

        return state_managed

    return decorator
|
|
@ -45,6 +45,7 @@ from pytorch_lightning.trainer.logging import TrainerLoggingMixin
|
|||
from pytorch_lightning.trainer.lr_finder import TrainerLRFinderMixin
|
||||
from pytorch_lightning.trainer.model_hooks import TrainerModelHooksMixin
|
||||
from pytorch_lightning.trainer.optimizers import TrainerOptimizersMixin
|
||||
from pytorch_lightning.trainer.states import TrainerState, trainer_state
|
||||
from pytorch_lightning.trainer.supporters import TensorRunningAccum
|
||||
from pytorch_lightning.trainer.training_io import TrainerIOMixin
|
||||
from pytorch_lightning.trainer.training_loop import TrainerTrainLoopMixin
|
||||
|
@ -395,6 +396,7 @@ class Trainer(
|
|||
self.interrupted = False
|
||||
self.should_stop = False
|
||||
self.running_sanity_check = False
|
||||
self.state = TrainerState.INITIALIZING
|
||||
|
||||
self._default_root_dir = default_root_dir or os.getcwd()
|
||||
self._weights_save_path = weights_save_path or self._default_root_dir
|
||||
|
@ -888,6 +890,7 @@ class Trainer(
|
|||
# -----------------------------
|
||||
# MODEL TRAINING
|
||||
# -----------------------------
|
||||
@trainer_state(entering=TrainerState.RUNNING, exiting=TrainerState.FINISHED)
|
||||
def fit(
|
||||
self,
|
||||
model: LightningModule,
|
||||
|
@ -1240,6 +1243,7 @@ class Trainer(
|
|||
self.on_sanity_check_end()
|
||||
self.running_sanity_check = False
|
||||
|
||||
@trainer_state(entering=TrainerState.RUNNING, exiting=TrainerState.FINISHED)
|
||||
def test(
|
||||
self,
|
||||
model: Optional[LightningModule] = None,
|
||||
|
|
|
@ -174,6 +174,7 @@ from pytorch_lightning.callbacks.base import Callback
|
|||
from pytorch_lightning.core.lightning import LightningModule
|
||||
from pytorch_lightning.core.step_result import EvalResult, Result
|
||||
from pytorch_lightning.loggers import LightningLoggerBase
|
||||
from pytorch_lightning.trainer.states import TrainerState
|
||||
from pytorch_lightning.trainer.supporters import TensorRunningAccum, Accumulator
|
||||
from pytorch_lightning.utilities import rank_zero_warn, AMPType
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
|
@ -253,6 +254,7 @@ class TrainerTrainLoopMixin(ABC):
|
|||
terminate_on_nan: bool
|
||||
tpu_id: int
|
||||
interactive_ddp_procs: ...
|
||||
state: TrainerState
|
||||
amp_type: AMPType
|
||||
on_tpu: bool
|
||||
|
||||
|
@ -418,6 +420,7 @@ class TrainerTrainLoopMixin(ABC):
|
|||
# user could press ctrl+c many times... only shutdown once
|
||||
if not self.interrupted:
|
||||
self.interrupted = True
|
||||
self.state = TrainerState.INTERRUPTED
|
||||
self.on_keyboard_interrupt()
|
||||
|
||||
self.run_training_teardown()
|
||||
|
|
|
@ -0,0 +1,210 @@
|
|||
import pytest
|
||||
|
||||
from pytorch_lightning import Trainer, Callback
|
||||
from pytorch_lightning.trainer.states import TrainerState, trainer_state
|
||||
from tests.base import EvalModelTemplate
|
||||
|
||||
|
||||
class StateSnapshotCallback(Callback):
    """ Allows to snapshot the trainer state inside a particular trainer method. """

    def __init__(self, snapshot_method: str):
        super().__init__()
        # Only these two hooks are supported as snapshot points.
        assert snapshot_method in ('on_batch_start', 'on_test_batch_start')
        self.snapshot_method = snapshot_method
        self.trainer_state = None

    def _capture(self, hook_name, trainer):
        # Record the state only when the firing hook is the one this
        # callback was armed for.
        if hook_name == self.snapshot_method:
            self.trainer_state = trainer.state

    def on_batch_start(self, trainer, pl_module):
        self._capture('on_batch_start', trainer)

    def on_test_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
        self._capture('on_test_batch_start', trainer)
|
||||
|
||||
|
||||
def test_state_decorator_nothing_passed(tmpdir):
    """ Test that the state is left untouched when the decorator gets no arguments. """

    @trainer_state()
    def inspect_state(self):
        return self.state

    trainer = Trainer(default_root_dir=tmpdir)
    trainer.state = TrainerState.INITIALIZING

    inside_state = inspect_state(trainer)

    # Neither during nor after the call should the state have changed.
    assert inside_state == TrainerState.INITIALIZING
    assert trainer.state == TrainerState.INITIALIZING
|
||||
|
||||
|
||||
def test_state_decorator_entering_only(tmpdir):
    """ Tests that state is set to `entering` inside the run function and
    restored to the previous value afterwards. """

    @trainer_state(entering=TrainerState.RUNNING)
    def inspect_state(self):
        return self.state

    trainer = Trainer(default_root_dir=tmpdir)
    trainer.state = TrainerState.INITIALIZING

    inside_state = inspect_state(trainer)

    assert inside_state == TrainerState.RUNNING
    # No `exiting` given -> the pre-call state is restored.
    assert trainer.state == TrainerState.INITIALIZING
|
||||
|
||||
|
||||
def test_state_decorator_exiting_only(tmpdir):
    """ Tests that state is unchanged inside the run function and set to
    `exiting` afterwards. """

    @trainer_state(exiting=TrainerState.FINISHED)
    def inspect_state(self):
        return self.state

    trainer = Trainer(default_root_dir=tmpdir)
    trainer.state = TrainerState.INITIALIZING

    inside_state = inspect_state(trainer)

    # No `entering` given -> state untouched during the call.
    assert inside_state == TrainerState.INITIALIZING
    assert trainer.state == TrainerState.FINISHED
|
||||
|
||||
|
||||
def test_state_decorator_entering_and_exiting(tmpdir):
    """ Tests that state is set to `entering` inside the run function and
    set to `exiting` afterwards. """

    @trainer_state(entering=TrainerState.RUNNING, exiting=TrainerState.FINISHED)
    def inspect_state(self):
        return self.state

    trainer = Trainer(default_root_dir=tmpdir)
    trainer.state = TrainerState.INITIALIZING

    inside_state = inspect_state(trainer)

    assert inside_state == TrainerState.RUNNING
    assert trainer.state == TrainerState.FINISHED
|
||||
|
||||
|
||||
def test_state_decorator_interrupt(tmpdir):
    """ Tests that the state remains `INTERRUPTED` if it is set inside the
    run function, overriding the decorator's `exiting` state. """

    @trainer_state(exiting=TrainerState.FINISHED)
    def simulate_interrupt(self):
        self.state = TrainerState.INTERRUPTED

    trainer = Trainer(default_root_dir=tmpdir)
    trainer.state = TrainerState.INITIALIZING

    simulate_interrupt(trainer)
    # INTERRUPTED is sticky: `exiting=FINISHED` must not overwrite it.
    assert trainer.state == TrainerState.INTERRUPTED
|
||||
|
||||
|
||||
def test_initialize_state(tmpdir):
    """ Tests that the state is INITIALIZING right after Trainer creation. """
    assert Trainer(default_root_dir=tmpdir).state == TrainerState.INITIALIZING
|
||||
|
||||
|
||||
@pytest.mark.parametrize("extra_params", [
    pytest.param(dict(fast_dev_run=True), id='Fast-Run'),
    pytest.param(dict(max_steps=1), id='Single-Step'),
])
def test_running_state_during_fit(tmpdir, extra_params):
    """ Tests that state is set to RUNNING during fit """
    model = EvalModelTemplate(**EvalModelTemplate.get_default_hparams())

    # Snapshot the state from inside a training batch.
    snapshot_callback = StateSnapshotCallback(snapshot_method='on_batch_start')

    trainer = Trainer(
        callbacks=[snapshot_callback],
        default_root_dir=tmpdir,
        **extra_params,
    )
    trainer.fit(model)

    assert snapshot_callback.trainer_state == TrainerState.RUNNING
|
||||
|
||||
|
||||
@pytest.mark.parametrize("extra_params", [
    pytest.param(dict(fast_dev_run=True), id='Fast-Run'),
    pytest.param(dict(max_steps=1), id='Single-Step'),
])
def test_finished_state_after_fit(tmpdir, extra_params):
    """ Tests that state is FINISHED after fit """
    model = EvalModelTemplate(**EvalModelTemplate.get_default_hparams())

    trainer = Trainer(
        default_root_dir=tmpdir,
        **extra_params,
    )
    trainer.fit(model)

    assert trainer.state == TrainerState.FINISHED
|
||||
|
||||
|
||||
def test_running_state_during_test(tmpdir):
    """ Tests that state is set to RUNNING during test """
    model = EvalModelTemplate(**EvalModelTemplate.get_default_hparams())

    # Snapshot the state from inside a test batch.
    snapshot_callback = StateSnapshotCallback(snapshot_method='on_test_batch_start')

    trainer = Trainer(
        callbacks=[snapshot_callback],
        default_root_dir=tmpdir,
        fast_dev_run=True,
    )
    trainer.test(model)

    assert snapshot_callback.trainer_state == TrainerState.RUNNING
|
||||
|
||||
|
||||
def test_finished_state_after_test(tmpdir):
    """ Tests that state is FINISHED after test """
    model = EvalModelTemplate(**EvalModelTemplate.get_default_hparams())

    trainer = Trainer(
        default_root_dir=tmpdir,
        fast_dev_run=True,
    )
    trainer.test(model)

    assert trainer.state == TrainerState.FINISHED
|
||||
|
||||
|
||||
@pytest.mark.parametrize("extra_params", [
    pytest.param(dict(fast_dev_run=True), id='Fast-Run'),
    pytest.param(dict(max_steps=1), id='Single-Step'),
])
def test_interrupt_state_on_keyboard_interrupt(tmpdir, extra_params):
    """ Tests that state is set to INTERRUPTED on KeyboardInterrupt """
    model = EvalModelTemplate(**EvalModelTemplate.get_default_hparams())

    class InterruptCallback(Callback):
        # Aborts the run on the very first training batch.
        def on_batch_start(self, trainer, pl_module):
            raise KeyboardInterrupt

    trainer = Trainer(
        callbacks=[InterruptCallback()],
        default_root_dir=tmpdir,
        **extra_params,
    )
    trainer.fit(model)

    assert trainer.state == TrainerState.INTERRUPTED
|
Loading…
Reference in New Issue