Follow-up changes to #10575 (#10957)

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Carlos Mocholí 2021-12-07 15:27:52 +01:00 committed by GitHub
parent 42b5417e9b
commit 99adc45af1
11 changed files with 65 additions and 46 deletions

View File

@ -186,11 +186,10 @@ class EvaluationLoop(DataLoaderLoop):
def _on_evaluation_model_train(self) -> None:
"""Sets model to train mode."""
-model_ref = self.trainer.lightning_module
if self.trainer.testing:
-model_ref.on_test_model_train()
+self.trainer._call_lightning_module_hook("on_test_model_train")
else:
-model_ref.on_validation_model_train()
+self.trainer._call_lightning_module_hook("on_validation_model_train")
def _on_evaluation_end(self, *args: Any, **kwargs: Any) -> None:
"""Runs ``on_{validation/test}_end`` hook."""

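Note: the hunk above drops the local `model_ref` and routes `on_test_model_train` / `on_validation_model_train` through the trainer's centralized `_call_lightning_module_hook` dispatcher. A minimal, runnable sketch of that dispatch pattern; `_Module` and `_Trainer` below are simplified stand-ins for illustration, not the real Lightning classes:

from typing import Any


class _Module:
    """Stand-in for a LightningModule with two no-op mode hooks."""

    def on_test_model_train(self) -> None:
        print("switching the model to train mode after testing")

    def on_validation_model_train(self) -> None:
        print("switching the model to train mode after validation")


class _Trainer:
    """Stand-in trainer that routes hook calls through a single dispatcher."""

    def __init__(self, module: _Module, testing: bool) -> None:
        self.lightning_module = module
        self.testing = testing

    def _call_lightning_module_hook(self, hook_name: str, *args: Any, **kwargs: Any) -> Any:
        # look the hook up by name; a real trainer would also profile the call and
        # set `_current_fx_name` so logging knows which hook is running
        fn = getattr(self.lightning_module, hook_name, None)
        if not callable(fn):
            return None
        return fn(*args, **kwargs)

    def _on_evaluation_model_train(self) -> None:
        # mirrors the updated EvaluationLoop: no local `model_ref` needed
        if self.testing:
            self._call_lightning_module_hook("on_test_model_train")
        else:
            self._call_lightning_module_hook("on_validation_model_train")


_Trainer(_Module(), testing=True)._on_evaluation_model_train()

Centralizing the call in one dispatcher is what lets the trainer track `_current_fx_name` and profile every hook in a single place rather than at each call site.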
View File

@ -218,9 +218,9 @@ class EvaluationEpochLoop(Loop):
the outputs of the step
"""
if self.trainer.testing:
-output = self.trainer._call_accelerator_hook("test_step", *kwargs.values())
+output = self.trainer._call_ttp_hook("test_step", *kwargs.values())
else:
-output = self.trainer._call_accelerator_hook("validation_step", *kwargs.values())
+output = self.trainer._call_ttp_hook("validation_step", *kwargs.values())
return output

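Note: `test_step` / `validation_step` are now dispatched through `_call_ttp_hook` (the training type plugin) instead of `_call_accelerator_hook`. A rough, self-contained sketch of that name-based dispatch; the `_Plugin` and `_Trainer` classes below are hypothetical stand-ins:

from typing import Any, Dict


class _Plugin:
    """Stand-in for a training type plugin exposing the evaluation steps."""

    def validation_step(self, batch: Any, batch_idx: int) -> Dict[str, float]:
        return {"val_loss": 0.0}

    def test_step(self, batch: Any, batch_idx: int) -> Dict[str, float]:
        return {"test_loss": 0.0}


class _Trainer:
    def __init__(self, plugin: _Plugin, testing: bool) -> None:
        self.training_type_plugin = plugin
        self.testing = testing

    def _call_ttp_hook(self, hook_name: str, *args: Any, **kwargs: Any) -> Any:
        # dispatch by name to the training type plugin
        fn = getattr(self.training_type_plugin, hook_name, None)
        if not callable(fn):
            return None
        return fn(*args, **kwargs)

    def _evaluation_step(self, **kwargs: Any) -> Any:
        # mirrors EvaluationEpochLoop: choose the hook name, then dispatch
        hook_name = "test_step" if self.testing else "validation_step"
        return self._call_ttp_hook(hook_name, *kwargs.values())


print(_Trainer(_Plugin(), testing=False)._evaluation_step(batch=None, batch_idx=0))  # {'val_loss': 0.0}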
View File

@ -130,7 +130,7 @@ class PredictionEpochLoop(Loop):
self.batch_progress.increment_started()
-predictions = self.trainer._call_accelerator_hook("predict_step", *step_kwargs.values())
+predictions = self.trainer._call_ttp_hook("predict_step", *step_kwargs.values())
self.batch_progress.increment_processed()

View File

@ -180,11 +180,10 @@ class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]):
# hook
self.trainer._call_callback_hooks("on_train_batch_start", batch, batch_idx, **extra_kwargs)
-model_response = self.trainer._call_lightning_module_hook(
+response = self.trainer._call_lightning_module_hook(
"on_train_batch_start", batch, batch_idx, **extra_kwargs
)
-ttp_response = self.trainer._call_ttp_hook("on_train_batch_start", batch, batch_idx, **extra_kwargs)
-response = ttp_response if model_response is None else model_response
+self.trainer._call_ttp_hook("on_train_batch_start", batch, batch_idx, **extra_kwargs)
if response == -1:
self.batch_progress.increment_processed()
raise StopIteration

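Note: after this change only the LightningModule's `on_train_batch_start` return value is kept as `response` (returning `-1` skips the rest of the batch); the training type plugin hook is still invoked, but its return value is ignored. A small illustration of that control flow, using hypothetical stand-ins:

from typing import Any, Optional


class _Module:
    """Stand-in LightningModule: returning -1 asks the loop to skip the batch."""

    def on_train_batch_start(self, batch: Any, batch_idx: int) -> Optional[int]:
        return -1 if batch is None else None


class _Plugin:
    """Stand-in training type plugin hook: called for side effects only."""

    def on_train_batch_start(self, batch: Any, batch_idx: int) -> None:
        pass


def run_batch(module: _Module, plugin: _Plugin, batch: Any, batch_idx: int) -> str:
    response = module.on_train_batch_start(batch, batch_idx)
    # the plugin hook still runs, but its return value no longer feeds the decision
    plugin.on_train_batch_start(batch, batch_idx)
    if response == -1:
        return "batch skipped"
    return "batch processed"


print(run_batch(_Module(), _Plugin(), batch=None, batch_idx=0))   # batch skipped
print(run_batch(_Module(), _Plugin(), batch=[1.0], batch_idx=1))  # batch processed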
View File

@ -102,7 +102,7 @@ class ManualOptimization(Loop[_OUTPUTS_TYPE]):
)
# manually capture logged metrics
-training_step_output = self.trainer._call_accelerator_hook("training_step", *step_kwargs.values())
+training_step_output = self.trainer._call_ttp_hook("training_step", *step_kwargs.values())
self.trainer.training_type_plugin.post_training_step()
del step_kwargs

View File

@ -144,12 +144,10 @@ class Closure(AbstractClosure[ClosureResult]):
)
if self._zero_grad_fn is not None:
-with self._profiler.profile("zero_grad"):
-    self._zero_grad_fn()
+self._zero_grad_fn()
if self._backward_fn is not None and step_output.closure_loss is not None:
-with self._profiler.profile("backward"):
-    self._backward_fn(step_output.closure_loss)
+self._backward_fn(step_output.closure_loss)
return step_output
@ -320,7 +318,7 @@ class OptimizerLoop(Loop[_OUTPUTS_TYPE]):
return None
def backward_fn(loss: Tensor) -> None:
-self.trainer.training_type_plugin.backward(loss, optimizer, opt_idx)
+self.trainer._call_ttp_hook("backward", loss, optimizer, opt_idx)
# check if model weights are nan
if self.trainer._terminate_on_nan:
@ -362,8 +360,6 @@ class OptimizerLoop(Loop[_OUTPUTS_TYPE]):
train_step_and_backward_closure: the closure function performing the train step and computing the
gradients. By default called by the optimizer (if possible)
"""
-lightning_module = self.trainer.lightning_module
is_lbfgs = isinstance(optimizer, torch.optim.LBFGS)
# wraps into LightningOptimizer only for running step
optimizer = LightningOptimizer._to_lightning_optimizer(optimizer, self.trainer, opt_idx)
@ -371,7 +367,8 @@ class OptimizerLoop(Loop[_OUTPUTS_TYPE]):
self.optim_progress.optimizer.step.increment_ready()
# model hook
-lightning_module.optimizer_step(
+self.trainer._call_lightning_module_hook(
+"optimizer_step",
self.trainer.current_epoch,
batch_idx,
optimizer,
@ -403,7 +400,7 @@ class OptimizerLoop(Loop[_OUTPUTS_TYPE]):
optimizer: the current optimizer
opt_idx: the index of the current optimizer
"""
-self.trainer.training_type_plugin.optimizer_zero_grad(self.trainer.current_epoch, batch_idx, optimizer, opt_idx)
+self.trainer._call_ttp_hook("optimizer_zero_grad", self.trainer.current_epoch, batch_idx, optimizer, opt_idx)
self.optim_progress.optimizer.zero_grad.increment_completed()
def _training_step(self, split_batch: Any, batch_idx: int, opt_idx: int) -> ClosureResult:
@ -427,7 +424,7 @@ class OptimizerLoop(Loop[_OUTPUTS_TYPE]):
)
# manually capture logged metrics
-training_step_output = self.trainer._call_accelerator_hook("training_step", *step_kwargs.values())
+training_step_output = self.trainer._call_ttp_hook("training_step", *step_kwargs.values())
self.trainer.training_type_plugin.post_training_step()
del step_kwargs

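Note: these hunks route `training_step`, `backward`, `optimizer_step`, and `optimizer_zero_grad` through the trainer's `_call_ttp_hook` / `_call_lightning_module_hook` dispatchers instead of calling the plugin or module directly; that is presumably also why the `Closure` no longer wraps `zero_grad` / `backward` in its own profiler blocks, since the dispatcher can profile each named call. A compact sketch of such a profiling dispatcher, with a hypothetical `_Profiler` and `_Plugin`:

import time
from contextlib import contextmanager
from typing import Any, Dict, Iterator


class _Profiler:
    """Hypothetical profiler that records wall-clock time per hook name."""

    def __init__(self) -> None:
        self.records: Dict[str, float] = {}

    @contextmanager
    def profile(self, name: str) -> Iterator[None]:
        start = time.perf_counter()
        try:
            yield
        finally:
            self.records[name] = self.records.get(name, 0.0) + time.perf_counter() - start


class _Plugin:
    """Stand-in training type plugin; the real one would do the actual work."""

    def backward(self, loss: float) -> None: ...
    def optimizer_step(self) -> None: ...
    def optimizer_zero_grad(self) -> None: ...


class _Trainer:
    def __init__(self) -> None:
        self.training_type_plugin = _Plugin()
        self.profiler = _Profiler()

    def _call_ttp_hook(self, hook_name: str, *args: Any, **kwargs: Any) -> Any:
        fn = getattr(self.training_type_plugin, hook_name, None)
        if not callable(fn):
            return None
        # every dispatched hook is profiled under its own name
        with self.profiler.profile(hook_name):
            return fn(*args, **kwargs)


trainer = _Trainer()
trainer._call_ttp_hook("optimizer_zero_grad")
trainer._call_ttp_hook("backward", 0.123)
trainer._call_ttp_hook("optimizer_step")
print(sorted(trainer.profiler.records))  # ['backward', 'optimizer_step', 'optimizer_zero_grad']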
View File

@ -164,8 +164,7 @@ class TrainingTypePlugin(ABC):
def optimizer_zero_grad(self, current_epoch: int, batch_idx: int, optimizer: Optimizer, opt_idx: int) -> None:
"""Zeros all model parameter's gradients."""
-model_ref = self.lightning_module
-model_ref.optimizer_zero_grad(current_epoch, batch_idx, optimizer, opt_idx)
+self.lightning_module.optimizer_zero_grad(current_epoch, batch_idx, optimizer, opt_idx)
def _setup_model_and_optimizers(self, model: Module, optimizers: List[Optimizer]) -> Tuple[Module, List[Optimizer]]:
"""Setup a model and multiple optimizers together.

View File

@ -31,15 +31,24 @@ class _FxValidator:
"on_before_backward": _LogOptions(
allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False
),
"backward": _LogOptions(
allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False
),
"on_after_backward": _LogOptions(
allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False
),
"on_before_optimizer_step": _LogOptions(
allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False
),
"optimizer_step": _LogOptions(
allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False
),
"on_before_zero_grad": _LogOptions(
allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False
),
"optimizer_zero_grad": _LogOptions(
allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False
),
"on_init_start": None,
"on_init_end": None,
"on_fit_start": None,
@ -160,6 +169,8 @@ class _FxValidator:
"configure_callbacks": None,
"on_validation_model_eval": None,
"on_test_model_eval": None,
"on_validation_model_train": None,
"on_test_model_train": None,
}
@classmethod

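Note: the validator gains explicit `_LogOptions` entries for `backward`, `optimizer_step`, and `optimizer_zero_grad` (so `self.log` is allowed from those hooks, defaulting to step-level logging) and marks `on_validation_model_train` / `on_test_model_train` as `None` (no logging allowed). A simplified sketch of how such a table can be consulted; `check_logging` below is a hypothetical helper, not the actual `_FxValidator` API:

from typing import Dict, NamedTuple, Optional, Tuple


class _LogOptions(NamedTuple):
    allowed_on_step: Tuple[bool, ...]
    allowed_on_epoch: Tuple[bool, ...]
    default_on_step: bool
    default_on_epoch: bool


# a few representative entries mirroring the diff above; None means "no logging here"
FUNCTIONS: Dict[str, Optional[_LogOptions]] = {
    "backward": _LogOptions((False, True), (False, True), True, False),
    "optimizer_step": _LogOptions((False, True), (False, True), True, False),
    "optimizer_zero_grad": _LogOptions((False, True), (False, True), True, False),
    "on_validation_model_train": None,
    "on_test_model_train": None,
}


def check_logging(fx_name: str, on_step: Optional[bool] = None, on_epoch: Optional[bool] = None) -> Tuple[bool, bool]:
    """Validate a `self.log` call made from hook `fx_name` and resolve defaults."""
    options = FUNCTIONS.get(fx_name)
    if options is None:
        raise RuntimeError(f"You can't `self.log()` inside `{fx_name}`")
    on_step = options.default_on_step if on_step is None else on_step
    on_epoch = options.default_on_epoch if on_epoch is None else on_epoch
    if on_step not in options.allowed_on_step or on_epoch not in options.allowed_on_epoch:
        raise RuntimeError(f"on_step={on_step}, on_epoch={on_epoch} is not allowed in `{fx_name}`")
    return on_step, on_epoch


print(check_logging("backward"))  # (True, False): step-level logging by default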
View File

@ -1452,7 +1452,7 @@ class Trainer(
*args: Any,
pl_module: Optional["pl.LightningModule"] = None,
**kwargs: Any,
-):
+) -> Any:
pl_module = pl_module or self.lightning_module
if pl_module is None:
@ -1460,7 +1460,7 @@ class Trainer(
fn = getattr(pl_module, hook_name)
if not callable(fn):
-return None
+return
prev_fx_name = pl_module._current_fx_name
pl_module._current_fx_name = hook_name
@ -1479,16 +1479,15 @@ class Trainer(
hook_name: str,
*args: Any,
**kwargs: Any,
-) -> Optional[Any]:
-output = None
+) -> None:
if hook_name in ("on_init_start", "on_init_end"):
# these `Callback` hooks are the only ones that do not take a lightning module.
# we also don't profile bc profiler hasn't been set yet
for callback in self.callbacks:
fn = getattr(callback, hook_name)
if callable(fn):
-output = fn(self, *args, **kwargs)
-return output
+fn(self, *args, **kwargs)
+return
pl_module = self.lightning_module
if pl_module:
@ -1500,34 +1499,39 @@ class Trainer(
fn = getattr(self, hook_name)
if callable(fn):
with self.profiler.profile(hook_name):
-output = fn(*args, **kwargs)
+fn(*args, **kwargs)
else:
for callback in self.callbacks:
fn = getattr(callback, hook_name)
if callable(fn):
with self.profiler.profile(hook_name):
-output = fn(self, self.lightning_module, *args, **kwargs)
+fn(self, self.lightning_module, *args, **kwargs)
if pl_module:
# restore current_fx when nested context
pl_module._current_fx_name = prev_fx_name
-return output
+# TODO: rename to _call_strategy_hook and eventually no longer need this
def _call_ttp_hook(
self,
hook_name: str,
*args: Any,
**kwargs: Any,
-):
+) -> Any:
pl_module = self.lightning_module
prev_fx_name = pl_module._current_fx_name
pl_module._current_fx_name = hook_name
fn = getattr(self.training_type_plugin, hook_name)
if not callable(fn):
-return None
+return
with self.profiler.profile(hook_name):
output = fn(*args, **kwargs)
# restore current_fx when nested context
pl_module._current_fx_name = prev_fx_name
return output
# TODO: eventually no longer need this
@ -1536,15 +1540,21 @@ class Trainer(
hook_name: str,
*args: Any,
**kwargs: Any,
-) -> Optional[Any]:
-self.lightning_module._current_fx_name = hook_name
-fn = getattr(self.training_type_plugin, hook_name)
+) -> Any:
+pl_module = self.lightning_module
+prev_fx_name = pl_module._current_fx_name
+pl_module._current_fx_name = hook_name
+fn = getattr(self.accelerator, hook_name)
if not callable(fn):
-return None
+return
with self.profiler.profile(hook_name):
output = fn(*args, **kwargs)
+# restore current_fx when nested context
+pl_module._current_fx_name = prev_fx_name
return output
@staticmethod

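Note: the `Trainer` hunks tighten the `_call_*_hook` family: `_call_lightning_module_hook`, `_call_ttp_hook`, and `_call_accelerator_hook` are annotated `-> Any` and return the hook's output, `_call_callback_hooks` becomes `-> None` and discards outputs, and `_call_accelerator_hook` now looks the hook up on `self.accelerator` and saves/restores `_current_fx_name` like the others. A condensed, self-contained sketch of that save/profile/restore pattern (stand-in classes, not the real Trainer):

from contextlib import contextmanager
from typing import Any, Iterator, Optional


class _Module:
    """Stand-in LightningModule; only tracks the currently running hook name."""

    _current_fx_name: Optional[str] = None


class _Accelerator:
    """Stand-in accelerator exposing a step method."""

    def training_step(self, batch: Any) -> float:
        return 0.5


class _Trainer:
    def __init__(self) -> None:
        self.lightning_module = _Module()
        self.accelerator = _Accelerator()

    @contextmanager
    def _profile(self, name: str) -> Iterator[None]:
        yield  # a real trainer would use its profiler here, keyed by hook name

    def _call_accelerator_hook(self, hook_name: str, *args: Any, **kwargs: Any) -> Any:
        pl_module = self.lightning_module
        prev_fx_name = pl_module._current_fx_name
        pl_module._current_fx_name = hook_name  # lets logging know which hook is running

        fn = getattr(self.accelerator, hook_name, None)
        if not callable(fn):
            return  # bare `return`, matching the updated style

        with self._profile(hook_name):
            output = fn(*args, **kwargs)

        # restore current_fx when nested context
        pl_module._current_fx_name = prev_fx_name
        return output


print(_Trainer()._call_accelerator_hook("training_step", batch=None))  # 0.5

Saving and restoring `prev_fx_name` keeps nested hook calls from clobbering each other's logging context.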
View File

@ -31,9 +31,9 @@ from tests.models.test_hooks import get_members
def test_fx_validator(tmpdir):
-funcs_name = sorted(get_members(Callback))
+funcs_name = get_members(Callback)
-callbacks_func = [
+callbacks_func = {
"on_before_backward",
"on_after_backward",
"on_before_optimizer_step",
@ -82,9 +82,9 @@ def test_fx_validator(tmpdir):
"on_predict_start",
"setup",
"teardown",
-]
+}
-not_supported = [
+not_supported = {
"on_before_accelerator_backend_setup",
"on_fit_end",
"on_fit_start",
@ -110,11 +110,10 @@ def test_fx_validator(tmpdir):
"on_validation_end",
"setup",
"teardown",
-]
+}
-assert funcs_name == sorted(
-    callbacks_func
-), "Detected new callback function. Need to add its logging permission to FxValidator and update this test"
+# Detected new callback function. Need to add its logging permission to FxValidator and update this test
+assert funcs_name == callbacks_func
validator = _FxValidator()
@ -233,6 +232,7 @@ def test_fx_validator_integration(tmpdir):
"prepare_data": "You can't",
"configure_callbacks": "You can't",
"on_validation_model_eval": "You can't",
"on_validation_model_train": "You can't",
"summarize": "not managed by the `Trainer",
}
model = HookedModel(not_supported)
@ -260,6 +260,7 @@ def test_fx_validator_integration(tmpdir):
"on_test_dataloader": "You can't",
"test_dataloader": "You can't",
"on_test_model_eval": "You can't",
"on_test_model_train": "You can't",
"on_test_end": "You can't",
}
)

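Note: the test now compares plain sets instead of sorted lists, so `get_members(Callback)` no longer needs `sorted(...)` and the assertion message becomes a comment. A tiny illustration of why set equality suffices for checking hook-name coverage; `get_members`, `Callback`, and the hook names below are simplified examples, not the real test helpers:

def get_members(cls: type) -> set:
    """Return the names of public methods defined on a class (simplified helper)."""
    return {name for name in vars(cls) if not name.startswith("_")}


class Callback:
    def on_fit_start(self) -> None: ...
    def on_fit_end(self) -> None: ...
    def on_train_batch_start(self) -> None: ...


funcs_name = get_members(Callback)
callbacks_func = {"on_fit_start", "on_fit_end", "on_train_batch_start"}

# Detected new callback function. Need to add its logging permission to FxValidator and update this test
assert funcs_name == callbacks_func

# a failing set comparison shows exactly which names are missing or unexpected,
# independent of ordering and duplicates:
print(funcs_name - {"on_fit_start"})  # e.g. {'on_fit_end', 'on_train_batch_start'}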
View File

@ -50,9 +50,12 @@ def test_default_level_for_hooks_that_support_logging():
trainer.state.stage = RunningStage.TRAINING
hooks = [
"on_before_backward",
"backward",
"on_after_backward",
"on_before_optimizer_step",
"optimizer_step",
"on_before_zero_grad",
"optimizer_zero_grad",
"training_step",
"training_step_end",
"on_batch_start",