Mark certain Trainer APIs as protected (#7420)

This commit is contained in:
ananthsub 2021-05-11 02:53:41 -07:00 committed by GitHub
parent ad9118f04a
commit fdf50a5e4b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 36 additions and 33 deletions

View File

@ -18,6 +18,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Log epoch metrics before the `on_evaluation_end` hook ([#7272](https://github.com/PyTorchLightning/pytorch-lightning/pull/7272))
- Changed these `Trainer` methods to be protected: `call_setup_hook`, `call_configure_sharded_model`, `pre_dispatch`, `dispatch`, `post_dispatch`, `call_teardown_hook`, `run_train`, `run_sanity_check`, `run_evaluate`, `run_evaluation`, `run_predict`, `track_output_for_epoch_end` ([#7420](https://github.com/PyTorchLightning/pytorch-lightning/pull/7420))
- Changed `clip_grad_norm` to use `torch.nn.utils.clip_grad_norm_` ([#7025](https://github.com/PyTorchLightning/pytorch-lightning/pull/7025))
- Refactored Loops

View File

@ -706,8 +706,8 @@ class Trainer(
self.call_hook("on_before_accelerator_backend_setup", model)
self.accelerator.connect(model)
self.accelerator.setup_environment()
self.call_setup_hook(model) # allow user to setup lightning_module in accelerator environment
self.call_configure_sharded_model(model) # allow user to setup in model sharded environment
self._call_setup_hook(model) # allow user to setup lightning_module in accelerator environment
self._call_configure_sharded_model(model) # allow user to setup in model sharded environment
self.accelerator.setup(self, model) # note: this sets up self.lightning_module
# ----------------------------
@ -719,7 +719,7 @@ class Trainer(
| ||
create accelerator ||
| ||
{self.dispatch} ||
{self._dispatch} ||
| || LIGHTNING
{self.accelerator.start_training} ||
or {self.accelerator.start_evaluating} ||
@ -727,13 +727,13 @@ class Trainer(
| ||
{self.run_stage} ||
| || DIRECTION
{self.run_train} ||
or {self.run_evaluation} ||
or {self.run_predict} ||
{self._run_train} ||
or {self._run_evaluation} ||
or {self._run_predict} ||
| ||
results \/
This is used to guide readers to the core loops: train, test, predict.
{self.run_predict} is the simplest to understand, use `Go to Definition` to read it :)
{self._run_predict} is the simplest to understand, use `Go to Definition` to read it :)
Search for `start_training` or `start_evaluating` or `start_predicting` in
`pytorch_lightning/plugins/training_type_plugin` to find accelerator dispatch functions.
""" # noqa: W605
@ -746,13 +746,13 @@ class Trainer(
self.call_hook("on_fit_start")
# plugin will setup fitting (e.g. ddp will launch child processes)
self.pre_dispatch()
self._pre_dispatch()
# dispatch `start_training` or `start_evaluating` or `start_predicting`
self.dispatch()
self._dispatch()
# plugin will finalize fitting (e.g. ddp_spawn will load trained model)
self.post_dispatch()
self._post_dispatch()
# ----------------------------
# POST-Training CLEAN UP
@ -762,7 +762,7 @@ class Trainer(
self.call_hook('on_fit_end')
# teardown
self.call_teardown_hook(model)
self._call_teardown_hook(model)
if self.state.status != TrainerStatus.INTERRUPTED:
self.state.status = TrainerStatus.FINISHED
@ -770,7 +770,7 @@ class Trainer(
return self.accelerator.results
def pre_dispatch(self):
def _pre_dispatch(self):
self.accelerator.pre_dispatch(self)
# log hyper-parameters
@ -780,11 +780,11 @@ class Trainer(
self.logger.log_graph(self.lightning_module)
self.logger.save()
def post_dispatch(self):
def _post_dispatch(self):
self.accelerator.post_dispatch(self)
self.accelerator.teardown()
def dispatch(self):
def _dispatch(self):
if self.evaluating:
self.accelerator.start_evaluating(self)
elif self.predicting:
@ -797,10 +797,10 @@ class Trainer(
self.profile_connector.setup()
if self.evaluating:
return self.run_evaluate()
return self._run_evaluate()
if self.predicting:
return self.run_predict()
return self.run_train()
return self._run_predict()
return self._run_train()
def _pre_training_routine(self):
# wait for all to join if on distributed
@ -829,13 +829,13 @@ class Trainer(
self.on_pretrain_routine_end()
ref_model.on_pretrain_routine_end()
def run_train(self) -> None:
def _run_train(self) -> None:
self._pre_training_routine()
if not self.is_global_zero and self.progress_bar_callback is not None:
self.progress_bar_callback.disable()
self.run_sanity_check(self.lightning_module)
self._run_sanity_check(self.lightning_module)
self.checkpoint_connector.has_trained = False
@ -904,10 +904,10 @@ class Trainer(
self.state.stage = None
raise
def run_evaluation(self, on_epoch: bool = False) -> _EVALUATE_OUTPUT:
def _run_evaluation(self, on_epoch: bool = False) -> _EVALUATE_OUTPUT:
if not (self.evaluating or self.sanity_checking):
rank_zero_warn(
f"`trainer.run_evaluation()` was called but the running stage is set to {self.state.stage}."
f"`trainer._run_evaluation()` was called but the running stage is set to {self.state.stage}."
" This should not happen normally. Setting it to `RunningStage.VALIDATING`", RuntimeWarning
)
self.validating = True
@ -965,7 +965,7 @@ class Trainer(
self.logger_connector.log_evaluation_step_metrics()
# track epoch level outputs
dl_outputs = self.track_output_for_epoch_end(dl_outputs, output)
dl_outputs = self._track_output_for_epoch_end(dl_outputs, output)
# store batch level output per dataloader
if self.evaluation_loop.should_track_batch_outputs_for_epoch_end:
@ -1016,7 +1016,7 @@ class Trainer(
return eval_loop_results
def track_output_for_epoch_end(self, outputs, output):
def _track_output_for_epoch_end(self, outputs, output):
if output is not None:
if isinstance(output, Result):
output = output.detach()
@ -1029,14 +1029,14 @@ class Trainer(
outputs.append(output)
return outputs
def run_evaluate(self) -> _EVALUATE_OUTPUT:
def _run_evaluate(self) -> _EVALUATE_OUTPUT:
if not self.is_global_zero and self.progress_bar_callback is not None:
self.progress_bar_callback.disable()
assert self.evaluating
with self.profiler.profile(f"run_{self.state.stage}_evaluation"):
eval_loop_results = self.run_evaluation()
eval_loop_results = self._run_evaluation()
# remove the tensors from the eval results
for i, result in enumerate(eval_loop_results):
@ -1047,7 +1047,7 @@ class Trainer(
return eval_loop_results
def run_predict(self) -> Optional[_PREDICT_OUTPUT]:
def _run_predict(self) -> Optional[_PREDICT_OUTPUT]:
# prepare dataloaders
dataloaders, max_batches = self.predict_loop.get_predict_dataloaders()
@ -1085,7 +1085,7 @@ class Trainer(
return results
def run_sanity_check(self, ref_model):
def _run_sanity_check(self, ref_model):
using_val_step = ref_model.val_dataloader is not None and is_overridden('validation_step', ref_model)
should_sanity_check = using_val_step and self.num_sanity_val_steps > 0 and self.limit_val_batches > 0
@ -1099,7 +1099,7 @@ class Trainer(
self.on_sanity_check_start()
# run eval step
self.run_evaluation()
self._run_evaluation()
self.on_sanity_check_end()
@ -1145,7 +1145,7 @@ class Trainer(
)
return ckpt_path
def call_setup_hook(self, model: LightningModule) -> None:
def _call_setup_hook(self, model: LightningModule) -> None:
fn = self.state.fn._setup_fn
self.accelerator.barrier("pre_setup")
@ -1157,7 +1157,7 @@ class Trainer(
self.accelerator.barrier("post_setup")
def call_configure_sharded_model(self, model: LightningModule) -> None:
def _call_configure_sharded_model(self, model: LightningModule) -> None:
# Call the configure sharded model hook if the accelerator requests it. In some cases
# we will not call the hook, for example when the hook has already initialized the sharded model.
@ -1170,7 +1170,7 @@ class Trainer(
model.call_configure_sharded_model_hook = True
self.accelerator.call_configure_sharded_model_hook = False
def call_teardown_hook(self, model: LightningModule) -> None:
def _call_teardown_hook(self, model: LightningModule) -> None:
fn = self.state.fn._setup_fn
if self.datamodule is not None:

View File

@ -513,7 +513,7 @@ class TrainLoop:
should_check_val = self._should_check_val_fx(batch_idx, is_last_batch)
if should_check_val:
self.trainer.validating = True
self.trainer.run_evaluation()
self.trainer._run_evaluation()
self.trainer.training = True
val_loop_called = True
@ -572,7 +572,7 @@ class TrainLoop:
if should_check_val:
self.trainer.validating = True
self.trainer.run_evaluation(on_epoch=True)
self.trainer._run_evaluation(on_epoch=True)
self.trainer.training = True
# increment the global step once