Use `update_wrapper` in test_hooks.py (#10578)
This commit is contained in:
parent
700521c7d3
commit
35f6cbe09f
|
@ -11,7 +11,7 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from functools import partial
|
||||
from functools import partial, update_wrapper
|
||||
from inspect import getmembers, isfunction
|
||||
from unittest import mock
|
||||
from unittest.mock import ANY, PropertyMock
|
||||
|
@ -223,7 +223,9 @@ class HookedCallback(Callback):
|
|||
|
||||
for h in get_members(Callback):
|
||||
attr = getattr(self, h)
|
||||
setattr(self, h, partial(call, h, attr))
|
||||
partial_h = partial(call, h, attr)
|
||||
update_wrapper(partial_h, attr)
|
||||
setattr(self, h, partial_h)
|
||||
|
||||
def on_save_checkpoint(*args, **kwargs):
|
||||
return {"foo": True}
|
||||
|
@ -256,7 +258,9 @@ class HookedModel(BoringModel):
|
|||
|
||||
for h in pl_module_hooks:
|
||||
attr = getattr(self, h)
|
||||
setattr(self, h, partial(call, h, attr))
|
||||
partial_h = partial(call, h, attr)
|
||||
update_wrapper(partial_h, attr)
|
||||
setattr(self, h, partial_h)
|
||||
|
||||
def validation_epoch_end(self, *args, **kwargs):
|
||||
# `BoringModel` does not have a return for `validation_step_end` so this would fail
|
||||
|
@ -852,7 +856,9 @@ def test_trainer_datamodule_hook_system(tmpdir):
|
|||
|
||||
for h in get_members(LightningDataModule):
|
||||
attr = getattr(self, h)
|
||||
setattr(self, h, partial(call, h, attr))
|
||||
partial_h = partial(call, h, attr)
|
||||
update_wrapper(partial_h, attr)
|
||||
setattr(self, h, partial_h)
|
||||
|
||||
model = BoringModel()
|
||||
batches = 2
|
||||
|
@ -871,20 +877,12 @@ def test_trainer_datamodule_hook_system(tmpdir):
|
|||
called = []
|
||||
dm = HookedDataModule(called)
|
||||
trainer.fit(model, datamodule=dm)
|
||||
batch_transfer = [
|
||||
dict(name="on_before_batch_transfer", args=(ANY, 0)),
|
||||
dict(name="transfer_batch_to_device", args=(ANY, torch.device("cpu"), 0)),
|
||||
dict(name="on_after_batch_transfer", args=(ANY, 0)),
|
||||
]
|
||||
expected = [
|
||||
dict(name="prepare_data"),
|
||||
dict(name="setup", kwargs=dict(stage="fit")),
|
||||
dict(name="val_dataloader"),
|
||||
*batch_transfer * batches,
|
||||
dict(name="train_dataloader"),
|
||||
*batch_transfer * batches,
|
||||
dict(name="val_dataloader"),
|
||||
*batch_transfer * batches,
|
||||
dict(
|
||||
name="on_save_checkpoint",
|
||||
args=(
|
||||
|
@ -910,7 +908,6 @@ def test_trainer_datamodule_hook_system(tmpdir):
|
|||
dict(name="prepare_data"),
|
||||
dict(name="setup", kwargs=dict(stage="validate")),
|
||||
dict(name="val_dataloader"),
|
||||
*batch_transfer * batches,
|
||||
dict(name="teardown", kwargs=dict(stage="validate")),
|
||||
]
|
||||
assert called == expected
|
||||
|
@ -922,7 +919,6 @@ def test_trainer_datamodule_hook_system(tmpdir):
|
|||
dict(name="prepare_data"),
|
||||
dict(name="setup", kwargs=dict(stage="test")),
|
||||
dict(name="test_dataloader"),
|
||||
*batch_transfer * batches,
|
||||
dict(name="teardown", kwargs=dict(stage="test")),
|
||||
]
|
||||
assert called == expected
|
||||
|
@ -934,7 +930,6 @@ def test_trainer_datamodule_hook_system(tmpdir):
|
|||
dict(name="prepare_data"),
|
||||
dict(name="setup", kwargs=dict(stage="predict")),
|
||||
dict(name="predict_dataloader"),
|
||||
*batch_transfer * batches,
|
||||
dict(name="teardown", kwargs=dict(stage="predict")),
|
||||
]
|
||||
assert called == expected
|
||||
|
|
Loading…
Reference in New Issue