[BUGFIX] AMP + Precision unscale grad (#4441)
* move unscale within Native plugin
* remove gradient tracking from lightning backward
* forgot trainer.fit
* typo
* update
* cleanup
* set to 1.6
* typo
* skip if below 1.6 strict
* update changelog
* remove useless code
* Update tests/plugins/test_amp_plugin.py
  Co-authored-by: Sean Naren <sean.narenthiran@gmail.com>
* Update tests/plugins/test_amp_plugin.py
  Co-authored-by: Sean Naren <sean.narenthiran@gmail.com>
* update changelog
* Update CHANGELOG.md

Co-authored-by: Sean Naren <sean.narenthiran@gmail.com>
Co-authored-by: Jeff Yang <ydcjeff@outlook.com>
parent 9b8102d1a5
commit 102fa9ee7d
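Taken together, the diffs below implement one idea: with native AMP, backward leaves *scaled* gradients in `param.grad`, so the `GradScaler.unscale_` call is moved out of `Accelerator.clip_gradients` and into the native AMP plugin's backward path, guarded so it only runs when an optimizer step is about to happen. That way `on_after_backward`, gradient tracking and gradient clipping all see real-scale gradients. For orientation, here is a hedged plain-PyTorch sketch of that pattern (toy model and loop, not Lightning's actual internals):

import torch

model = torch.nn.Linear(4, 4).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler()

for _ in range(4):
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        loss = model(torch.randn(8, 4, device="cuda")).sum()

    scaler.scale(loss).backward()   # param.grad now holds scaled gradients

    scaler.unscale_(optimizer)      # bring gradients back to their real scale
    # anything that inspects gradients here (an `on_after_backward`-style hook,
    # grad-norm logging, clipping) now sees true values
    torch.nn.utils.clip_grad_norm_(model.parameters(), 2.0)

    scaler.step(optimizer)          # skipped internally if inf/NaN grads were found
    scaler.update()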
@@ -17,12 +17,18 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Added multiclass AUROC metric ([#4236](https://github.com/PyTorchLightning/pytorch-lightning/pull/4236))

+- Added timeout for `tpu_device_exists` to ensure process does not hang indefinitely ([#4340](https://github.com/PyTorchLightning/pytorch-lightning/pull/4340))
+
 - Added global step indexing to the checkpoint name for a better sub-epoch checkpointing experience ([#3807](https://github.com/PyTorchLightning/pytorch-lightning/pull/3807))

 ### Changed

 - W&B log in sync with Trainer step ([#4405](https://github.com/PyTorchLightning/pytorch-lightning/pull/4405))

+- Hook `on_after_backward` is called only when `optimizer_step` is being called ([#4439](https://github.com/PyTorchLightning/pytorch-lightning/pull/4439))
+
+- Moved `track_and_norm_grad` into `training loop` and called only when `optimizer_step` is being called ([#4439](https://github.com/PyTorchLightning/pytorch-lightning/pull/4439))
+
 ### Deprecated

 - Deprecated passing `ModelCheckpoint` instance to `checkpoint_callback` Trainer argument ([#4336](https://github.com/PyTorchLightning/pytorch-lightning/pull/4336))

@@ -33,6 +39,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Fixed error using `auto_select_gpus=True` with `gpus=-1` ([#4209](https://github.com/PyTorchLightning/pytorch-lightning/pull/4209))

+- Fixed AMP unscale for `on_after_backward` ([#4439](https://github.com/PyTorchLightning/pytorch-lightning/pull/4439))

 ## [1.0.4] - 2020-10-27

@@ -50,8 +57,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Added support for string values in `Trainer`'s `profiler` parameter ([#3656](https://github.com/PyTorchLightning/pytorch-lightning/pull/3656))

-- Added timeout for `tpu_device_exists` to ensure process does not hang indefinitely ([#4340](https://github.com/PyTorchLightning/pytorch-lightning/pull/4340))
-
 ### Changed

 - Improved error messages for invalid `configure_optimizers` returns ([#3587](https://github.com/PyTorchLightning/pytorch-lightning/pull/3587))

@@ -78,7 +83,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Fixed WandbLogger not uploading checkpoint artifacts at the end of training ([#4341](https://github.com/PyTorchLightning/pytorch-lightning/pull/4341))

 ## [1.0.3] - 2020-10-20

 ### Added
@@ -131,11 +131,6 @@ class Accelerator(object):
         model_ref.optimizer_zero_grad(self.trainer.current_epoch, batch_idx, optimizer, opt_idx)

     def clip_gradients(self, optimizer, clip_val=None):
-        if self.trainer.amp_backend == AMPType.NATIVE:
-            self.trainer.scaler.unscale_(optimizer)
-
         # apply clip gradients
         # TODO: separate TPU case from here
         self._clip_gradients(optimizer, clip_val)
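The unscale used to happen right here, immediately before clipping; moving it earlier, into the plugin's backward path, means `on_after_backward` also sees unscaled gradients, while clipping still operates on real-scale values. That ordering matters because clipping scaled gradients would compare the loss-scale-inflated norm against `clip_val`. A toy numeric illustration of why the norm must be computed on unscaled gradients (made-up numbers, not project code):

import torch

true_grad = torch.tensor([3.0, 4.0])   # true gradient, L2 norm == 5
scale = 2.0 ** 16                      # a typical GradScaler loss scale
scaled_grad = true_grad * scale        # what param.grad holds before unscale_

print(true_grad.norm().item())         # 5.0
print(scaled_grad.norm().item())       # 327680.0 -- clipping this against
                                       # max_norm=2 would shrink the update
                                       # by roughly the loss scale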
@@ -1101,7 +1101,6 @@ class LightningModule(

         """
         loss.backward(*args, **kwargs)
-        self.trainer.train_loop.track_and_norm_grad(optimizer=optimizer)

     def toggle_optimizer(self, optimizer: Optimizer, optimizer_idx: int):
         """
@@ -38,6 +38,11 @@ class NativeAMPPlugin:

         # once backward has been applied, release graph
         closure_loss = closure_loss.detach()

+        # unscale gradient to allow analyze within `on_after_backward`
+        if not self.trainer.train_loop.should_accumulate():
+            self.trainer.scaler.unscale_(optimizer)
+
         return closure_loss

     def training_step(self, fx, args):
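The `should_accumulate()` guard matters because, per the PyTorch AMP recipe, `unscale_` should be called at most once per optimizer step, right before the gradients are actually used; while micro-batches are still accumulating, gradients are left in scaled form. A hedged, self-contained sketch of that recipe in plain PyTorch (toy model, an assumed accumulation factor of 2):

import torch

model = torch.nn.Linear(4, 1).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler()
accumulate_grad_batches = 2

for i in range(8):
    with torch.cuda.amp.autocast():
        loss = model(torch.randn(8, 4, device="cuda")).sum() / accumulate_grad_batches
    scaler.scale(loss).backward()                # gradients keep accumulating in scaled form

    if (i + 1) % accumulate_grad_batches == 0:   # mirrors `not should_accumulate()`
        scaler.unscale_(optimizer)               # unscale exactly once per optimizer step
        torch.nn.utils.clip_grad_norm_(model.parameters(), 2.0)
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()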
@@ -652,11 +652,6 @@ class TrainLoop:
         if response == -1:
             return AttributeDict(signal=-1, grad_norm_dic=grad_norm_dic)

-        # checks if backward or backward + optimizer step (via closure)
-        accumulation_done = self._accumulated_batches_reached()
-        is_final_batch = self._num_training_batches_reached()
-        should_accumulate = not (accumulation_done or is_final_batch)
-
         # lightning module hook
         splits = self.tbptt_split_batch(batch)
@@ -676,7 +671,7 @@ class TrainLoop:
                 model = self.trainer.get_model()
                 model.toggle_optimizer(optimizer, opt_idx)

-                if should_accumulate:
+                if self.should_accumulate():
                     # For gradient accumulation

                     # -------------------
@@ -767,7 +762,7 @@ class TrainLoop:
     @contextmanager
     def block_ddp_sync_behaviour(self):
         if isinstance(self.trainer.model, torch.nn.parallel.DistributedDataParallel):
-            yield from self.trainer.model.no_sync()
+            yield self.trainer.model.no_sync()
         else:
             yield
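For reference on what this helper wraps: `DistributedDataParallel.no_sync()` returns a context manager that skips the gradient all-reduce for backward passes run inside it; the first backward after the block triggers the synchronization. A minimal runnable illustration (single-process gloo group on CPU, purely illustrative and unrelated to Lightning's wrapper):

import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

model = DistributedDataParallel(torch.nn.Linear(4, 1))

with model.no_sync():                           # context manager: no all-reduce here
    model(torch.randn(2, 4)).sum().backward()   # gradients accumulate locally
model(torch.randn(2, 4)).sum().backward()       # this backward synchronizes gradients

dist.destroy_process_group()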
@@ -817,7 +812,9 @@ class TrainLoop:
         with self.trainer.profiler.profile("model_backward"):
             self.backward(result, optimizer, opt_idx)

-        # hook
-        self.on_after_backward(result.training_step_output, batch_idx, result.loss)
+        # hook - call this hook only
+        # when gradients have finished to accumulate
+        if not self.should_accumulate():
+            self.on_after_backward(result.training_step_output, batch_idx, result.loss)

         # check if loss or model weights are nan
@@ -837,6 +834,10 @@ class TrainLoop:
                 result.closure_loss, optimizer, opt_idx, *args, **kwargs
             )

+        if not self.should_accumulate():
+            # track gradients
+            self.track_and_norm_grad(optimizer=optimizer)
+
     def update_train_loop_lr_schedulers(self, monitor_metrics=None):
         num_accumulated_batches_reached = self._accumulated_batches_reached()
         num_training_batches_reached = self._num_training_batches_reached()
@@ -863,6 +864,12 @@ class TrainLoop:
     def _num_training_batches_reached(self):
         return (self.trainer.batch_idx + 1) == self.trainer.num_training_batches

+    def should_accumulate(self):
+        # checks if backward or backward + optimizer step (via closure)
+        accumulation_done = self._accumulated_batches_reached()
+        is_final_batch = self._num_training_batches_reached()
+        return not (accumulation_done or is_final_batch)
+
     def should_check_val_fx(self, batch_idx, is_last_batch):
         # decide if we should run validation
         is_val_check_batch = (batch_idx + 1) % self.trainer.val_check_batch == 0
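The same accumulation arithmetic, pulled out into a standalone (hypothetical) helper to make the behaviour concrete: with `accumulate_grad_batches=4` and 10 training batches, accumulation stops, and hence the optimizer steps, on batch indices 3, 7 and on the final batch 9.

def should_accumulate(batch_idx: int, accumulate_grad_batches: int, num_training_batches: int) -> bool:
    # standalone mirror of TrainLoop.should_accumulate(), for illustration only
    accumulation_done = (batch_idx + 1) % accumulate_grad_batches == 0
    is_final_batch = (batch_idx + 1) == num_training_batches
    return not (accumulation_done or is_final_batch)

step_batches = [i for i in range(10) if not should_accumulate(i, 4, 10)]
print(step_batches)  # [3, 7, 9]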
@@ -84,3 +84,65 @@ def test_amp_choice_custom_ddp_cpu(tmpdir, ddp_backend, gpus, num_processes):

     with pytest.raises(SystemExit):
         trainer.fit(model)
+
+
+@pytest.mark.skipif(
+    LooseVersion(torch.__version__) < LooseVersion("1.6.0"),
+    reason="Minimal PT version is set to 1.6")
+@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
+def test_amp_gradient_unscale(tmpdir):
+
+    class ExtendedBoringModel(BoringModel):
+
+        def on_after_backward(self):
+            norm = torch.nn.utils.clip_grad_norm_(self.parameters(), 2)
+            if not (torch.isinf(norm) or torch.isnan(norm)):
+                assert norm.item() < 15.
+
+    model = ExtendedBoringModel()
+
+    trainer = Trainer(
+        max_epochs=2,
+        default_root_dir=os.getcwd(),
+        limit_train_batches=2,
+        limit_test_batches=2,
+        limit_val_batches=2,
+        amp_backend='native',
+        distributed_backend='ddp_spawn',
+        gpus=2,
+        precision=16,
+        track_grad_norm=2,
+        log_every_n_steps=1
+    )
+    trainer.fit(model)
+
+
+@pytest.mark.skipif(
+    LooseVersion(torch.__version__) < LooseVersion("1.6.0"), reason="Minimal PT version is set to 1.6")
+@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
+def test_amp_gradient_unscale_accumulate_grad_batches(tmpdir):
+
+    class ExtendedBoringModel(BoringModel):
+
+        def on_after_backward(self):
+            norm = torch.nn.utils.clip_grad_norm_(self.parameters(), 2)
+            if not (torch.isinf(norm) or torch.isnan(norm)):
+                assert norm.item() < 15.
+
+    model = ExtendedBoringModel()
+
+    trainer = Trainer(
+        max_epochs=2,
+        default_root_dir=os.getcwd(),
+        limit_train_batches=2,
+        limit_test_batches=2,
+        limit_val_batches=2,
+        amp_backend='native',
+        distributed_backend='ddp_spawn',
+        gpus=2,
+        precision=16,
+        track_grad_norm=2,
+        log_every_n_steps=1,
+        accumulate_grad_batches=2,
+    )
+    trainer.fit(model)
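Both new tests are gated on PyTorch >= 1.6 and at least two CUDA devices; on such a machine they can be run with something like `python -m pytest tests/plugins/test_amp_plugin.py -k gradient_unscale`.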