Mark certain Trainer APIs as protected (#7420)
Parent: ad9118f04a
Commit: fdf50a5e4b
CHANGELOG.md
@@ -18,6 +18,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Log epoch metrics before the `on_evaluation_end` hook ([#7272](https://github.com/PyTorchLightning/pytorch-lightning/pull/7272))
+- Changed these `Trainer` methods to be protected: `call_setup_hook`, `call_configure_sharded_model`, `pre_dispatch`, `dispatch`, `post_dispatch`, `call_teardown_hook`, `run_train`, `run_sanity_check`, `run_evaluate`, `run_evaluation`, `run_predict`, `track_output_for_epoch_end` ([#7420](https://github.com/PyTorchLightning/pytorch-lightning/pull/7420))
 - Changed `clip_grad_norm` to use `torch.nn.utils.clip_grad_norm_` ([#7025](https://github.com/PyTorchLightning/pytorch-lightning/pull/7025))
 - Refactored Loops
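For illustration only (not part of the diff): a minimal sketch of the user-facing effect of this rename, assuming a PyTorch Lightning build that already contains this commit. The supported public entry points (`fit`/`validate`/`test`/`predict`) are unchanged; only the internal loop methods move behind an underscore.

```python
from pytorch_lightning import Trainer

trainer = Trainer()

# The old public names are gone after this change...
assert not hasattr(trainer, "run_evaluation")
assert not hasattr(trainer, "call_setup_hook")

# ...and the underscore-prefixed counterparts exist, but they are internal:
# downstream code should go through trainer.fit()/validate()/test()/predict().
assert hasattr(trainer, "_run_evaluation")
assert hasattr(trainer, "_call_setup_hook")
```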
pytorch_lightning/trainer/trainer.py
@@ -706,8 +706,8 @@ class Trainer(
         self.call_hook("on_before_accelerator_backend_setup", model)
         self.accelerator.connect(model)
         self.accelerator.setup_environment()
-        self.call_setup_hook(model)  # allow user to setup lightning_module in accelerator environment
-        self.call_configure_sharded_model(model)  # allow user to setup in model sharded environment
+        self._call_setup_hook(model)  # allow user to setup lightning_module in accelerator environment
+        self._call_configure_sharded_model(model)  # allow user to setup in model sharded environment
         self.accelerator.setup(self, model)  # note: this sets up self.lightning_module

         # ----------------------------
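For illustration only (not part of the diff): `_call_setup_hook` ultimately invokes the user-facing `setup` hook. Below is a minimal sketch of that hook on the user side, assuming the `LightningModule.setup(stage)` signature of this release line; `MyModel` and `self.ready` are made-up names.

```python
from typing import Optional

from pytorch_lightning import LightningModule


class MyModel(LightningModule):
    def setup(self, stage: Optional[str] = None) -> None:
        # Runs once the accelerator environment is up (see the call order in the
        # hunk above), so rank/process-dependent state can be built here.
        self.ready = True
```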
@@ -719,7 +719,7 @@ class Trainer(
                                 |                             ||
                         create accelerator                    ||
                                 |                             ||
-                        {self.dispatch}                       ||
+                        {self._dispatch}                      ||
                                 |                             ||  LIGHTNING
                  {self.accelerator.start_training}            ||
               or {self.accelerator.start_evaluating}          ||
@@ -727,13 +727,13 @@ class Trainer(
                                 |                             ||
                         {self.run_stage}                      ||
                                 |                             ||  DIRECTION
-                        {self.run_train}                      ||
-                     or {self.run_evaluation}                 ||
-                     or {self.run_predict}                    ||
+                        {self._run_train}                     ||
+                     or {self._run_evaluation}                ||
+                     or {self._run_predict}                   ||
                                 |                             ||
                              results                          \/
         This is used to guide readers to the core loops: train, test, predict.
-        {self.run_predict} is the simplest to understand, use `Go to Definition` to read it :)
+        {self._run_predict} is the simplest to understand, use `Go to Definition` to read it :)
         Search for `start_training` or `start_evaluating` or `start_predicting` in
         `pytorch_lightning/plugins/training_type_plugin` to find accelerator dispatch functions.
         """  # noqa: W605
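For illustration only (not part of the diff): a rough paraphrase of the flow chart above in plain Python, under the simplifying assumption that the accelerator immediately calls back into the Trainer. `SketchTrainer` and its method bodies are illustrative names only, not the real implementation.

```python
class SketchTrainer:
    """Illustrative only: mirrors the renamed call chain, not the real Trainer."""

    evaluating = False
    predicting = False

    def fit(self, model):      # same idea for test()/predict()
        self._dispatch()       # real code: self.accelerator.start_training(self)

    def _dispatch(self):
        # the training-type plugin eventually re-enters the Trainer here
        return self.run_stage()

    def run_stage(self):
        if self.evaluating:
            return self._run_evaluation()
        if self.predicting:
            return self._run_predict()
        return self._run_train()

    def _run_train(self):
        print("core training loop")

    def _run_evaluation(self):
        print("core evaluation loop")

    def _run_predict(self):
        print("core prediction loop")


SketchTrainer().fit(model=None)  # -> "core training loop"
```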
@@ -746,13 +746,13 @@ class Trainer(
         self.call_hook("on_fit_start")

         # plugin will setup fitting (e.g. ddp will launch child processes)
-        self.pre_dispatch()
+        self._pre_dispatch()

         # dispatch `start_training` or `start_evaluating` or `start_predicting`
-        self.dispatch()
+        self._dispatch()

         # plugin will finalized fitting (e.g. ddp_spawn will load trained model)
-        self.post_dispatch()
+        self._post_dispatch()

         # ----------------------------
         # POST-Training CLEAN UP
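For illustration only (not part of the diff): the three phases above correspond to the training-type plugin hooks named in the docstring (`pre_dispatch`, `start_training`/`start_evaluating`/`start_predicting`, `post_dispatch`). A standalone sketch of that ordering; `SketchPlugin` and `run_fit` are made-up names, not Lightning APIs.

```python
class SketchPlugin:
    # stand-in for a training-type plugin; hook names follow the docstring above
    def pre_dispatch(self):
        print("set up fitting, e.g. DDP launches child processes")

    def start_training(self):
        print("run the training loop through this plugin")

    def post_dispatch(self):
        print("finalize fitting, e.g. ddp_spawn loads the trained model back")


def run_fit(plugin: SketchPlugin) -> None:
    # mirrors Trainer._pre_dispatch() -> _dispatch() -> _post_dispatch()
    plugin.pre_dispatch()
    plugin.start_training()
    plugin.post_dispatch()


run_fit(SketchPlugin())
```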
@@ -762,7 +762,7 @@ class Trainer(
         self.call_hook('on_fit_end')

         # teardown
-        self.call_teardown_hook(model)
+        self._call_teardown_hook(model)

         if self.state.status != TrainerStatus.INTERRUPTED:
             self.state.status = TrainerStatus.FINISHED
@@ -770,7 +770,7 @@ class Trainer(

         return self.accelerator.results

-    def pre_dispatch(self):
+    def _pre_dispatch(self):
         self.accelerator.pre_dispatch(self)

         # log hyper-parameters
@@ -780,11 +780,11 @@ class Trainer(
             self.logger.log_graph(self.lightning_module)
             self.logger.save()

-    def post_dispatch(self):
+    def _post_dispatch(self):
         self.accelerator.post_dispatch(self)
         self.accelerator.teardown()

-    def dispatch(self):
+    def _dispatch(self):
         if self.evaluating:
             self.accelerator.start_evaluating(self)
         elif self.predicting:
@@ -797,10 +797,10 @@ class Trainer(
         self.profile_connector.setup()

         if self.evaluating:
-            return self.run_evaluate()
+            return self._run_evaluate()
         if self.predicting:
-            return self.run_predict()
-        return self.run_train()
+            return self._run_predict()
+        return self._run_train()

     def _pre_training_routine(self):
         # wait for all to join if on distributed
@@ -829,13 +829,13 @@ class Trainer(
         self.on_pretrain_routine_end()
         ref_model.on_pretrain_routine_end()

-    def run_train(self) -> None:
+    def _run_train(self) -> None:
         self._pre_training_routine()

         if not self.is_global_zero and self.progress_bar_callback is not None:
             self.progress_bar_callback.disable()

-        self.run_sanity_check(self.lightning_module)
+        self._run_sanity_check(self.lightning_module)

         self.checkpoint_connector.has_trained = False

@@ -904,10 +904,10 @@ class Trainer(
             self.state.stage = None
             raise

-    def run_evaluation(self, on_epoch: bool = False) -> _EVALUATE_OUTPUT:
+    def _run_evaluation(self, on_epoch: bool = False) -> _EVALUATE_OUTPUT:
         if not (self.evaluating or self.sanity_checking):
             rank_zero_warn(
-                f"`trainer.run_evaluation()` was called but the running stage is set to {self.state.stage}."
+                f"`trainer._run_evaluation()` was called but the running stage is set to {self.state.stage}."
                 " This should not happen normally. Setting it to `RunningStage.VALIDATING`", RuntimeWarning
             )
             self.validating = True
@@ -965,7 +965,7 @@ class Trainer(
                 self.logger_connector.log_evaluation_step_metrics()

                 # track epoch level outputs
-                dl_outputs = self.track_output_for_epoch_end(dl_outputs, output)
+                dl_outputs = self._track_output_for_epoch_end(dl_outputs, output)

                 # store batch level output per dataloader
                 if self.evaluation_loop.should_track_batch_outputs_for_epoch_end:
@@ -1016,7 +1016,7 @@ class Trainer(

         return eval_loop_results

-    def track_output_for_epoch_end(self, outputs, output):
+    def _track_output_for_epoch_end(self, outputs, output):
         if output is not None:
             if isinstance(output, Result):
                 output = output.detach()
@@ -1029,14 +1029,14 @@ class Trainer(
             outputs.append(output)
         return outputs

-    def run_evaluate(self) -> _EVALUATE_OUTPUT:
+    def _run_evaluate(self) -> _EVALUATE_OUTPUT:
         if not self.is_global_zero and self.progress_bar_callback is not None:
             self.progress_bar_callback.disable()

         assert self.evaluating

         with self.profiler.profile(f"run_{self.state.stage}_evaluation"):
-            eval_loop_results = self.run_evaluation()
+            eval_loop_results = self._run_evaluation()

         # remove the tensors from the eval results
         for i, result in enumerate(eval_loop_results):
@@ -1047,7 +1047,7 @@ class Trainer(

         return eval_loop_results

-    def run_predict(self) -> Optional[_PREDICT_OUTPUT]:
+    def _run_predict(self) -> Optional[_PREDICT_OUTPUT]:
         # prepare dataloaders
         dataloaders, max_batches = self.predict_loop.get_predict_dataloaders()

@@ -1085,7 +1085,7 @@ class Trainer(

         return results

-    def run_sanity_check(self, ref_model):
+    def _run_sanity_check(self, ref_model):
         using_val_step = ref_model.val_dataloader is not None and is_overridden('validation_step', ref_model)
         should_sanity_check = using_val_step and self.num_sanity_val_steps > 0 and self.limit_val_batches > 0

@@ -1099,7 +1099,7 @@ class Trainer(
             self.on_sanity_check_start()

             # run eval step
-            self.run_evaluation()
+            self._run_evaluation()

             self.on_sanity_check_end()

@@ -1145,7 +1145,7 @@ class Trainer(
             )
         return ckpt_path

-    def call_setup_hook(self, model: LightningModule) -> None:
+    def _call_setup_hook(self, model: LightningModule) -> None:
         fn = self.state.fn._setup_fn

         self.accelerator.barrier("pre_setup")
@@ -1157,7 +1157,7 @@ class Trainer(

         self.accelerator.barrier("post_setup")

-    def call_configure_sharded_model(self, model: LightningModule) -> None:
+    def _call_configure_sharded_model(self, model: LightningModule) -> None:
         # Call configure sharded model hook if accelerator requests. In some cases
         # we will not call the hook; the hook has initialized the sharded model for example.

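For illustration only (not part of the diff): `_call_configure_sharded_model` wraps the user-facing `configure_sharded_model` hook used with sharded plugins such as DeepSpeed/FSDP. A minimal sketch of overriding it, assuming the hook signature of this release line; `ShardedModel` and the layer shape are made-up.

```python
import torch.nn as nn

from pytorch_lightning import LightningModule


class ShardedModel(LightningModule):
    def configure_sharded_model(self) -> None:
        # Called by the Trainer when the accelerator requests it (see the hunk
        # above); layers created here are instantiated inside the sharded context.
        self.layer = nn.Linear(32, 2)
```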
@@ -1170,7 +1170,7 @@ class Trainer(
             model.call_configure_sharded_model_hook = True
             self.accelerator.call_configure_sharded_model_hook = False

-    def call_teardown_hook(self, model: LightningModule) -> None:
+    def _call_teardown_hook(self, model: LightningModule) -> None:
         fn = self.state.fn._setup_fn

         if self.datamodule is not None:
pytorch_lightning/trainer/training_loop.py
@@ -513,7 +513,7 @@ class TrainLoop:
             should_check_val = self._should_check_val_fx(batch_idx, is_last_batch)
             if should_check_val:
                 self.trainer.validating = True
-                self.trainer.run_evaluation()
+                self.trainer._run_evaluation()
                 self.trainer.training = True
                 val_loop_called = True

@@ -572,7 +572,7 @@ class TrainLoop:

         if should_check_val:
             self.trainer.validating = True
-            self.trainer.run_evaluation(on_epoch=True)
+            self.trainer._run_evaluation(on_epoch=True)
             self.trainer.training = True

         # increment the global step once