From ac4eb0a06a82cb46385ef3aa9ba86b785f260312 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Fri, 11 Jun 2021 13:47:00 +0200 Subject: [PATCH] `is_overridden` improvements (#7918) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Adrian Wälchli --- CHANGELOG.md | 6 ++ pytorch_lightning/trainer/evaluation_loop.py | 8 +-- pytorch_lightning/trainer/training_loop.py | 6 +- pytorch_lightning/utilities/model_helpers.py | 68 ++++++++++++++------ tests/deprecated_api/test_remove_1-6.py | 9 +++ tests/trainer/test_states.py | 15 +++-- tests/trainer/test_trainer.py | 63 ++++++++++++++---- tests/utilities/test_model_helpers.py | 67 +++++++++++++++++++ 8 files changed, 195 insertions(+), 47 deletions(-) create mode 100644 tests/utilities/test_model_helpers.py diff --git a/CHANGELOG.md b/CHANGELOG.md index de24cc9daa..2a3503a94e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -35,6 +35,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `clip_grad_by_value` support for TPUs ([#7025](https://github.com/PyTorchLightning/pytorch-lightning/pull/7025)) +- Added support for passing any class to `is_overridden` ([#7918](https://github.com/PyTorchLightning/pytorch-lightning/pull/7918)) + + - Added `sub_dir` parameter to `TensorBoardLogger` ([#6195](https://github.com/PyTorchLightning/pytorch-lightning/pull/6195)) @@ -172,6 +175,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Deprecated `self.log(sync_dist_op)` in favor of `self.log(reduce_fx)`. ([#7891](https://github.com/PyTorchLightning/pytorch-lightning/pull/7891)) +- Deprecated `is_overridden(model=...)` in favor of `is_overridden(instance=...)` ([#7918](https://github.com/PyTorchLightning/pytorch-lightning/pull/7918)) + + - Deprecated default value of `monitor` argument in EarlyStopping callback to enforce `monitor` as a required argument ([#7907](https://github.com/PyTorchLightning/pytorch-lightning/pull/7907)) diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index f7dc100d78..ef13964235 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -201,9 +201,9 @@ class EvaluationLoop(object): def _should_track_batch_outputs_for_epoch_end(self) -> bool: model = self.trainer.lightning_module if self.trainer.testing: - return is_overridden('test_epoch_end', model=model) + return is_overridden('test_epoch_end', model) else: - return is_overridden('validation_epoch_end', model=model) + return is_overridden('validation_epoch_end', model) def evaluation_epoch_end(self, outputs: EPOCH_OUTPUT) -> None: # inform logger the batch loop has finished @@ -216,12 +216,12 @@ class EvaluationLoop(object): model._current_dataloader_idx = None if self.trainer.testing: - if is_overridden('test_epoch_end', model=model): + if is_overridden('test_epoch_end', model): model._current_fx_name = 'test_epoch_end' model.test_epoch_end(outputs) else: - if is_overridden('validation_epoch_end', model=model): + if is_overridden('validation_epoch_end', model): model._current_fx_name = 'validation_epoch_end' model.validation_epoch_end(outputs) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index ddab8f837b..f76568454b 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -216,10 +216,10 @@ class TrainLoop: # 2. The model overrides on_train_epoch_end which has `outputs` in the signature # TODO: in v1.5 this only needs to check if training_epoch_end is overridden lightning_module = self.trainer.lightning_module - if is_overridden("training_epoch_end", model=lightning_module): + if is_overridden("training_epoch_end", lightning_module): return True - if is_overridden("on_train_epoch_end", model=lightning_module): + if is_overridden("on_train_epoch_end", lightning_module): model_hook_fx = getattr(lightning_module, "on_train_epoch_end") if is_param_in_hook_signature(model_hook_fx, "outputs"): return True @@ -540,7 +540,7 @@ class TrainLoop: # get the model and call model.training_epoch_end model = self.trainer.lightning_module - if is_overridden('training_epoch_end', model=model): + if is_overridden('training_epoch_end', model): # run training_epoch_end # refresh the result for custom logging at the epoch level model._current_fx_name = 'training_epoch_end' diff --git a/pytorch_lightning/utilities/model_helpers.py b/pytorch_lightning/utilities/model_helpers.py index 87bd9e6c45..b7c3c09aff 100644 --- a/pytorch_lightning/utilities/model_helpers.py +++ b/pytorch_lightning/utilities/model_helpers.py @@ -11,33 +11,59 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -from typing import Union +from functools import partial +from typing import Optional, Type, Union +from unittest.mock import Mock from pytorch_lightning.core.datamodule import LightningDataModule from pytorch_lightning.core.lightning import LightningModule +from pytorch_lightning.utilities import rank_zero_deprecation -def is_overridden(method_name: str, model: Union[LightningModule, LightningDataModule]) -> bool: - # if you pass DataModule instead of None or a LightningModule, we use LightningDataModule as super - # TODO - refector this function to accept model_name, instance, parent so it makes more sense - super_object = LightningModule if not isinstance(model, LightningDataModule) else LightningDataModule +def is_overridden( + method_name: str, + instance: Optional[object] = None, + parent: Optional[Type[object]] = None, + model: Optional[Union[LightningModule, LightningDataModule]] = None, +) -> bool: + if model is not None and instance is None: + rank_zero_deprecation( + '`is_overriden(model=...)` has been deprecated and will be removed in v1.6.' + 'Please use `is_overriden(instance=...)`' + ) + instance = model - if not hasattr(model, method_name) or not hasattr(super_object, method_name): - # in case of calling deprecated method + if instance is None: + # if `self.lightning_module` was passed as instance, it can be `None` return False - instance_attr = getattr(model, method_name) - if not instance_attr: - return False - super_attr = getattr(super_object, method_name) + if parent is None: + if isinstance(instance, LightningModule): + parent = LightningModule + elif isinstance(instance, LightningDataModule): + parent = LightningDataModule + if parent is None: + raise ValueError("Expected a parent") - # when code pointers are different, it was implemented - if hasattr(instance_attr, 'patch_loader_code'): - # cannot pickle __code__ so cannot verify if PatchDataloader - # exists which shows dataloader methods have been overwritten. - # so, we hack it by using the string representation - is_overridden = instance_attr.patch_loader_code != str(super_attr.__code__) - else: - is_overridden = instance_attr.__code__ is not super_attr.__code__ - return is_overridden + instance_attr = getattr(instance, method_name, None) + # `Mock(wraps=...)` support + if isinstance(instance_attr, Mock): + # access the wrapped function + instance_attr = instance_attr._mock_wraps + # `partial` support + elif isinstance(instance_attr, partial): + instance_attr = instance_attr.func + if instance_attr is None: + return False + + parent_attr = getattr(parent, method_name, None) + if parent_attr is None: + raise ValueError("The parent should define the method") + + # cannot pickle `__code__` so cannot verify if `PatchDataloader` + # exists which shows dataloader methods have been overwritten. + # so, we hack it by using the string representation + instance_code = getattr(instance_attr, 'patch_loader_code', None) or instance_attr.__code__ + parent_code = parent_attr.__code__ + + return instance_code != parent_code diff --git a/tests/deprecated_api/test_remove_1-6.py b/tests/deprecated_api/test_remove_1-6.py index 1dfbb91022..ced066381a 100644 --- a/tests/deprecated_api/test_remove_1-6.py +++ b/tests/deprecated_api/test_remove_1-6.py @@ -17,6 +17,7 @@ import pytest from pytorch_lightning import Trainer from pytorch_lightning.callbacks.early_stopping import EarlyStopping from pytorch_lightning.plugins.training_type import DDPPlugin, DDPSpawnPlugin +from pytorch_lightning.utilities.model_helpers import is_overridden from tests.helpers import BoringDataModule, BoringModel @@ -175,6 +176,14 @@ def test_v1_6_0_datamodule_hooks_calls(tmpdir): assert dm.teardown_calls == ['validate', 'test'] +def test_v1_6_0_is_overridden_model(): + model = BoringModel() + with pytest.deprecated_call(match="and will be removed in v1.6"): + assert is_overridden("validation_step", model=model) + with pytest.deprecated_call(match="and will be removed in v1.6"): + assert not is_overridden("foo", model=model) + + def test_v1_6_0_early_stopping_monitor(tmpdir): with pytest.deprecated_call( match=r"The `EarlyStopping\(monitor\)` argument will be required starting in v1.6." diff --git a/tests/trainer/test_states.py b/tests/trainer/test_states.py index c9fb50e850..2614eda6d4 100644 --- a/tests/trainer/test_states.py +++ b/tests/trainer/test_states.py @@ -37,25 +37,28 @@ def test_trainer_fn_while_running(tmpdir, extra_params): def __init__(self, expected_fn, expected_stage): super().__init__() - self.expected_state = expected_fn + self.expected_fn = expected_fn self.expected_stage = expected_stage self.lr = 0.1 - def on_batch_start(self, *_): - assert self.trainer.state == TrainerState( - status=TrainerStatus.RUNNING, fn=self.expected_fn, stage=self.expected_stage - ) - def on_train_batch_start(self, *_): + assert self.trainer.state.status == TrainerStatus.RUNNING + assert self.trainer.state.fn == self.expected_fn assert self.trainer.training def on_sanity_check_start(self, *_): + assert self.trainer.state.status == TrainerStatus.RUNNING + assert self.trainer.state.fn == self.expected_fn assert self.trainer.sanity_checking def on_validation_batch_start(self, *_): + assert self.trainer.state.status == TrainerStatus.RUNNING + assert self.trainer.state.fn == self.expected_fn assert self.trainer.validating or self.trainer.sanity_checking def on_test_batch_start(self, *_): + assert self.trainer.state.status == TrainerStatus.RUNNING + assert self.trainer.state.fn == self.expected_fn assert self.trainer.testing model = TestModel(TrainerFn.TUNING, RunningStage.TRAINING) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 76e98329d2..d353c0941d 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -232,29 +232,66 @@ def test_trainer_accumulate_grad_batches_zero_grad(tmpdir, accumulate_grad_batch def test_gradient_accumulation_scheduling_last_batch(tmpdir, accumulate_grad_batches, limit_train_batches): """ Verify optimizer.step() applied to last batch while grad accumulation """ - class CurrentModel(BoringModel): + class TestModel(BoringModel): - def on_batch_start(self, *_): - self.on_train_batch_start_state_dict = self.state_dict() + def state_dict(self, *args, **kwargs): + return deepcopy(super().state_dict(*args, **kwargs)) - def on_batch_end(self, outputs, batch, batch_idx, *_): - self.on_train_batch_start_end_dict = self.state_dict() - for key in self.on_train_batch_start_end_dict.keys(): - equal = torch.equal(self.on_train_batch_start_state_dict[key], self.on_train_batch_start_end_dict[key]) - if (batch_idx + 1) == self.trainer.num_training_batches: - assert equal - else: - assert not equal + def check(self, d1, d2, equal=True): + keys = d1.keys() | d2.keys() + values = [torch.equal(d1[k], d2[k]) for k in keys] + return all(values) if equal else not any(values) - model = CurrentModel() + def backward(self, *args, **kwargs) -> None: + pre_bwd_state_dict = self.state_dict() + assert self.check(self.start_state_dict, pre_bwd_state_dict) + out = super().backward(*args, **kwargs) + + # state dict is equal, just the gradients changed + assert self.check(pre_bwd_state_dict, self.state_dict()) + + return out + + # def optimizer_step(self, *args, **kwargs): + # pre_opt_step_state_dict = self.state_dict() + # assert self.check(self.start_state_dict, pre_opt_step_state_dict) + + # # this calls `backward` and `on_after_backward` inside the closure + # out = super().optimizer_step(*args, **kwargs) + + # # the state dict changed + # assert self.check(pre_opt_step_state_dict, self.state_dict(), equal=False) + + # self.opt_step_called = True + # return out + + def on_after_backward(self): + # should override `optimizer_step` instead but can't with `accumulate_grad_batches` + # replace with the above after https://github.com/PyTorchLightning/pytorch-lightning/issues/6910 + self.opt_step_called = True + + def on_train_batch_start(self, *_): + self.start_state_dict = self.state_dict() + self.opt_step_called = False + + def on_train_batch_end(self, outputs, batch, batch_idx, *_): + end_state_dict = self.state_dict() + is_last_batch = (batch_idx + 1) == self.trainer.num_training_batches + + if is_last_batch or self.opt_step_called: + assert self.check(self.start_state_dict, end_state_dict, equal=False) + else: + assert self.check(self.start_state_dict, end_state_dict) + + model = TestModel() trainer = Trainer( accumulate_grad_batches=accumulate_grad_batches, max_epochs=2, limit_train_batches=limit_train_batches, limit_val_batches=0, - limit_test_batches=0, default_root_dir=tmpdir, + progress_bar_refresh_rate=0, ) trainer.fit(model) diff --git a/tests/utilities/test_model_helpers.py b/tests/utilities/test_model_helpers.py new file mode 100644 index 0000000000..f63d46bdb6 --- /dev/null +++ b/tests/utilities/test_model_helpers.py @@ -0,0 +1,67 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from functools import partial +from unittest.mock import Mock + +import pytest + +from pytorch_lightning import LightningDataModule, Trainer +from pytorch_lightning.utilities.model_helpers import is_overridden +from tests.helpers import BoringDataModule, BoringModel + + +def test_is_overridden(): + model = BoringModel() + datamodule = BoringDataModule() + + # edge cases + assert not is_overridden("whatever", None) + with pytest.raises(ValueError, match="Expected a parent"): + is_overridden("whatever", object()) + assert not is_overridden("whatever", model) + assert not is_overridden("whatever", model, parent=LightningDataModule) + + class TestModel(BoringModel): + + def foo(self): + pass + + with pytest.raises(ValueError, match="The parent should define the method"): + is_overridden("foo", TestModel()) + + # normal usage + assert is_overridden("training_step", model) + assert is_overridden("train_dataloader", datamodule) + + # `Mock` support + mock = Mock(spec=BoringModel, wraps=model) + assert is_overridden("training_step", mock) + mock = Mock(spec=BoringDataModule, wraps=datamodule) + assert is_overridden("train_dataloader", mock) + + # `partial` support + model.training_step = partial(model.training_step) + assert is_overridden("training_step", model) + + # `_PatchDataLoader.patch_loader_code` support + class TestModel(BoringModel): + + def on_fit_start(self): + assert is_overridden("train_dataloader", self) + self.on_fit_start_called = True + + model = TestModel() + trainer = Trainer(fast_dev_run=1) + trainer.fit(model, train_dataloader=model.train_dataloader()) + assert model.on_fit_start_called