Reset current_fx properties on lightning module in teardown (#7247)
* Update trainer.py * cleanup module properties in teardown * Update test_trainer.py * Update lightning.py * Formatting * flake8 * Update pytorch_lightning/trainer/trainer.py Co-authored-by: Carlos Mocholi <carlossmocholi@gmail.com>
This commit is contained in:
parent
40f80230fe
commit
075de9356c
|
@ -92,19 +92,19 @@ class LightningModule(
|
|||
self._device_type = None
|
||||
|
||||
#: True if using amp
|
||||
self.use_amp = False
|
||||
self.use_amp: bool = False
|
||||
|
||||
#: The precision used
|
||||
self.precision = 32
|
||||
self.precision: int = 32
|
||||
|
||||
# optionally can be set by user
|
||||
self._example_input_array = None
|
||||
self._datamodule = None
|
||||
self._results: Optional[Result] = None
|
||||
self._current_fx_name = ''
|
||||
self._running_manual_backward = False
|
||||
self._current_hook_fx_name = None
|
||||
self._current_dataloader_idx = None
|
||||
self._current_fx_name: str = ''
|
||||
self._running_manual_backward: bool = False
|
||||
self._current_hook_fx_name: Optional[str] = None
|
||||
self._current_dataloader_idx: Optional[int] = None
|
||||
self._automatic_optimization: bool = True
|
||||
self._param_requires_grad_state = dict()
|
||||
|
||||
|
|
|
@ -1147,6 +1147,10 @@ class Trainer(
|
|||
self.teardown(stage=state)
|
||||
model.teardown(stage=state)
|
||||
|
||||
model._current_fx_name = ""
|
||||
model._current_hook_fx_name = None
|
||||
model._current_dataloader_idx = None
|
||||
|
||||
def _reset_result_and_set_hook_fx_name(self, hook_name: str) -> bool:
|
||||
# on_before_zero_grad is called within training_step
|
||||
if "batch_start" in hook_name or hook_name in ("on_before_zero_grad", "on_after_backward"):
|
||||
|
|
|
@ -2041,3 +2041,33 @@ def test_fit_test_synchronization(tmpdir):
|
|||
trainer.fit(model)
|
||||
assert os.path.exists(checkpoint.best_model_path), f'Could not find checkpoint at rank {trainer.global_rank}'
|
||||
trainer.test()
|
||||
|
||||
|
||||
def test_module_current_fx_attributes_reset(tmpdir):
|
||||
""" Ensure that lightning module's attributes related to current hook fx are reset at the end of execution. """
|
||||
model = BoringModel()
|
||||
model.validation_step = None
|
||||
model.training_epoch_end = None
|
||||
trainer = Trainer(
|
||||
default_root_dir=tmpdir,
|
||||
max_epochs=1,
|
||||
checkpoint_callback=False,
|
||||
logger=False,
|
||||
limit_val_batches=0,
|
||||
)
|
||||
trainer.fit(model)
|
||||
assert model._current_fx_name == "", f"_current_fx_name not reset after fit: {model._current_fx_name}"
|
||||
assert (
|
||||
model._current_hook_fx_name is None
|
||||
), f"_current_hook_fx_name not reset after fit: {model._current_hook_fx_name}"
|
||||
assert (
|
||||
model._current_dataloader_idx is None
|
||||
), f"_current_dataloader_idx not reset after fit: {model._current_dataloader_idx}"
|
||||
trainer.test(model)
|
||||
assert model._current_fx_name == "", f"_current_fx_name not reset after test: {model._current_fx_name}"
|
||||
assert (
|
||||
model._current_hook_fx_name is None
|
||||
), f"_current_hook_fx_name not reset after test: {model._current_hook_fx_name}"
|
||||
assert (
|
||||
model._current_dataloader_idx is None
|
||||
), f"_current_dataloader_idx not reset after test: {model._current_dataloader_idx}"
|
||||
|
|
Loading…
Reference in New Issue