Use `update_wrapper` in test_hooks.py (#10578)

This commit is contained in:
Carlos Mocholí 2021-11-19 01:52:55 +01:00 committed by GitHub
parent 700521c7d3
commit 35f6cbe09f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 10 additions and 15 deletions

View File

@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
from functools import partial, update_wrapper
from inspect import getmembers, isfunction
from unittest import mock
from unittest.mock import ANY, PropertyMock
@ -223,7 +223,9 @@ class HookedCallback(Callback):
for h in get_members(Callback):
attr = getattr(self, h)
setattr(self, h, partial(call, h, attr))
partial_h = partial(call, h, attr)
update_wrapper(partial_h, attr)
setattr(self, h, partial_h)
def on_save_checkpoint(*args, **kwargs):
return {"foo": True}
@ -256,7 +258,9 @@ class HookedModel(BoringModel):
for h in pl_module_hooks:
attr = getattr(self, h)
setattr(self, h, partial(call, h, attr))
partial_h = partial(call, h, attr)
update_wrapper(partial_h, attr)
setattr(self, h, partial_h)
def validation_epoch_end(self, *args, **kwargs):
# `BoringModel` does not have a return for `validation_step_end` so this would fail
@ -852,7 +856,9 @@ def test_trainer_datamodule_hook_system(tmpdir):
for h in get_members(LightningDataModule):
attr = getattr(self, h)
setattr(self, h, partial(call, h, attr))
partial_h = partial(call, h, attr)
update_wrapper(partial_h, attr)
setattr(self, h, partial_h)
model = BoringModel()
batches = 2
@ -871,20 +877,12 @@ def test_trainer_datamodule_hook_system(tmpdir):
called = []
dm = HookedDataModule(called)
trainer.fit(model, datamodule=dm)
batch_transfer = [
dict(name="on_before_batch_transfer", args=(ANY, 0)),
dict(name="transfer_batch_to_device", args=(ANY, torch.device("cpu"), 0)),
dict(name="on_after_batch_transfer", args=(ANY, 0)),
]
expected = [
dict(name="prepare_data"),
dict(name="setup", kwargs=dict(stage="fit")),
dict(name="val_dataloader"),
*batch_transfer * batches,
dict(name="train_dataloader"),
*batch_transfer * batches,
dict(name="val_dataloader"),
*batch_transfer * batches,
dict(
name="on_save_checkpoint",
args=(
@ -910,7 +908,6 @@ def test_trainer_datamodule_hook_system(tmpdir):
dict(name="prepare_data"),
dict(name="setup", kwargs=dict(stage="validate")),
dict(name="val_dataloader"),
*batch_transfer * batches,
dict(name="teardown", kwargs=dict(stage="validate")),
]
assert called == expected
@ -922,7 +919,6 @@ def test_trainer_datamodule_hook_system(tmpdir):
dict(name="prepare_data"),
dict(name="setup", kwargs=dict(stage="test")),
dict(name="test_dataloader"),
*batch_transfer * batches,
dict(name="teardown", kwargs=dict(stage="test")),
]
assert called == expected
@ -934,7 +930,6 @@ def test_trainer_datamodule_hook_system(tmpdir):
dict(name="prepare_data"),
dict(name="setup", kwargs=dict(stage="predict")),
dict(name="predict_dataloader"),
*batch_transfer * batches,
dict(name="teardown", kwargs=dict(stage="predict")),
]
assert called == expected