Add `functools.wraps` support for `is_overridden` (#8296)

This commit is contained in:
Carlos Mocholí 2021-07-06 10:40:54 +02:00 committed by GitHub
parent 34efadd5b8
commit 8fead58273
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 36 additions and 6 deletions

View File

@ -358,6 +358,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed a bug where using `precision=64` would cause buffers with complex dtype to be cast to real ([#8208](https://github.com/PyTorchLightning/pytorch-lightning/pull/8208))
- Fixed `is_overridden` returning true for wrapped functions with no changes ([#8296](https://github.com/PyTorchLightning/pytorch-lightning/pull/8296))
- Fixed a bug where `truncated_bptt_steps` would throw an AttributeError when the target RNN has multiple hidden states ([#8145](https://github.com/PyTorchLightning/pytorch-lightning/pull/8145))

View File

@ -62,10 +62,7 @@ class LightningDataModule(CheckpointHooks, DataHooks):
* **test_dataloader** the test dataloader(s).
* **teardown** (things to do on every accelerator in distributed mode when finished)
This allows you to share a full dataset without explaining how to download,
split transform and process the data
This allows you to share a full dataset without explaining how to download, split, transform, and process the data
"""
name: str = ...
@ -380,7 +377,7 @@ class LightningDataModule(CheckpointHooks, DataHooks):
def __new__(cls, *args: Any, **kwargs: Any) -> 'LightningDataModule':
obj = super().__new__(cls)
# track `DataHooks` calls and run `prepare_data` only on rank zero
# track `DataHooks` calls
obj.prepare_data = cls._track_data_hook_calls(obj, obj.prepare_data)
obj.setup = cls._track_data_hook_calls(obj, obj.setup)
obj.teardown = cls._track_data_hook_calls(obj, obj.teardown)

View File

@ -45,6 +45,9 @@ def is_overridden(
raise ValueError("Expected a parent")
instance_attr = getattr(instance, method_name, None)
# `functools.wraps()` support
if hasattr(instance_attr, '__wrapped__'):
instance_attr = instance_attr.__wrapped__
# `Mock(wraps=...)` support
if isinstance(instance_attr, Mock):
# access the wrapped function

View File

@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
from functools import partial, wraps
from unittest.mock import Mock
import pytest
@ -37,6 +37,9 @@ def test_is_overridden():
def foo(self):
pass
def bar(self):
return 1
with pytest.raises(ValueError, match="The parent should define the method"):
is_overridden("foo", TestModel())
@ -44,6 +47,30 @@ def test_is_overridden():
assert is_overridden("training_step", model)
assert is_overridden("train_dataloader", datamodule)
class WrappedModel(TestModel):
def __new__(cls, *args, **kwargs):
obj = super().__new__(cls)
obj.foo = cls.wrap(obj.foo)
obj.bar = cls.wrap(obj.bar)
return obj
@staticmethod
def wrap(fn):
@wraps(fn)
def wrapper():
fn()
return wrapper
def bar(self):
return 2
# `functools.wraps()` support
assert not is_overridden("foo", WrappedModel(), parent=TestModel)
assert is_overridden("bar", WrappedModel(), parent=TestModel)
# `Mock` support
mock = Mock(spec=BoringModel, wraps=model)
assert is_overridden("training_step", mock)