206 lines
6.0 KiB
Python
206 lines
6.0 KiB
Python
import pytest
|
|
|
|
from pytorch_lightning import Trainer, Callback
|
|
from pytorch_lightning.trainer.states import TrainerState, trainer_state
|
|
from tests.base import EvalModelTemplate
|
|
|
|
|
|
class StateSnapshotCallback(Callback):
    """Captures ``trainer.state`` from inside one selected callback hook.

    Args:
        snapshot_method: name of the hook whose invocation records the state;
            must be either ``'on_batch_start'`` or ``'on_test_batch_start'``.
    """

    def __init__(self, snapshot_method: str):
        super().__init__()
        assert snapshot_method in ('on_batch_start', 'on_test_batch_start')
        self.snapshot_method = snapshot_method
        # Remains ``None`` until the selected hook runs; each later run overwrites it.
        self.trainer_state = None

    def _capture_if_selected(self, hook_name, trainer):
        """Store the trainer's current state when ``hook_name`` is the selected hook."""
        if hook_name == self.snapshot_method:
            self.trainer_state = trainer.state

    def on_batch_start(self, trainer, pl_module):
        self._capture_if_selected('on_batch_start', trainer)

    def on_test_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
        self._capture_if_selected('on_test_batch_start', trainer)
|
|
|
|
|
|
def test_state_decorator_nothing_passed(tmpdir):
    """A bare ``@trainer_state()`` leaves the state unchanged both inside and after the call."""

    @trainer_state()
    def read_state(self):
        return self.state

    trainer = Trainer(default_root_dir=tmpdir)
    state_inside = read_state(trainer)

    # Neither ``entering`` nor ``exiting`` was supplied, so INITIALIZING persists throughout.
    assert state_inside == TrainerState.INITIALIZING
    assert trainer.state == TrainerState.INITIALIZING
|
|
|
|
|
|
def test_state_decorator_entering_only(tmpdir):
    """With only ``entering`` given, the state is switched inside the call and restored afterwards."""

    @trainer_state(entering=TrainerState.RUNNING)
    def read_state(self):
        return self.state

    trainer = Trainer(default_root_dir=tmpdir)
    state_inside = read_state(trainer)

    # Inside the wrapped call the ``entering`` value is visible ...
    assert state_inside == TrainerState.RUNNING
    # ... and once it returns, the pre-call state is back.
    assert trainer.state == TrainerState.INITIALIZING
|
|
|
|
|
|
def test_state_decorator_exiting_only(tmpdir):
    """With only ``exiting`` given, the state is untouched inside the call and set to ``exiting`` after it."""

    @trainer_state(exiting=TrainerState.FINISHED)
    def read_state(self):
        return self.state

    trainer = Trainer(default_root_dir=tmpdir)
    state_inside = read_state(trainer)

    # No ``entering`` value, so the call still observes the initial state.
    assert state_inside == TrainerState.INITIALIZING
    # The ``exiting`` value is applied once the call returns.
    assert trainer.state == TrainerState.FINISHED
|
|
|
|
|
|
def test_state_decorator_entering_and_exiting(tmpdir):
    """Tests that state is set to `entering` inside a run function and set to `exiting` after."""

    @trainer_state(entering=TrainerState.RUNNING, exiting=TrainerState.FINISHED)
    def read_state(self):
        return self.state

    trainer = Trainer(default_root_dir=tmpdir)
    state_inside = read_state(trainer)

    assert state_inside == TrainerState.RUNNING  # ``entering`` applied during the call
    assert trainer.state == TrainerState.FINISHED  # ``exiting`` applied after it
|
|
|
|
|
|
def test_state_decorator_interrupt(tmpdir):
    """Tests that state remains `INTERRUPTED` if it is set inside the run function."""

    @trainer_state(exiting=TrainerState.FINISHED)
    def interrupting_run(self):
        # Writes ``_state`` directly; presumably the attribute behind ``trainer.state``.
        self._state = TrainerState.INTERRUPTED

    trainer = Trainer(default_root_dir=tmpdir)
    interrupting_run(trainer)

    # The decorator's ``exiting`` value must not overwrite an INTERRUPTED state.
    assert trainer.state == TrainerState.INTERRUPTED
|
|
|
|
|
|
def test_initialize_state(tmpdir):
    """A freshly constructed Trainer reports the INITIALIZING state."""
    assert Trainer(default_root_dir=tmpdir).state == TrainerState.INITIALIZING
|
|
|
|
|
|
@pytest.mark.parametrize("extra_params", [
    pytest.param({'fast_dev_run': True}, id='Fast-Run'),
    pytest.param({'max_steps': 1}, id='Single-Step'),
])
def test_running_state_during_fit(tmpdir, extra_params):
    """Tests that state is RUNNING while training batches are being processed by fit."""
    model = EvalModelTemplate(**EvalModelTemplate.get_default_hparams())
    snapshot_callback = StateSnapshotCallback(snapshot_method='on_batch_start')

    trainer = Trainer(default_root_dir=tmpdir, callbacks=[snapshot_callback], **extra_params)
    trainer.fit(model)

    # The callback recorded trainer.state from inside ``on_batch_start``.
    assert snapshot_callback.trainer_state == TrainerState.RUNNING
|
|
|
|
|
|
@pytest.mark.parametrize("extra_params", [
    pytest.param({'fast_dev_run': True}, id='Fast-Run'),
    pytest.param({'max_steps': 1}, id='Single-Step'),
])
def test_finished_state_after_fit(tmpdir, extra_params):
    """Tests that state is FINISHED once fit has returned."""
    model = EvalModelTemplate(**EvalModelTemplate.get_default_hparams())

    trainer = Trainer(default_root_dir=tmpdir, **extra_params)
    trainer.fit(model)

    assert trainer.state == TrainerState.FINISHED
|
|
|
|
|
|
def test_running_state_during_test(tmpdir):
    """Tests that state is RUNNING while test batches are being processed."""
    model = EvalModelTemplate(**EvalModelTemplate.get_default_hparams())
    snapshot_callback = StateSnapshotCallback(snapshot_method='on_test_batch_start')

    trainer = Trainer(default_root_dir=tmpdir, callbacks=[snapshot_callback], fast_dev_run=True)
    trainer.test(model)

    # The callback recorded trainer.state from inside ``on_test_batch_start``.
    assert snapshot_callback.trainer_state == TrainerState.RUNNING
|
|
|
|
|
|
def test_finished_state_after_test(tmpdir):
    """Tests that state is FINISHED after test."""
    # Docstring previously said "after fit", but this test only calls ``trainer.test``.
    hparams = EvalModelTemplate.get_default_hparams()
    model = EvalModelTemplate(**hparams)

    trainer = Trainer(
        default_root_dir=tmpdir,
        fast_dev_run=True,
    )

    trainer.test(model)

    assert trainer.state == TrainerState.FINISHED
|
|
|
|
|
|
@pytest.mark.parametrize("extra_params", [
    pytest.param(dict(fast_dev_run=True), id='Fast-Run'),
    pytest.param(dict(max_steps=1), id='Single-Step'),
])
def test_interrupt_state_on_keyboard_interrupt(tmpdir, extra_params):
    """Tests that state is set to INTERRUPTED on KeyboardInterrupt."""
    hparams = EvalModelTemplate.get_default_hparams()
    model = EvalModelTemplate(**hparams)

    class InterruptCallback(Callback):
        """Raises KeyboardInterrupt at the first training batch to simulate a user abort."""

        # No ``__init__`` override: the previous one only delegated to ``super().__init__()``.
        def on_batch_start(self, trainer, pl_module):
            raise KeyboardInterrupt

    trainer = Trainer(
        callbacks=[InterruptCallback()],
        default_root_dir=tmpdir,
        **extra_params
    )

    # ``fit`` is expected to handle the interrupt itself; the test only checks the resulting state.
    trainer.fit(model)

    assert trainer.state == TrainerState.INTERRUPTED
|