ref: decouple apex second attempt part 9/n (#4063)

* ref: decouple apex second attempt part 9/n

This commit is contained in:
parent e3717ed36e
commit dbfe2b6129
@@ -22,7 +22,7 @@ For advanced research topics like reinforcement learning, sparse coding, or GAN
 to manually manage the optimization process. To do so, do the following:

 * Ignore the optimizer_idx argument
-* So we can scale the loss automatically for you use self.backward(loss) instead of loss.backward()
+* So we can scale the loss automatically for you use self.manual_backward(loss) instead of loss.backward()

 .. code-block:: python

@@ -34,7 +34,7 @@ to manually manage the optimization process. To do so, do the following:
 loss_a = ...

 # use self.backward which will also handle scaling the loss when using amp
-self.backward(loss_a, opt_g)
+self.manual_backward(loss_a, opt_g)
 opt_g.step()
 opt_g.zero_grad()
@@ -42,8 +42,8 @@ to manually manage the optimization process. To do so, do the following:
 loss_b = ...

 # pass in any args that loss.backward() normally takes
-self.backward(loss_b, opt_d, retain_graph=True)
-self.backward(loss_b, opt_d, retain_graph=True)
+self.manual_backward(loss_b, opt_d, retain_graph=True)
+self.manual_backward(loss_b, opt_d, retain_graph=True)
 loss_b.step()
 loss_b.zero_grad()
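For orientation, here is a minimal sketch of the manual optimization flow the documentation hunks above describe, assuming manual optimization has been enabled on the Trainer (the exact switch is not part of these hunks) and using the self.optimizers() / self.manual_backward() API shown in this diff. The layer sizes, loss, and module name are placeholders, not code from this commit.

    import torch
    from torch import nn
    import pytorch_lightning as pl


    class ManualOptimizationSketch(pl.LightningModule):
        """Illustrative only: layer size, loss, and data handling are placeholders."""

        def __init__(self):
            super().__init__()
            self.layer = nn.Linear(32, 2)

        def training_step(self, batch, batch_idx, optimizer_idx=None):
            # in manual mode the optimizer_idx argument is ignored, per the docs above
            opt = self.optimizers()

            loss = self.layer(batch).sum()

            # manual_backward instead of loss.backward(), so AMP loss scaling
            # and related bookkeeping are still handled for us
            self.manual_backward(loss, opt)
            opt.step()
            opt.zero_grad()

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.1)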
@@ -1038,7 +1038,7 @@ class LightningModule(
             "`configure_optimizers` must be implemented to be used with the Lightning Trainer"
         )

-    def backward(self, loss: Tensor, optimizer: Optimizer, *args, **kwargs) -> None:
+    def manual_backward(self, loss: Tensor, optimizer: Optimizer, *args, **kwargs) -> None:
        """
        Call this directly from your training_step when doing optimizations manually.
        By using this we can ensure that all the proper scaling when using 16-bit etc has been done for you
@@ -1051,10 +1051,31 @@ class LightningModule(
             (opt_a, opt_b) = self.optimizers()
             loss = ...
             # automatically applies scaling, etc...
-            self.backward(loss, opt_a)
+            self.manual_backward(loss, opt_a)
        """
        self.trainer.train_loop.backward(loss, optimizer, *args, **kwargs)

+    def backward(self, loss: Tensor, optimizer: Optimizer, optimizer_idx: int) -> None:
+        """
+        Override backward with your own implementation if you need to.
+
+        Args:
+            loss: Loss is already scaled by accumulated grads
+            optimizer: Current optimizer being used
+            optimizer_idx: Index of the current optimizer being used
+
+        Called to perform backward step.
+        Feel free to override as needed.
+        The loss passed in has already been scaled for accumulated gradients if requested.
+
+        Example::
+
+            def backward(self, trainer, loss, optimizer, optimizer_idx):
+                loss.backward()
+
+        """
+        loss.backward()
+
     def toggle_optimizer(self, optimizer: Optimizer, optimizer_idx: int):
        """
        Makes sure only the gradients of the current optimizer's parameters are calculated
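The hunk above also adds an overridable backward hook. As a rough illustration of how a user might override it, assuming the (loss, optimizer, optimizer_idx) signature introduced here, the sketch below runs the default loss.backward() and then inspects gradient magnitudes; the logging behaviour is invented for illustration and is not something this commit implements.

    import torch
    from torch import Tensor
    from torch.optim import Optimizer
    import pytorch_lightning as pl


    class CustomBackwardSketch(pl.LightningModule):
        """Hypothetical module showing only the override; the rest of the module is omitted."""

        def backward(self, loss: Tensor, optimizer: Optimizer, optimizer_idx: int) -> None:
            # the default hook body is just loss.backward(); this override also
            # inspects gradient magnitudes afterwards (illustrative behaviour only)
            loss.backward()
            grad_norms = [p.grad.norm() for p in self.parameters() if p.grad is not None]
            if grad_norms:
                mean_norm = torch.stack(grad_norms).mean().item()
                print(f"optimizer {optimizer_idx}: mean grad norm {mean_norm:.3f}")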
@@ -296,7 +296,7 @@ Example::
 (opt) = self.optimizers()

 loss = ...
-self.backward(loss, opt)
+self.manual_backward(loss, opt)
 opt.step()
 opt.zero_grad()
@@ -311,12 +311,12 @@ Example::
 (opt_a, opt_b) = self.optimizers()

 gen_loss = ...
-self.backward(gen_loss, opt_a)
+self.manual_backward(gen_loss, opt_a)
 opt_a.step()
 opt_a.zero_grad()

 disc_loss = ...
-self.backward(disc_loss, opt_b)
+self.manual_backward(disc_loss, opt_b)
 opt_b.step()
 opt_b.zero_grad()
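Expanding the two-optimizer example above into a full training_step, here is a hypothetical GAN-style sketch, assuming manual optimization is enabled; the modules and losses are placeholders chosen only to make the snippet self-contained.

    import torch
    import pytorch_lightning as pl


    class TwoOptimizerSketch(pl.LightningModule):
        """Illustrative GAN-style manual loop; names and losses are placeholders."""

        def __init__(self):
            super().__init__()
            self.generator = torch.nn.Linear(8, 8)
            self.discriminator = torch.nn.Linear(8, 1)

        def training_step(self, batch, batch_idx, optimizer_idx=None):
            (opt_a, opt_b) = self.optimizers()

            # generator update
            gen_loss = self.discriminator(self.generator(batch)).mean()
            self.manual_backward(gen_loss, opt_a)
            opt_a.step()
            opt_a.zero_grad()

            # discriminator update (recomputed so the graph is fresh)
            disc_loss = -self.discriminator(self.generator(batch).detach()).mean()
            self.manual_backward(disc_loss, opt_b)
            opt_b.step()
            opt_b.zero_grad()

        def configure_optimizers(self):
            opt_a = torch.optim.Adam(self.generator.parameters(), lr=1e-3)
            opt_b = torch.optim.Adam(self.discriminator.parameters(), lr=1e-3)
            return opt_a, opt_b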
@@ -65,13 +65,13 @@ def test_multiple_optimizers_manual(tmpdir):
 loss_1 = self.step(batch[0])

 # fake generator
-self.backward(loss_1, opt_a)
+self.manual_backward(loss_1, opt_a)
 opt_a.step()
 opt_a.zero_grad()

 # fake discriminator
 loss_2 = self.step(batch[0])
-self.backward(loss_2, opt_b)
+self.manual_backward(loss_2, opt_b)
 opt_b.step()
 opt_b.zero_grad()
@@ -23,7 +23,7 @@ def test_multiple_optimizers_manual(tmpdir):
 if batch_idx > 0:
     assert torch.all(self.layer.weight.grad == 0)

-self.backward(loss_1, opt_a)
+self.manual_backward(loss_1, opt_a)
 opt_a.step()
 opt_a.zero_grad()
 assert torch.all(self.layer.weight.grad == 0)
@@ -33,8 +33,8 @@ def test_multiple_optimizers_manual(tmpdir):

 # ensure we forward the correct params to the optimizer
 # without retain_graph we can't do multiple backward passes
-self.backward(loss_2, opt_b, retain_graph=True)
-self.backward(loss_2, opt_a, retain_graph=True)
+self.manual_backward(loss_2, opt_b, retain_graph=True)
+self.manual_backward(loss_2, opt_a, retain_graph=True)

 assert self.layer.weight.grad is not None
 opt_b.step()
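The retain_graph=True in these tests matters because the same loss is back-propagated twice. A standalone PyTorch sketch of the underlying autograd behaviour (generic PyTorch, not code from this repository) shows why the flag is needed:

    import torch

    layer = torch.nn.Linear(4, 1)
    loss = layer(torch.randn(2, 4)).sum()

    # first backward keeps the graph alive so it can be reused
    loss.backward(retain_graph=True)

    # second backward over the same graph; without retain_graph above,
    # autograd would have freed the intermediate buffers and raised an error
    loss.backward()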
@@ -87,7 +87,7 @@ def test_multiple_optimizers_manual_native_amp(tmpdir):
 if batch_idx > 0:
     assert torch.all(self.layer.weight.grad == 0)

-self.backward(loss_1, opt_a)
+self.manual_backward(loss_1, opt_a)
 opt_a.step()
 opt_a.zero_grad()
 assert torch.all(self.layer.weight.grad == 0)
@@ -97,8 +97,8 @@ def test_multiple_optimizers_manual_native_amp(tmpdir):

 # ensure we forward the correct params to the optimizer
 # without retain_graph we can't do multiple backward passes
-self.backward(loss_2, opt_b, retain_graph=True)
-self.backward(loss_2, opt_a, retain_graph=True)
+self.manual_backward(loss_2, opt_b, retain_graph=True)
+self.manual_backward(loss_2, opt_a, retain_graph=True)

 assert self.layer.weight.grad is not None
 opt_b.step()
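For the native-AMP case exercised above, the reason manual_backward (rather than a bare loss.backward()) matters is loss scaling. A rough standalone sketch of what a scaled backward pass involves with torch.cuda.amp.GradScaler, shown as generic PyTorch usage rather than Lightning's internal code, and requiring a CUDA device:

    import torch

    model = torch.nn.Linear(4, 1).cuda()
    opt = torch.optim.SGD(model.parameters(), lr=0.1)
    scaler = torch.cuda.amp.GradScaler()

    x = torch.randn(2, 4, device="cuda")
    with torch.cuda.amp.autocast():
        loss = model(x).sum()

    # the loss must be scaled before backward so fp16 gradients don't underflow;
    # a plain loss.backward() would skip this step
    scaler.scale(loss).backward()
    scaler.step(opt)
    scaler.update()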
@@ -157,7 +157,7 @@ def test_multiple_optimizers_manual_apex(tmpdir):
 if batch_idx > 0:
     assert torch.all(self.layer.weight.grad == 0)

-self.backward(loss_1, opt_a)
+self.manual_backward(loss_1, opt_a)
 opt_a.step()
 opt_a.zero_grad()
 assert torch.all(self.layer.weight.grad == 0)
@@ -168,8 +168,8 @@ def test_multiple_optimizers_manual_apex(tmpdir):

 # ensure we forward the correct params to the optimizer
 # without retain_graph we can't do multiple backward passes
-self.backward(loss_2, opt_b, retain_graph=True)
-self.backward(loss_2, opt_a, retain_graph=True)
+self.manual_backward(loss_2, opt_b, retain_graph=True)
+self.manual_backward(loss_2, opt_a, retain_graph=True)

 assert self.layer.weight.grad is not None
 opt_b.step()
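The apex variant of the same test relies on apex's own loss-scaling context manager; routing the call through manual_backward lets the trainer pick the right scaling path. An illustrative snippet of the classic apex pattern, shown as generic apex usage rather than this repository's code, and assuming NVIDIA apex and a CUDA device are available:

    import torch
    from apex import amp  # requires NVIDIA apex to be installed

    model = torch.nn.Linear(4, 1).cuda()
    opt = torch.optim.SGD(model.parameters(), lr=0.1)
    model, opt = amp.initialize(model, opt, opt_level="O1")

    loss = model(torch.randn(2, 4, device="cuda")).sum()

    # apex scales the loss inside this context so fp16 gradients don't underflow
    with amp.scale_loss(loss, opt) as scaled_loss:
        scaled_loss.backward()
    opt.step()
    opt.zero_grad()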