From 02152c17299eaacb26748a858d1a3545b419b93a Mon Sep 17 00:00:00 2001 From: chaton Date: Mon, 7 Dec 2020 12:55:49 +0000 Subject: [PATCH] Simplify optimization Logic (#4984) * Rely on ddp plugin for blocking sync behaviour, and skip if we're using manual optimization * debug * Revert "debug" This reverts commit ccca6b6b * Expose manual reduce for automatic optimization * Add input arguments * Enable parity test * clean imports * Expose hook after to ensure we reset * Fix naming * add * fix test * uniformize optimizer logic * resolve test * resovle flake8 * resolve amp bug * update tests * remove bug * remove optimizer_step in accelerators * typo * update lightning optimizer * set doesn't work with ddp_spawn * resolve flake8 * update threshold * ignore pyright * correct codeFactor * remove useless if * remove zer_grad function * simplify step * remove typo * resolve bug * Apply suggestions from code review * update on comments * resolve bugs * remove tests * Update pytorch_lightning/trainer/configuration_validator.py Co-authored-by: Rohit Gupta * simplify testing * add more tests Co-authored-by: SeanNaren Co-authored-by: Jirka Borovec Co-authored-by: Rohit Gupta --- .drone.yml | 7 +- pytorch_lightning/accelerators/accelerator.py | 41 +--- .../accelerators/tpu_accelerator.py | 27 +-- .../gradient_accumulation_scheduler.py | 3 + pytorch_lightning/core/lightning.py | 14 +- pytorch_lightning/core/optimizer.py | 100 +++++--- pytorch_lightning/overrides/data_parallel.py | 2 +- pytorch_lightning/plugins/apex.py | 4 + pytorch_lightning/plugins/ddp_plugin.py | 5 +- pytorch_lightning/plugins/native_amp.py | 5 + pytorch_lightning/plugins/sharded_plugin.py | 2 +- .../trainer/configuration_validator.py | 33 ++- .../logger_connector/logger_connector.py | 2 +- pytorch_lightning/trainer/trainer.py | 3 +- pytorch_lightning/trainer/training_loop.py | 53 ++--- tests/callbacks/test_callbacks.py | 3 +- tests/core/test_lightning_module.py | 102 ++++++++ tests/core/test_lightning_optimizer.py | 219 +++++++++++++++++- tests/plugins/test_amp_plugin.py | 14 +- tests/special_tests.sh | 17 ++ .../optimization/test_manual_optimization.py | 40 ++++ 21 files changed, 554 insertions(+), 142 deletions(-) create mode 100644 tests/core/test_lightning_module.py create mode 100644 tests/special_tests.sh diff --git a/.drone.yml b/.drone.yml index 886f3a429b..c87130844c 100644 --- a/.drone.yml +++ b/.drone.yml @@ -39,10 +39,11 @@ steps: # todo: temprarl fix till https://github.com/PyTorchLightning/pytorch-lightning/pull/4922 is resolved - pip install --extra-index-url https://developer.download.nvidia.com/compute/redist "nvidia-dali-cuda100<0.27" --upgrade-strategy only-if-needed - pip list - - coverage run --source pytorch_lightning -m pytest pytorch_lightning tests -v --durations=25 # --flake8 - - python -m pytest benchmarks pl_examples -v --maxfail=2 --durations=0 # --flake8 - #- cd docs; make doctest; make coverage + - python -m coverage run --source pytorch_lightning -m pytest pytorch_lightning tests -v --durations=25 # --flake8 + # Running special tests + - sh tests/special_tests.sh - coverage report + - python -m pytest benchmarks pl_examples -v --maxfail=2 --durations=0 # see: https://docs.codecov.io/docs/merging-reports - codecov --token $CODECOV_TOKEN --flags=gpu,pytest --name="GPU-coverage" --env=linux --build $DRONE_BUILD_NUMBER --commit $DRONE_COMMIT # --build $DRONE_BUILD_NUMBER --branch $DRONE_BRANCH --commit $DRONE_COMMIT --tag $DRONE_TAG --pr $DRONE_PULL_REQUEST diff --git 
a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 0f61c53ffa..105c3a1309 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -16,14 +16,15 @@ from enum import Enum from typing import Any, Optional, Union import torch +import torch.distributed as torch_distrib from torch.optim import Optimizer +from pytorch_lightning.core.lightning import LightningModule +from pytorch_lightning.core.optimizer import LightningOptimizer from pytorch_lightning.utilities import AMPType from pytorch_lightning.utilities.apply_func import move_data_to_device from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.parsing import AttributeDict -from pytorch_lightning.core.lightning import LightningModule -import torch.distributed as torch_distrib if torch.distributed.is_available(): from torch.distributed import ReduceOp @@ -98,40 +99,6 @@ class Accelerator(object): closure_loss = closure_loss.detach() return closure_loss - def optimizer_step(self, optimizer, batch_idx, opt_idx, lambda_closure, *args, **kwargs): - model_ref = self.trainer.get_model() - is_lbfgs = isinstance(optimizer, torch.optim.LBFGS) - using_native_amp = self.trainer.amp_backend == AMPType.NATIVE - automatic_optimization = self.trainer.train_loop.automatic_optimization - - # native amp + lbfgs is a no go right now - if using_native_amp and is_lbfgs: - raise MisconfigurationException( - 'native PyTorch amp and lbfgs are not compatible.' - ' To request, please file a Github issue in PyTorch and tag @mcarilli') - - # model hook - model_ref.optimizer_step( - epoch=self.trainer.current_epoch, - batch_idx=batch_idx, - optimizer=optimizer, - optimizer_idx=opt_idx, - optimizer_closure=lambda_closure, - on_tpu=False, # TPUAccelerator class sets this as True - using_native_amp=using_native_amp, - using_lbfgs=is_lbfgs, - *args, - **kwargs, - ) - - # scale when native amp - if automatic_optimization and using_native_amp: - self.trainer.scaler.update() - - def optimizer_zero_grad(self, batch_idx, optimizer, opt_idx): - model_ref = self.trainer.get_model() - model_ref.optimizer_zero_grad(self.trainer.current_epoch, batch_idx, optimizer, opt_idx) - def clip_gradients(self, optimizer, clip_val=None): # use the trainer's clip val if none passed grad_clip_val = self.trainer.gradient_clip_val @@ -160,7 +127,7 @@ class Accelerator(object): return self.trainer.should_stop def setup_optimizers(self, model): - if self.trainer.testing is True: + if self.trainer.testing: return optimizers, lr_schedulers, optimizer_frequencies = self.trainer.init_optimizers(model) diff --git a/pytorch_lightning/accelerators/tpu_accelerator.py b/pytorch_lightning/accelerators/tpu_accelerator.py index 6da5150d1f..cd6b99fa64 100644 --- a/pytorch_lightning/accelerators/tpu_accelerator.py +++ b/pytorch_lightning/accelerators/tpu_accelerator.py @@ -23,7 +23,14 @@ from torch.optim import Optimizer from pytorch_lightning import _logger as log from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp from pytorch_lightning.core import LightningModule -from pytorch_lightning.utilities import TPU_AVAILABLE, rank_zero_info, rank_zero_only, rank_zero_warn, move_data_to_device +from pytorch_lightning.core.optimizer import LightningOptimizer +from pytorch_lightning.utilities import ( + TPU_AVAILABLE, + move_data_to_device, + rank_zero_info, + rank_zero_only, + rank_zero_warn, +) from pytorch_lightning.utilities.cloud_io 
import atomic_save from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -245,24 +252,6 @@ class TPUAccelerator(Accelerator): return closure_loss - def optimizer_step(self, optimizer, batch_idx, opt_idx, lambda_closure, *args, **kwargs): - model_ref = self.trainer.get_model() - is_lbfgs = isinstance(optimizer, torch.optim.LBFGS) - - # model hook - model_ref.optimizer_step( - epoch=self.trainer.current_epoch, - batch_idx=batch_idx, - optimizer=optimizer, - optimizer_idx=opt_idx, - optimizer_closure=lambda_closure, - on_tpu=True, - using_native_amp=False, - using_lbfgs=is_lbfgs, - *args, - **kwargs, - ) - def _clip_gradients(self, optimizer: Optimizer, grad_clip_val: Union[float, int], norm_type: float = 2.0): # this code is a modification of torch.nn.utils.clip_grad_norm_ # with TPU support based on https://github.com/pytorch/xla/blob/master/TROUBLESHOOTING.md diff --git a/pytorch_lightning/callbacks/gradient_accumulation_scheduler.py b/pytorch_lightning/callbacks/gradient_accumulation_scheduler.py index 7b723c3fc9..bc7e9eba0a 100644 --- a/pytorch_lightning/callbacks/gradient_accumulation_scheduler.py +++ b/pytorch_lightning/callbacks/gradient_accumulation_scheduler.py @@ -67,6 +67,9 @@ class GradientAccumulationScheduler(Callback): self.scheduling = scheduling self.epochs = sorted(scheduling.keys()) + def going_to_accumulate_grad_batches(self): + return any([v > 1 for v in self.scheduling.values()]) + def on_epoch_start(self, trainer, pl_module): epoch = trainer.current_epoch for i in reversed(range(len(self.epochs))): diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index c33297934e..f1a0c725e2 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -33,6 +33,7 @@ from pytorch_lightning import _logger as log from pytorch_lightning.core.grads import GradInformation from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks, ModelHooks from pytorch_lightning.core.memory import ModelSummary +from pytorch_lightning.core.optimizer import LightningOptimizer from pytorch_lightning.core.saving import ALLOWED_CONFIG_TYPES, PRIMITIVE_TYPES, ModelIO from pytorch_lightning.core.step_result import Result from pytorch_lightning.utilities import TPU_AVAILABLE, rank_zero_warn @@ -1236,15 +1237,10 @@ class LightningModule( model hook don't forget to add the call to it before ``optimizer.zero_grad()`` yourself. 
""" - if on_tpu and TPU_AVAILABLE: - xm.optimizer_step(optimizer, optimizer_args={'closure': optimizer_closure, **kwargs}) - - elif self.trainer.amp_backend is not None: - self.trainer.precision_connector.backend.optimizer_step( - self.trainer, optimizer, optimizer_closure) - - else: - optimizer.step(closure=optimizer_closure, *args, **kwargs) + if not isinstance(optimizer, LightningOptimizer): + # wraps into LightingOptimizer only for running step + optimizer = LightningOptimizer.to_lightning_optimizer(optimizer, self.trainer) + optimizer.step(closure=optimizer_closure, *args, **kwargs) def optimizer_zero_grad( self, epoch: int, batch_idx: int, optimizer: Optimizer, optimizer_idx: int diff --git a/pytorch_lightning/core/optimizer.py b/pytorch_lightning/core/optimizer.py index f8f6a7b6c0..f07f467810 100644 --- a/pytorch_lightning/core/optimizer.py +++ b/pytorch_lightning/core/optimizer.py @@ -35,7 +35,7 @@ def do_nothing_closure(): class LightningOptimizer: """ This class is used to wrap the user optimizers and handle properly - the backward and optimizer_step logic across accelerators, AMP, accumulated_grad_batches + the backward and optimizer_step logic across accelerators, AMP, accumulate_grad_batches """ def __init__(self, optimizer: Optimizer, @@ -60,17 +60,35 @@ class LightningOptimizer: self._trainer = None self._optimizer = optimizer self._accumulate_grad_batches = accumulate_grad_batches - self._use_accumulate_grad_batches_from_trainer = accumulate_grad_batches is None + self._automatic_optimization = None + self._optimizer_idx = None + + @property + def accumulate_grad_batches(self): + return self._accumulate_grad_batches + + @accumulate_grad_batches.setter + def accumulate_grad_batches(self, accumulate_grad_batches): + self._accumulate_grad_batches = accumulate_grad_batches def _on_trainer_init(self, trainer): self._trainer = proxy(trainer) + self._automatic_optimization = trainer.train_loop.automatic_optimization + for opt_idx, opt in enumerate(trainer.optimizers): + if opt == self._optimizer: + self._optimizer_idx = opt_idx + break + + @classmethod + def to_lightning_optimizer(cls, optimizer, trainer): + optimizer = cls(optimizer) + optimizer._on_trainer_init(trainer) + return optimizer def _accumulated_batches_reached(self): - if self._use_accumulate_grad_batches_from_trainer: - accumulate_grad_batches = self._trainer.accumulate_grad_batches - else: - accumulate_grad_batches = self._accumulate_grad_batches - return (self._trainer.batch_idx + 1) % accumulate_grad_batches == 0 + if self.accumulate_grad_batches is None: + return self._trainer.train_loop._accumulated_batches_reached() + return (self._trainer.batch_idx + 1) % self.accumulate_grad_batches == 0 @property def _should_accumulate(self): @@ -79,6 +97,45 @@ class LightningOptimizer: is_final_batch = self._trainer.train_loop._num_training_batches_reached() return not (accumulation_done or is_final_batch) + def __optimizer_step(self, *args, closure: Optional[Callable] = None, profiler_name: str = None, **kwargs): + trainer = self._trainer + optimizer = self._optimizer + model = trainer.get_model() + + if trainer.on_tpu: + with trainer.profiler.profile(profiler_name): + xm.optimizer_step(optimizer, optimizer_args={'closure': closure, **kwargs}) + + elif trainer.amp_backend is not None: + trainer.precision_connector.backend.optimizer_step(trainer, optimizer, closure) + + else: + with trainer.profiler.profile(profiler_name): + optimizer.step(closure=closure, *args, **kwargs) + + trainer.train_loop.on_before_zero_grad(self) + 
+ model.optimizer_zero_grad( + trainer.current_epoch, + trainer.batch_idx, + optimizer, + self._optimizer_idx + ) + + def _check_make_optimizer_step(self, make_optimizer_step: Optional[bool]) -> bool: + if make_optimizer_step is not None and self._trainer.overriden_optimizer_zero_grad: + raise MisconfigurationException( + "When overriding LightningModule `optimizer_zero_grad`, make_optimizer_step is not allowed.") + + if self._trainer.train_loop.automatic_optimization: + if self._trainer.overriden_optimizer_step and self._trainer.overriden_optimizer_zero_grad: + return True + + if make_optimizer_step is None: + make_optimizer_step = not self._should_accumulate + + return make_optimizer_step + def step(self, *args, closure: Optional[Callable] = None, make_optimizer_step: Optional[bool] = None, **kwargs): """ Call this directly from your training_step when doing optimizations manually. @@ -173,40 +230,23 @@ class LightningOptimizer: # Trainer(accumulate_grad_batches=x) opt_dis.step(closure=optimizer_closure, make_optimizer_step=True) """ - profiler_name = "optimizer_step_and_closure" + profiler_name = f"optimizer_step_and_closure_{self._optimizer_idx}" if closure is None: closure = do_nothing_closure - profile_name = "optimizer_step" + profile_name = f"optimizer_step_{self._optimizer_idx}" else: if not isinstance(closure, types.FunctionType): raise MisconfigurationException("When closure is provided, it should be a function") - if make_optimizer_step is None: - make_optimizer_step = not self._should_accumulate - - trainer = self._trainer - optimizer = self._optimizer + make_optimizer_step = self._check_make_optimizer_step(make_optimizer_step) if make_optimizer_step: - if trainer.on_tpu: - with trainer.profiler.profile(profiler_name): - xm.optimizer_step(optimizer, optimizer_args={'closure': closure, **kwargs}) - - elif trainer.amp_backend is not None: - trainer.precision_connector.backend.optimizer_step( - trainer, optimizer, closure) - - else: - with trainer.profiler.profile(profiler_name): - optimizer.step(closure=closure, *args, **kwargs) - - # perform zero grad - optimizer.zero_grad() + self.__optimizer_step(*args, closure=closure, profiler_name=profiler_name, **kwargs) else: # make sure to call optimizer_closure when accumulating - with trainer.profiler.profile("closure"): - with trainer.train_loop.block_ddp_sync_behaviour(): + with self._trainer.profiler.profile(f"closure_{self._optimizer_idx}"): + with self._trainer.train_loop.block_ddp_sync_behaviour(): closure() def __repr__(self): diff --git a/pytorch_lightning/overrides/data_parallel.py b/pytorch_lightning/overrides/data_parallel.py index 94cbd18781..602badca2c 100644 --- a/pytorch_lightning/overrides/data_parallel.py +++ b/pytorch_lightning/overrides/data_parallel.py @@ -14,7 +14,7 @@ import itertools import threading -from collections.abc import Mapping, Iterable +from collections.abc import Iterable, Mapping from itertools import chain import torch diff --git a/pytorch_lightning/plugins/apex.py b/pytorch_lightning/plugins/apex.py index 64d77ffc87..085d0e729d 100644 --- a/pytorch_lightning/plugins/apex.py +++ b/pytorch_lightning/plugins/apex.py @@ -137,5 +137,9 @@ class ApexPlugin(PrecisionPlugin): # TODO: pass the closure to the step ASAP with trainer.profiler.profile("closure"): closure() + + if not self.trainer.train_loop.automatic_optimization: + trainer.call_hook("on_after_backward") + with trainer.profiler.profile("optimizer_step"): optimizer.step() diff --git a/pytorch_lightning/plugins/ddp_plugin.py 
b/pytorch_lightning/plugins/ddp_plugin.py index 5171c95cfd..a96415ff35 100644 --- a/pytorch_lightning/plugins/ddp_plugin.py +++ b/pytorch_lightning/plugins/ddp_plugin.py @@ -1,9 +1,10 @@ import os -from typing import Any, Dict, List, Union, Optional +from typing import Any, Dict, List, Optional, Union import torch.distributed as torch_distrib -from pytorch_lightning import _logger as log from torch.optim import Optimizer + +from pytorch_lightning import _logger as log from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel from pytorch_lightning.plugins.plugin import LightningPlugin diff --git a/pytorch_lightning/plugins/native_amp.py b/pytorch_lightning/plugins/native_amp.py index 39a2403fc4..4df5d12847 100644 --- a/pytorch_lightning/plugins/native_amp.py +++ b/pytorch_lightning/plugins/native_amp.py @@ -69,6 +69,11 @@ class NativeAMPPlugin(PrecisionPlugin): # TODO: pass the closure to the step ASAP with trainer.profiler.profile("closure"): closure() + + if not self.trainer.train_loop.automatic_optimization: + trainer.scaler.unscale_(optimizer) + trainer.call_hook("on_after_backward") + with trainer.profiler.profile("optimizer_step"): trainer.scaler.step(optimizer) trainer.scaler.update() diff --git a/pytorch_lightning/plugins/sharded_plugin.py b/pytorch_lightning/plugins/sharded_plugin.py index f8a793af85..52f0df2a53 100644 --- a/pytorch_lightning/plugins/sharded_plugin.py +++ b/pytorch_lightning/plugins/sharded_plugin.py @@ -17,7 +17,7 @@ from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.core.optimizer import is_lightning_optimizer from pytorch_lightning.plugins.ddp_plugin import DDPPlugin from pytorch_lightning.plugins.sharded_native_amp_plugin import ShardedNativeAMPPlugin -from pytorch_lightning.utilities import AMPType, FAIRSCALE_AVAILABLE, rank_zero_only +from pytorch_lightning.utilities import FAIRSCALE_AVAILABLE, AMPType, rank_zero_only from pytorch_lightning.utilities.exceptions import MisconfigurationException if FAIRSCALE_AVAILABLE: diff --git a/pytorch_lightning/trainer/configuration_validator.py b/pytorch_lightning/trainer/configuration_validator.py index 01c0119e85..974bd69229 100644 --- a/pytorch_lightning/trainer/configuration_validator.py +++ b/pytorch_lightning/trainer/configuration_validator.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +from pytorch_lightning import _logger as log from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -69,6 +69,37 @@ class ConfigValidator(object): ' `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined.' 
) + trainer = self.trainer + + trainer.overriden_optimizer_step = is_overridden('optimizer_step', model) + trainer.overriden_optimizer_zero_grad = is_overridden('optimizer_zero_grad', model) + + enable_pl_optimizer = trainer._enable_pl_optimizer + automatic_optimization = trainer.train_loop.automatic_optimization + if trainer.overriden_optimizer_step and not enable_pl_optimizer and automatic_optimization: + rank_zero_warn( + "When overriding `LightningModule` optimizer_step with" + " `Trainer(..., enable_pl_optimizer=False, automatic_optimization=True, ...)`," + " we won't be calling `.zero_grad`, as we can't assume when you call `optimizer.step()`." + " For Lightning to take care of it, please use `Trainer(enable_pl_optimizer=True)`." + ) + + going_to_accumulate_grad_batches = trainer.accumulation_scheduler.going_to_accumulate_grad_batches() + + has_overriden_optimization_functions = trainer.overriden_optimizer_step or trainer.overriden_optimizer_zero_grad + if (has_overriden_optimization_functions) and going_to_accumulate_grad_batches and automatic_optimization: + raise MisconfigurationException( + 'When overriding `LightningModule` optimizer_step or optimizer_zero_grad with ' + '`Trainer(automatic_optimization=True, ...)`, `accumulate_grad_batches` should be 1.' + ' It ensures optimizer_step or optimizer_zero_grad are called on every batch.' + ) + + if (enable_pl_optimizer) and trainer.overriden_optimizer_zero_grad and not automatic_optimization: + raise MisconfigurationException( + 'When overriding `LightningModule` optimizer_zero_grad with ' + '`Trainer(automatic_optimization=False, enable_pl_optimizer=True, ...) is not supported' + ) + def __verify_eval_loop_configuration(self, model, eval_loop_name): step_name = f'{eval_loop_name}_step' diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 851a48e014..99a0c846fe 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -42,7 +42,7 @@ class LoggerConnector: @property def cached_results(self) -> Union[EpochResultStore, None]: - return self._cached_results.get(self._current_stage) + return self._cached_results.get(self._current_stage) # type: ignore def set_stage(self, stage_or_testing: Union[str, bool], reset: bool = False) -> None: self._current_stage = LoggerStages.determine_stage(stage_or_testing) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 92e3b6af2e..6740406c5f 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -855,7 +855,8 @@ class Trainer( model.setup(stage_name) def _reset_result_and_set_hook_fx_name(self, hook_name): - if "batch_start" in hook_name: + # on_before_zero_grad is called within training_step + if "batch_start" in hook_name or "on_before_zero_grad" in hook_name: return True model_ref = self.get_model() if model_ref is not None: diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 679f59c05e..36db77bf99 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -26,7 +26,7 @@ from pytorch_lightning.core.optimizer import LightningOptimizer from pytorch_lightning.core.step_result import EvalResult, Result from pytorch_lightning.trainer.states import TrainerState from
pytorch_lightning.trainer.supporters import Accumulator, TensorRunningAccum -from pytorch_lightning.utilities import AMPType, parsing +from pytorch_lightning.utilities import TPU_AVAILABLE, AMPType, parsing from pytorch_lightning.utilities.distributed import rank_zero_info, rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.memory import recursive_detach @@ -321,7 +321,6 @@ class TrainLoop: args = self.build_train_args(split_batch, batch_idx, opt_idx, hiddens) # manually capture logged metrics - model_ref._results = Result() model_ref._current_fx_name = 'training_step' model_ref._results = Result() training_step_output = self.trainer.accelerator_backend.training_step(args) @@ -475,21 +474,34 @@ class TrainLoop: return training_step_output_for_epoch_end def optimizer_step(self, optimizer, opt_idx, batch_idx, train_step_and_backward_closure, *args, **kwargs): - # optimizer step lightningModule hook - if isinstance(optimizer, LightningOptimizer): - optimizer.step(closure=train_step_and_backward_closure) - else: - with self.trainer.profiler.profile("optimizer_step"): - self.trainer.accelerator_backend.optimizer_step( - optimizer, batch_idx, opt_idx, train_step_and_backward_closure, *args, **kwargs - ) + model_ref = self.trainer.get_model() + + is_lbfgs = isinstance(optimizer, torch.optim.LBFGS) + using_native_amp = self.trainer.amp_backend == AMPType.NATIVE + + # native amp + lbfgs is a no go right now + if using_native_amp and is_lbfgs: + raise MisconfigurationException( + 'native PyTorch amp and lbfgs are not compatible.' + ' To request, please file a Github issue in PyTorch and tag @mcarilli') + + # model hook + model_ref.optimizer_step( + epoch=self.trainer.current_epoch, + batch_idx=batch_idx, + optimizer=optimizer, + optimizer_idx=opt_idx, + optimizer_closure=train_step_and_backward_closure, + on_tpu=self.trainer.use_tpu and TPU_AVAILABLE, + using_native_amp=using_native_amp, + using_lbfgs=is_lbfgs, + *args, + **kwargs, + ) def on_before_zero_grad(self, optimizer): self.trainer.call_hook('on_before_zero_grad', optimizer) - def optimizer_zero_grad(self, batch_idx, optimizer, opt_idx): - self.trainer.accelerator_backend.optimizer_zero_grad(batch_idx, optimizer, opt_idx) - def track_and_norm_grad(self, optimizer): # track gradient norms grad_norm_dic = self._track_gradient_norm() @@ -708,7 +720,6 @@ class TrainLoop: if self._curr_step_result is None: # user decided to skip optimization # make sure to zero grad. 
- self.zero_grad_handler(batch_idx, optimizer, opt_idx) continue batch_outputs = self._process_closure_result( @@ -720,9 +731,6 @@ class TrainLoop: grad_norm_dic = self._cur_grad_norm_dict self._cur_grad_norm_dict = None - # hook + clear gradients - self.zero_grad_handler(batch_idx, optimizer, opt_idx) - # update running loss + reset accumulated loss self.update_running_loss() @@ -947,14 +955,3 @@ class TrainLoop: # reset for next set of accumulated grads self.accumulated_loss.reset() - - def zero_grad_handler(self, batch_idx, optimizer, opt_idx): - if self.automatic_optimization: - # hook - self.on_before_zero_grad(optimizer) - optimizers = enumerate([optimizer]) - else: - optimizers = [] - - for idx, optimizer in optimizers: - self.optimizer_zero_grad(batch_idx, optimizer, opt_idx) diff --git a/tests/callbacks/test_callbacks.py b/tests/callbacks/test_callbacks.py index bb740b1dcb..c00c712bb3 100644 --- a/tests/callbacks/test_callbacks.py +++ b/tests/callbacks/test_callbacks.py @@ -11,9 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - from unittest import mock -from unittest.mock import MagicMock, call, ANY +from unittest.mock import ANY, MagicMock, call from pytorch_lightning import Trainer from tests.base import BoringModel diff --git a/tests/core/test_lightning_module.py b/tests/core/test_lightning_module.py new file mode 100644 index 0000000000..0c71259373 --- /dev/null +++ b/tests/core/test_lightning_module.py @@ -0,0 +1,102 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
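# A minimal sketch (not taken from the patch) of the gating that `LightningOptimizer.step`
# now applies in automatic optimization: the underlying `optimizer.step` only runs when the
# accumulation window closes or the final training batch is reached; otherwise only the
# closure (forward + backward) runs, wrapped in `block_ddp_sync_behaviour()` so DDP skips
# the gradient all-reduce. The helper name below is hypothetical and only illustrates the decision.
def should_make_optimizer_step(batch_idx: int, accumulate_grad_batches: int, is_final_batch: bool) -> bool:
    accumulation_done = (batch_idx + 1) % accumulate_grad_batches == 0
    return accumulation_done or is_final_batch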
+import pickle +from argparse import ArgumentParser +from typing import Optional +from unittest.mock import MagicMock, patch + +import pytest +import torch +from torch.optim import SGD, Adam +from torch.utils.data import DataLoader, random_split + +from pytorch_lightning import LightningDataModule, Trainer, seed_everything +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from tests.base import BoringModel + + +def test_automatic_optimization(tmpdir): + class TestModel(BoringModel): + def optimizer_step(self, *_, **__): + pass + + model = TestModel() + + try: + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=2, + accumulate_grad_batches=2, + automatic_optimization=True + ) + + trainer.fit(model) + except MisconfigurationException as e: + assert "It ensures optimizer_step or optimizer_zero_grad are called on every batch" in str(e) + + +@pytest.mark.parametrize("enable_pl_optimizer", [False, True]) +def test_automatic_optimization_num_calls(enable_pl_optimizer, tmpdir): + + with patch("torch.optim.SGD.step") as sgd_step, \ + patch("torch.optim.SGD.zero_grad") as sgd_zero_grad, \ + patch("torch.optim.Adam.step") as adam_step, \ + patch("torch.optim.Adam.zero_grad") as adam_zero_grad: + + class TestModel(BoringModel): + + def configure_optimizers(self): + optimizer = SGD(self.layer.parameters(), lr=0.1) + optimizer_2 = Adam(self.layer.parameters(), lr=0.1) + return [optimizer, optimizer_2] + + def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, + optimizer_closure, on_tpu, using_native_amp, using_lbfgs): + + assert optimizer_closure.__name__ == "train_step_and_backward_closure" + + # update generator opt every 2 steps + if optimizer_idx == 0: + if batch_idx % 2 == 0: + assert isinstance(optimizer, SGD) + optimizer.step(closure=optimizer_closure) + if not enable_pl_optimizer: + optimizer.zero_grad() + + # update discriminator opt every 4 steps + if optimizer_idx == 1: + if batch_idx % 4 == 0: + assert isinstance(optimizer, Adam) + optimizer.step(closure=optimizer_closure) + if not enable_pl_optimizer: + optimizer.zero_grad() + + model = TestModel() + model.training_epoch_end = None + + trainer = Trainer( + max_epochs=1, + default_root_dir=tmpdir, + limit_train_batches=8, + accumulate_grad_batches=1, + automatic_optimization=True, + enable_pl_optimizer=enable_pl_optimizer + ) + + trainer.fit(model) + + assert sgd_step.call_count == 4 + assert sgd_zero_grad.call_count == 4 + assert adam_step.call_count == 2 + assert adam_zero_grad.call_count == 2 diff --git a/tests/core/test_lightning_optimizer.py b/tests/core/test_lightning_optimizer.py index a120365237..bd19c26784 100644 --- a/tests/core/test_lightning_optimizer.py +++ b/tests/core/test_lightning_optimizer.py @@ -21,6 +21,7 @@ from torch.optim import Adam, Optimizer from pytorch_lightning import LightningModule, Trainer from pytorch_lightning.core.optimizer import LightningOptimizer +from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.base.boring_model import BoringModel, RandomDictDataset, RandomDictStringDataset @@ -188,10 +189,226 @@ def test_state(tmpdir): assert isinstance(lightning_optimizer, Optimizer) lightning_dict = {} special_attrs = ["_accumulate_grad_batches", "_optimizer", "_optimizer_idx", - "_trainer", "_use_accumulate_grad_batches_from_trainer", "_lightning_step"] + "_trainer", "_use_accumulate_grad_batches_from_trainer", "_automatic_optimization", + "_accumulate_grad_batches"] for k, v in lightning_optimizer.__dict__.items(): 
if k not in special_attrs: lightning_dict[k] = v assert lightning_dict == optimizer.__dict__ assert optimizer.state_dict() == lightning_optimizer.state_dict() assert optimizer.state == lightning_optimizer.state + + +def test_lightning_optimizer_automatic_optimization(tmpdir): + """ + Test lightning optimize works with make_optimizer_step in automatic_optimization + """ + class TestModel(BoringModel): + + def training_step(self, batch, batch_idx, optimizer_idx=None): + output = self.layer(batch) + loss = self.loss(batch, output) + return {"loss": loss} + + def training_epoch_end(self, outputs): + outputs = sum(outputs, []) + torch.stack([x["loss"] for x in outputs]).mean() + + def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, + optimizer_closure, on_tpu, using_native_amp, using_lbfgs): + + assert optimizer_closure.__name__ == "train_step_and_backward_closure" + + optimizer.step(closure=optimizer_closure, make_optimizer_step=batch_idx % 2 == 0) + + def configure_optimizers(self): + optimizer_1 = torch.optim.SGD(self.layer.parameters(), lr=0.1) + optimizer_2 = torch.optim.Adam(self.layer.parameters(), lr=0.1) + optimizer_1 = LightningOptimizer(optimizer_1, 4) + + lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer_1, step_size=1) + return [optimizer_1, optimizer_2], [lr_scheduler] + + model = TestModel() + trainer = Trainer( + default_root_dir=os.getcwd(), + limit_train_batches=10, + limit_val_batches=1, + max_epochs=1, + weights_summary=None, + enable_pl_optimizer=True, + automatic_optimization=True + ) + trainer.fit(model) + + +def test_lightning_optimizer_automatic_optimization_optimizer_zero_grad(tmpdir): + """ + Test lightning optimize works with optimizer_zero_grad overrides in automatic_optimization + """ + + with patch("torch.optim.Adam.zero_grad") as adam_zero_grad, \ + patch("torch.optim.SGD.zero_grad") as sgd_zero_grad: + + class TestModel(BoringModel): + + def training_step(self, batch, batch_idx, optimizer_idx=None): + output = self.layer(batch) + loss = self.loss(batch, output) + return {"loss": loss} + + def training_epoch_end(self, outputs): + outputs = sum(outputs, []) + torch.stack([x["loss"] for x in outputs]).mean() + + def optimizer_zero_grad(self, epoch: int, batch_idx: int, optimizer: Optimizer, optimizer_idx: int): + if optimizer_idx == 0: + if batch_idx % 2 == 0: + optimizer.zero_grad() + + if optimizer_idx == 1: + if batch_idx % 5 == 0: + optimizer.zero_grad() + + def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, + optimizer_closure, on_tpu, using_native_amp, using_lbfgs): + + assert optimizer_closure.__name__ == "train_step_and_backward_closure" + + optimizer.step(closure=optimizer_closure) + + def configure_optimizers(self): + optimizer_1 = torch.optim.SGD(self.layer.parameters(), lr=0.1) + optimizer_2 = torch.optim.Adam(self.layer.parameters(), lr=0.1) + lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer_1, step_size=1) + return [optimizer_1, optimizer_2], [lr_scheduler] + + model = TestModel() + trainer = Trainer( + default_root_dir=os.getcwd(), + limit_train_batches=10, + limit_val_batches=1, + max_epochs=1, + weights_summary=None, + enable_pl_optimizer=True, + automatic_optimization=True + ) + trainer.fit(model) + + assert adam_zero_grad.call_count == 2 + assert sgd_zero_grad.call_count == 5 + + +def test_lightning_optimizer_automatic_optimization_optimizer_zero_grad_make_optimizer_step(tmpdir): + """ + Test lightning optimize works with optimizer_zero_grad overrides and make_optimizer_step in 
automatic_optimization + """ + + try: + with patch("torch.optim.Adam.zero_grad") as adam_zero_grad, \ + patch("torch.optim.SGD.zero_grad") as sgd_zero_grad: + + class TestModel(BoringModel): + + def training_step(self, batch, batch_idx, optimizer_idx=None): + output = self.layer(batch) + loss = self.loss(batch, output) + return {"loss": loss} + + def training_epoch_end(self, outputs): + outputs = sum(outputs, []) + torch.stack([x["loss"] for x in outputs]).mean() + + def optimizer_zero_grad(self, epoch: int, batch_idx: int, optimizer: Optimizer, optimizer_idx: int): + if optimizer_idx == 0: + if batch_idx % 2 == 0: + optimizer.zero_grad() + + if optimizer_idx == 1: + if batch_idx % 5 == 0: + optimizer.zero_grad() + + def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, + optimizer_closure, on_tpu, using_native_amp, using_lbfgs): + + assert optimizer_closure.__name__ == "train_step_and_backward_closure" + + if optimizer_idx == 0: + optimizer.step(closure=optimizer_closure, make_optimizer_step=batch_idx % 3 == 0) + return + optimizer.step(closure=optimizer_closure) + + def configure_optimizers(self): + optimizer_1 = torch.optim.SGD(self.layer.parameters(), lr=0.1) + optimizer_2 = torch.optim.Adam(self.layer.parameters(), lr=0.1) + lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer_1, step_size=1) + return [optimizer_1, optimizer_2], [lr_scheduler] + + model = TestModel() + trainer = Trainer( + default_root_dir=os.getcwd(), + limit_train_batches=20, + limit_val_batches=1, + max_epochs=1, + weights_summary=None, + enable_pl_optimizer=True, + automatic_optimization=True + ) + trainer.fit(model) + + assert adam_zero_grad.call_count == 4 + assert sgd_zero_grad.call_count == 10 + + except MisconfigurationException as e: + assert "When overriding LightningModule `optimizer_zero_grad`, make_optimizer_step is not allowed" in str(e) + + +def test_lightning_optimizer_automatic_optimization_make_optimizer_step_2(tmpdir): + """ + Test lightning optimize works with make_optimizer_step in automatic_optimization + """ + + with patch("torch.optim.Adam.zero_grad") as adam_zero_grad, \ + patch("torch.optim.SGD.zero_grad") as sgd_zero_grad: + + class TestModel(BoringModel): + + def training_step(self, batch, batch_idx, optimizer_idx=None): + output = self.layer(batch) + loss = self.loss(batch, output) + return {"loss": loss} + + def training_epoch_end(self, outputs): + outputs = sum(outputs, []) + torch.stack([x["loss"] for x in outputs]).mean() + + def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, + optimizer_closure, on_tpu, using_native_amp, using_lbfgs): + + assert optimizer_closure.__name__ == "train_step_and_backward_closure" + + make_optimizer_step = None + if optimizer_idx == 0: + make_optimizer_step = batch_idx % 4 == 0 + optimizer.step(closure=optimizer_closure, make_optimizer_step=make_optimizer_step) + + def configure_optimizers(self): + optimizer_1 = torch.optim.SGD(self.layer.parameters(), lr=0.1) + optimizer_2 = torch.optim.Adam(self.layer.parameters(), lr=0.1) + lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer_1, step_size=1) + return [optimizer_1, optimizer_2], [lr_scheduler] + + model = TestModel() + trainer = Trainer( + default_root_dir=os.getcwd(), + limit_train_batches=20, + limit_val_batches=1, + max_epochs=1, + weights_summary=None, + enable_pl_optimizer=True, + automatic_optimization=True, + ) + trainer.fit(model) + + assert adam_zero_grad.call_count == 20 + assert sgd_zero_grad.call_count == 5 diff --git a/tests/plugins/test_amp_plugin.py 
b/tests/plugins/test_amp_plugin.py index 6b1bd1f745..724ebe7c82 100644 --- a/tests/plugins/test_amp_plugin.py +++ b/tests/plugins/test_amp_plugin.py @@ -1,13 +1,15 @@ -from pytorch_lightning.callbacks import Callback -from pytorch_lightning.utilities import NATIVE_AMP_AVAILABLE -from tests.base.boring_model import BoringModel -from pytorch_lightning import Trainer -import pytest import os from unittest import mock -from pytorch_lightning.plugins.native_amp import NativeAMPPlugin + +import pytest import torch +from pytorch_lightning import Trainer +from pytorch_lightning.callbacks import Callback +from pytorch_lightning.plugins.native_amp import NativeAMPPlugin +from pytorch_lightning.utilities import NATIVE_AMP_AVAILABLE +from tests.base.boring_model import BoringModel + @pytest.mark.skipif(not NATIVE_AMP_AVAILABLE, reason="Minimal PT version is set to 1.6") @mock.patch.dict(os.environ, { diff --git a/tests/special_tests.sh b/tests/special_tests.sh new file mode 100644 index 0000000000..8e25f446a5 --- /dev/null +++ b/tests/special_tests.sh @@ -0,0 +1,17 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +export PL_RUNNING_SPECIAL_TESTS=1 +# Running special tests +DEFAULTS="-m coverage run --source pytorch_lightning -a -m pytest --verbose --capture=no" diff --git a/tests/trainer/optimization/test_manual_optimization.py b/tests/trainer/optimization/test_manual_optimization.py index 1a904daebe..0df99d1c1e 100644 --- a/tests/trainer/optimization/test_manual_optimization.py +++ b/tests/trainer/optimization/test_manual_optimization.py @@ -22,6 +22,7 @@ import torch.nn.functional as F from pytorch_lightning import Trainer, seed_everything from pytorch_lightning.utilities import APEX_AVAILABLE +from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.base.boring_model import BoringModel @@ -563,6 +564,15 @@ def test_multiple_optimizers_step(tmpdir): Tests that `step` works with several optimizers """ class TestModel(BoringModel): + + called = False + + def on_after_backward(self): + self.called = True + norm = torch.nn.utils.clip_grad_norm_(self.parameters(), 2) + if not (torch.isinf(norm) or torch.isnan(norm)): + assert norm.item() < 100, norm.item() + def training_step(self, batch, batch_idx, optimizer_idx): # manual (opt_a, opt_b) = self.optimizers() @@ -621,6 +631,7 @@ def test_multiple_optimizers_step(tmpdir): num_manual_backward_calls = 3 assert trainer.dev_debugger.count_events('backward_call') == limit_train_batches * num_manual_backward_calls + assert model.called def test_step_with_optimizer_closure(tmpdir): @@ -891,3 +902,32 @@ def test_step_with_optimizer_closure_with_different_frequencies(mock_sgd_step, m expected_calls = [call(closure=ANY, optim='adam') for s in range(2)] mock_adam_step.assert_has_calls(expected_calls) + + +def test_step_with_misconfiguraiton_error_when_overriding_optimizer_zero_grad(tmpdir): + """ + Tests that `optimizer_zero_grad` in manual_optimization triggers a MisconfigurationException + """ 
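    # For reference, the checks added to configuration_validator.py in this PR:
    #   * overriding `optimizer_step` with `enable_pl_optimizer=False` in automatic optimization
    #     emits a warning recommending `Trainer(enable_pl_optimizer=True)`;
    #   * overriding `optimizer_step` or `optimizer_zero_grad` while accumulating gradients
    #     (`accumulate_grad_batches > 1`) in automatic optimization raises a MisconfigurationException;
    #   * overriding `optimizer_zero_grad` with `Trainer(automatic_optimization=False,
    #     enable_pl_optimizer=True)` raises a MisconfigurationException, which is the case this test covers.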
+ try: + class TestModel(BoringModel): + + def optimizer_zero_grad(self, *_): + pass + + model = TestModel() + model.val_dataloader = None + model.training_epoch_end = None + + limit_train_batches = 8 + trainer = Trainer( + automatic_optimization=False, + default_root_dir=tmpdir, + limit_train_batches=limit_train_batches, + limit_val_batches=2, + max_epochs=1, + log_every_n_steps=1, + accumulate_grad_batches=2, + enable_pl_optimizer=True, + ) + trainer.fit(model) + except MisconfigurationException as e: + assert "`Trainer(automatic_optimization=False, enable_pl_optimizer=True, ...) is not supported" in str(e)
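As a usage reference, here is a minimal sketch of the manual-optimization flow this PR consolidates around LightningOptimizer. It is not part of the diff above: the model, data and loss are placeholders, and the relevant calls (self.optimizers(), manual_backward, and step(closure=..., make_optimizer_step=...)) follow the LightningOptimizer docstring and the tests in this patch.

import torch
from torch.utils.data import DataLoader, TensorDataset

from pytorch_lightning import LightningModule, Trainer


class ManualOptimizationExample(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx, optimizer_idx=None):
        # automatic_optimization=False: Lightning does not call backward/step/zero_grad itself,
        # the calls go through the LightningOptimizer wrappers returned by self.optimizers()
        (opt_a, opt_b) = self.optimizers()

        def closure_a():
            loss = self.layer(batch[0]).sum()
            self.manual_backward(loss, opt_a)

        # step the first optimizer on every batch
        opt_a.step(closure=closure_a)

        def closure_b():
            loss = self.layer(batch[0]).sum()
            self.manual_backward(loss, opt_b)

        # step the second optimizer only every other batch; on skipped batches the closure
        # still runs, so gradients accumulate and the wrapper blocks the DDP sync
        opt_b.step(closure=closure_b, make_optimizer_step=batch_idx % 2 == 0)

    def configure_optimizers(self):
        return [
            torch.optim.SGD(self.layer.parameters(), lr=0.1),
            torch.optim.Adam(self.layer.parameters(), lr=0.1),
        ]


if __name__ == "__main__":
    train_data = DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=2)
    trainer = Trainer(automatic_optimization=False, enable_pl_optimizer=True, max_epochs=1)
    trainer.fit(ManualOptimizationExample(), train_data)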