From 102fa9ee7dd087ee167b120ea7812360928408f7 Mon Sep 17 00:00:00 2001
From: chaton
Date: Mon, 2 Nov 2020 16:36:48 +0000
Subject: [PATCH] [BUGFIX] AMP + Precision unscale grad (#4441)

* move unscale within Native plugin
* remove gradient tracking from lightning backward
* forgot trainer.fit
* typo
* update
* cleanup
* set to 1.6
* typo
* skip if below 1.6 strict
* update changelog
* remove useless code
* Update tests/plugins/test_amp_plugin.py

Co-authored-by: Sean Naren

* Update tests/plugins/test_amp_plugin.py

Co-authored-by: Sean Naren

* update changelog
* Update CHANGELOG.md

Co-authored-by: Sean Naren
Co-authored-by: Jeff Yang
---
 CHANGELOG.md                                  | 10 ++-
 pytorch_lightning/accelerators/accelerator.py |  5 --
 pytorch_lightning/core/lightning.py           |  1 -
 pytorch_lightning/plugins/native_amp.py       |  5 ++
 pytorch_lightning/trainer/training_loop.py    | 25 +++++---
 tests/plugins/test_amp_plugin.py              | 62 +++++++++++++++++++
 6 files changed, 90 insertions(+), 18 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index cdb9ddc804..84d483dd03 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -17,12 +17,18 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added multiclass AUROC metric ([#4236](https://github.com/PyTorchLightning/pytorch-lightning/pull/4236))
 
+- Added timeout for `tpu_device_exists` to ensure process does not hang indefinitely ([#4340](https://github.com/PyTorchLightning/pytorch-lightning/pull/4340))
+
 - Added global step indexing to the checkpoint name for a better sub-epoch checkpointing experience ([#3807](https://github.com/PyTorchLightning/pytorch-lightning/pull/3807))
 
 ### Changed
 
 - W&B log in sync with Trainer step ([#4405](https://github.com/PyTorchLightning/pytorch-lightning/pull/4405))
 
+- Hook `on_after_backward` is called only when `optimizer_step` is being called ([#4439](https://github.com/PyTorchLightning/pytorch-lightning/pull/4439))
+
+- Moved `track_and_norm_grad` into `training loop` and called only when `optimizer_step` is being called ([#4439](https://github.com/PyTorchLightning/pytorch-lightning/pull/4439))
+
 ### Deprecated
 
 - Deprecated passing `ModelCheckpoint` instance to `checkpoint_callback` Trainer argument ([#4336](https://github.com/PyTorchLightning/pytorch-lightning/pull/4336))
@@ -33,6 +39,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed error using `auto_select_gpus=True` with `gpus=-1` ([#4209](https://github.com/PyTorchLightning/pytorch-lightning/pull/4209))
 
+- Fixed AMP unscale for `on_after_backward` ([#4439](https://github.com/PyTorchLightning/pytorch-lightning/pull/4439))
 
 ## [1.0.4] - 2020-10-27
 
@@ -50,8 +57,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added support for string values in `Trainer`'s `profiler` parameter ([#3656](https://github.com/PyTorchLightning/pytorch-lightning/pull/3656))
 
-- Added timeout for `tpu_device_exists` to ensure process does not hang indefinitely ([#4340](https://github.com/PyTorchLightning/pytorch-lightning/pull/4340))
-
 ### Changed
 
 - Improved error messages for invalid `configure_optimizers` returns ([#3587](https://github.com/PyTorchLightning/pytorch-lightning/pull/3587))
@@ -78,7 +83,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed WandbLogger not uploading checkpoint artifacts at the end of training ([#4341](https://github.com/PyTorchLightning/pytorch-lightning/pull/4341))
 
-
 ## [1.0.3] - 2020-10-20
 
 ### Added
diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py
index 8e1969cc93..8ece6c4ec1 100644
--- a/pytorch_lightning/accelerators/accelerator.py
+++ b/pytorch_lightning/accelerators/accelerator.py
@@ -131,11 +131,6 @@ class Accelerator(object):
         model_ref.optimizer_zero_grad(self.trainer.current_epoch, batch_idx, optimizer, opt_idx)
 
     def clip_gradients(self, optimizer, clip_val=None):
-
-        if self.trainer.amp_backend == AMPType.NATIVE:
-            self.trainer.scaler.unscale_(optimizer)
-
-        # apply clip gradients
         # TODO: separate TPU case from here
         self._clip_gradients(optimizer, clip_val)
 
diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py
index 22d63d0a03..d7125eb171 100644
--- a/pytorch_lightning/core/lightning.py
+++ b/pytorch_lightning/core/lightning.py
@@ -1101,7 +1101,6 @@ class LightningModule(
 
         """
         loss.backward(*args, **kwargs)
-        self.trainer.train_loop.track_and_norm_grad(optimizer=optimizer)
 
     def toggle_optimizer(self, optimizer: Optimizer, optimizer_idx: int):
         """
diff --git a/pytorch_lightning/plugins/native_amp.py b/pytorch_lightning/plugins/native_amp.py
index 6506540bde..b016b6c5d2 100644
--- a/pytorch_lightning/plugins/native_amp.py
+++ b/pytorch_lightning/plugins/native_amp.py
@@ -38,6 +38,11 @@ class NativeAMPPlugin:
 
         # once backward has been applied, release graph
         closure_loss = closure_loss.detach()
+
+        # unscale gradient to allow analyze within `on_after_backward`
+        if not self.trainer.train_loop.should_accumulate():
+            self.trainer.scaler.unscale_(optimizer)
+
         return closure_loss
 
     def training_step(self, fx, args):
diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index d1dfb3eec3..0d269c333b 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -652,11 +652,6 @@ class TrainLoop:
         if response == -1:
             return AttributeDict(signal=-1, grad_norm_dic=grad_norm_dic)
 
-        # checks if backward or backward + optimizer step (via closure)
-        accumulation_done = self._accumulated_batches_reached()
-        is_final_batch = self._num_training_batches_reached()
-        should_accumulate = not (accumulation_done or is_final_batch)
-
         # lightning module hook
         splits = self.tbptt_split_batch(batch)
 
@@ -676,7 +671,7 @@ class TrainLoop:
                 model = self.trainer.get_model()
                 model.toggle_optimizer(optimizer, opt_idx)
 
-                if should_accumulate:
+                if self.should_accumulate():
                     # For gradient accumulation
 
                     # -------------------
@@ -767,7 +762,7 @@ class TrainLoop:
     @contextmanager
    def block_ddp_sync_behaviour(self):
         if isinstance(self.trainer.model, torch.nn.parallel.DistributedDataParallel):
-            yield from self.trainer.model.no_sync()
+            yield self.trainer.model.no_sync()
         else:
             yield
 
@@ -817,8 +812,10 @@ class TrainLoop:
         with self.trainer.profiler.profile("model_backward"):
             self.backward(result, optimizer, opt_idx)
 
-        # hook
-        self.on_after_backward(result.training_step_output, batch_idx, result.loss)
+        # hook - call this hook only
+        # when gradients have finished to accumulate
+        if not self.should_accumulate():
+            self.on_after_backward(result.training_step_output, batch_idx, result.loss)
 
         # check if loss or model weights are nan
         if self.trainer.terminate_on_nan:
@@ -837,6 +834,10 @@ class TrainLoop:
                 result.closure_loss, optimizer, opt_idx, *args, **kwargs
             )
 
+        if not self.should_accumulate():
+            # track gradients
+            self.track_and_norm_grad(optimizer=optimizer)
+
     def update_train_loop_lr_schedulers(self, monitor_metrics=None):
         num_accumulated_batches_reached = self._accumulated_batches_reached()
         num_training_batches_reached = self._num_training_batches_reached()
@@ -863,6 +864,12 @@ class TrainLoop:
     def _num_training_batches_reached(self):
         return (self.trainer.batch_idx + 1) == self.trainer.num_training_batches
 
+    def should_accumulate(self):
+        # checks if backward or backward + optimizer step (via closure)
+        accumulation_done = self._accumulated_batches_reached()
+        is_final_batch = self._num_training_batches_reached()
+        return not (accumulation_done or is_final_batch)
+
     def should_check_val_fx(self, batch_idx, is_last_batch):
         # decide if we should run validation
         is_val_check_batch = (batch_idx + 1) % self.trainer.val_check_batch == 0
diff --git a/tests/plugins/test_amp_plugin.py b/tests/plugins/test_amp_plugin.py
index c0d5747b5f..6fd000b61d 100644
--- a/tests/plugins/test_amp_plugin.py
+++ b/tests/plugins/test_amp_plugin.py
@@ -84,3 +84,65 @@ def test_amp_choice_custom_ddp_cpu(tmpdir, ddp_backend, gpus, num_processes):
 
     with pytest.raises(SystemExit):
         trainer.fit(model)
+
+
+@pytest.mark.skipif(
+    LooseVersion(torch.__version__) < LooseVersion("1.6.0"),
+    reason="Minimal PT version is set to 1.6")
+@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
+def test_amp_gradient_unscale(tmpdir):
+
+    class ExtendedBoringModel(BoringModel):
+
+        def on_after_backward(self):
+            norm = torch.nn.utils.clip_grad_norm_(self.parameters(), 2)
+            if not (torch.isinf(norm) or torch.isnan(norm)):
+                assert norm.item() < 15.
+
+    model = ExtendedBoringModel()
+
+    trainer = Trainer(
+        max_epochs=2,
+        default_root_dir=os.getcwd(),
+        limit_train_batches=2,
+        limit_test_batches=2,
+        limit_val_batches=2,
+        amp_backend='native',
+        distributed_backend='ddp_spawn',
+        gpus=2,
+        precision=16,
+        track_grad_norm=2,
+        log_every_n_steps=1
+    )
+    trainer.fit(model)
+
+
+@pytest.mark.skipif(
+    LooseVersion(torch.__version__) < LooseVersion("1.6.0"), reason="Minimal PT version is set to 1.6")
+@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
+def test_amp_gradient_unscale_accumulate_grad_batches(tmpdir):
+
+    class ExtendedBoringModel(BoringModel):
+
+        def on_after_backward(self):
+            norm = torch.nn.utils.clip_grad_norm_(self.parameters(), 2)
+            if not (torch.isinf(norm) or torch.isnan(norm)):
+                assert norm.item() < 15.
+
+    model = ExtendedBoringModel()
+
+    trainer = Trainer(
+        max_epochs=2,
+        default_root_dir=os.getcwd(),
+        limit_train_batches=2,
+        limit_test_batches=2,
+        limit_val_batches=2,
+        amp_backend='native',
+        distributed_backend='ddp_spawn',
+        gpus=2,
+        precision=16,
+        track_grad_norm=2,
+        log_every_n_steps=1,
+        accumulate_grad_batches=2,
+    )
+    trainer.fit(model)
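
Note on the pattern enforced by this patch: with native AMP, gradients remain scaled after backward(), so the unscale now happens only on iterations where the optimizer will actually step, letting `on_after_backward` and gradient clipping see real gradient magnitudes (and avoiding a second unscale_() on the same optimizer during accumulation). Below is a minimal standalone sketch of that pattern in plain PyTorch, not Lightning; the toy model, data, loss, and `accumulate_grad_batches` value are illustrative assumptions, and a CUDA device is required for torch.cuda.amp.

import torch

# Illustrative setup (assumption, not part of the PR): a toy model and random data on GPU.
model = torch.nn.Linear(32, 2).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler()
accumulate_grad_batches = 2

data = [(torch.randn(8, 32).cuda(), torch.randn(8, 2).cuda()) for _ in range(8)]

for batch_idx, (x, y) in enumerate(data):
    with torch.cuda.amp.autocast():
        loss = torch.nn.functional.mse_loss(model(x), y) / accumulate_grad_batches

    # backward always runs on the scaled loss, accumulating scaled gradients
    scaler.scale(loss).backward()

    # mirror the should_accumulate() guard: skip unscale/clip/step until the
    # accumulation window is complete (Lightning also forces a step on the final batch)
    should_accumulate = (batch_idx + 1) % accumulate_grad_batches != 0
    if should_accumulate:
        continue

    # unscale exactly once per optimizer step so gradient inspection
    # (an `on_after_backward`-style hook) and clipping see true magnitudes
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), 2.0)

    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad()

Calling scaler.unscale_(optimizer) a second time before scaler.update() raises a RuntimeError, which is why the unscale, the hook, and track_and_norm_grad all sit behind the same should_accumulate() check in the patch.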