Stricter `FxValidator` and add hooks (#7874)
* Stricter FxValidator and add hooks * Update CHANGELOG
This commit is contained in:
parent
ce976769ef
commit
3427cb728d
|
@ -78,6 +78,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
- Log epoch metrics before the `on_evaluation_end` hook ([#7272](https://github.com/PyTorchLightning/pytorch-lightning/pull/7272))
|
||||
|
||||
|
||||
- Explicitly disallow calling `self.log(on_epoch=False)` during epoch-only or single-call hooks ([#7874](https://github.com/PyTorchLightning/pytorch-lightning/pull/7874))
|
||||
|
||||
|
||||
- Changed these `Trainer` methods to be protected: `call_setup_hook`, `call_configure_sharded_model`, `pre_dispatch`, `dispatch`, `post_dispatch`, `call_teardown_hook`, `run_train`, `run_sanity_check`, `run_evaluate`, `run_evaluation`, `run_predict`, `track_output_for_epoch_end`
|
||||
|
||||
|
||||
|
|
|
@ -29,26 +29,26 @@ class FxValidator:
|
|||
on_fit_end=None,
|
||||
on_sanity_check_start=None,
|
||||
on_sanity_check_end=None,
|
||||
on_train_start=dict(on_step=(False, True), on_epoch=(False, True)),
|
||||
on_train_start=dict(on_step=(False, ), on_epoch=(True, )),
|
||||
on_train_end=None,
|
||||
on_validation_start=dict(on_step=(False, True), on_epoch=(False, True)),
|
||||
on_validation_start=dict(on_step=(False, ), on_epoch=(True, )),
|
||||
on_validation_end=None,
|
||||
on_test_start=dict(on_step=(False, True), on_epoch=(False, True)),
|
||||
on_test_start=dict(on_step=(False, ), on_epoch=(True, )),
|
||||
on_test_end=None,
|
||||
on_predict_start=None,
|
||||
on_predict_end=None,
|
||||
on_pretrain_routine_start=None,
|
||||
on_pretrain_routine_end=None,
|
||||
on_train_epoch_start=dict(on_step=(False, True), on_epoch=(False, True)),
|
||||
on_train_epoch_end=dict(on_step=(False, ), on_epoch=(False, True)),
|
||||
on_validation_epoch_start=dict(on_step=(False, True), on_epoch=(False, True)),
|
||||
on_validation_epoch_end=dict(on_step=(False, ), on_epoch=(False, True)),
|
||||
on_test_epoch_start=dict(on_step=(False, True), on_epoch=(False, True)),
|
||||
on_test_epoch_end=dict(on_step=(False, ), on_epoch=(False, True)),
|
||||
on_train_epoch_start=dict(on_step=(False, True), on_epoch=(True, )),
|
||||
on_train_epoch_end=dict(on_step=(False, ), on_epoch=(True, )),
|
||||
on_validation_epoch_start=dict(on_step=(False, True), on_epoch=(True, )),
|
||||
on_validation_epoch_end=dict(on_step=(False, ), on_epoch=(True, )),
|
||||
on_test_epoch_start=dict(on_step=(False, True), on_epoch=(True, )),
|
||||
on_test_epoch_end=dict(on_step=(False, ), on_epoch=(True, )),
|
||||
on_predict_epoch_start=None,
|
||||
on_predict_epoch_end=None,
|
||||
on_epoch_start=dict(on_step=(False, True), on_epoch=(False, True)),
|
||||
on_epoch_end=dict(on_step=(False, ), on_epoch=(False, True)),
|
||||
on_epoch_start=dict(on_step=(False, True), on_epoch=(True, )),
|
||||
on_epoch_end=dict(on_step=(False, ), on_epoch=(True, )),
|
||||
on_batch_start=dict(on_step=(False, True), on_epoch=(False, True)),
|
||||
on_batch_end=dict(on_step=(False, True), on_epoch=(False, True)),
|
||||
on_train_batch_start=dict(on_step=(False, True), on_epoch=(False, True)),
|
||||
|
@ -72,19 +72,26 @@ class FxValidator:
|
|||
training_step_end=dict(on_step=(False, True), on_epoch=(False, True)),
|
||||
validation_step_end=dict(on_step=(False, True), on_epoch=(False, True)),
|
||||
test_step_end=dict(on_step=(False, True), on_epoch=(False, True)),
|
||||
training_epoch_end=dict(on_step=(False, ), on_epoch=(False, True)),
|
||||
validation_epoch_end=dict(on_step=(False, ), on_epoch=(False, True)),
|
||||
test_epoch_end=dict(on_step=(False, ), on_epoch=(False, True)),
|
||||
training_epoch_end=dict(on_step=(False, ), on_epoch=(True, )),
|
||||
validation_epoch_end=dict(on_step=(False, ), on_epoch=(True, )),
|
||||
test_epoch_end=dict(on_step=(False, ), on_epoch=(True, )),
|
||||
on_before_batch_transfer=None,
|
||||
transfer_batch_to_device=None,
|
||||
on_after_batch_transfer=None,
|
||||
backward=None,
|
||||
optimizer_step=None,
|
||||
# TODO(@carmocca): some {step,epoch}_{start,end} are missing
|
||||
)
|
||||
|
||||
def check_logging(self, fx_name: str, on_step: bool, on_epoch: bool) -> None:
|
||||
if fx_name not in self.functions:
|
||||
@classmethod
|
||||
def check_logging(cls, fx_name: str, on_step: bool, on_epoch: bool) -> None:
|
||||
"""Check if the given function name is allowed to log"""
|
||||
if fx_name not in cls.functions:
|
||||
raise RuntimeError(
|
||||
f'You are trying to `self.log()` inside `{fx_name}` but it is not implemented.'
|
||||
' Please, open an issue in `https://github.com/PyTorchLightning/pytorch-lightning/issues`'
|
||||
)
|
||||
allowed = self.functions[fx_name]
|
||||
allowed = cls.functions[fx_name]
|
||||
if allowed is None:
|
||||
raise MisconfigurationException(f"{fx_name} function doesn't support logging using `self.log()`")
|
||||
|
||||
|
|
|
@ -349,33 +349,18 @@ def test_log_works_in_val_callback(tmpdir):
|
|||
|
||||
def on_validation_start(self, trainer, pl_module):
|
||||
self.make_logging(
|
||||
pl_module,
|
||||
'on_validation_start',
|
||||
1,
|
||||
on_steps=self.choices,
|
||||
on_epochs=self.choices,
|
||||
prob_bars=self.choices
|
||||
pl_module, 'on_validation_start', 1, on_steps=[False], on_epochs=[True], prob_bars=self.choices
|
||||
)
|
||||
|
||||
def on_epoch_start(self, trainer, pl_module):
|
||||
if trainer.validating:
|
||||
self.make_logging(
|
||||
pl_module,
|
||||
'on_epoch_start',
|
||||
2,
|
||||
on_steps=self.choices,
|
||||
on_epochs=self.choices,
|
||||
prob_bars=self.choices
|
||||
pl_module, 'on_epoch_start', 2, on_steps=[False], on_epochs=[True], prob_bars=self.choices
|
||||
)
|
||||
|
||||
def on_validation_epoch_start(self, trainer, pl_module):
|
||||
self.make_logging(
|
||||
pl_module,
|
||||
'on_validation_epoch_start',
|
||||
3,
|
||||
on_steps=self.choices,
|
||||
on_epochs=self.choices,
|
||||
prob_bars=self.choices
|
||||
pl_module, 'on_validation_epoch_start', 3, on_steps=[False], on_epochs=[True], prob_bars=self.choices
|
||||
)
|
||||
|
||||
def on_batch_end(self, trainer, pl_module):
|
||||
|
@ -400,17 +385,12 @@ def test_log_works_in_val_callback(tmpdir):
|
|||
def on_epoch_end(self, trainer, pl_module):
|
||||
if trainer.validating:
|
||||
self.make_logging(
|
||||
pl_module, 'on_epoch_end', 8, on_steps=[False], on_epochs=self.choices, prob_bars=self.choices
|
||||
pl_module, 'on_epoch_end', 8, on_steps=[False], on_epochs=[True], prob_bars=self.choices
|
||||
)
|
||||
|
||||
def on_validation_epoch_end(self, trainer, pl_module):
|
||||
self.make_logging(
|
||||
pl_module,
|
||||
'on_validation_epoch_end',
|
||||
9,
|
||||
on_steps=[False],
|
||||
on_epochs=self.choices,
|
||||
prob_bars=self.choices
|
||||
pl_module, 'on_validation_epoch_end', 9, on_steps=[False], on_epochs=[True], prob_bars=self.choices
|
||||
)
|
||||
|
||||
class TestModel(BoringModel):
|
||||
|
@ -558,18 +538,11 @@ def test_log_works_in_test_callback(tmpdir):
|
|||
}
|
||||
|
||||
def on_test_start(self, trainer, pl_module):
|
||||
self.make_logging(
|
||||
pl_module, 'on_test_start', 1, on_steps=self.choices, on_epochs=self.choices, prob_bars=self.choices
|
||||
)
|
||||
self.make_logging(pl_module, 'on_test_start', 1, on_steps=[False], on_epochs=[True], prob_bars=self.choices)
|
||||
|
||||
def on_test_epoch_start(self, trainer, pl_module):
|
||||
self.make_logging(
|
||||
pl_module,
|
||||
'on_test_epoch_start',
|
||||
3,
|
||||
on_steps=self.choices,
|
||||
on_epochs=self.choices,
|
||||
prob_bars=self.choices
|
||||
pl_module, 'on_test_epoch_start', 3, on_steps=[False], on_epochs=[True], prob_bars=self.choices
|
||||
)
|
||||
|
||||
def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
|
||||
|
@ -589,7 +562,7 @@ def test_log_works_in_test_callback(tmpdir):
|
|||
|
||||
def on_test_epoch_end(self, trainer, pl_module):
|
||||
self.make_logging(
|
||||
pl_module, 'on_test_epoch_end', 7, on_steps=[False], on_epochs=self.choices, prob_bars=self.choices
|
||||
pl_module, 'on_test_epoch_end', 7, on_steps=[False], on_epochs=[True], prob_bars=self.choices
|
||||
)
|
||||
|
||||
max_epochs = 2
|
||||
|
|
|
@ -359,7 +359,8 @@ def test_fx_validator(tmpdir):
|
|||
# This summarizes where and what is currently possible to log using `self.log`
|
||||
is_stage = "train" in func_name or "test" in func_name or "validation" in func_name
|
||||
is_start = "start" in func_name or "batch" in func_name
|
||||
on_step = is_stage and is_start
|
||||
is_epoch = "epoch" in func_name
|
||||
on_step = is_stage and not is_start and not is_epoch
|
||||
on_epoch = True
|
||||
# creating allowed condition
|
||||
allowed = (
|
||||
|
|
|
@ -399,22 +399,17 @@ def test_log_works_in_train_callback(tmpdir):
|
|||
|
||||
def on_train_start(self, trainer, pl_module):
|
||||
self.make_logging(
|
||||
pl_module, 'on_train_start', 1, on_steps=self.choices, on_epochs=self.choices, prob_bars=self.choices
|
||||
pl_module, 'on_train_start', 1, on_steps=[False], on_epochs=[True], prob_bars=self.choices
|
||||
)
|
||||
|
||||
def on_epoch_start(self, trainer, pl_module):
|
||||
self.make_logging(
|
||||
pl_module, 'on_epoch_start', 2, on_steps=self.choices, on_epochs=self.choices, prob_bars=self.choices
|
||||
pl_module, 'on_epoch_start', 2, on_steps=[False], on_epochs=[True], prob_bars=self.choices
|
||||
)
|
||||
|
||||
def on_train_epoch_start(self, trainer, pl_module):
|
||||
self.make_logging(
|
||||
pl_module,
|
||||
'on_train_epoch_start',
|
||||
3,
|
||||
on_steps=self.choices,
|
||||
on_epochs=self.choices,
|
||||
prob_bars=self.choices
|
||||
pl_module, 'on_train_epoch_start', 3, on_steps=[False], on_epochs=[True], prob_bars=self.choices
|
||||
)
|
||||
|
||||
def on_batch_end(self, trainer, pl_module):
|
||||
|
@ -438,13 +433,11 @@ def test_log_works_in_train_callback(tmpdir):
|
|||
|
||||
def on_train_epoch_end(self, trainer, pl_module):
|
||||
self.make_logging(
|
||||
pl_module, 'on_train_epoch_end', 8, on_steps=[False], on_epochs=self.choices, prob_bars=self.choices
|
||||
pl_module, 'on_train_epoch_end', 8, on_steps=[False], on_epochs=[True], prob_bars=self.choices
|
||||
)
|
||||
|
||||
def on_epoch_end(self, trainer, pl_module):
|
||||
self.make_logging(
|
||||
pl_module, 'on_epoch_end', 9, on_steps=[False], on_epochs=self.choices, prob_bars=self.choices
|
||||
)
|
||||
self.make_logging(pl_module, 'on_epoch_end', 9, on_steps=[False], on_epochs=[True], prob_bars=self.choices)
|
||||
|
||||
class TestModel(BoringModel):
|
||||
|
||||
|
|
Loading…
Reference in New Issue