mark `FitLoop.should_accumulate` as protected (#9515)
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
This commit is contained in:
parent
200ed9eb9f
commit
b9fa69ea57
|
@ -78,6 +78,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
* Extracted `ManualOptimization` logic from `TrainingBatchLoop` into its own separate loop class ([#9266](https://github.com/PyTorchLightning/pytorch-lightning/pull/9266))
|
||||
* Added `OutputResult` and `ManualResult` classes ([#9437](https://github.com/PyTorchLightning/pytorch-lightning/pull/9437), [#9424](https://github.com/PyTorchLightning/pytorch-lightning/pull/9424))
|
||||
* Marked `OptimizerLoop.backward` as protected ([#9514](https://github.com/PyTorchLightning/pytorch-lightning/pull/9514))
|
||||
* Marked `FitLoop.should_accumulate` as protected ([#9515](https://github.com/PyTorchLightning/pytorch-lightning/pull/9515))
|
||||
|
||||
|
||||
- Added support for saving and loading state of multiple callbacks of the same type ([#7187](https://github.com/PyTorchLightning/pytorch-lightning/pull/7187))
|
||||
|
|
|
@ -135,19 +135,6 @@ class FitLoop(Loop):
|
|||
return self.epoch_loop.val_loop._results
|
||||
raise RuntimeError("`FitLoop._results` property isn't defined. Accessed outside of scope")
|
||||
|
||||
@staticmethod
|
||||
def _is_max_limit_enabled(max_value: Optional[int]) -> bool:
|
||||
"""Checks whether the max_value is enabled. This can be used for checking whether max_epochs or max_steps
|
||||
is enabled.
|
||||
|
||||
Args:
|
||||
max_value: the value to check
|
||||
|
||||
Returns:
|
||||
whether the limit for this value should be enabled
|
||||
"""
|
||||
return max_value not in (None, -1)
|
||||
|
||||
@property
|
||||
def done(self) -> bool:
|
||||
"""Evaluates when to leave the loop.
|
||||
|
@ -254,10 +241,6 @@ class FitLoop(Loop):
|
|||
# give accelerators a chance to finish
|
||||
self.trainer.accelerator.on_train_end()
|
||||
|
||||
def should_accumulate(self) -> bool:
|
||||
"""Whether the gradients should be accumulated."""
|
||||
return self.epoch_loop._should_accumulate()
|
||||
|
||||
def teardown(self) -> None:
|
||||
self.epoch_loop.teardown()
|
||||
|
||||
|
@ -270,3 +253,20 @@ class FitLoop(Loop):
|
|||
def on_load_checkpoint(self, state_dict: Dict) -> None:
|
||||
# cache the dataloader state dict until the dataloader objects are available
|
||||
self._dataloader_state_dict = state_dict.get("dataloader_state_dict", {})
|
||||
|
||||
def _should_accumulate(self) -> bool:
|
||||
"""Whether the gradients should be accumulated."""
|
||||
return self.epoch_loop._should_accumulate()
|
||||
|
||||
@staticmethod
|
||||
def _is_max_limit_enabled(max_value: Optional[int]) -> bool:
|
||||
"""Checks whether the max_value is enabled. This can be used for checking whether max_epochs or max_steps
|
||||
is enabled.
|
||||
|
||||
Args:
|
||||
max_value: the value to check
|
||||
|
||||
Returns:
|
||||
whether the limit for this value should be enabled
|
||||
"""
|
||||
return max_value not in (None, -1)
|
||||
|
|
|
@ -233,7 +233,7 @@ class OptimizerLoop(Loop):
|
|||
"""
|
||||
self.trainer.accelerator.backward(loss, optimizer, opt_idx, *args, **kwargs)
|
||||
|
||||
if not self.trainer.fit_loop.should_accumulate():
|
||||
if not self.trainer.fit_loop._should_accumulate():
|
||||
# track gradients
|
||||
grad_norm_dict = self._track_and_norm_grad(optimizer=optimizer)
|
||||
if grad_norm_dict:
|
||||
|
@ -260,7 +260,7 @@ class OptimizerLoop(Loop):
|
|||
if (
|
||||
# when the training type plugin handles accumulation, we want to always call the optimizer step
|
||||
not self.trainer.training_type_plugin.handles_gradient_accumulation
|
||||
and self.trainer.fit_loop.should_accumulate()
|
||||
and self.trainer.fit_loop._should_accumulate()
|
||||
):
|
||||
# For gradient accumulation
|
||||
|
||||
|
|
|
@ -218,7 +218,7 @@ class LoggerConnector:
|
|||
self._split_idx = split_idx
|
||||
|
||||
def update_train_step_metrics(self) -> None:
|
||||
if self.trainer.fit_loop.should_accumulate() and self.trainer.lightning_module.automatic_optimization:
|
||||
if self.trainer.fit_loop._should_accumulate() and self.trainer.lightning_module.automatic_optimization:
|
||||
return
|
||||
|
||||
self._log_gpus_metrics()
|
||||
|
|
|
@ -256,7 +256,7 @@ def test_tensorboard_with_accummulated_gradients(mock_log_metrics, tmpdir):
|
|||
|
||||
def training_step(self, *args):
|
||||
self.log("foo", 1, on_step=True, on_epoch=True)
|
||||
if not self.trainer.fit_loop.should_accumulate():
|
||||
if not self.trainer.fit_loop._should_accumulate():
|
||||
if self.trainer.logger_connector.should_update_logs:
|
||||
self.indexes.append(self.trainer.global_step)
|
||||
return super().training_step(*args)
|
||||
|
|
|
@ -559,7 +559,7 @@ def test_step_with_optimizer_closure_and_accumulated_grad(tmpdir):
|
|||
opt.step(closure=optimizer_closure)
|
||||
|
||||
weight_after = self.layer.weight.clone()
|
||||
if not self.trainer.fit_loop.should_accumulate():
|
||||
if not self.trainer.fit_loop._should_accumulate():
|
||||
assert not torch.equal(weight_before, weight_after)
|
||||
else:
|
||||
assert self.layer.weight.grad is not None
|
||||
|
|
Loading…
Reference in New Issue