ref: decouple apex second attempt part 9/n (#4063)

* ref: decouple apex second attempt part 9/n

* ref: decouple apex second attempt part 9/n
William Falcon 2020-10-10 18:44:24 -04:00 committed by GitHub
parent e3717ed36e
commit dbfe2b6129
5 changed files with 41 additions and 20 deletions


@@ -22,7 +22,7 @@ For advanced research topics like reinforcement learning, sparse coding, or GAN
to manually manage the optimization process. To do so, do the following:
* Ignore the optimizer_idx argument
* So we can scale the loss automatically for you use self.backward(loss) instead of loss.backward()
* So we can scale the loss automatically for you use self.manual_backward(loss) instead of loss.backward()
.. code-block:: python
@@ -34,7 +34,7 @@ to manually manage the optimization process. To do so, do the following:
loss_a = ...
# use self.backward which will also handle scaling the loss when using amp
self.backward(loss_a, opt_g)
self.manual_backward(loss_a, opt_g)
opt_g.step()
opt_g.zero_grad()
@@ -42,8 +42,8 @@ to manually manage the optimization process. To do so, do the following:
loss_b = ...
# pass in any args that loss.backward() normally takes
self.backward(loss_b, opt_d, retain_graph=True)
self.backward(loss_b, opt_d, retain_graph=True)
self.manual_backward(loss_b, opt_d, retain_graph=True)
self.manual_backward(loss_b, opt_d, retain_graph=True)
loss_b.step()
loss_b.zero_grad()
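
The documented pattern, pulled together into one self-contained sketch: a LightningModule that drives two optimizers by hand with self.manual_backward, mirroring the GAN-style example above. The stand-in layers, losses and learning rates are illustrative placeholders, and the sketch assumes manual optimization has been enabled on the Trainer, which is not shown in this diff.

.. code-block:: python

    import torch
    from torch import nn
    import pytorch_lightning as pl

    class ManualOptModel(pl.LightningModule):
        """Minimal sketch: two optimizers stepped by hand via manual_backward."""

        def __init__(self):
            super().__init__()
            self.gen = nn.Linear(32, 32)   # stand-in "generator"
            self.disc = nn.Linear(32, 1)   # stand-in "discriminator"

        def training_step(self, batch, batch_idx, optimizer_idx=None):
            # the optimizer_idx argument is ignored in manual mode (see docs above)
            (opt_g, opt_d) = self.optimizers()

            loss_a = self.disc(self.gen(batch)).mean()
            # manual_backward also takes care of loss scaling under amp
            self.manual_backward(loss_a, opt_g)
            opt_g.step()
            opt_g.zero_grad()

            loss_b = self.disc(self.gen(batch)).mean()
            # retain_graph allows backpropagating through the same graph again
            self.manual_backward(loss_b, opt_d, retain_graph=True)
            self.manual_backward(loss_b, opt_d)
            opt_d.step()
            opt_d.zero_grad()

        def configure_optimizers(self):
            opt_g = torch.optim.SGD(self.gen.parameters(), lr=0.1)
            opt_d = torch.optim.SGD(self.disc.parameters(), lr=0.1)
            return opt_g, opt_d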


@@ -1038,7 +1038,7 @@ class LightningModule(
"`configure_optimizers` must be implemented to be used with the Lightning Trainer"
)
def backward(self, loss: Tensor, optimizer: Optimizer, *args, **kwargs) -> None:
def manual_backward(self, loss: Tensor, optimizer: Optimizer, *args, **kwargs) -> None:
"""
Call this directly from your training_step when doing optimizations manually.
By using this we can ensure that all the proper scaling when using 16-bit etc has been done for you
@@ -1051,10 +1051,31 @@
(opt_a, opt_b) = self.optimizers()
loss = ...
# automatically applies scaling, etc...
self.backward(loss, opt_a)
self.manual_backward(loss, opt_a)
"""
self.trainer.train_loop.backward(loss, optimizer, *args, **kwargs)
def backward(self, loss: Tensor, optimizer: Optimizer, optimizer_idx: int) -> None:
"""
Override backward with your own implementation if you need to.
Args:
loss: Loss is already scaled by accumulated grads
optimizer: Current optimizer being used
optimizer_idx: Index of the current optimizer being used
Called to perform backward step.
Feel free to override as needed.
The loss passed in has already been scaled for accumulated gradients if requested.
Example::
def backward(self, trainer, loss, optimizer, optimizer_idx):
loss.backward()
"""
loss.backward()
def toggle_optimizer(self, optimizer: Optimizer, optimizer_idx: int):
"""
Makes sure only the gradients of the current optimizer's parameters are calculated
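
The other hook introduced here, backward, is the automatic-optimization counterpart: Lightning hands it a loss that is already scaled for gradient accumulation, and it can be overridden. A minimal sketch of such an override follows; it uses the method signature from the hunk above (the hook's docstring example shows an extra trainer argument), and the gradient clipping is a purely illustrative addition, not something this commit prescribes.

.. code-block:: python

    import torch
    from torch import Tensor
    from torch.optim import Optimizer
    import pytorch_lightning as pl

    class CustomBackwardModel(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(32, 2)

        def training_step(self, batch, batch_idx):
            return self.layer(batch).sum()

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.1)

        def backward(self, loss: Tensor, optimizer: Optimizer, optimizer_idx: int) -> None:
            # same default behaviour as the new hook: run the backward pass
            loss.backward()
            # illustrative extra step (an assumption, not part of this commit):
            # clip gradients right after they are produced
            torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm=1.0)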


@@ -296,7 +296,7 @@ Example::
(opt) = self.optimizers()
loss = ...
self.backward(loss, opt)
self.manual_backward(loss, opt)
opt.step()
opt.zero_grad()
@@ -311,12 +311,12 @@ Example::
(opt_a, opt_b) = self.optimizers()
gen_loss = ...
self.backward(gen_loss, opt_a)
self.manual_backward(gen_loss, opt_a)
opt_a.step()
opt_a.zero_grad()
disc_loss = ...
self.backward(disc_loss, opt_b)
self.manual_backward(disc_loss, opt_b)
opt_b.step()
opt_b.zero_grad()
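
For completeness, a sketch of how such a module was typically launched in this era of Lightning. The automatic_optimization=False Trainer flag is recalled from the 1.0-series API and should be treated as an assumption rather than something this diff establishes.

.. code-block:: python

    import torch
    from torch.utils.data import DataLoader
    import pytorch_lightning as pl

    # toy data matching the 32-feature stand-in model from the earlier sketch
    train_data = DataLoader(torch.randn(64, 32), batch_size=8)

    # assumption: in the 1.0-era API manual optimization was switched on via a
    # Trainer flag; the flag name is recalled, not taken from this diff
    trainer = pl.Trainer(automatic_optimization=False, max_steps=4)
    trainer.fit(ManualOptModel(), train_data)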


@@ -65,13 +65,13 @@ def test_multiple_optimizers_manual(tmpdir):
loss_1 = self.step(batch[0])
# fake generator
self.backward(loss_1, opt_a)
self.manual_backward(loss_1, opt_a)
opt_a.step()
opt_a.zero_grad()
# fake discriminator
loss_2 = self.step(batch[0])
self.backward(loss_2, opt_b)
self.manual_backward(loss_2, opt_b)
opt_b.step()
opt_b.zero_grad()


@@ -23,7 +23,7 @@ def test_multiple_optimizers_manual(tmpdir):
if batch_idx > 0:
assert torch.all(self.layer.weight.grad == 0)
self.backward(loss_1, opt_a)
self.manual_backward(loss_1, opt_a)
opt_a.step()
opt_a.zero_grad()
assert torch.all(self.layer.weight.grad == 0)
@@ -33,8 +33,8 @@ def test_multiple_optimizers_manual(tmpdir):
# ensure we forward the correct params to the optimizer
# without retain_graph we can't do multiple backward passes
self.backward(loss_2, opt_b, retain_graph=True)
self.backward(loss_2, opt_a, retain_graph=True)
self.manual_backward(loss_2, opt_b, retain_graph=True)
self.manual_backward(loss_2, opt_a, retain_graph=True)
assert self.layer.weight.grad is not None
opt_b.step()
@@ -87,7 +87,7 @@ def test_multiple_optimizers_manual_native_amp(tmpdir):
if batch_idx > 0:
assert torch.all(self.layer.weight.grad == 0)
self.backward(loss_1, opt_a)
self.manual_backward(loss_1, opt_a)
opt_a.step()
opt_a.zero_grad()
assert torch.all(self.layer.weight.grad == 0)
@@ -97,8 +97,8 @@ def test_multiple_optimizers_manual_native_amp(tmpdir):
# ensure we forward the correct params to the optimizer
# without retain_graph we can't do multiple backward passes
self.backward(loss_2, opt_b, retain_graph=True)
self.backward(loss_2, opt_a, retain_graph=True)
self.manual_backward(loss_2, opt_b, retain_graph=True)
self.manual_backward(loss_2, opt_a, retain_graph=True)
assert self.layer.weight.grad is not None
opt_b.step()
@@ -157,7 +157,7 @@ def test_multiple_optimizers_manual_apex(tmpdir):
if batch_idx > 0:
assert torch.all(self.layer.weight.grad == 0)
self.backward(loss_1, opt_a)
self.manual_backward(loss_1, opt_a)
opt_a.step()
opt_a.zero_grad()
assert torch.all(self.layer.weight.grad == 0)
@@ -168,8 +168,8 @@ def test_multiple_optimizers_manual_apex(tmpdir):
# ensure we forward the correct params to the optimizer
# without retain_graph we can't do multiple backward passes
self.backward(loss_2, opt_b, retain_graph=True)
self.backward(loss_2, opt_a, retain_graph=True)
self.manual_backward(loss_2, opt_b, retain_graph=True)
self.manual_backward(loss_2, opt_a, retain_graph=True)
assert self.layer.weight.grad is not None
opt_b.step()
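
A closing note on the retain_graph=True arguments exercised in these tests: the reason is plain PyTorch, independent of Lightning. A backward pass frees the autograd graph by default, so a second backward through the same graph fails unless the first one retained it. A standalone illustration:

.. code-block:: python

    import torch

    layer = torch.nn.Linear(4, 1)
    loss = layer(torch.randn(2, 4)).sum()

    # the first backward keeps the graph alive so it can be reused
    loss.backward(retain_graph=True)
    # a second backward through the same graph is now allowed
    loss.backward()

    # without retain_graph=True above, the second call would raise a
    # "Trying to backward through the graph a second time" RuntimeError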