`is_overridden` improvements (#7918)
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
This commit is contained in:
parent
9e932f4dfd
commit
ac4eb0a06a
|
@ -35,6 +35,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
- Added `clip_grad_by_value` support for TPUs ([#7025](https://github.com/PyTorchLightning/pytorch-lightning/pull/7025))
|
||||
|
||||
|
||||
- Added support for passing any class to `is_overridden` ([#7918](https://github.com/PyTorchLightning/pytorch-lightning/pull/7918))
|
||||
|
||||
|
||||
- Added `sub_dir` parameter to `TensorBoardLogger` ([#6195](https://github.com/PyTorchLightning/pytorch-lightning/pull/6195))
|
||||
|
||||
|
||||
|
@ -172,6 +175,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
- Deprecated `self.log(sync_dist_op)` in favor of `self.log(reduce_fx)`. ([#7891](https://github.com/PyTorchLightning/pytorch-lightning/pull/7891))
|
||||
|
||||
|
||||
- Deprecated `is_overridden(model=...)` in favor of `is_overridden(instance=...)` ([#7918](https://github.com/PyTorchLightning/pytorch-lightning/pull/7918))
|
||||
|
||||
|
||||
- Deprecated default value of `monitor` argument in EarlyStopping callback to enforce `monitor` as a required argument ([#7907](https://github.com/PyTorchLightning/pytorch-lightning/pull/7907))
|
||||
|
||||
|
||||
|
|
|
@ -201,9 +201,9 @@ class EvaluationLoop(object):
|
|||
def _should_track_batch_outputs_for_epoch_end(self) -> bool:
|
||||
model = self.trainer.lightning_module
|
||||
if self.trainer.testing:
|
||||
return is_overridden('test_epoch_end', model=model)
|
||||
return is_overridden('test_epoch_end', model)
|
||||
else:
|
||||
return is_overridden('validation_epoch_end', model=model)
|
||||
return is_overridden('validation_epoch_end', model)
|
||||
|
||||
def evaluation_epoch_end(self, outputs: EPOCH_OUTPUT) -> None:
|
||||
# inform logger the batch loop has finished
|
||||
|
@ -216,12 +216,12 @@ class EvaluationLoop(object):
|
|||
model._current_dataloader_idx = None
|
||||
|
||||
if self.trainer.testing:
|
||||
if is_overridden('test_epoch_end', model=model):
|
||||
if is_overridden('test_epoch_end', model):
|
||||
model._current_fx_name = 'test_epoch_end'
|
||||
model.test_epoch_end(outputs)
|
||||
|
||||
else:
|
||||
if is_overridden('validation_epoch_end', model=model):
|
||||
if is_overridden('validation_epoch_end', model):
|
||||
model._current_fx_name = 'validation_epoch_end'
|
||||
model.validation_epoch_end(outputs)
|
||||
|
||||
|
|
|
@ -216,10 +216,10 @@ class TrainLoop:
|
|||
# 2. The model overrides on_train_epoch_end which has `outputs` in the signature
|
||||
# TODO: in v1.5 this only needs to check if training_epoch_end is overridden
|
||||
lightning_module = self.trainer.lightning_module
|
||||
if is_overridden("training_epoch_end", model=lightning_module):
|
||||
if is_overridden("training_epoch_end", lightning_module):
|
||||
return True
|
||||
|
||||
if is_overridden("on_train_epoch_end", model=lightning_module):
|
||||
if is_overridden("on_train_epoch_end", lightning_module):
|
||||
model_hook_fx = getattr(lightning_module, "on_train_epoch_end")
|
||||
if is_param_in_hook_signature(model_hook_fx, "outputs"):
|
||||
return True
|
||||
|
@ -540,7 +540,7 @@ class TrainLoop:
|
|||
# get the model and call model.training_epoch_end
|
||||
model = self.trainer.lightning_module
|
||||
|
||||
if is_overridden('training_epoch_end', model=model):
|
||||
if is_overridden('training_epoch_end', model):
|
||||
# run training_epoch_end
|
||||
# refresh the result for custom logging at the epoch level
|
||||
model._current_fx_name = 'training_epoch_end'
|
||||
|
|
|
@ -11,33 +11,59 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Union
|
||||
from functools import partial
|
||||
from typing import Optional, Type, Union
|
||||
from unittest.mock import Mock
|
||||
|
||||
from pytorch_lightning.core.datamodule import LightningDataModule
|
||||
from pytorch_lightning.core.lightning import LightningModule
|
||||
from pytorch_lightning.utilities import rank_zero_deprecation
|
||||
|
||||
|
||||
def is_overridden(method_name: str, model: Union[LightningModule, LightningDataModule]) -> bool:
|
||||
# if you pass DataModule instead of None or a LightningModule, we use LightningDataModule as super
|
||||
# TODO - refector this function to accept model_name, instance, parent so it makes more sense
|
||||
super_object = LightningModule if not isinstance(model, LightningDataModule) else LightningDataModule
|
||||
def is_overridden(
|
||||
method_name: str,
|
||||
instance: Optional[object] = None,
|
||||
parent: Optional[Type[object]] = None,
|
||||
model: Optional[Union[LightningModule, LightningDataModule]] = None,
|
||||
) -> bool:
|
||||
if model is not None and instance is None:
|
||||
rank_zero_deprecation(
|
||||
'`is_overriden(model=...)` has been deprecated and will be removed in v1.6.'
|
||||
'Please use `is_overriden(instance=...)`'
|
||||
)
|
||||
instance = model
|
||||
|
||||
if not hasattr(model, method_name) or not hasattr(super_object, method_name):
|
||||
# in case of calling deprecated method
|
||||
if instance is None:
|
||||
# if `self.lightning_module` was passed as instance, it can be `None`
|
||||
return False
|
||||
|
||||
instance_attr = getattr(model, method_name)
|
||||
if not instance_attr:
|
||||
return False
|
||||
super_attr = getattr(super_object, method_name)
|
||||
if parent is None:
|
||||
if isinstance(instance, LightningModule):
|
||||
parent = LightningModule
|
||||
elif isinstance(instance, LightningDataModule):
|
||||
parent = LightningDataModule
|
||||
if parent is None:
|
||||
raise ValueError("Expected a parent")
|
||||
|
||||
# when code pointers are different, it was implemented
|
||||
if hasattr(instance_attr, 'patch_loader_code'):
|
||||
# cannot pickle __code__ so cannot verify if PatchDataloader
|
||||
# exists which shows dataloader methods have been overwritten.
|
||||
# so, we hack it by using the string representation
|
||||
is_overridden = instance_attr.patch_loader_code != str(super_attr.__code__)
|
||||
else:
|
||||
is_overridden = instance_attr.__code__ is not super_attr.__code__
|
||||
return is_overridden
|
||||
instance_attr = getattr(instance, method_name, None)
|
||||
# `Mock(wraps=...)` support
|
||||
if isinstance(instance_attr, Mock):
|
||||
# access the wrapped function
|
||||
instance_attr = instance_attr._mock_wraps
|
||||
# `partial` support
|
||||
elif isinstance(instance_attr, partial):
|
||||
instance_attr = instance_attr.func
|
||||
if instance_attr is None:
|
||||
return False
|
||||
|
||||
parent_attr = getattr(parent, method_name, None)
|
||||
if parent_attr is None:
|
||||
raise ValueError("The parent should define the method")
|
||||
|
||||
# cannot pickle `__code__` so cannot verify if `PatchDataloader`
|
||||
# exists which shows dataloader methods have been overwritten.
|
||||
# so, we hack it by using the string representation
|
||||
instance_code = getattr(instance_attr, 'patch_loader_code', None) or instance_attr.__code__
|
||||
parent_code = parent_attr.__code__
|
||||
|
||||
return instance_code != parent_code
|
||||
|
|
|
@ -17,6 +17,7 @@ import pytest
|
|||
from pytorch_lightning import Trainer
|
||||
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
|
||||
from pytorch_lightning.plugins.training_type import DDPPlugin, DDPSpawnPlugin
|
||||
from pytorch_lightning.utilities.model_helpers import is_overridden
|
||||
from tests.helpers import BoringDataModule, BoringModel
|
||||
|
||||
|
||||
|
@ -175,6 +176,14 @@ def test_v1_6_0_datamodule_hooks_calls(tmpdir):
|
|||
assert dm.teardown_calls == ['validate', 'test']
|
||||
|
||||
|
||||
def test_v1_6_0_is_overridden_model():
|
||||
model = BoringModel()
|
||||
with pytest.deprecated_call(match="and will be removed in v1.6"):
|
||||
assert is_overridden("validation_step", model=model)
|
||||
with pytest.deprecated_call(match="and will be removed in v1.6"):
|
||||
assert not is_overridden("foo", model=model)
|
||||
|
||||
|
||||
def test_v1_6_0_early_stopping_monitor(tmpdir):
|
||||
with pytest.deprecated_call(
|
||||
match=r"The `EarlyStopping\(monitor\)` argument will be required starting in v1.6."
|
||||
|
|
|
@ -37,25 +37,28 @@ def test_trainer_fn_while_running(tmpdir, extra_params):
|
|||
|
||||
def __init__(self, expected_fn, expected_stage):
|
||||
super().__init__()
|
||||
self.expected_state = expected_fn
|
||||
self.expected_fn = expected_fn
|
||||
self.expected_stage = expected_stage
|
||||
self.lr = 0.1
|
||||
|
||||
def on_batch_start(self, *_):
|
||||
assert self.trainer.state == TrainerState(
|
||||
status=TrainerStatus.RUNNING, fn=self.expected_fn, stage=self.expected_stage
|
||||
)
|
||||
|
||||
def on_train_batch_start(self, *_):
|
||||
assert self.trainer.state.status == TrainerStatus.RUNNING
|
||||
assert self.trainer.state.fn == self.expected_fn
|
||||
assert self.trainer.training
|
||||
|
||||
def on_sanity_check_start(self, *_):
|
||||
assert self.trainer.state.status == TrainerStatus.RUNNING
|
||||
assert self.trainer.state.fn == self.expected_fn
|
||||
assert self.trainer.sanity_checking
|
||||
|
||||
def on_validation_batch_start(self, *_):
|
||||
assert self.trainer.state.status == TrainerStatus.RUNNING
|
||||
assert self.trainer.state.fn == self.expected_fn
|
||||
assert self.trainer.validating or self.trainer.sanity_checking
|
||||
|
||||
def on_test_batch_start(self, *_):
|
||||
assert self.trainer.state.status == TrainerStatus.RUNNING
|
||||
assert self.trainer.state.fn == self.expected_fn
|
||||
assert self.trainer.testing
|
||||
|
||||
model = TestModel(TrainerFn.TUNING, RunningStage.TRAINING)
|
||||
|
|
|
@ -232,29 +232,66 @@ def test_trainer_accumulate_grad_batches_zero_grad(tmpdir, accumulate_grad_batch
|
|||
def test_gradient_accumulation_scheduling_last_batch(tmpdir, accumulate_grad_batches, limit_train_batches):
|
||||
""" Verify optimizer.step() applied to last batch while grad accumulation """
|
||||
|
||||
class CurrentModel(BoringModel):
|
||||
class TestModel(BoringModel):
|
||||
|
||||
def on_batch_start(self, *_):
|
||||
self.on_train_batch_start_state_dict = self.state_dict()
|
||||
def state_dict(self, *args, **kwargs):
|
||||
return deepcopy(super().state_dict(*args, **kwargs))
|
||||
|
||||
def on_batch_end(self, outputs, batch, batch_idx, *_):
|
||||
self.on_train_batch_start_end_dict = self.state_dict()
|
||||
for key in self.on_train_batch_start_end_dict.keys():
|
||||
equal = torch.equal(self.on_train_batch_start_state_dict[key], self.on_train_batch_start_end_dict[key])
|
||||
if (batch_idx + 1) == self.trainer.num_training_batches:
|
||||
assert equal
|
||||
else:
|
||||
assert not equal
|
||||
def check(self, d1, d2, equal=True):
|
||||
keys = d1.keys() | d2.keys()
|
||||
values = [torch.equal(d1[k], d2[k]) for k in keys]
|
||||
return all(values) if equal else not any(values)
|
||||
|
||||
model = CurrentModel()
|
||||
def backward(self, *args, **kwargs) -> None:
|
||||
pre_bwd_state_dict = self.state_dict()
|
||||
assert self.check(self.start_state_dict, pre_bwd_state_dict)
|
||||
|
||||
out = super().backward(*args, **kwargs)
|
||||
|
||||
# state dict is equal, just the gradients changed
|
||||
assert self.check(pre_bwd_state_dict, self.state_dict())
|
||||
|
||||
return out
|
||||
|
||||
# def optimizer_step(self, *args, **kwargs):
|
||||
# pre_opt_step_state_dict = self.state_dict()
|
||||
# assert self.check(self.start_state_dict, pre_opt_step_state_dict)
|
||||
|
||||
# # this calls `backward` and `on_after_backward` inside the closure
|
||||
# out = super().optimizer_step(*args, **kwargs)
|
||||
|
||||
# # the state dict changed
|
||||
# assert self.check(pre_opt_step_state_dict, self.state_dict(), equal=False)
|
||||
|
||||
# self.opt_step_called = True
|
||||
# return out
|
||||
|
||||
def on_after_backward(self):
|
||||
# should override `optimizer_step` instead but can't with `accumulate_grad_batches`
|
||||
# replace with the above after https://github.com/PyTorchLightning/pytorch-lightning/issues/6910
|
||||
self.opt_step_called = True
|
||||
|
||||
def on_train_batch_start(self, *_):
|
||||
self.start_state_dict = self.state_dict()
|
||||
self.opt_step_called = False
|
||||
|
||||
def on_train_batch_end(self, outputs, batch, batch_idx, *_):
|
||||
end_state_dict = self.state_dict()
|
||||
is_last_batch = (batch_idx + 1) == self.trainer.num_training_batches
|
||||
|
||||
if is_last_batch or self.opt_step_called:
|
||||
assert self.check(self.start_state_dict, end_state_dict, equal=False)
|
||||
else:
|
||||
assert self.check(self.start_state_dict, end_state_dict)
|
||||
|
||||
model = TestModel()
|
||||
trainer = Trainer(
|
||||
accumulate_grad_batches=accumulate_grad_batches,
|
||||
max_epochs=2,
|
||||
limit_train_batches=limit_train_batches,
|
||||
limit_val_batches=0,
|
||||
limit_test_batches=0,
|
||||
default_root_dir=tmpdir,
|
||||
progress_bar_refresh_rate=0,
|
||||
)
|
||||
|
||||
trainer.fit(model)
|
||||
|
|
|
@ -0,0 +1,67 @@
|
|||
# Copyright The PyTorch Lightning team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from functools import partial
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from pytorch_lightning import LightningDataModule, Trainer
|
||||
from pytorch_lightning.utilities.model_helpers import is_overridden
|
||||
from tests.helpers import BoringDataModule, BoringModel
|
||||
|
||||
|
||||
def test_is_overridden():
|
||||
model = BoringModel()
|
||||
datamodule = BoringDataModule()
|
||||
|
||||
# edge cases
|
||||
assert not is_overridden("whatever", None)
|
||||
with pytest.raises(ValueError, match="Expected a parent"):
|
||||
is_overridden("whatever", object())
|
||||
assert not is_overridden("whatever", model)
|
||||
assert not is_overridden("whatever", model, parent=LightningDataModule)
|
||||
|
||||
class TestModel(BoringModel):
|
||||
|
||||
def foo(self):
|
||||
pass
|
||||
|
||||
with pytest.raises(ValueError, match="The parent should define the method"):
|
||||
is_overridden("foo", TestModel())
|
||||
|
||||
# normal usage
|
||||
assert is_overridden("training_step", model)
|
||||
assert is_overridden("train_dataloader", datamodule)
|
||||
|
||||
# `Mock` support
|
||||
mock = Mock(spec=BoringModel, wraps=model)
|
||||
assert is_overridden("training_step", mock)
|
||||
mock = Mock(spec=BoringDataModule, wraps=datamodule)
|
||||
assert is_overridden("train_dataloader", mock)
|
||||
|
||||
# `partial` support
|
||||
model.training_step = partial(model.training_step)
|
||||
assert is_overridden("training_step", model)
|
||||
|
||||
# `_PatchDataLoader.patch_loader_code` support
|
||||
class TestModel(BoringModel):
|
||||
|
||||
def on_fit_start(self):
|
||||
assert is_overridden("train_dataloader", self)
|
||||
self.on_fit_start_called = True
|
||||
|
||||
model = TestModel()
|
||||
trainer = Trainer(fast_dev_run=1)
|
||||
trainer.fit(model, train_dataloader=model.train_dataloader())
|
||||
assert model.on_fit_start_called
|
Loading…
Reference in New Issue