mark `FitLoop.should_accumulate` as protected (#9515)

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Adrian Wälchli authored on 2021-09-15 15:32:14 +02:00, committed by GitHub
commit b9fa69ea57 (parent 200ed9eb9f)
6 changed files with 23 additions and 22 deletions
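For downstream code, the practical effect is a rename at the call site: anything that queried `trainer.fit_loop.should_accumulate()` now goes through the protected `trainer.fit_loop._should_accumulate()`, as the updated call sites in this diff do. A minimal sketch of such a caller, assuming PyTorch Lightning at roughly the state of this commit; the model, metric name, and optimizer settings are illustrative, not part of this change:

import torch
from pytorch_lightning import LightningModule


class AccumulationAwareModel(LightningModule):
    """Illustrative module that checks whether the current batch will trigger an optimizer step."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        loss = self.layer(batch).sum()
        # before this commit: self.trainer.fit_loop.should_accumulate()
        # after this commit the query is protected:
        if not self.trainer.fit_loop._should_accumulate():
            self.log("optimizer_will_step", 1.0)  # gradients are applied on this batch
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)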

@@ -78,6 +78,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 * Extracted `ManualOptimization` logic from `TrainingBatchLoop` into its own separate loop class ([#9266](https://github.com/PyTorchLightning/pytorch-lightning/pull/9266))
 * Added `OutputResult` and `ManualResult` classes ([#9437](https://github.com/PyTorchLightning/pytorch-lightning/pull/9437), [#9424](https://github.com/PyTorchLightning/pytorch-lightning/pull/9424))
 * Marked `OptimizerLoop.backward` as protected ([#9514](https://github.com/PyTorchLightning/pytorch-lightning/pull/9514))
+* Marked `FitLoop.should_accumulate` as protected ([#9515](https://github.com/PyTorchLightning/pytorch-lightning/pull/9515))
 - Added support for saving and loading state of multiple callbacks of the same type ([#7187](https://github.com/PyTorchLightning/pytorch-lightning/pull/7187))

@@ -135,19 +135,6 @@ class FitLoop(Loop):
             return self.epoch_loop.val_loop._results
         raise RuntimeError("`FitLoop._results` property isn't defined. Accessed outside of scope")
-    @staticmethod
-    def _is_max_limit_enabled(max_value: Optional[int]) -> bool:
-        """Checks whether the max_value is enabled. This can be used for checking whether max_epochs or max_steps
-        is enabled.
-        Args:
-            max_value: the value to check
-        Returns:
-            whether the limit for this value should be enabled
-        """
-        return max_value not in (None, -1)
     @property
     def done(self) -> bool:
         """Evaluates when to leave the loop.
@@ -254,10 +241,6 @@ class FitLoop(Loop):
         # give accelerators a chance to finish
         self.trainer.accelerator.on_train_end()
-    def should_accumulate(self) -> bool:
-        """Whether the gradients should be accumulated."""
-        return self.epoch_loop._should_accumulate()
     def teardown(self) -> None:
         self.epoch_loop.teardown()
@@ -270,3 +253,20 @@ class FitLoop(Loop):
     def on_load_checkpoint(self, state_dict: Dict) -> None:
         # cache the dataloader state dict until the dataloader objects are available
         self._dataloader_state_dict = state_dict.get("dataloader_state_dict", {})
+    def _should_accumulate(self) -> bool:
+        """Whether the gradients should be accumulated."""
+        return self.epoch_loop._should_accumulate()
+    @staticmethod
+    def _is_max_limit_enabled(max_value: Optional[int]) -> bool:
+        """Checks whether the max_value is enabled. This can be used for checking whether max_epochs or max_steps
+        is enabled.
+        Args:
+            max_value: the value to check
+        Returns:
+            whether the limit for this value should be enabled
+        """
+        return max_value not in (None, -1)
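The `_is_max_limit_enabled` helper relocated in this hunk treats both `None` and `-1` as "no limit". A standalone sketch of that contract, with the function body copied from the diff above and illustrative assertions:

from typing import Optional


def _is_max_limit_enabled(max_value: Optional[int]) -> bool:
    # same body as the static method above
    return max_value not in (None, -1)


assert _is_max_limit_enabled(10)        # e.g. max_epochs=10 is a real limit
assert not _is_max_limit_enabled(None)  # limit not set
assert not _is_max_limit_enabled(-1)    # -1 is the "unlimited" sentinel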

@@ -233,7 +233,7 @@ class OptimizerLoop(Loop):
         """
         self.trainer.accelerator.backward(loss, optimizer, opt_idx, *args, **kwargs)
-        if not self.trainer.fit_loop.should_accumulate():
+        if not self.trainer.fit_loop._should_accumulate():
             # track gradients
             grad_norm_dict = self._track_and_norm_grad(optimizer=optimizer)
             if grad_norm_dict:
@@ -260,7 +260,7 @@ class OptimizerLoop(Loop):
         if (
             # when the training type plugin handles accumulation, we want to always call the optimizer step
             not self.trainer.training_type_plugin.handles_gradient_accumulation
-            and self.trainer.fit_loop.should_accumulate()
+            and self.trainer.fit_loop._should_accumulate()
         ):
             # For gradient accumulation
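Both call sites use the renamed query to gate work on the accumulation window: gradient-norm tracking and the optimizer step only run on the batch that closes the window. A plain-PyTorch sketch of that pattern, independent of Lightning's loop classes; the model, data, and `accumulate_grad_batches` value are illustrative:

import torch

model = torch.nn.Linear(32, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
accumulate_grad_batches = 4  # illustrative window size

for batch_idx in range(16):
    batch = torch.randn(8, 32)
    # scale the loss so the accumulated gradients match one large batch
    loss = model(batch).sum() / accumulate_grad_batches
    loss.backward()

    # analogue of `fit_loop._should_accumulate()`: are we still inside the window?
    should_accumulate = (batch_idx + 1) % accumulate_grad_batches != 0
    if not should_accumulate:
        # only now track/clip gradient norms and advance the optimizer
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        optimizer.zero_grad()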

@@ -218,7 +218,7 @@ class LoggerConnector:
         self._split_idx = split_idx
     def update_train_step_metrics(self) -> None:
-        if self.trainer.fit_loop.should_accumulate() and self.trainer.lightning_module.automatic_optimization:
+        if self.trainer.fit_loop._should_accumulate() and self.trainer.lightning_module.automatic_optimization:
             return
         self._log_gpus_metrics()

@@ -256,7 +256,7 @@ def test_tensorboard_with_accummulated_gradients(mock_log_metrics, tmpdir):
         def training_step(self, *args):
             self.log("foo", 1, on_step=True, on_epoch=True)
-            if not self.trainer.fit_loop.should_accumulate():
+            if not self.trainer.fit_loop._should_accumulate():
                 if self.trainer.logger_connector.should_update_logs:
                     self.indexes.append(self.trainer.global_step)
             return super().training_step(*args)

@@ -559,7 +559,7 @@ def test_step_with_optimizer_closure_and_accumulated_grad(tmpdir):
             opt.step(closure=optimizer_closure)
             weight_after = self.layer.weight.clone()
-            if not self.trainer.fit_loop.should_accumulate():
+            if not self.trainer.fit_loop._should_accumulate():
                 assert not torch.equal(weight_before, weight_after)
             else:
                 assert self.layer.weight.grad is not None