[BUGFIX] AMP + Precision unscale grad (#4441)

* move unscale within Native plugin

* remove gradient tracking from lightning backward

* forgot trainer.fit

* typo

* update

* cleanup

* set to 1.6

* typo

* skip if below 1.6 strict

* update changelog

* remove useless code

* Update tests/plugins/test_amp_plugin.py

Co-authored-by: Sean Naren <sean.narenthiran@gmail.com>

* Update tests/plugins/test_amp_plugin.py

Co-authored-by: Sean Naren <sean.narenthiran@gmail.com>

* update changelog

* Update CHANGELOG.md

Co-authored-by: Sean Naren <sean.narenthiran@gmail.com>
Co-authored-by: Jeff Yang <ydcjeff@outlook.com>
chaton 2020-11-02 16:36:48 +00:00 committed by GitHub
parent 9b8102d1a5
commit 102fa9ee7d
6 changed files with 90 additions and 18 deletions

CHANGELOG.md

@@ -17,12 +17,18 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added multiclass AUROC metric ([#4236](https://github.com/PyTorchLightning/pytorch-lightning/pull/4236))
- Added timeout for `tpu_device_exists` to ensure process does not hang indefinitely ([#4340](https://github.com/PyTorchLightning/pytorch-lightning/pull/4340))
- Added global step indexing to the checkpoint name for a better sub-epoch checkpointing experience ([#3807](https://github.com/PyTorchLightning/pytorch-lightning/pull/3807))

### Changed

- W&B log in sync with Trainer step ([#4405](https://github.com/PyTorchLightning/pytorch-lightning/pull/4405))
- Hook `on_after_backward` is called only when `optimizer_step` is being called ([#4439](https://github.com/PyTorchLightning/pytorch-lightning/pull/4439))
- Moved `track_and_norm_grad` into `training loop` and called only when `optimizer_step` is being called ([#4439](https://github.com/PyTorchLightning/pytorch-lightning/pull/4439))

### Deprecated

- Deprecated passing `ModelCheckpoint` instance to `checkpoint_callback` Trainer argument ([#4336](https://github.com/PyTorchLightning/pytorch-lightning/pull/4336))
@@ -33,6 +39,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed error using `auto_select_gpus=True` with `gpus=-1` ([#4209](https://github.com/PyTorchLightning/pytorch-lightning/pull/4209))
- Fixed AMP unscale for `on_after_backward` ([#4439](https://github.com/PyTorchLightning/pytorch-lightning/pull/4439))

## [1.0.4] - 2020-10-27
@@ -50,8 +57,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added support for string values in `Trainer`'s `profiler` parameter ([#3656](https://github.com/PyTorchLightning/pytorch-lightning/pull/3656))
- Added timeout for `tpu_device_exists` to ensure process does not hang indefinitely ([#4340](https://github.com/PyTorchLightning/pytorch-lightning/pull/4340))

### Changed

- Improved error messages for invalid `configure_optimizers` returns ([#3587](https://github.com/PyTorchLightning/pytorch-lightning/pull/3587))
@@ -78,7 +83,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed WandbLogger not uploading checkpoint artifacts at the end of training ([#4341](https://github.com/PyTorchLightning/pytorch-lightning/pull/4341))

## [1.0.3] - 2020-10-20

### Added

pytorch_lightning/accelerators/accelerator.py

@@ -131,11 +131,6 @@ class Accelerator(object):
        model_ref.optimizer_zero_grad(self.trainer.current_epoch, batch_idx, optimizer, opt_idx)

    def clip_gradients(self, optimizer, clip_val=None):
        if self.trainer.amp_backend == AMPType.NATIVE:
            self.trainer.scaler.unscale_(optimizer)

        # apply clip gradients
        # TODO: separate TPU case from here
        self._clip_gradients(optimizer, clip_val)
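
With native AMP the gradients are now unscaled once inside the plugin's backward (shown further down), so the extra `unscale_` in `clip_gradients` becomes redundant; calling `GradScaler.unscale_` a second time for the same optimizer before `scaler.update()` raises a RuntimeError. Below is a minimal sketch of the standard `torch.cuda.amp` ordering this refactor relies on (unscale once, then clip, then step). The model, optimizer and data are placeholders, not names from this repository.

import torch

# Minimal sketch of the torch.cuda.amp recipe: gradients are unscaled
# exactly once per step, before any clipping or inspection.
model = torch.nn.Linear(32, 2).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler()

x = torch.randn(8, 32, device="cuda")
y = torch.randint(0, 2, (8,), device="cuda")

optimizer.zero_grad()
with torch.cuda.amp.autocast():
    loss = torch.nn.functional.cross_entropy(model(x), y)
scaler.scale(loss).backward()       # gradients are scaled here
scaler.unscale_(optimizer)          # unscale once, before clipping
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=2.0)
scaler.step(optimizer)              # a second unscale_ before update() would raise
scaler.update()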

pytorch_lightning/core/lightning.py

@@ -1101,7 +1101,6 @@ class LightningModule(
        """
        loss.backward(*args, **kwargs)
        self.trainer.train_loop.track_and_norm_grad(optimizer=optimizer)

    def toggle_optimizer(self, optimizer: Optimizer, optimizer_idx: int):
        """

pytorch_lightning/plugins/native_amp.py

@@ -38,6 +38,11 @@ class NativeAMPPlugin:
        # once backward has been applied, release graph
        closure_loss = closure_loss.detach()

        # unscale gradient to allow analyze within `on_after_backward`
        if not self.trainer.train_loop.should_accumulate():
            self.trainer.scaler.unscale_(optimizer)

        return closure_loss

    def training_step(self, fx, args):
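
The unscale now happens at the tail of the plugin's backward, and only once gradients have finished accumulating, so `on_after_backward` and the loop's norm tracking observe real (unscaled) gradient values. The sketch below shows the same conditional-unscale pattern in plain `torch.cuda.amp`; `accumulate_grad_batches` mirrors the Trainer argument, and the model, optimizer and loader are synthetic placeholders.

import torch

# Conditional unscale with gradient accumulation, sketched in plain
# torch.cuda.amp: accumulate scaled grads, unscale only when stepping.
model = torch.nn.Linear(32, 2).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler()
accumulate_grad_batches = 2
loader = [(torch.randn(8, 32), torch.randint(0, 2, (8,))) for _ in range(4)]

for step, (x, y) in enumerate(loader):
    with torch.cuda.amp.autocast():
        loss = torch.nn.functional.cross_entropy(model(x.cuda()), y.cuda())
    scaler.scale(loss / accumulate_grad_batches).backward()

    if (step + 1) % accumulate_grad_batches == 0:
        # gradients are complete: unscale once, inspect/clip, then step
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 2.0)
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()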

pytorch_lightning/trainer/training_loop.py

@@ -652,11 +652,6 @@ class TrainLoop:
        if response == -1:
            return AttributeDict(signal=-1, grad_norm_dic=grad_norm_dic)

        # checks if backward or backward + optimizer step (via closure)
        accumulation_done = self._accumulated_batches_reached()
        is_final_batch = self._num_training_batches_reached()
        should_accumulate = not (accumulation_done or is_final_batch)

        # lightning module hook
        splits = self.tbptt_split_batch(batch)
@@ -676,7 +671,7 @@
                model = self.trainer.get_model()
                model.toggle_optimizer(optimizer, opt_idx)

                if should_accumulate:
                if self.should_accumulate():
                    # For gradient accumulation
                    # -------------------
@@ -767,7 +762,7 @@
    @contextmanager
    def block_ddp_sync_behaviour(self):
        if isinstance(self.trainer.model, torch.nn.parallel.DistributedDataParallel):
            yield from self.trainer.model.no_sync()
            yield self.trainer.model.no_sync()
        else:
            yield
@@ -817,8 +812,10 @@
            with self.trainer.profiler.profile("model_backward"):
                self.backward(result, optimizer, opt_idx)

            # hook
            self.on_after_backward(result.training_step_output, batch_idx, result.loss)
            # hook - call this hook only
            # when gradients have finished to accumulate
            if not self.should_accumulate():
                self.on_after_backward(result.training_step_output, batch_idx, result.loss)

            # check if loss or model weights are nan
            if self.trainer.terminate_on_nan:
@@ -837,6 +834,10 @@
                result.closure_loss, optimizer, opt_idx, *args, **kwargs
            )

        if not self.should_accumulate():
            # track gradients
            self.track_and_norm_grad(optimizer=optimizer)

    def update_train_loop_lr_schedulers(self, monitor_metrics=None):
        num_accumulated_batches_reached = self._accumulated_batches_reached()
        num_training_batches_reached = self._num_training_batches_reached()
@@ -863,6 +864,12 @@
    def _num_training_batches_reached(self):
        return (self.trainer.batch_idx + 1) == self.trainer.num_training_batches

    def should_accumulate(self):
        # checks if backward or backward + optimizer step (via closure)
        accumulation_done = self._accumulated_batches_reached()
        is_final_batch = self._num_training_batches_reached()
        return not (accumulation_done or is_final_batch)

    def should_check_val_fx(self, batch_idx, is_last_batch):
        # decide if we should run validation
        is_val_check_batch = (batch_idx + 1) % self.trainer.val_check_batch == 0
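
The decision is centralized in `should_accumulate()`: keep accumulating while the window is still filling, but never on the epoch's final batch. The short trace below replays that rule with illustrative numbers; the two helper expressions paraphrase `_accumulated_batches_reached` and `_num_training_batches_reached` rather than copying them from the trainer.

# Trace of the should_accumulate() decision for accumulate_grad_batches=2
# over a 5-batch epoch (numbers are illustrative only).
accumulate_grad_batches = 2
num_training_batches = 5

for batch_idx in range(num_training_batches):
    accumulation_done = (batch_idx + 1) % accumulate_grad_batches == 0
    is_final_batch = (batch_idx + 1) == num_training_batches
    should_accumulate = not (accumulation_done or is_final_batch)
    action = "accumulate" if should_accumulate else "unscale, clip, step"
    print(f"batch {batch_idx}: {action}")
# batches 0 and 2 only accumulate; batches 1 and 3 step because the window
# is full; batch 4 steps because it is the last batch of the epoch.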

tests/plugins/test_amp_plugin.py

@@ -84,3 +84,65 @@ def test_amp_choice_custom_ddp_cpu(tmpdir, ddp_backend, gpus, num_processes):
    with pytest.raises(SystemExit):
        trainer.fit(model)


@pytest.mark.skipif(
    LooseVersion(torch.__version__) < LooseVersion("1.6.0"),
    reason="Minimal PT version is set to 1.6")
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
def test_amp_gradient_unscale(tmpdir):

    class ExtendedBoringModel(BoringModel):

        def on_after_backward(self):
            norm = torch.nn.utils.clip_grad_norm_(self.parameters(), 2)
            if not (torch.isinf(norm) or torch.isnan(norm)):
                assert norm.item() < 15.

    model = ExtendedBoringModel()
    trainer = Trainer(
        max_epochs=2,
        default_root_dir=os.getcwd(),
        limit_train_batches=2,
        limit_test_batches=2,
        limit_val_batches=2,
        amp_backend='native',
        distributed_backend='ddp_spawn',
        gpus=2,
        precision=16,
        track_grad_norm=2,
        log_every_n_steps=1
    )
    trainer.fit(model)


@pytest.mark.skipif(
    LooseVersion(torch.__version__) < LooseVersion("1.6.0"), reason="Minimal PT version is set to 1.6")
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
def test_amp_gradient_unscale_accumulate_grad_batches(tmpdir):

    class ExtendedBoringModel(BoringModel):

        def on_after_backward(self):
            norm = torch.nn.utils.clip_grad_norm_(self.parameters(), 2)
            if not (torch.isinf(norm) or torch.isnan(norm)):
                assert norm.item() < 15.

    model = ExtendedBoringModel()
    trainer = Trainer(
        max_epochs=2,
        default_root_dir=os.getcwd(),
        limit_train_batches=2,
        limit_test_batches=2,
        limit_val_batches=2,
        amp_backend='native',
        distributed_backend='ddp_spawn',
        gpus=2,
        precision=16,
        track_grad_norm=2,
        log_every_n_steps=1,
        accumulate_grad_batches=2,
    )
    trainer.fit(model)
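
Both tests clip inside `on_after_backward` and assert a sane norm, which only holds if gradients were already unscaled (and, in the second test, only once accumulation has finished); they require torch >= 1.6 and two CUDA devices. For reference, a hypothetical hook that relies on the same guarantee to log gradient norms; the `BoringModel` import path is assumed to match these tests.

import torch
from tests.base.boring_model import BoringModel  # path assumed, as in the tests above

# Hypothetical hook relying on the fixed behaviour: by the time
# on_after_backward runs under native AMP, gradients are unscaled, so the
# logged norm is comparable to full-precision training.
class GradNormLoggingModel(BoringModel):
    def on_after_backward(self):
        grads = [p.grad.norm(2) for p in self.parameters() if p.grad is not None]
        if grads:
            self.log("grad_2_norm", torch.stack(grads).norm(2))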