Stricter `FxValidator` and add hooks (#7874)

* Stricter FxValidator and add hooks

* Update CHANGELOG
This commit is contained in:
Carlos Mocholí 2021-06-08 09:26:05 +02:00 committed by GitHub
parent ce976769ef
commit 3427cb728d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 42 additions and 65 deletions

View File

@ -78,6 +78,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Log epoch metrics before the `on_evaluation_end` hook ([#7272](https://github.com/PyTorchLightning/pytorch-lightning/pull/7272))
- Explicitly disallow calling `self.log(on_epoch=False)` during epoch-only or single-call hooks ([#7874](https://github.com/PyTorchLightning/pytorch-lightning/pull/7874))
- Changed these `Trainer` methods to be protected: `call_setup_hook`, `call_configure_sharded_model`, `pre_dispatch`, `dispatch`, `post_dispatch`, `call_teardown_hook`, `run_train`, `run_sanity_check`, `run_evaluate`, `run_evaluation`, `run_predict`, `track_output_for_epoch_end`

View File

@ -29,26 +29,26 @@ class FxValidator:
on_fit_end=None,
on_sanity_check_start=None,
on_sanity_check_end=None,
on_train_start=dict(on_step=(False, True), on_epoch=(False, True)),
on_train_start=dict(on_step=(False, ), on_epoch=(True, )),
on_train_end=None,
on_validation_start=dict(on_step=(False, True), on_epoch=(False, True)),
on_validation_start=dict(on_step=(False, ), on_epoch=(True, )),
on_validation_end=None,
on_test_start=dict(on_step=(False, True), on_epoch=(False, True)),
on_test_start=dict(on_step=(False, ), on_epoch=(True, )),
on_test_end=None,
on_predict_start=None,
on_predict_end=None,
on_pretrain_routine_start=None,
on_pretrain_routine_end=None,
on_train_epoch_start=dict(on_step=(False, True), on_epoch=(False, True)),
on_train_epoch_end=dict(on_step=(False, ), on_epoch=(False, True)),
on_validation_epoch_start=dict(on_step=(False, True), on_epoch=(False, True)),
on_validation_epoch_end=dict(on_step=(False, ), on_epoch=(False, True)),
on_test_epoch_start=dict(on_step=(False, True), on_epoch=(False, True)),
on_test_epoch_end=dict(on_step=(False, ), on_epoch=(False, True)),
on_train_epoch_start=dict(on_step=(False, True), on_epoch=(True, )),
on_train_epoch_end=dict(on_step=(False, ), on_epoch=(True, )),
on_validation_epoch_start=dict(on_step=(False, True), on_epoch=(True, )),
on_validation_epoch_end=dict(on_step=(False, ), on_epoch=(True, )),
on_test_epoch_start=dict(on_step=(False, True), on_epoch=(True, )),
on_test_epoch_end=dict(on_step=(False, ), on_epoch=(True, )),
on_predict_epoch_start=None,
on_predict_epoch_end=None,
on_epoch_start=dict(on_step=(False, True), on_epoch=(False, True)),
on_epoch_end=dict(on_step=(False, ), on_epoch=(False, True)),
on_epoch_start=dict(on_step=(False, True), on_epoch=(True, )),
on_epoch_end=dict(on_step=(False, ), on_epoch=(True, )),
on_batch_start=dict(on_step=(False, True), on_epoch=(False, True)),
on_batch_end=dict(on_step=(False, True), on_epoch=(False, True)),
on_train_batch_start=dict(on_step=(False, True), on_epoch=(False, True)),
@ -72,19 +72,26 @@ class FxValidator:
training_step_end=dict(on_step=(False, True), on_epoch=(False, True)),
validation_step_end=dict(on_step=(False, True), on_epoch=(False, True)),
test_step_end=dict(on_step=(False, True), on_epoch=(False, True)),
training_epoch_end=dict(on_step=(False, ), on_epoch=(False, True)),
validation_epoch_end=dict(on_step=(False, ), on_epoch=(False, True)),
test_epoch_end=dict(on_step=(False, ), on_epoch=(False, True)),
training_epoch_end=dict(on_step=(False, ), on_epoch=(True, )),
validation_epoch_end=dict(on_step=(False, ), on_epoch=(True, )),
test_epoch_end=dict(on_step=(False, ), on_epoch=(True, )),
on_before_batch_transfer=None,
transfer_batch_to_device=None,
on_after_batch_transfer=None,
backward=None,
optimizer_step=None,
# TODO(@carmocca): some {step,epoch}_{start,end} are missing
)
def check_logging(self, fx_name: str, on_step: bool, on_epoch: bool) -> None:
if fx_name not in self.functions:
@classmethod
def check_logging(cls, fx_name: str, on_step: bool, on_epoch: bool) -> None:
"""Check if the given function name is allowed to log"""
if fx_name not in cls.functions:
raise RuntimeError(
f'You are trying to `self.log()` inside `{fx_name}` but it is not implemented.'
' Please, open an issue in `https://github.com/PyTorchLightning/pytorch-lightning/issues`'
)
allowed = self.functions[fx_name]
allowed = cls.functions[fx_name]
if allowed is None:
raise MisconfigurationException(f"{fx_name} function doesn't support logging using `self.log()`")

View File

@ -349,33 +349,18 @@ def test_log_works_in_val_callback(tmpdir):
def on_validation_start(self, trainer, pl_module):
self.make_logging(
pl_module,
'on_validation_start',
1,
on_steps=self.choices,
on_epochs=self.choices,
prob_bars=self.choices
pl_module, 'on_validation_start', 1, on_steps=[False], on_epochs=[True], prob_bars=self.choices
)
def on_epoch_start(self, trainer, pl_module):
if trainer.validating:
self.make_logging(
pl_module,
'on_epoch_start',
2,
on_steps=self.choices,
on_epochs=self.choices,
prob_bars=self.choices
pl_module, 'on_epoch_start', 2, on_steps=[False], on_epochs=[True], prob_bars=self.choices
)
def on_validation_epoch_start(self, trainer, pl_module):
self.make_logging(
pl_module,
'on_validation_epoch_start',
3,
on_steps=self.choices,
on_epochs=self.choices,
prob_bars=self.choices
pl_module, 'on_validation_epoch_start', 3, on_steps=[False], on_epochs=[True], prob_bars=self.choices
)
def on_batch_end(self, trainer, pl_module):
@ -400,17 +385,12 @@ def test_log_works_in_val_callback(tmpdir):
def on_epoch_end(self, trainer, pl_module):
if trainer.validating:
self.make_logging(
pl_module, 'on_epoch_end', 8, on_steps=[False], on_epochs=self.choices, prob_bars=self.choices
pl_module, 'on_epoch_end', 8, on_steps=[False], on_epochs=[True], prob_bars=self.choices
)
def on_validation_epoch_end(self, trainer, pl_module):
self.make_logging(
pl_module,
'on_validation_epoch_end',
9,
on_steps=[False],
on_epochs=self.choices,
prob_bars=self.choices
pl_module, 'on_validation_epoch_end', 9, on_steps=[False], on_epochs=[True], prob_bars=self.choices
)
class TestModel(BoringModel):
@ -558,18 +538,11 @@ def test_log_works_in_test_callback(tmpdir):
}
def on_test_start(self, trainer, pl_module):
self.make_logging(
pl_module, 'on_test_start', 1, on_steps=self.choices, on_epochs=self.choices, prob_bars=self.choices
)
self.make_logging(pl_module, 'on_test_start', 1, on_steps=[False], on_epochs=[True], prob_bars=self.choices)
def on_test_epoch_start(self, trainer, pl_module):
self.make_logging(
pl_module,
'on_test_epoch_start',
3,
on_steps=self.choices,
on_epochs=self.choices,
prob_bars=self.choices
pl_module, 'on_test_epoch_start', 3, on_steps=[False], on_epochs=[True], prob_bars=self.choices
)
def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
@ -589,7 +562,7 @@ def test_log_works_in_test_callback(tmpdir):
def on_test_epoch_end(self, trainer, pl_module):
self.make_logging(
pl_module, 'on_test_epoch_end', 7, on_steps=[False], on_epochs=self.choices, prob_bars=self.choices
pl_module, 'on_test_epoch_end', 7, on_steps=[False], on_epochs=[True], prob_bars=self.choices
)
max_epochs = 2

View File

@ -359,7 +359,8 @@ def test_fx_validator(tmpdir):
# This summarizes where and what is currently possible to log using `self.log`
is_stage = "train" in func_name or "test" in func_name or "validation" in func_name
is_start = "start" in func_name or "batch" in func_name
on_step = is_stage and is_start
is_epoch = "epoch" in func_name
on_step = is_stage and not is_start and not is_epoch
on_epoch = True
# creating allowed condition
allowed = (

View File

@ -399,22 +399,17 @@ def test_log_works_in_train_callback(tmpdir):
def on_train_start(self, trainer, pl_module):
self.make_logging(
pl_module, 'on_train_start', 1, on_steps=self.choices, on_epochs=self.choices, prob_bars=self.choices
pl_module, 'on_train_start', 1, on_steps=[False], on_epochs=[True], prob_bars=self.choices
)
def on_epoch_start(self, trainer, pl_module):
self.make_logging(
pl_module, 'on_epoch_start', 2, on_steps=self.choices, on_epochs=self.choices, prob_bars=self.choices
pl_module, 'on_epoch_start', 2, on_steps=[False], on_epochs=[True], prob_bars=self.choices
)
def on_train_epoch_start(self, trainer, pl_module):
self.make_logging(
pl_module,
'on_train_epoch_start',
3,
on_steps=self.choices,
on_epochs=self.choices,
prob_bars=self.choices
pl_module, 'on_train_epoch_start', 3, on_steps=[False], on_epochs=[True], prob_bars=self.choices
)
def on_batch_end(self, trainer, pl_module):
@ -438,13 +433,11 @@ def test_log_works_in_train_callback(tmpdir):
def on_train_epoch_end(self, trainer, pl_module):
self.make_logging(
pl_module, 'on_train_epoch_end', 8, on_steps=[False], on_epochs=self.choices, prob_bars=self.choices
pl_module, 'on_train_epoch_end', 8, on_steps=[False], on_epochs=[True], prob_bars=self.choices
)
def on_epoch_end(self, trainer, pl_module):
self.make_logging(
pl_module, 'on_epoch_end', 9, on_steps=[False], on_epochs=self.choices, prob_bars=self.choices
)
self.make_logging(pl_module, 'on_epoch_end', 9, on_steps=[False], on_epochs=[True], prob_bars=self.choices)
class TestModel(BoringModel):