Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
This commit is contained in:
parent
42b5417e9b
commit
99adc45af1
|
@ -186,11 +186,10 @@ class EvaluationLoop(DataLoaderLoop):
|
|||
|
||||
def _on_evaluation_model_train(self) -> None:
|
||||
"""Sets model to train mode."""
|
||||
model_ref = self.trainer.lightning_module
|
||||
if self.trainer.testing:
|
||||
model_ref.on_test_model_train()
|
||||
self.trainer._call_lightning_module_hook("on_test_model_train")
|
||||
else:
|
||||
model_ref.on_validation_model_train()
|
||||
self.trainer._call_lightning_module_hook("on_validation_model_train")
|
||||
|
||||
def _on_evaluation_end(self, *args: Any, **kwargs: Any) -> None:
|
||||
"""Runs ``on_{validation/test}_end`` hook."""
|
||||
|
|
|
@ -218,9 +218,9 @@ class EvaluationEpochLoop(Loop):
|
|||
the outputs of the step
|
||||
"""
|
||||
if self.trainer.testing:
|
||||
output = self.trainer._call_accelerator_hook("test_step", *kwargs.values())
|
||||
output = self.trainer._call_ttp_hook("test_step", *kwargs.values())
|
||||
else:
|
||||
output = self.trainer._call_accelerator_hook("validation_step", *kwargs.values())
|
||||
output = self.trainer._call_ttp_hook("validation_step", *kwargs.values())
|
||||
|
||||
return output
|
||||
|
||||
|
|
|
@ -130,7 +130,7 @@ class PredictionEpochLoop(Loop):
|
|||
|
||||
self.batch_progress.increment_started()
|
||||
|
||||
predictions = self.trainer._call_accelerator_hook("predict_step", *step_kwargs.values())
|
||||
predictions = self.trainer._call_ttp_hook("predict_step", *step_kwargs.values())
|
||||
|
||||
self.batch_progress.increment_processed()
|
||||
|
||||
|
|
|
@ -180,11 +180,10 @@ class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]):
|
|||
|
||||
# hook
|
||||
self.trainer._call_callback_hooks("on_train_batch_start", batch, batch_idx, **extra_kwargs)
|
||||
model_response = self.trainer._call_lightning_module_hook(
|
||||
response = self.trainer._call_lightning_module_hook(
|
||||
"on_train_batch_start", batch, batch_idx, **extra_kwargs
|
||||
)
|
||||
ttp_response = self.trainer._call_ttp_hook("on_train_batch_start", batch, batch_idx, **extra_kwargs)
|
||||
response = ttp_response if model_response is None else model_response
|
||||
self.trainer._call_ttp_hook("on_train_batch_start", batch, batch_idx, **extra_kwargs)
|
||||
if response == -1:
|
||||
self.batch_progress.increment_processed()
|
||||
raise StopIteration
|
||||
|
|
|
@ -102,7 +102,7 @@ class ManualOptimization(Loop[_OUTPUTS_TYPE]):
|
|||
)
|
||||
|
||||
# manually capture logged metrics
|
||||
training_step_output = self.trainer._call_accelerator_hook("training_step", *step_kwargs.values())
|
||||
training_step_output = self.trainer._call_ttp_hook("training_step", *step_kwargs.values())
|
||||
self.trainer.training_type_plugin.post_training_step()
|
||||
|
||||
del step_kwargs
|
||||
|
|
|
@ -144,12 +144,10 @@ class Closure(AbstractClosure[ClosureResult]):
|
|||
)
|
||||
|
||||
if self._zero_grad_fn is not None:
|
||||
with self._profiler.profile("zero_grad"):
|
||||
self._zero_grad_fn()
|
||||
self._zero_grad_fn()
|
||||
|
||||
if self._backward_fn is not None and step_output.closure_loss is not None:
|
||||
with self._profiler.profile("backward"):
|
||||
self._backward_fn(step_output.closure_loss)
|
||||
self._backward_fn(step_output.closure_loss)
|
||||
|
||||
return step_output
|
||||
|
||||
|
@ -320,7 +318,7 @@ class OptimizerLoop(Loop[_OUTPUTS_TYPE]):
|
|||
return None
|
||||
|
||||
def backward_fn(loss: Tensor) -> None:
|
||||
self.trainer.training_type_plugin.backward(loss, optimizer, opt_idx)
|
||||
self.trainer._call_ttp_hook("backward", loss, optimizer, opt_idx)
|
||||
|
||||
# check if model weights are nan
|
||||
if self.trainer._terminate_on_nan:
|
||||
|
@ -362,8 +360,6 @@ class OptimizerLoop(Loop[_OUTPUTS_TYPE]):
|
|||
train_step_and_backward_closure: the closure function performing the train step and computing the
|
||||
gradients. By default called by the optimizer (if possible)
|
||||
"""
|
||||
lightning_module = self.trainer.lightning_module
|
||||
|
||||
is_lbfgs = isinstance(optimizer, torch.optim.LBFGS)
|
||||
# wraps into LightningOptimizer only for running step
|
||||
optimizer = LightningOptimizer._to_lightning_optimizer(optimizer, self.trainer, opt_idx)
|
||||
|
@ -371,7 +367,8 @@ class OptimizerLoop(Loop[_OUTPUTS_TYPE]):
|
|||
self.optim_progress.optimizer.step.increment_ready()
|
||||
|
||||
# model hook
|
||||
lightning_module.optimizer_step(
|
||||
self.trainer._call_lightning_module_hook(
|
||||
"optimizer_step",
|
||||
self.trainer.current_epoch,
|
||||
batch_idx,
|
||||
optimizer,
|
||||
|
@ -403,7 +400,7 @@ class OptimizerLoop(Loop[_OUTPUTS_TYPE]):
|
|||
optimizer: the current optimizer
|
||||
opt_idx: the index of the current optimizer
|
||||
"""
|
||||
self.trainer.training_type_plugin.optimizer_zero_grad(self.trainer.current_epoch, batch_idx, optimizer, opt_idx)
|
||||
self.trainer._call_ttp_hook("optimizer_zero_grad", self.trainer.current_epoch, batch_idx, optimizer, opt_idx)
|
||||
self.optim_progress.optimizer.zero_grad.increment_completed()
|
||||
|
||||
def _training_step(self, split_batch: Any, batch_idx: int, opt_idx: int) -> ClosureResult:
|
||||
|
@ -427,7 +424,7 @@ class OptimizerLoop(Loop[_OUTPUTS_TYPE]):
|
|||
)
|
||||
|
||||
# manually capture logged metrics
|
||||
training_step_output = self.trainer._call_accelerator_hook("training_step", *step_kwargs.values())
|
||||
training_step_output = self.trainer._call_ttp_hook("training_step", *step_kwargs.values())
|
||||
self.trainer.training_type_plugin.post_training_step()
|
||||
|
||||
del step_kwargs
|
||||
|
|
|
@ -164,8 +164,7 @@ class TrainingTypePlugin(ABC):
|
|||
|
||||
def optimizer_zero_grad(self, current_epoch: int, batch_idx: int, optimizer: Optimizer, opt_idx: int) -> None:
|
||||
"""Zeros all model parameter's gradients."""
|
||||
model_ref = self.lightning_module
|
||||
model_ref.optimizer_zero_grad(current_epoch, batch_idx, optimizer, opt_idx)
|
||||
self.lightning_module.optimizer_zero_grad(current_epoch, batch_idx, optimizer, opt_idx)
|
||||
|
||||
def _setup_model_and_optimizers(self, model: Module, optimizers: List[Optimizer]) -> Tuple[Module, List[Optimizer]]:
|
||||
"""Setup a model and multiple optimizers together.
|
||||
|
|
|
@ -31,15 +31,24 @@ class _FxValidator:
|
|||
"on_before_backward": _LogOptions(
|
||||
allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False
|
||||
),
|
||||
"backward": _LogOptions(
|
||||
allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False
|
||||
),
|
||||
"on_after_backward": _LogOptions(
|
||||
allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False
|
||||
),
|
||||
"on_before_optimizer_step": _LogOptions(
|
||||
allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False
|
||||
),
|
||||
"optimizer_step": _LogOptions(
|
||||
allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False
|
||||
),
|
||||
"on_before_zero_grad": _LogOptions(
|
||||
allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False
|
||||
),
|
||||
"optimizer_zero_grad": _LogOptions(
|
||||
allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False
|
||||
),
|
||||
"on_init_start": None,
|
||||
"on_init_end": None,
|
||||
"on_fit_start": None,
|
||||
|
@ -160,6 +169,8 @@ class _FxValidator:
|
|||
"configure_callbacks": None,
|
||||
"on_validation_model_eval": None,
|
||||
"on_test_model_eval": None,
|
||||
"on_validation_model_train": None,
|
||||
"on_test_model_train": None,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
|
|
|
@ -1452,7 +1452,7 @@ class Trainer(
|
|||
*args: Any,
|
||||
pl_module: Optional["pl.LightningModule"] = None,
|
||||
**kwargs: Any,
|
||||
):
|
||||
) -> Any:
|
||||
pl_module = pl_module or self.lightning_module
|
||||
|
||||
if pl_module is None:
|
||||
|
@ -1460,7 +1460,7 @@ class Trainer(
|
|||
|
||||
fn = getattr(pl_module, hook_name)
|
||||
if not callable(fn):
|
||||
return None
|
||||
return
|
||||
|
||||
prev_fx_name = pl_module._current_fx_name
|
||||
pl_module._current_fx_name = hook_name
|
||||
|
@ -1479,16 +1479,15 @@ class Trainer(
|
|||
hook_name: str,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> Optional[Any]:
|
||||
output = None
|
||||
) -> None:
|
||||
if hook_name in ("on_init_start", "on_init_end"):
|
||||
# these `Callback` hooks are the only ones that do not take a lightning module.
|
||||
# we also don't profile bc profiler hasn't been set yet
|
||||
for callback in self.callbacks:
|
||||
fn = getattr(callback, hook_name)
|
||||
if callable(fn):
|
||||
output = fn(self, *args, **kwargs)
|
||||
return output
|
||||
fn(self, *args, **kwargs)
|
||||
return
|
||||
|
||||
pl_module = self.lightning_module
|
||||
if pl_module:
|
||||
|
@ -1500,34 +1499,39 @@ class Trainer(
|
|||
fn = getattr(self, hook_name)
|
||||
if callable(fn):
|
||||
with self.profiler.profile(hook_name):
|
||||
output = fn(*args, **kwargs)
|
||||
fn(*args, **kwargs)
|
||||
else:
|
||||
for callback in self.callbacks:
|
||||
fn = getattr(callback, hook_name)
|
||||
if callable(fn):
|
||||
with self.profiler.profile(hook_name):
|
||||
output = fn(self, self.lightning_module, *args, **kwargs)
|
||||
fn(self, self.lightning_module, *args, **kwargs)
|
||||
|
||||
if pl_module:
|
||||
# restore current_fx when nested context
|
||||
pl_module._current_fx_name = prev_fx_name
|
||||
|
||||
return output
|
||||
|
||||
# TODO: rename to _call_strategy_hook and eventually no longer need this
|
||||
def _call_ttp_hook(
|
||||
self,
|
||||
hook_name: str,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
):
|
||||
) -> Any:
|
||||
pl_module = self.lightning_module
|
||||
prev_fx_name = pl_module._current_fx_name
|
||||
pl_module._current_fx_name = hook_name
|
||||
|
||||
fn = getattr(self.training_type_plugin, hook_name)
|
||||
if not callable(fn):
|
||||
return None
|
||||
return
|
||||
|
||||
with self.profiler.profile(hook_name):
|
||||
output = fn(*args, **kwargs)
|
||||
|
||||
# restore current_fx when nested context
|
||||
pl_module._current_fx_name = prev_fx_name
|
||||
|
||||
return output
|
||||
|
||||
# TODO: eventually no longer need this
|
||||
|
@ -1536,15 +1540,21 @@ class Trainer(
|
|||
hook_name: str,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> Optional[Any]:
|
||||
self.lightning_module._current_fx_name = hook_name
|
||||
fn = getattr(self.training_type_plugin, hook_name)
|
||||
) -> Any:
|
||||
pl_module = self.lightning_module
|
||||
prev_fx_name = pl_module._current_fx_name
|
||||
pl_module._current_fx_name = hook_name
|
||||
|
||||
fn = getattr(self.accelerator, hook_name)
|
||||
if not callable(fn):
|
||||
return None
|
||||
return
|
||||
|
||||
with self.profiler.profile(hook_name):
|
||||
output = fn(*args, **kwargs)
|
||||
|
||||
# restore current_fx when nested context
|
||||
pl_module._current_fx_name = prev_fx_name
|
||||
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
|
|
|
@ -31,9 +31,9 @@ from tests.models.test_hooks import get_members
|
|||
|
||||
|
||||
def test_fx_validator(tmpdir):
|
||||
funcs_name = sorted(get_members(Callback))
|
||||
funcs_name = get_members(Callback)
|
||||
|
||||
callbacks_func = [
|
||||
callbacks_func = {
|
||||
"on_before_backward",
|
||||
"on_after_backward",
|
||||
"on_before_optimizer_step",
|
||||
|
@ -82,9 +82,9 @@ def test_fx_validator(tmpdir):
|
|||
"on_predict_start",
|
||||
"setup",
|
||||
"teardown",
|
||||
]
|
||||
}
|
||||
|
||||
not_supported = [
|
||||
not_supported = {
|
||||
"on_before_accelerator_backend_setup",
|
||||
"on_fit_end",
|
||||
"on_fit_start",
|
||||
|
@ -110,11 +110,10 @@ def test_fx_validator(tmpdir):
|
|||
"on_validation_end",
|
||||
"setup",
|
||||
"teardown",
|
||||
]
|
||||
}
|
||||
|
||||
assert funcs_name == sorted(
|
||||
callbacks_func
|
||||
), "Detected new callback function. Need to add its logging permission to FxValidator and update this test"
|
||||
# Detected new callback function. Need to add its logging permission to FxValidator and update this test
|
||||
assert funcs_name == callbacks_func
|
||||
|
||||
validator = _FxValidator()
|
||||
|
||||
|
@ -233,6 +232,7 @@ def test_fx_validator_integration(tmpdir):
|
|||
"prepare_data": "You can't",
|
||||
"configure_callbacks": "You can't",
|
||||
"on_validation_model_eval": "You can't",
|
||||
"on_validation_model_train": "You can't",
|
||||
"summarize": "not managed by the `Trainer",
|
||||
}
|
||||
model = HookedModel(not_supported)
|
||||
|
@ -260,6 +260,7 @@ def test_fx_validator_integration(tmpdir):
|
|||
"on_test_dataloader": "You can't",
|
||||
"test_dataloader": "You can't",
|
||||
"on_test_model_eval": "You can't",
|
||||
"on_test_model_train": "You can't",
|
||||
"on_test_end": "You can't",
|
||||
}
|
||||
)
|
||||
|
|
|
@ -50,9 +50,12 @@ def test_default_level_for_hooks_that_support_logging():
|
|||
trainer.state.stage = RunningStage.TRAINING
|
||||
hooks = [
|
||||
"on_before_backward",
|
||||
"backward",
|
||||
"on_after_backward",
|
||||
"on_before_optimizer_step",
|
||||
"optimizer_step",
|
||||
"on_before_zero_grad",
|
||||
"optimizer_zero_grad",
|
||||
"training_step",
|
||||
"training_step_end",
|
||||
"on_batch_start",
|
||||
|
|
Loading…
Reference in New Issue