Follow-up changes to #10575 (#10957)

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Authored by Carlos Mocholí on 2021-12-07 15:27:52 +01:00, committed by GitHub
parent 42b5417e9b
commit 99adc45af1
11 changed files with 65 additions and 46 deletions

View File

@@ -186,11 +186,10 @@ class EvaluationLoop(DataLoaderLoop):
     def _on_evaluation_model_train(self) -> None:
         """Sets model to train mode."""
-        model_ref = self.trainer.lightning_module
         if self.trainer.testing:
-            model_ref.on_test_model_train()
+            self.trainer._call_lightning_module_hook("on_test_model_train")
         else:
-            model_ref.on_validation_model_train()
+            self.trainer._call_lightning_module_hook("on_validation_model_train")

     def _on_evaluation_end(self, *args: Any, **kwargs: Any) -> None:
         """Runs ``on_{validation/test}_end`` hook."""

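For context, the hooks being re-routed here are overridable on the LightningModule; they control how the model is toggled between eval and train mode around a validation or test run. A minimal, hypothetical override (not part of this commit), using the hook names shown in the hunk above:

from pytorch_lightning import LightningModule


class EvalToggleModel(LightningModule):
    def on_validation_model_eval(self) -> None:
        # runs before the validation loop starts; the default puts the module in eval mode
        self.eval()

    def on_validation_model_train(self) -> None:
        # runs after validation finishes; the default switches back to train mode
        self.train()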
View File

@@ -218,9 +218,9 @@ class EvaluationEpochLoop(Loop):
             the outputs of the step
         """
         if self.trainer.testing:
-            output = self.trainer._call_accelerator_hook("test_step", *kwargs.values())
+            output = self.trainer._call_ttp_hook("test_step", *kwargs.values())
         else:
-            output = self.trainer._call_accelerator_hook("validation_step", *kwargs.values())
+            output = self.trainer._call_ttp_hook("validation_step", *kwargs.values())
         return output

View File

@@ -130,7 +130,7 @@ class PredictionEpochLoop(Loop):
         self.batch_progress.increment_started()

-        predictions = self.trainer._call_accelerator_hook("predict_step", *step_kwargs.values())
+        predictions = self.trainer._call_ttp_hook("predict_step", *step_kwargs.values())

         self.batch_progress.increment_processed()

View File

@@ -180,11 +180,10 @@ class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]):
         # hook
         self.trainer._call_callback_hooks("on_train_batch_start", batch, batch_idx, **extra_kwargs)
-        model_response = self.trainer._call_lightning_module_hook(
+        response = self.trainer._call_lightning_module_hook(
             "on_train_batch_start", batch, batch_idx, **extra_kwargs
         )
-        ttp_response = self.trainer._call_ttp_hook("on_train_batch_start", batch, batch_idx, **extra_kwargs)
-        response = ttp_response if model_response is None else model_response
+        self.trainer._call_ttp_hook("on_train_batch_start", batch, batch_idx, **extra_kwargs)
         if response == -1:
             self.batch_progress.increment_processed()
             raise StopIteration

View File

@@ -102,7 +102,7 @@ class ManualOptimization(Loop[_OUTPUTS_TYPE]):
             )

             # manually capture logged metrics
-            training_step_output = self.trainer._call_accelerator_hook("training_step", *step_kwargs.values())
+            training_step_output = self.trainer._call_ttp_hook("training_step", *step_kwargs.values())
             self.trainer.training_type_plugin.post_training_step()

         del step_kwargs

View File

@@ -144,12 +144,10 @@ class Closure(AbstractClosure[ClosureResult]):
             )

         if self._zero_grad_fn is not None:
-            with self._profiler.profile("zero_grad"):
-                self._zero_grad_fn()
+            self._zero_grad_fn()

         if self._backward_fn is not None and step_output.closure_loss is not None:
-            with self._profiler.profile("backward"):
-                self._backward_fn(step_output.closure_loss)
+            self._backward_fn(step_output.closure_loss)

         return step_output
@@ -320,7 +318,7 @@ class OptimizerLoop(Loop[_OUTPUTS_TYPE]):
             return None

         def backward_fn(loss: Tensor) -> None:
-            self.trainer.training_type_plugin.backward(loss, optimizer, opt_idx)
+            self.trainer._call_ttp_hook("backward", loss, optimizer, opt_idx)

             # check if model weights are nan
             if self.trainer._terminate_on_nan:
@@ -362,8 +360,6 @@ class OptimizerLoop(Loop[_OUTPUTS_TYPE]):
             train_step_and_backward_closure: the closure function performing the train step and computing the
                 gradients. By default called by the optimizer (if possible)
         """
-        lightning_module = self.trainer.lightning_module
-
         is_lbfgs = isinstance(optimizer, torch.optim.LBFGS)
         # wraps into LightningOptimizer only for running step
         optimizer = LightningOptimizer._to_lightning_optimizer(optimizer, self.trainer, opt_idx)
@@ -371,7 +367,8 @@ class OptimizerLoop(Loop[_OUTPUTS_TYPE]):
         self.optim_progress.optimizer.step.increment_ready()

         # model hook
-        lightning_module.optimizer_step(
+        self.trainer._call_lightning_module_hook(
+            "optimizer_step",
             self.trainer.current_epoch,
             batch_idx,
             optimizer,
@@ -403,7 +400,7 @@ class OptimizerLoop(Loop[_OUTPUTS_TYPE]):
             optimizer: the current optimizer
             opt_idx: the index of the current optimizer
         """
-        self.trainer.training_type_plugin.optimizer_zero_grad(self.trainer.current_epoch, batch_idx, optimizer, opt_idx)
+        self.trainer._call_ttp_hook("optimizer_zero_grad", self.trainer.current_epoch, batch_idx, optimizer, opt_idx)
         self.optim_progress.optimizer.zero_grad.increment_completed()

     def _training_step(self, split_batch: Any, batch_idx: int, opt_idx: int) -> ClosureResult:
@@ -427,7 +424,7 @@ class OptimizerLoop(Loop[_OUTPUTS_TYPE]):
             )

             # manually capture logged metrics
-            training_step_output = self.trainer._call_accelerator_hook("training_step", *step_kwargs.values())
+            training_step_output = self.trainer._call_ttp_hook("training_step", *step_kwargs.values())
             self.trainer.training_type_plugin.post_training_step()

         del step_kwargs

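The Closure hunk above drops the per-call profiler blocks because profiling now happens inside the trainer dispatch helper: _call_ttp_hook("backward", ...) already runs under self.profiler.profile("backward"), as the Trainer hunks further down show. Below is a toy, self-contained sketch of the resulting closure flow, with plain callables standing in for the Lightning internals; none of these names come from the PR itself.

from typing import Callable, Optional

import torch


def run_closure(
    step_fn: Callable[[], torch.Tensor],
    zero_grad_fn: Optional[Callable[[], None]],
    backward_fn: Optional[Callable[[torch.Tensor], None]],
) -> torch.Tensor:
    # same ordering as the updated Closure: step, zero_grad, backward,
    # with no profiler context managers of its own
    loss = step_fn()
    if zero_grad_fn is not None:
        zero_grad_fn()
    if backward_fn is not None and loss is not None:
        backward_fn(loss)
    return loss


param = torch.nn.Parameter(torch.ones(1))
optimizer = torch.optim.SGD([param], lr=0.1)
loss = run_closure(
    step_fn=lambda: (param * 2).sum(),
    zero_grad_fn=optimizer.zero_grad,
    backward_fn=lambda value: value.backward(),
)
optimizer.step()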
View File

@@ -164,8 +164,7 @@ class TrainingTypePlugin(ABC):
     def optimizer_zero_grad(self, current_epoch: int, batch_idx: int, optimizer: Optimizer, opt_idx: int) -> None:
         """Zeros all model parameter's gradients."""
-        model_ref = self.lightning_module
-        model_ref.optimizer_zero_grad(current_epoch, batch_idx, optimizer, opt_idx)
+        self.lightning_module.optimizer_zero_grad(current_epoch, batch_idx, optimizer, opt_idx)

     def _setup_model_and_optimizers(self, model: Module, optimizers: List[Optimizer]) -> Tuple[Module, List[Optimizer]]:
         """Setup a model and multiple optimizers together.

View File

@@ -31,15 +31,24 @@ class _FxValidator:
         "on_before_backward": _LogOptions(
             allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False
         ),
+        "backward": _LogOptions(
+            allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False
+        ),
         "on_after_backward": _LogOptions(
             allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False
         ),
         "on_before_optimizer_step": _LogOptions(
             allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False
         ),
+        "optimizer_step": _LogOptions(
+            allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False
+        ),
         "on_before_zero_grad": _LogOptions(
             allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False
         ),
+        "optimizer_zero_grad": _LogOptions(
+            allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False
+        ),
         "on_init_start": None,
         "on_init_end": None,
         "on_fit_start": None,
@@ -160,6 +169,8 @@ class _FxValidator:
         "configure_callbacks": None,
         "on_validation_model_eval": None,
         "on_test_model_eval": None,
+        "on_validation_model_train": None,
+        "on_test_model_train": None,
     }

     @classmethod

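The three new _LogOptions entries ("backward", "optimizer_step", "optimizer_zero_grad") mean self.log is now accepted inside those LightningModule overrides, defaulting to on_step=True / on_epoch=False. A hedged sketch of what that allows; the hook signatures follow the Lightning API of this era and may differ in other releases:

from pytorch_lightning import LightningModule


class LoggingOptimizationModel(LightningModule):
    def backward(self, loss, optimizer, optimizer_idx, *args, **kwargs) -> None:
        # logged per step by default, per the new "backward" entry above
        self.log("loss_before_backward", loss.detach())
        loss.backward(*args, **kwargs)

    def optimizer_zero_grad(self, epoch, batch_idx, optimizer, optimizer_idx) -> None:
        # also newly allowed, via the "optimizer_zero_grad" entry
        self.log("zero_grad_batch_idx", float(batch_idx))
        optimizer.zero_grad()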
View File

@@ -1452,7 +1452,7 @@ class Trainer(
         *args: Any,
         pl_module: Optional["pl.LightningModule"] = None,
         **kwargs: Any,
-    ):
+    ) -> Any:
         pl_module = pl_module or self.lightning_module

         if pl_module is None:
@@ -1460,7 +1460,7 @@
         fn = getattr(pl_module, hook_name)
         if not callable(fn):
-            return None
+            return

         prev_fx_name = pl_module._current_fx_name
         pl_module._current_fx_name = hook_name
@@ -1479,16 +1479,15 @@
         hook_name: str,
         *args: Any,
         **kwargs: Any,
-    ) -> Optional[Any]:
-        output = None
+    ) -> None:
         if hook_name in ("on_init_start", "on_init_end"):
             # these `Callback` hooks are the only ones that do not take a lightning module.
             # we also don't profile bc profiler hasn't been set yet
             for callback in self.callbacks:
                 fn = getattr(callback, hook_name)
                 if callable(fn):
-                    output = fn(self, *args, **kwargs)
-            return output
+                    fn(self, *args, **kwargs)
+            return

         pl_module = self.lightning_module
         if pl_module:
@@ -1500,34 +1499,39 @@
             fn = getattr(self, hook_name)
             if callable(fn):
                 with self.profiler.profile(hook_name):
-                    output = fn(*args, **kwargs)
+                    fn(*args, **kwargs)
         else:
             for callback in self.callbacks:
                 fn = getattr(callback, hook_name)
                 if callable(fn):
                     with self.profiler.profile(hook_name):
-                        output = fn(self, self.lightning_module, *args, **kwargs)
+                        fn(self, self.lightning_module, *args, **kwargs)

         if pl_module:
             # restore current_fx when nested context
             pl_module._current_fx_name = prev_fx_name

-        return output
-
     # TODO: rename to _call_strategy_hook and eventually no longer need this
     def _call_ttp_hook(
         self,
         hook_name: str,
         *args: Any,
         **kwargs: Any,
-    ):
+    ) -> Any:
+        pl_module = self.lightning_module
+        prev_fx_name = pl_module._current_fx_name
+        pl_module._current_fx_name = hook_name
+
         fn = getattr(self.training_type_plugin, hook_name)
         if not callable(fn):
-            return None
+            return

         with self.profiler.profile(hook_name):
             output = fn(*args, **kwargs)

+        # restore current_fx when nested context
+        pl_module._current_fx_name = prev_fx_name
+
         return output

     # TODO: eventually no longer need this
@@ -1536,15 +1540,21 @@
         hook_name: str,
         *args: Any,
         **kwargs: Any,
-    ) -> Optional[Any]:
-        self.lightning_module._current_fx_name = hook_name
-        fn = getattr(self.training_type_plugin, hook_name)
+    ) -> Any:
+        pl_module = self.lightning_module
+        prev_fx_name = pl_module._current_fx_name
+        pl_module._current_fx_name = hook_name
+
+        fn = getattr(self.accelerator, hook_name)
         if not callable(fn):
-            return None
+            return

         with self.profiler.profile(hook_name):
             output = fn(*args, **kwargs)

+        # restore current_fx when nested context
+        pl_module._current_fx_name = prev_fx_name
+
         return output

     @staticmethod

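After these hunks, the _call_lightning_module_hook, _call_ttp_hook, and _call_accelerator_hook helpers share the same shape: resolve the hook on the target object, record the hook name on the LightningModule so self.log can attribute calls, run the hook under the profiler, then restore the previous name to support nesting. A self-contained toy distillation of that pattern follows; the classes here are stand-ins, not the real Trainer, profiler, or module.

from contextlib import contextmanager
from typing import Any, Optional


class ToyProfiler:
    @contextmanager
    def profile(self, action_name: str):
        print(f"profiling: {action_name}")
        yield


class ToyModule:
    _current_fx_name: Optional[str] = None

    def on_validation_model_train(self) -> None:
        print("back to train mode")


class ToyTrainer:
    def __init__(self) -> None:
        self.lightning_module = ToyModule()
        self.profiler = ToyProfiler()

    def _call_lightning_module_hook(self, hook_name: str, *args: Any, **kwargs: Any) -> Any:
        pl_module = self.lightning_module
        fn = getattr(pl_module, hook_name, None)
        if not callable(fn):
            return
        # remember which hook is running, run it under the profiler,
        # then restore the previous name so nested hook calls stay consistent
        prev_fx_name = pl_module._current_fx_name
        pl_module._current_fx_name = hook_name
        with self.profiler.profile(hook_name):
            output = fn(*args, **kwargs)
        pl_module._current_fx_name = prev_fx_name
        return output


ToyTrainer()._call_lightning_module_hook("on_validation_model_train")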
View File

@@ -31,9 +31,9 @@ from tests.models.test_hooks import get_members
 def test_fx_validator(tmpdir):
-    funcs_name = sorted(get_members(Callback))
+    funcs_name = get_members(Callback)

-    callbacks_func = [
+    callbacks_func = {
         "on_before_backward",
         "on_after_backward",
         "on_before_optimizer_step",
@@ -82,9 +82,9 @@
         "on_predict_start",
         "setup",
         "teardown",
-    ]
+    }

-    not_supported = [
+    not_supported = {
         "on_before_accelerator_backend_setup",
         "on_fit_end",
         "on_fit_start",
@@ -110,11 +110,10 @@
         "on_validation_end",
         "setup",
         "teardown",
-    ]
+    }

-    assert funcs_name == sorted(
-        callbacks_func
-    ), "Detected new callback function. Need to add its logging permission to FxValidator and update this test"
+    # Detected new callback function. Need to add its logging permission to FxValidator and update this test
+    assert funcs_name == callbacks_func

     validator = _FxValidator()
@@ -233,6 +232,7 @@ def test_fx_validator_integration(tmpdir):
         "prepare_data": "You can't",
         "configure_callbacks": "You can't",
         "on_validation_model_eval": "You can't",
+        "on_validation_model_train": "You can't",
         "summarize": "not managed by the `Trainer",
     }
     model = HookedModel(not_supported)
@@ -260,6 +260,7 @@
             "on_test_dataloader": "You can't",
             "test_dataloader": "You can't",
             "on_test_model_eval": "You can't",
+            "on_test_model_train": "You can't",
             "on_test_end": "You can't",
         }
     )

View File

@@ -50,9 +50,12 @@ def test_default_level_for_hooks_that_support_logging():
     trainer.state.stage = RunningStage.TRAINING
     hooks = [
         "on_before_backward",
+        "backward",
         "on_after_backward",
         "on_before_optimizer_step",
+        "optimizer_step",
         "on_before_zero_grad",
+        "optimizer_zero_grad",
         "training_step",
         "training_step_end",
         "on_batch_start",