`is_overridden` improvements (#7918)

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Carlos Mocholí 2021-06-11 13:47:00 +02:00 committed by GitHub
parent 9e932f4dfd
commit ac4eb0a06a
8 changed files with 195 additions and 47 deletions


@@ -35,6 +35,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added `clip_grad_by_value` support for TPUs ([#7025](https://github.com/PyTorchLightning/pytorch-lightning/pull/7025))
+- Added support for passing any class to `is_overridden` ([#7918](https://github.com/PyTorchLightning/pytorch-lightning/pull/7918))
 - Added `sub_dir` parameter to `TensorBoardLogger` ([#6195](https://github.com/PyTorchLightning/pytorch-lightning/pull/6195))
@@ -172,6 +175,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Deprecated `self.log(sync_dist_op)` in favor of `self.log(reduce_fx)`. ([#7891](https://github.com/PyTorchLightning/pytorch-lightning/pull/7891))
+- Deprecated `is_overridden(model=...)` in favor of `is_overridden(instance=...)` ([#7918](https://github.com/PyTorchLightning/pytorch-lightning/pull/7918))
 - Deprecated default value of `monitor` argument in EarlyStopping callback to enforce `monitor` as a required argument ([#7907](https://github.com/PyTorchLightning/pytorch-lightning/pull/7907))
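For reference, a minimal sketch of the migration path between the deprecated and the new call styles (the `MyModel` class below is hypothetical, not part of this commit):

from pytorch_lightning import LightningModule
from pytorch_lightning.utilities.model_helpers import is_overridden

class MyModel(LightningModule):
    def training_step(self, batch, batch_idx):
        ...

model = MyModel()

# deprecated keyword, emits a warning and will be removed in v1.6:
assert is_overridden('training_step', model=model)

# replacement: pass the instance positionally (or as `instance=`),
# optionally with an explicit `parent` class to compare against:
assert is_overridden('training_step', model)
assert is_overridden('training_step', instance=model, parent=LightningModule)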


@@ -201,9 +201,9 @@ class EvaluationLoop(object):
     def _should_track_batch_outputs_for_epoch_end(self) -> bool:
         model = self.trainer.lightning_module
         if self.trainer.testing:
-            return is_overridden('test_epoch_end', model=model)
+            return is_overridden('test_epoch_end', model)
         else:
-            return is_overridden('validation_epoch_end', model=model)
+            return is_overridden('validation_epoch_end', model)

     def evaluation_epoch_end(self, outputs: EPOCH_OUTPUT) -> None:
         # inform logger the batch loop has finished
@@ -216,12 +216,12 @@ class EvaluationLoop(object):
         model._current_dataloader_idx = None

         if self.trainer.testing:
-            if is_overridden('test_epoch_end', model=model):
+            if is_overridden('test_epoch_end', model):
                 model._current_fx_name = 'test_epoch_end'
                 model.test_epoch_end(outputs)
         else:
-            if is_overridden('validation_epoch_end', model=model):
+            if is_overridden('validation_epoch_end', model):
                 model._current_fx_name = 'validation_epoch_end'
                 model.validation_epoch_end(outputs)


@@ -216,10 +216,10 @@ class TrainLoop:
         # 2. The model overrides on_train_epoch_end which has `outputs` in the signature
         # TODO: in v1.5 this only needs to check if training_epoch_end is overridden
         lightning_module = self.trainer.lightning_module
-        if is_overridden("training_epoch_end", model=lightning_module):
+        if is_overridden("training_epoch_end", lightning_module):
             return True
-        if is_overridden("on_train_epoch_end", model=lightning_module):
+        if is_overridden("on_train_epoch_end", lightning_module):
             model_hook_fx = getattr(lightning_module, "on_train_epoch_end")
             if is_param_in_hook_signature(model_hook_fx, "outputs"):
                 return True
@@ -540,7 +540,7 @@ class TrainLoop:
         # get the model and call model.training_epoch_end
         model = self.trainer.lightning_module

-        if is_overridden('training_epoch_end', model=model):
+        if is_overridden('training_epoch_end', model):
            # run training_epoch_end
            # refresh the result for custom logging at the epoch level
            model._current_fx_name = 'training_epoch_end'
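The `is_param_in_hook_signature` call above decides whether the user's hook accepts an `outputs` argument. Roughly, it amounts to a signature inspection like the following sketch (an approximation for illustration, not the actual helper, which also handles variadic parameters):

import inspect

def accepts_outputs(hook) -> bool:
    # rough equivalent of `is_param_in_hook_signature(hook, "outputs")`:
    # look the parameter up by name in the hook's signature
    return "outputs" in inspect.signature(hook).parameters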


@@ -11,33 +11,59 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Union
+from functools import partial
+from typing import Optional, Type, Union
+from unittest.mock import Mock

 from pytorch_lightning.core.datamodule import LightningDataModule
 from pytorch_lightning.core.lightning import LightningModule
+from pytorch_lightning.utilities import rank_zero_deprecation


-def is_overridden(method_name: str, model: Union[LightningModule, LightningDataModule]) -> bool:
-    # if you pass DataModule instead of None or a LightningModule, we use LightningDataModule as super
-    # TODO - refector this function to accept model_name, instance, parent so it makes more sense
-    super_object = LightningModule if not isinstance(model, LightningDataModule) else LightningDataModule
+def is_overridden(
+    method_name: str,
+    instance: Optional[object] = None,
+    parent: Optional[Type[object]] = None,
+    model: Optional[Union[LightningModule, LightningDataModule]] = None,
+) -> bool:
+    if model is not None and instance is None:
+        rank_zero_deprecation(
+            '`is_overridden(model=...)` has been deprecated and will be removed in v1.6.'
+            ' Please use `is_overridden(instance=...)`'
+        )
+        instance = model

-    if not hasattr(model, method_name) or not hasattr(super_object, method_name):
-        # in case of calling deprecated method
+    if instance is None:
+        # if `self.lightning_module` was passed as instance, it can be `None`
         return False

-    instance_attr = getattr(model, method_name)
-    if not instance_attr:
-        return False
-    super_attr = getattr(super_object, method_name)
+    if parent is None:
+        if isinstance(instance, LightningModule):
+            parent = LightningModule
+        elif isinstance(instance, LightningDataModule):
+            parent = LightningDataModule
+        if parent is None:
+            raise ValueError("Expected a parent")

-    # when code pointers are different, it was implemented
-    if hasattr(instance_attr, 'patch_loader_code'):
-        # cannot pickle __code__ so cannot verify if PatchDataloader
-        # exists which shows dataloader methods have been overwritten.
-        # so, we hack it by using the string representation
-        is_overridden = instance_attr.patch_loader_code != str(super_attr.__code__)
-    else:
-        is_overridden = instance_attr.__code__ is not super_attr.__code__
-    return is_overridden
+    instance_attr = getattr(instance, method_name, None)
+    # `Mock(wraps=...)` support
+    if isinstance(instance_attr, Mock):
+        # access the wrapped function
+        instance_attr = instance_attr._mock_wraps
+    # `partial` support
+    elif isinstance(instance_attr, partial):
+        instance_attr = instance_attr.func
+    if instance_attr is None:
+        return False
+
+    parent_attr = getattr(parent, method_name, None)
+    if parent_attr is None:
+        raise ValueError("The parent should define the method")
+
+    # cannot pickle `__code__` so cannot verify if `PatchDataloader`
+    # exists which shows dataloader methods have been overwritten.
+    # so, we hack it by using the string representation
+    instance_code = getattr(instance_attr, 'patch_loader_code', None) or instance_attr.__code__
+    parent_code = parent_attr.__code__
+
+    return instance_code != parent_code
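The explicit `parent` argument is what enables the "passing any class" support listed in the changelog: once `parent` is given, nothing Lightning-specific is involved. A small illustration of the underlying code-object comparison (the `Base`/`Child`/`Passthrough` classes are made up for this sketch):

from pytorch_lightning.utilities.model_helpers import is_overridden

class Base:
    def run(self):
        ...

class Child(Base):
    def run(self):
        return 1

class Passthrough(Base):
    pass

# `parent` may now be any class, not just LightningModule/LightningDataModule
assert is_overridden("run", Child(), parent=Base)
assert not is_overridden("run", Passthrough(), parent=Base)

# which boils down to comparing the methods' code objects:
assert Child.run.__code__ is not Base.run.__code__
assert Passthrough.run.__code__ is Base.run.__code__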


@@ -17,6 +17,7 @@ import pytest
 from pytorch_lightning import Trainer
 from pytorch_lightning.callbacks.early_stopping import EarlyStopping
 from pytorch_lightning.plugins.training_type import DDPPlugin, DDPSpawnPlugin
+from pytorch_lightning.utilities.model_helpers import is_overridden
 from tests.helpers import BoringDataModule, BoringModel
@@ -175,6 +176,14 @@ def test_v1_6_0_datamodule_hooks_calls(tmpdir):
     assert dm.teardown_calls == ['validate', 'test']


+def test_v1_6_0_is_overridden_model():
+    model = BoringModel()
+    with pytest.deprecated_call(match="and will be removed in v1.6"):
+        assert is_overridden("validation_step", model=model)
+    with pytest.deprecated_call(match="and will be removed in v1.6"):
+        assert not is_overridden("foo", model=model)
+
+
 def test_v1_6_0_early_stopping_monitor(tmpdir):
     with pytest.deprecated_call(
         match=r"The `EarlyStopping\(monitor\)` argument will be required starting in v1.6."


@@ -37,25 +37,28 @@ def test_trainer_fn_while_running(tmpdir, extra_params):
         def __init__(self, expected_fn, expected_stage):
             super().__init__()
-            self.expected_state = expected_fn
+            self.expected_fn = expected_fn
             self.expected_stage = expected_stage
             self.lr = 0.1

-        def on_batch_start(self, *_):
-            assert self.trainer.state == TrainerState(
-                status=TrainerStatus.RUNNING, fn=self.expected_fn, stage=self.expected_stage
-            )
+        def on_train_batch_start(self, *_):
+            assert self.trainer.state.status == TrainerStatus.RUNNING
+            assert self.trainer.state.fn == self.expected_fn
+            assert self.trainer.training

         def on_sanity_check_start(self, *_):
+            assert self.trainer.state.status == TrainerStatus.RUNNING
+            assert self.trainer.state.fn == self.expected_fn
             assert self.trainer.sanity_checking

         def on_validation_batch_start(self, *_):
+            assert self.trainer.state.status == TrainerStatus.RUNNING
+            assert self.trainer.state.fn == self.expected_fn
             assert self.trainer.validating or self.trainer.sanity_checking

         def on_test_batch_start(self, *_):
+            assert self.trainer.state.status == TrainerStatus.RUNNING
+            assert self.trainer.state.fn == self.expected_fn
             assert self.trainer.testing

     model = TestModel(TrainerFn.TUNING, RunningStage.TRAINING)


@@ -232,29 +232,66 @@ def test_trainer_accumulate_grad_batches_zero_grad(tmpdir, accumulate_grad_batch
 def test_gradient_accumulation_scheduling_last_batch(tmpdir, accumulate_grad_batches, limit_train_batches):
     """ Verify optimizer.step() applied to last batch while grad accumulation """

-    class CurrentModel(BoringModel):
+    class TestModel(BoringModel):

-        def on_batch_start(self, *_):
-            self.on_train_batch_start_state_dict = self.state_dict()
+        def state_dict(self, *args, **kwargs):
+            return deepcopy(super().state_dict(*args, **kwargs))

-        def on_batch_end(self, outputs, batch, batch_idx, *_):
-            self.on_train_batch_start_end_dict = self.state_dict()
-            for key in self.on_train_batch_start_end_dict.keys():
-                equal = torch.equal(self.on_train_batch_start_state_dict[key], self.on_train_batch_start_end_dict[key])
-                if (batch_idx + 1) == self.trainer.num_training_batches:
-                    assert equal
-                else:
-                    assert not equal
+        def check(self, d1, d2, equal=True):
+            keys = d1.keys() | d2.keys()
+            values = [torch.equal(d1[k], d2[k]) for k in keys]
+            return all(values) if equal else not any(values)

-    model = CurrentModel()
+        def backward(self, *args, **kwargs) -> None:
+            pre_bwd_state_dict = self.state_dict()
+            assert self.check(self.start_state_dict, pre_bwd_state_dict)
+            out = super().backward(*args, **kwargs)
+            # state dict is equal, just the gradients changed
+            assert self.check(pre_bwd_state_dict, self.state_dict())
+            return out
+
+        # def optimizer_step(self, *args, **kwargs):
+        #     pre_opt_step_state_dict = self.state_dict()
+        #     assert self.check(self.start_state_dict, pre_opt_step_state_dict)
+        #     # this calls `backward` and `on_after_backward` inside the closure
+        #     out = super().optimizer_step(*args, **kwargs)
+        #     # the state dict changed
+        #     assert self.check(pre_opt_step_state_dict, self.state_dict(), equal=False)
+        #     self.opt_step_called = True
+        #     return out
+
+        def on_after_backward(self):
+            # should override `optimizer_step` instead but can't with `accumulate_grad_batches`
+            # replace with the above after https://github.com/PyTorchLightning/pytorch-lightning/issues/6910
+            self.opt_step_called = True
+
+        def on_train_batch_start(self, *_):
+            self.start_state_dict = self.state_dict()
+            self.opt_step_called = False
+
+        def on_train_batch_end(self, outputs, batch, batch_idx, *_):
+            end_state_dict = self.state_dict()
+            is_last_batch = (batch_idx + 1) == self.trainer.num_training_batches
+            if is_last_batch or self.opt_step_called:
+                assert self.check(self.start_state_dict, end_state_dict, equal=False)
+            else:
+                assert self.check(self.start_state_dict, end_state_dict)
+
+    model = TestModel()

     trainer = Trainer(
         accumulate_grad_batches=accumulate_grad_batches,
         max_epochs=2,
         limit_train_batches=limit_train_batches,
         limit_val_batches=0,
         limit_test_batches=0,
         default_root_dir=tmpdir,
         progress_bar_refresh_rate=0,
     )

     trainer.fit(model)
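The invariant this test checks, stated in plain PyTorch terms: with gradient accumulation, parameters only change on batches where `optimizer.step()` actually runs, and the last batch must always step even when the accumulation window is not full. A standalone sketch of that schedule (illustrative only, not Lightning's internal loop):

import torch

model = torch.nn.Linear(2, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
accumulate_grad_batches = 4
num_batches = 10  # not a multiple of 4: the final partial window must still step

for batch_idx in range(num_batches):
    loss = model(torch.randn(8, 2)).sum()
    loss.backward()  # gradients accumulate across the window
    is_last = batch_idx + 1 == num_batches
    if (batch_idx + 1) % accumulate_grad_batches == 0 or is_last:
        optimizer.step()       # weights change only on these batches
        optimizer.zero_grad()  # reset the accumulated gradients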


@@ -0,0 +1,67 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from functools import partial
+from unittest.mock import Mock
+
+import pytest
+
+from pytorch_lightning import LightningDataModule, Trainer
+from pytorch_lightning.utilities.model_helpers import is_overridden
+from tests.helpers import BoringDataModule, BoringModel
+
+
+def test_is_overridden():
+    model = BoringModel()
+    datamodule = BoringDataModule()
+
+    # edge cases
+    assert not is_overridden("whatever", None)
+    with pytest.raises(ValueError, match="Expected a parent"):
+        is_overridden("whatever", object())
+    assert not is_overridden("whatever", model)
+    assert not is_overridden("whatever", model, parent=LightningDataModule)
+
+    class TestModel(BoringModel):
+        def foo(self):
+            pass
+
+    with pytest.raises(ValueError, match="The parent should define the method"):
+        is_overridden("foo", TestModel())
+
+    # normal usage
+    assert is_overridden("training_step", model)
+    assert is_overridden("train_dataloader", datamodule)
+
+    # `Mock` support
+    mock = Mock(spec=BoringModel, wraps=model)
+    assert is_overridden("training_step", mock)
+    mock = Mock(spec=BoringDataModule, wraps=datamodule)
+    assert is_overridden("train_dataloader", mock)
+
+    # `partial` support
+    model.training_step = partial(model.training_step)
+    assert is_overridden("training_step", model)
+
+    # `_PatchDataLoader.patch_loader_code` support
+    class TestModel(BoringModel):
+        def on_fit_start(self):
+            assert is_overridden("train_dataloader", self)
+            self.on_fit_start_called = True
+
+    model = TestModel()
+    trainer = Trainer(fast_dev_run=1)
+    trainer.fit(model, train_dataloader=model.train_dataloader())
+    assert model.on_fit_start_called
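The `Mock` branch tested above exists because a spy created with `Mock(wraps=...)` replaces each attribute with a child mock that has no usable `__code__` of its own; `is_overridden` therefore reaches through `_mock_wraps` (the standard-library attribute holding the wrapped callable) before comparing code objects. A minimal sketch of why that matters:

from unittest.mock import Mock

from pytorch_lightning.utilities.model_helpers import is_overridden
from tests.helpers import BoringModel

model = BoringModel()
spy = Mock(spec=BoringModel, wraps=model)

# `spy.training_step` is a child mock; its `_mock_wraps` attribute holds the
# real bound method, which is what gets compared against the parent's code.
# Without the unwrapping, the code-object comparison could not be performed.
assert is_overridden("training_step", spy)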