diff --git a/CHANGELOG.md b/CHANGELOG.md
index 0271ee6115..f96005f322 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -78,6 +78,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
     * Extracted `ManualOptimization` logic from `TrainingBatchLoop` into its own separate loop class ([#9266](https://github.com/PyTorchLightning/pytorch-lightning/pull/9266))
     * Added `OutputResult` and `ManualResult` classes ([#9437](https://github.com/PyTorchLightning/pytorch-lightning/pull/9437), [#9424](https://github.com/PyTorchLightning/pytorch-lightning/pull/9424))
     * Marked `OptimizerLoop.backward` as protected ([#9514](https://github.com/PyTorchLightning/pytorch-lightning/pull/9514))
+    * Marked `FitLoop.should_accumulate` as protected ([#9515](https://github.com/PyTorchLightning/pytorch-lightning/pull/9515))
 
 
 - Added support for saving and loading state of multiple callbacks of the same type ([#7187](https://github.com/PyTorchLightning/pytorch-lightning/pull/7187))
diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py
index 684b3522c1..3e9917a551 100644
--- a/pytorch_lightning/loops/fit_loop.py
+++ b/pytorch_lightning/loops/fit_loop.py
@@ -135,19 +135,6 @@ class FitLoop(Loop):
             return self.epoch_loop.val_loop._results
         raise RuntimeError("`FitLoop._results` property isn't defined. Accessed outside of scope")
 
-    @staticmethod
-    def _is_max_limit_enabled(max_value: Optional[int]) -> bool:
-        """Checks whether the max_value is enabled. This can be used for checking whether max_epochs or max_steps
-        is enabled.
-
-        Args:
-            max_value: the value to check
-
-        Returns:
-            whether the limit for this value should be enabled
-        """
-        return max_value not in (None, -1)
-
     @property
     def done(self) -> bool:
         """Evaluates when to leave the loop.
@@ -254,10 +241,6 @@ class FitLoop(Loop):
         # give accelerators a chance to finish
         self.trainer.accelerator.on_train_end()
 
-    def should_accumulate(self) -> bool:
-        """Whether the gradients should be accumulated."""
-        return self.epoch_loop._should_accumulate()
-
     def teardown(self) -> None:
         self.epoch_loop.teardown()
 
@@ -270,3 +253,20 @@ class FitLoop(Loop):
     def on_load_checkpoint(self, state_dict: Dict) -> None:
         # cache the dataloader state dict until the dataloader objects are available
         self._dataloader_state_dict = state_dict.get("dataloader_state_dict", {})
+
+    def _should_accumulate(self) -> bool:
+        """Whether the gradients should be accumulated."""
+        return self.epoch_loop._should_accumulate()
+
+    @staticmethod
+    def _is_max_limit_enabled(max_value: Optional[int]) -> bool:
+        """Checks whether the max_value is enabled. This can be used for checking whether max_epochs or max_steps
+        is enabled.
+
+        Args:
+            max_value: the value to check
+
+        Returns:
+            whether the limit for this value should be enabled
+        """
+        return max_value not in (None, -1)
diff --git a/pytorch_lightning/loops/optimization/optimizer_loop.py b/pytorch_lightning/loops/optimization/optimizer_loop.py
index f5352c42b2..e3b1083b04 100644
--- a/pytorch_lightning/loops/optimization/optimizer_loop.py
+++ b/pytorch_lightning/loops/optimization/optimizer_loop.py
@@ -233,7 +233,7 @@ class OptimizerLoop(Loop):
         """
         self.trainer.accelerator.backward(loss, optimizer, opt_idx, *args, **kwargs)
 
-        if not self.trainer.fit_loop.should_accumulate():
+        if not self.trainer.fit_loop._should_accumulate():
             # track gradients
             grad_norm_dict = self._track_and_norm_grad(optimizer=optimizer)
             if grad_norm_dict:
@@ -260,7 +260,7 @@ class OptimizerLoop(Loop):
         if (
             # when the training type plugin handles accumulation, we want to always call the optimizer step
             not self.trainer.training_type_plugin.handles_gradient_accumulation
-            and self.trainer.fit_loop.should_accumulate()
+            and self.trainer.fit_loop._should_accumulate()
         ):
             # For gradient accumulation
 
diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py
index c3356c1392..2e6b607784 100644
--- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py
+++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py
@@ -218,7 +218,7 @@ class LoggerConnector:
         self._split_idx = split_idx
 
     def update_train_step_metrics(self) -> None:
-        if self.trainer.fit_loop.should_accumulate() and self.trainer.lightning_module.automatic_optimization:
+        if self.trainer.fit_loop._should_accumulate() and self.trainer.lightning_module.automatic_optimization:
             return
 
         self._log_gpus_metrics()
diff --git a/tests/loggers/test_tensorboard.py b/tests/loggers/test_tensorboard.py
index 027a29d94f..5e82c79463 100644
--- a/tests/loggers/test_tensorboard.py
+++ b/tests/loggers/test_tensorboard.py
@@ -256,7 +256,7 @@ def test_tensorboard_with_accummulated_gradients(mock_log_metrics, tmpdir):
 
         def training_step(self, *args):
             self.log("foo", 1, on_step=True, on_epoch=True)
-            if not self.trainer.fit_loop.should_accumulate():
+            if not self.trainer.fit_loop._should_accumulate():
                 if self.trainer.logger_connector.should_update_logs:
                     self.indexes.append(self.trainer.global_step)
             return super().training_step(*args)
diff --git a/tests/trainer/optimization/test_manual_optimization.py b/tests/trainer/optimization/test_manual_optimization.py
index 38015c7c11..544e32c089 100644
--- a/tests/trainer/optimization/test_manual_optimization.py
+++ b/tests/trainer/optimization/test_manual_optimization.py
@@ -559,7 +559,7 @@ def test_step_with_optimizer_closure_and_accumulated_grad(tmpdir):
             opt.step(closure=optimizer_closure)
 
             weight_after = self.layer.weight.clone()
-            if not self.trainer.fit_loop.should_accumulate():
+            if not self.trainer.fit_loop._should_accumulate():
                 assert not torch.equal(weight_before, weight_after)
             else:
                 assert self.layer.weight.grad is not None
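Note: the `fit_loop.py` hunks above move `_is_max_limit_enabled` verbatim within `FitLoop`. As a quick reference, the sketch below restates its semantics as a standalone function; the function body is taken from the diff, while the example values and the `__main__` guard are illustrative additions and not part of the patch.

from typing import Optional


def _is_max_limit_enabled(max_value: Optional[int]) -> bool:
    """Return True when a limit such as ``max_epochs`` or ``max_steps`` is active.

    Both ``None`` and ``-1`` mean "no limit", mirroring the docstring in the hunk above.
    """
    return max_value not in (None, -1)


if __name__ == "__main__":
    # Illustrative values only.
    assert _is_max_limit_enabled(10)        # explicit limit -> enabled
    assert not _is_max_limit_enabled(None)  # unset -> no limit
    assert not _is_max_limit_enabled(-1)    # -1 is the "unlimited" sentinel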