Add `functools.wraps` support for `is_overridden` (#8296)
This commit is contained in:
parent
34efadd5b8
commit
8fead58273
|
@ -358,6 +358,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
- Fixed a bug where using `precision=64` would cause buffers with complex dtype to be cast to real ([#8208](https://github.com/PyTorchLightning/pytorch-lightning/pull/8208))
|
||||
|
||||
|
||||
- Fixed `is_overridden` returning true for wrapped functions with no changes ([#8296](https://github.com/PyTorchLightning/pytorch-lightning/pull/8296))
|
||||
|
||||
|
||||
- Fixed a bug where `truncated_bptt_steps` would throw an AttributeError when the target RNN has multiple hidden states ([#8145](https://github.com/PyTorchLightning/pytorch-lightning/pull/8145))
|
||||
|
||||
|
||||
|
|
|
@ -62,10 +62,7 @@ class LightningDataModule(CheckpointHooks, DataHooks):
|
|||
* **test_dataloader** the test dataloader(s).
|
||||
* **teardown** (things to do on every accelerator in distributed mode when finished)
|
||||
|
||||
|
||||
This allows you to share a full dataset without explaining how to download,
|
||||
split transform and process the data
|
||||
|
||||
This allows you to share a full dataset without explaining how to download, split, transform, and process the data
|
||||
"""
|
||||
|
||||
name: str = ...
|
||||
|
@ -380,7 +377,7 @@ class LightningDataModule(CheckpointHooks, DataHooks):
|
|||
|
||||
def __new__(cls, *args: Any, **kwargs: Any) -> 'LightningDataModule':
|
||||
obj = super().__new__(cls)
|
||||
# track `DataHooks` calls and run `prepare_data` only on rank zero
|
||||
# track `DataHooks` calls
|
||||
obj.prepare_data = cls._track_data_hook_calls(obj, obj.prepare_data)
|
||||
obj.setup = cls._track_data_hook_calls(obj, obj.setup)
|
||||
obj.teardown = cls._track_data_hook_calls(obj, obj.teardown)
|
||||
|
|
|
@ -45,6 +45,9 @@ def is_overridden(
|
|||
raise ValueError("Expected a parent")
|
||||
|
||||
instance_attr = getattr(instance, method_name, None)
|
||||
# `functools.wraps()` support
|
||||
if hasattr(instance_attr, '__wrapped__'):
|
||||
instance_attr = instance_attr.__wrapped__
|
||||
# `Mock(wraps=...)` support
|
||||
if isinstance(instance_attr, Mock):
|
||||
# access the wrapped function
|
||||
|
|
|
@ -11,7 +11,7 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from functools import partial
|
||||
from functools import partial, wraps
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
@ -37,6 +37,9 @@ def test_is_overridden():
|
|||
def foo(self):
|
||||
pass
|
||||
|
||||
def bar(self):
|
||||
return 1
|
||||
|
||||
with pytest.raises(ValueError, match="The parent should define the method"):
|
||||
is_overridden("foo", TestModel())
|
||||
|
||||
|
@ -44,6 +47,30 @@ def test_is_overridden():
|
|||
assert is_overridden("training_step", model)
|
||||
assert is_overridden("train_dataloader", datamodule)
|
||||
|
||||
class WrappedModel(TestModel):
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
obj = super().__new__(cls)
|
||||
obj.foo = cls.wrap(obj.foo)
|
||||
obj.bar = cls.wrap(obj.bar)
|
||||
return obj
|
||||
|
||||
@staticmethod
|
||||
def wrap(fn):
|
||||
|
||||
@wraps(fn)
|
||||
def wrapper():
|
||||
fn()
|
||||
|
||||
return wrapper
|
||||
|
||||
def bar(self):
|
||||
return 2
|
||||
|
||||
# `functools.wraps()` support
|
||||
assert not is_overridden("foo", WrappedModel(), parent=TestModel)
|
||||
assert is_overridden("bar", WrappedModel(), parent=TestModel)
|
||||
|
||||
# `Mock` support
|
||||
mock = Mock(spec=BoringModel, wraps=model)
|
||||
assert is_overridden("training_step", mock)
|
||||
|
|
Loading…
Reference in New Issue