Reset current_fx properties on lightning module in teardown (#7247)

* Update trainer.py

* cleanup module properties in teardown

* Update test_trainer.py

* Update lightning.py

* Formatting

* flake8

* Update pytorch_lightning/trainer/trainer.py

Co-authored-by: Carlos Mocholi <carlossmocholi@gmail.com>
This commit is contained in:
ananthsub 2021-04-28 12:17:20 -07:00 committed by GitHub
parent 40f80230fe
commit 075de9356c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 40 additions and 6 deletions

View File

@@ -92,19 +92,19 @@ class LightningModule(
self._device_type = None
#: True if using amp
self.use_amp = False
self.use_amp: bool = False
#: The precision used
self.precision = 32
self.precision: int = 32
# optionally can be set by user
self._example_input_array = None
self._datamodule = None
self._results: Optional[Result] = None
self._current_fx_name = ''
self._running_manual_backward = False
self._current_hook_fx_name = None
self._current_dataloader_idx = None
self._current_fx_name: str = ''
self._running_manual_backward: bool = False
self._current_hook_fx_name: Optional[str] = None
self._current_dataloader_idx: Optional[int] = None
self._automatic_optimization: bool = True
self._param_requires_grad_state = dict()

View File

@@ -1147,6 +1147,10 @@ class Trainer(
self.teardown(stage=state)
model.teardown(stage=state)
model._current_fx_name = ""
model._current_hook_fx_name = None
model._current_dataloader_idx = None
def _reset_result_and_set_hook_fx_name(self, hook_name: str) -> bool:
# on_before_zero_grad is called within training_step
if "batch_start" in hook_name or hook_name in ("on_before_zero_grad", "on_after_backward"):

View File

@@ -2041,3 +2041,33 @@ def test_fit_test_synchronization(tmpdir):
trainer.fit(model)
assert os.path.exists(checkpoint.best_model_path), f'Could not find checkpoint at rank {trainer.global_rank}'
trainer.test()
def test_module_current_fx_attributes_reset(tmpdir):
    """ Ensure that lightning module's attributes related to current hook fx are reset at the end of execution. """
    # Disable validation and the epoch-end hook so the run exercises only the
    # training path; the reset must still happen for every tracked attribute.
    model = BoringModel()
    model.validation_step = None
    model.training_epoch_end = None
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        checkpoint_callback=False,
        logger=False,
        limit_val_batches=0,
    )

    def _check_reset(stage):
        # The trainer teardown is expected to clear all three attributes.
        assert model._current_fx_name == "", f"_current_fx_name not reset after {stage}: {model._current_fx_name}"
        assert (
            model._current_hook_fx_name is None
        ), f"_current_hook_fx_name not reset after {stage}: {model._current_hook_fx_name}"
        assert (
            model._current_dataloader_idx is None
        ), f"_current_dataloader_idx not reset after {stage}: {model._current_dataloader_idx}"

    trainer.fit(model)
    _check_reset("fit")
    trainer.test(model)
    _check_reset("test")