Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
This commit is contained in:
parent
42b5417e9b
commit
99adc45af1
|
@ -186,11 +186,10 @@ class EvaluationLoop(DataLoaderLoop):
|
||||||
|
|
||||||
def _on_evaluation_model_train(self) -> None:
|
def _on_evaluation_model_train(self) -> None:
|
||||||
"""Sets model to train mode."""
|
"""Sets model to train mode."""
|
||||||
model_ref = self.trainer.lightning_module
|
|
||||||
if self.trainer.testing:
|
if self.trainer.testing:
|
||||||
model_ref.on_test_model_train()
|
self.trainer._call_lightning_module_hook("on_test_model_train")
|
||||||
else:
|
else:
|
||||||
model_ref.on_validation_model_train()
|
self.trainer._call_lightning_module_hook("on_validation_model_train")
|
||||||
|
|
||||||
def _on_evaluation_end(self, *args: Any, **kwargs: Any) -> None:
|
def _on_evaluation_end(self, *args: Any, **kwargs: Any) -> None:
|
||||||
"""Runs ``on_{validation/test}_end`` hook."""
|
"""Runs ``on_{validation/test}_end`` hook."""
|
||||||
|
|
|
@ -218,9 +218,9 @@ class EvaluationEpochLoop(Loop):
|
||||||
the outputs of the step
|
the outputs of the step
|
||||||
"""
|
"""
|
||||||
if self.trainer.testing:
|
if self.trainer.testing:
|
||||||
output = self.trainer._call_accelerator_hook("test_step", *kwargs.values())
|
output = self.trainer._call_ttp_hook("test_step", *kwargs.values())
|
||||||
else:
|
else:
|
||||||
output = self.trainer._call_accelerator_hook("validation_step", *kwargs.values())
|
output = self.trainer._call_ttp_hook("validation_step", *kwargs.values())
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
|
@ -130,7 +130,7 @@ class PredictionEpochLoop(Loop):
|
||||||
|
|
||||||
self.batch_progress.increment_started()
|
self.batch_progress.increment_started()
|
||||||
|
|
||||||
predictions = self.trainer._call_accelerator_hook("predict_step", *step_kwargs.values())
|
predictions = self.trainer._call_ttp_hook("predict_step", *step_kwargs.values())
|
||||||
|
|
||||||
self.batch_progress.increment_processed()
|
self.batch_progress.increment_processed()
|
||||||
|
|
||||||
|
|
|
@ -180,11 +180,10 @@ class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]):
|
||||||
|
|
||||||
# hook
|
# hook
|
||||||
self.trainer._call_callback_hooks("on_train_batch_start", batch, batch_idx, **extra_kwargs)
|
self.trainer._call_callback_hooks("on_train_batch_start", batch, batch_idx, **extra_kwargs)
|
||||||
model_response = self.trainer._call_lightning_module_hook(
|
response = self.trainer._call_lightning_module_hook(
|
||||||
"on_train_batch_start", batch, batch_idx, **extra_kwargs
|
"on_train_batch_start", batch, batch_idx, **extra_kwargs
|
||||||
)
|
)
|
||||||
ttp_response = self.trainer._call_ttp_hook("on_train_batch_start", batch, batch_idx, **extra_kwargs)
|
self.trainer._call_ttp_hook("on_train_batch_start", batch, batch_idx, **extra_kwargs)
|
||||||
response = ttp_response if model_response is None else model_response
|
|
||||||
if response == -1:
|
if response == -1:
|
||||||
self.batch_progress.increment_processed()
|
self.batch_progress.increment_processed()
|
||||||
raise StopIteration
|
raise StopIteration
|
||||||
|
|
|
@ -102,7 +102,7 @@ class ManualOptimization(Loop[_OUTPUTS_TYPE]):
|
||||||
)
|
)
|
||||||
|
|
||||||
# manually capture logged metrics
|
# manually capture logged metrics
|
||||||
training_step_output = self.trainer._call_accelerator_hook("training_step", *step_kwargs.values())
|
training_step_output = self.trainer._call_ttp_hook("training_step", *step_kwargs.values())
|
||||||
self.trainer.training_type_plugin.post_training_step()
|
self.trainer.training_type_plugin.post_training_step()
|
||||||
|
|
||||||
del step_kwargs
|
del step_kwargs
|
||||||
|
|
|
@ -144,12 +144,10 @@ class Closure(AbstractClosure[ClosureResult]):
|
||||||
)
|
)
|
||||||
|
|
||||||
if self._zero_grad_fn is not None:
|
if self._zero_grad_fn is not None:
|
||||||
with self._profiler.profile("zero_grad"):
|
self._zero_grad_fn()
|
||||||
self._zero_grad_fn()
|
|
||||||
|
|
||||||
if self._backward_fn is not None and step_output.closure_loss is not None:
|
if self._backward_fn is not None and step_output.closure_loss is not None:
|
||||||
with self._profiler.profile("backward"):
|
self._backward_fn(step_output.closure_loss)
|
||||||
self._backward_fn(step_output.closure_loss)
|
|
||||||
|
|
||||||
return step_output
|
return step_output
|
||||||
|
|
||||||
|
@ -320,7 +318,7 @@ class OptimizerLoop(Loop[_OUTPUTS_TYPE]):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def backward_fn(loss: Tensor) -> None:
|
def backward_fn(loss: Tensor) -> None:
|
||||||
self.trainer.training_type_plugin.backward(loss, optimizer, opt_idx)
|
self.trainer._call_ttp_hook("backward", loss, optimizer, opt_idx)
|
||||||
|
|
||||||
# check if model weights are nan
|
# check if model weights are nan
|
||||||
if self.trainer._terminate_on_nan:
|
if self.trainer._terminate_on_nan:
|
||||||
|
@ -362,8 +360,6 @@ class OptimizerLoop(Loop[_OUTPUTS_TYPE]):
|
||||||
train_step_and_backward_closure: the closure function performing the train step and computing the
|
train_step_and_backward_closure: the closure function performing the train step and computing the
|
||||||
gradients. By default called by the optimizer (if possible)
|
gradients. By default called by the optimizer (if possible)
|
||||||
"""
|
"""
|
||||||
lightning_module = self.trainer.lightning_module
|
|
||||||
|
|
||||||
is_lbfgs = isinstance(optimizer, torch.optim.LBFGS)
|
is_lbfgs = isinstance(optimizer, torch.optim.LBFGS)
|
||||||
# wraps into LightningOptimizer only for running step
|
# wraps into LightningOptimizer only for running step
|
||||||
optimizer = LightningOptimizer._to_lightning_optimizer(optimizer, self.trainer, opt_idx)
|
optimizer = LightningOptimizer._to_lightning_optimizer(optimizer, self.trainer, opt_idx)
|
||||||
|
@ -371,7 +367,8 @@ class OptimizerLoop(Loop[_OUTPUTS_TYPE]):
|
||||||
self.optim_progress.optimizer.step.increment_ready()
|
self.optim_progress.optimizer.step.increment_ready()
|
||||||
|
|
||||||
# model hook
|
# model hook
|
||||||
lightning_module.optimizer_step(
|
self.trainer._call_lightning_module_hook(
|
||||||
|
"optimizer_step",
|
||||||
self.trainer.current_epoch,
|
self.trainer.current_epoch,
|
||||||
batch_idx,
|
batch_idx,
|
||||||
optimizer,
|
optimizer,
|
||||||
|
@ -403,7 +400,7 @@ class OptimizerLoop(Loop[_OUTPUTS_TYPE]):
|
||||||
optimizer: the current optimizer
|
optimizer: the current optimizer
|
||||||
opt_idx: the index of the current optimizer
|
opt_idx: the index of the current optimizer
|
||||||
"""
|
"""
|
||||||
self.trainer.training_type_plugin.optimizer_zero_grad(self.trainer.current_epoch, batch_idx, optimizer, opt_idx)
|
self.trainer._call_ttp_hook("optimizer_zero_grad", self.trainer.current_epoch, batch_idx, optimizer, opt_idx)
|
||||||
self.optim_progress.optimizer.zero_grad.increment_completed()
|
self.optim_progress.optimizer.zero_grad.increment_completed()
|
||||||
|
|
||||||
def _training_step(self, split_batch: Any, batch_idx: int, opt_idx: int) -> ClosureResult:
|
def _training_step(self, split_batch: Any, batch_idx: int, opt_idx: int) -> ClosureResult:
|
||||||
|
@ -427,7 +424,7 @@ class OptimizerLoop(Loop[_OUTPUTS_TYPE]):
|
||||||
)
|
)
|
||||||
|
|
||||||
# manually capture logged metrics
|
# manually capture logged metrics
|
||||||
training_step_output = self.trainer._call_accelerator_hook("training_step", *step_kwargs.values())
|
training_step_output = self.trainer._call_ttp_hook("training_step", *step_kwargs.values())
|
||||||
self.trainer.training_type_plugin.post_training_step()
|
self.trainer.training_type_plugin.post_training_step()
|
||||||
|
|
||||||
del step_kwargs
|
del step_kwargs
|
||||||
|
|
|
@ -164,8 +164,7 @@ class TrainingTypePlugin(ABC):
|
||||||
|
|
||||||
def optimizer_zero_grad(self, current_epoch: int, batch_idx: int, optimizer: Optimizer, opt_idx: int) -> None:
|
def optimizer_zero_grad(self, current_epoch: int, batch_idx: int, optimizer: Optimizer, opt_idx: int) -> None:
|
||||||
"""Zeros all model parameter's gradients."""
|
"""Zeros all model parameter's gradients."""
|
||||||
model_ref = self.lightning_module
|
self.lightning_module.optimizer_zero_grad(current_epoch, batch_idx, optimizer, opt_idx)
|
||||||
model_ref.optimizer_zero_grad(current_epoch, batch_idx, optimizer, opt_idx)
|
|
||||||
|
|
||||||
def _setup_model_and_optimizers(self, model: Module, optimizers: List[Optimizer]) -> Tuple[Module, List[Optimizer]]:
|
def _setup_model_and_optimizers(self, model: Module, optimizers: List[Optimizer]) -> Tuple[Module, List[Optimizer]]:
|
||||||
"""Setup a model and multiple optimizers together.
|
"""Setup a model and multiple optimizers together.
|
||||||
|
|
|
@ -31,15 +31,24 @@ class _FxValidator:
|
||||||
"on_before_backward": _LogOptions(
|
"on_before_backward": _LogOptions(
|
||||||
allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False
|
allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False
|
||||||
),
|
),
|
||||||
|
"backward": _LogOptions(
|
||||||
|
allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False
|
||||||
|
),
|
||||||
"on_after_backward": _LogOptions(
|
"on_after_backward": _LogOptions(
|
||||||
allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False
|
allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False
|
||||||
),
|
),
|
||||||
"on_before_optimizer_step": _LogOptions(
|
"on_before_optimizer_step": _LogOptions(
|
||||||
allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False
|
allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False
|
||||||
),
|
),
|
||||||
|
"optimizer_step": _LogOptions(
|
||||||
|
allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False
|
||||||
|
),
|
||||||
"on_before_zero_grad": _LogOptions(
|
"on_before_zero_grad": _LogOptions(
|
||||||
allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False
|
allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False
|
||||||
),
|
),
|
||||||
|
"optimizer_zero_grad": _LogOptions(
|
||||||
|
allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False
|
||||||
|
),
|
||||||
"on_init_start": None,
|
"on_init_start": None,
|
||||||
"on_init_end": None,
|
"on_init_end": None,
|
||||||
"on_fit_start": None,
|
"on_fit_start": None,
|
||||||
|
@ -160,6 +169,8 @@ class _FxValidator:
|
||||||
"configure_callbacks": None,
|
"configure_callbacks": None,
|
||||||
"on_validation_model_eval": None,
|
"on_validation_model_eval": None,
|
||||||
"on_test_model_eval": None,
|
"on_test_model_eval": None,
|
||||||
|
"on_validation_model_train": None,
|
||||||
|
"on_test_model_train": None,
|
||||||
}
|
}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
|
@ -1452,7 +1452,7 @@ class Trainer(
|
||||||
*args: Any,
|
*args: Any,
|
||||||
pl_module: Optional["pl.LightningModule"] = None,
|
pl_module: Optional["pl.LightningModule"] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
):
|
) -> Any:
|
||||||
pl_module = pl_module or self.lightning_module
|
pl_module = pl_module or self.lightning_module
|
||||||
|
|
||||||
if pl_module is None:
|
if pl_module is None:
|
||||||
|
@ -1460,7 +1460,7 @@ class Trainer(
|
||||||
|
|
||||||
fn = getattr(pl_module, hook_name)
|
fn = getattr(pl_module, hook_name)
|
||||||
if not callable(fn):
|
if not callable(fn):
|
||||||
return None
|
return
|
||||||
|
|
||||||
prev_fx_name = pl_module._current_fx_name
|
prev_fx_name = pl_module._current_fx_name
|
||||||
pl_module._current_fx_name = hook_name
|
pl_module._current_fx_name = hook_name
|
||||||
|
@ -1479,16 +1479,15 @@ class Trainer(
|
||||||
hook_name: str,
|
hook_name: str,
|
||||||
*args: Any,
|
*args: Any,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Optional[Any]:
|
) -> None:
|
||||||
output = None
|
|
||||||
if hook_name in ("on_init_start", "on_init_end"):
|
if hook_name in ("on_init_start", "on_init_end"):
|
||||||
# these `Callback` hooks are the only ones that do not take a lightning module.
|
# these `Callback` hooks are the only ones that do not take a lightning module.
|
||||||
# we also don't profile bc profiler hasn't been set yet
|
# we also don't profile bc profiler hasn't been set yet
|
||||||
for callback in self.callbacks:
|
for callback in self.callbacks:
|
||||||
fn = getattr(callback, hook_name)
|
fn = getattr(callback, hook_name)
|
||||||
if callable(fn):
|
if callable(fn):
|
||||||
output = fn(self, *args, **kwargs)
|
fn(self, *args, **kwargs)
|
||||||
return output
|
return
|
||||||
|
|
||||||
pl_module = self.lightning_module
|
pl_module = self.lightning_module
|
||||||
if pl_module:
|
if pl_module:
|
||||||
|
@ -1500,34 +1499,39 @@ class Trainer(
|
||||||
fn = getattr(self, hook_name)
|
fn = getattr(self, hook_name)
|
||||||
if callable(fn):
|
if callable(fn):
|
||||||
with self.profiler.profile(hook_name):
|
with self.profiler.profile(hook_name):
|
||||||
output = fn(*args, **kwargs)
|
fn(*args, **kwargs)
|
||||||
else:
|
else:
|
||||||
for callback in self.callbacks:
|
for callback in self.callbacks:
|
||||||
fn = getattr(callback, hook_name)
|
fn = getattr(callback, hook_name)
|
||||||
if callable(fn):
|
if callable(fn):
|
||||||
with self.profiler.profile(hook_name):
|
with self.profiler.profile(hook_name):
|
||||||
output = fn(self, self.lightning_module, *args, **kwargs)
|
fn(self, self.lightning_module, *args, **kwargs)
|
||||||
|
|
||||||
if pl_module:
|
if pl_module:
|
||||||
# restore current_fx when nested context
|
# restore current_fx when nested context
|
||||||
pl_module._current_fx_name = prev_fx_name
|
pl_module._current_fx_name = prev_fx_name
|
||||||
|
|
||||||
return output
|
|
||||||
|
|
||||||
# TODO: rename to _call_strategy_hook and eventually no longer need this
|
# TODO: rename to _call_strategy_hook and eventually no longer need this
|
||||||
def _call_ttp_hook(
|
def _call_ttp_hook(
|
||||||
self,
|
self,
|
||||||
hook_name: str,
|
hook_name: str,
|
||||||
*args: Any,
|
*args: Any,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
):
|
) -> Any:
|
||||||
|
pl_module = self.lightning_module
|
||||||
|
prev_fx_name = pl_module._current_fx_name
|
||||||
|
pl_module._current_fx_name = hook_name
|
||||||
|
|
||||||
fn = getattr(self.training_type_plugin, hook_name)
|
fn = getattr(self.training_type_plugin, hook_name)
|
||||||
if not callable(fn):
|
if not callable(fn):
|
||||||
return None
|
return
|
||||||
|
|
||||||
with self.profiler.profile(hook_name):
|
with self.profiler.profile(hook_name):
|
||||||
output = fn(*args, **kwargs)
|
output = fn(*args, **kwargs)
|
||||||
|
|
||||||
|
# restore current_fx when nested context
|
||||||
|
pl_module._current_fx_name = prev_fx_name
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
# TODO: eventually no longer need this
|
# TODO: eventually no longer need this
|
||||||
|
@ -1536,15 +1540,21 @@ class Trainer(
|
||||||
hook_name: str,
|
hook_name: str,
|
||||||
*args: Any,
|
*args: Any,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Optional[Any]:
|
) -> Any:
|
||||||
self.lightning_module._current_fx_name = hook_name
|
pl_module = self.lightning_module
|
||||||
fn = getattr(self.training_type_plugin, hook_name)
|
prev_fx_name = pl_module._current_fx_name
|
||||||
|
pl_module._current_fx_name = hook_name
|
||||||
|
|
||||||
|
fn = getattr(self.accelerator, hook_name)
|
||||||
if not callable(fn):
|
if not callable(fn):
|
||||||
return None
|
return
|
||||||
|
|
||||||
with self.profiler.profile(hook_name):
|
with self.profiler.profile(hook_name):
|
||||||
output = fn(*args, **kwargs)
|
output = fn(*args, **kwargs)
|
||||||
|
|
||||||
|
# restore current_fx when nested context
|
||||||
|
pl_module._current_fx_name = prev_fx_name
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|
|
@ -31,9 +31,9 @@ from tests.models.test_hooks import get_members
|
||||||
|
|
||||||
|
|
||||||
def test_fx_validator(tmpdir):
|
def test_fx_validator(tmpdir):
|
||||||
funcs_name = sorted(get_members(Callback))
|
funcs_name = get_members(Callback)
|
||||||
|
|
||||||
callbacks_func = [
|
callbacks_func = {
|
||||||
"on_before_backward",
|
"on_before_backward",
|
||||||
"on_after_backward",
|
"on_after_backward",
|
||||||
"on_before_optimizer_step",
|
"on_before_optimizer_step",
|
||||||
|
@ -82,9 +82,9 @@ def test_fx_validator(tmpdir):
|
||||||
"on_predict_start",
|
"on_predict_start",
|
||||||
"setup",
|
"setup",
|
||||||
"teardown",
|
"teardown",
|
||||||
]
|
}
|
||||||
|
|
||||||
not_supported = [
|
not_supported = {
|
||||||
"on_before_accelerator_backend_setup",
|
"on_before_accelerator_backend_setup",
|
||||||
"on_fit_end",
|
"on_fit_end",
|
||||||
"on_fit_start",
|
"on_fit_start",
|
||||||
|
@ -110,11 +110,10 @@ def test_fx_validator(tmpdir):
|
||||||
"on_validation_end",
|
"on_validation_end",
|
||||||
"setup",
|
"setup",
|
||||||
"teardown",
|
"teardown",
|
||||||
]
|
}
|
||||||
|
|
||||||
assert funcs_name == sorted(
|
# Detected new callback function. Need to add its logging permission to FxValidator and update this test
|
||||||
callbacks_func
|
assert funcs_name == callbacks_func
|
||||||
), "Detected new callback function. Need to add its logging permission to FxValidator and update this test"
|
|
||||||
|
|
||||||
validator = _FxValidator()
|
validator = _FxValidator()
|
||||||
|
|
||||||
|
@ -233,6 +232,7 @@ def test_fx_validator_integration(tmpdir):
|
||||||
"prepare_data": "You can't",
|
"prepare_data": "You can't",
|
||||||
"configure_callbacks": "You can't",
|
"configure_callbacks": "You can't",
|
||||||
"on_validation_model_eval": "You can't",
|
"on_validation_model_eval": "You can't",
|
||||||
|
"on_validation_model_train": "You can't",
|
||||||
"summarize": "not managed by the `Trainer",
|
"summarize": "not managed by the `Trainer",
|
||||||
}
|
}
|
||||||
model = HookedModel(not_supported)
|
model = HookedModel(not_supported)
|
||||||
|
@ -260,6 +260,7 @@ def test_fx_validator_integration(tmpdir):
|
||||||
"on_test_dataloader": "You can't",
|
"on_test_dataloader": "You can't",
|
||||||
"test_dataloader": "You can't",
|
"test_dataloader": "You can't",
|
||||||
"on_test_model_eval": "You can't",
|
"on_test_model_eval": "You can't",
|
||||||
|
"on_test_model_train": "You can't",
|
||||||
"on_test_end": "You can't",
|
"on_test_end": "You can't",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
|
@ -50,9 +50,12 @@ def test_default_level_for_hooks_that_support_logging():
|
||||||
trainer.state.stage = RunningStage.TRAINING
|
trainer.state.stage = RunningStage.TRAINING
|
||||||
hooks = [
|
hooks = [
|
||||||
"on_before_backward",
|
"on_before_backward",
|
||||||
|
"backward",
|
||||||
"on_after_backward",
|
"on_after_backward",
|
||||||
"on_before_optimizer_step",
|
"on_before_optimizer_step",
|
||||||
|
"optimizer_step",
|
||||||
"on_before_zero_grad",
|
"on_before_zero_grad",
|
||||||
|
"optimizer_zero_grad",
|
||||||
"training_step",
|
"training_step",
|
||||||
"training_step_end",
|
"training_step_end",
|
||||||
"on_batch_start",
|
"on_batch_start",
|
||||||
|
|
Loading…
Reference in New Issue