ref: decouple apex second attempt part 7/n (#4061)

* ref: decouple apex second attempt part 7/n

* ref: decouple apex second attempt part 7/n

* ref: decouple apex second attempt part 7/n
William Falcon 2020-10-10 16:44:15 -04:00 committed by GitHub
parent dca86c310e
commit 5ce9fc6bb3
5 changed files with 135 additions and 74 deletions


@@ -285,6 +285,41 @@ Example::
    # default used by the Trainer
    trainer = Trainer(amp_level='O2')

automatic_optimization
^^^^^^^^^^^^^^^^^^^^^^
When set to False, Lightning does not automate the optimization process. This means you are responsible for
handling the optimizers yourself: calling backward, stepping them, and zeroing gradients.

Example::

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()

        loss = ...
        self.backward(loss, opt)
        opt.step()
        opt.zero_grad()

This is not recommended when using a single optimizer. It is intended for expert users working with 2+ optimizers,
and is most useful for research such as reinforcement learning, sparse coding, and GANs.

In the multi-optimizer case, ignore the optimizer_idx argument and use the optimizers directly.

Example::

    def training_step(self, batch, batch_idx, optimizer_idx):
        (opt_a, opt_b) = self.optimizers()

        gen_loss = ...
        self.backward(gen_loss, opt_a)
        opt_a.step()
        opt_a.zero_grad()

        disc_loss = ...
        self.backward(disc_loss, opt_b)
        opt_b.step()
        opt_b.zero_grad()
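
For completeness, a minimal sketch of enabling this behaviour from the Trainer side using the flag added in this
commit; the model class name is a placeholder:

Example::

    # assumes MyManualOptModel defines training_step as in the snippets above
    model = MyManualOptModel()
    trainer = Trainer(automatic_optimization=False)
    trainer.fit(model)
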
auto_scale_batch_size
^^^^^^^^^^^^^^^^^^^^^


@@ -134,6 +134,7 @@ class Trainer(
        amp_backend: str = 'native',
        amp_level: str = 'O2',
        distributed_backend: Optional[str] = None,
        automatic_optimization: bool = True,
    ):
        r"""
        Customize every aspect of training via flags
@@ -201,6 +202,9 @@ class Trainer(
            log_every_n_steps: How often to log within steps (defaults to every 50 steps).

            automatic_optimization: If False, you are responsible for calling .backward, .step and .zero_grad.
                Meant to be used with multiple optimizers by advanced users.

            prepare_data_per_node: If True, each LOCAL_RANK=0 will call prepare data.
                Otherwise only NODE_RANK=0, LOCAL_RANK=0 will prepare data.
@@ -337,7 +341,14 @@ class Trainer(
        )

        # init train loop related flags
        self.train_loop.on_trainer_init(max_epochs, min_epochs, max_steps, min_steps, num_sanity_val_steps)
        self.train_loop.on_trainer_init(
            max_epochs,
            min_epochs,
            max_steps,
            min_steps,
            num_sanity_val_steps,
            automatic_optimization
        )
        self.evaluation_loop.on_trainer_init()

        # configure tuner


@@ -44,8 +44,9 @@ class TrainLoop:
        self.warning_cache = WarningCache()
        self._teardown_already_run = False
        self.running_loss = TensorRunningAccum(window_length=20)
        self.automatic_optimization = True

    def on_trainer_init(self, max_epochs, min_epochs, max_steps, min_steps, num_sanity_val_steps):
    def on_trainer_init(self, max_epochs, min_epochs, max_steps, min_steps, num_sanity_val_steps, automatic_optimization):
        self.trainer.global_step = 0
        self.trainer.current_epoch = 0
        self.trainer.interrupted = False
@@ -56,6 +57,7 @@ class TrainLoop:
        self.trainer.batch_idx = 0
        self.trainer.num_training_batches = 0
        self.trainer.train_dataloader = None
        self.automatic_optimization = automatic_optimization

        self.trainer.max_epochs = max_epochs
        self.trainer.min_epochs = min_epochs
@@ -275,7 +277,7 @@ class TrainLoop:
        # find optimizer index by looking for the first {item > current_place} in the cumsum list
        opt_idx = np.argmax(optimizer_freq_cumsum > current_place_in_loop)
        return [(opt_idx, self.trainer.optimizers[opt_idx])]
        return [[opt_idx, self.trainer.optimizers[opt_idx]]]

    def on_after_backward(self, training_step_output, batch_idx, untouched_loss):
        is_result_obj = isinstance(training_step_output, Result)
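
As a quick worked example of the frequency-cumsum lookup above (hypothetical optimizer frequencies, not code from
this commit):

    import numpy as np

    # optimizer 0 runs for 2 batches, then optimizer 1 for 1 batch, repeating every 3 batches
    optimizer_freq_cumsum = np.cumsum([2, 1])   # -> array([2, 3])

    for current_place_in_loop in range(3):
        opt_idx = np.argmax(optimizer_freq_cumsum > current_place_in_loop)
        print(current_place_in_loop, opt_idx)   # 0 -> 0, 1 -> 0, 2 -> 1
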
@@ -640,11 +642,16 @@ class TrainLoop:
        for split_idx, split_batch in enumerate(splits):
            self.trainer.split_idx = split_idx

            # in manual optimization, training_step drives all optimizers itself,
            # so the loop below runs only once (with the first optimizer)
            optimizers = self.get_optimizers_iterable()
            if not self.automatic_optimization:
                optimizers = [optimizers[0]]

            # loop over optimizers
            for opt_idx, optimizer in self.get_optimizers_iterable():
            for opt_idx, optimizer in optimizers:
                # make sure only the gradients of the current optimizer's parameters are calculated
                # in the training step to prevent dangling gradients in multiple-optimizer setup.
                if len(self.trainer.optimizers) > 1:
                if self.automatic_optimization and len(self.trainer.optimizers) > 1:
                    model = self.trainer.get_model()
                    model.toggle_optimizer(optimizer, opt_idx)
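
For readers unfamiliar with toggle_optimizer, here is a conceptual sketch of what an optimizer-toggling helper
typically does; this illustrates the idea only and is not Lightning's actual implementation:

    def toggle_optimizer_sketch(optimizers, active_idx):
        # parameters belonging to the active optimizer keep requires_grad=True;
        # parameters used only by the other optimizers are frozen, so the current
        # training_step cannot leave dangling gradients on them
        active_params = {
            id(p) for group in optimizers[active_idx].param_groups for p in group["params"]
        }
        for i, opt in enumerate(optimizers):
            if i == active_idx:
                continue
            for group in opt.param_groups:
                for p in group["params"]:
                    p.requires_grad = id(p) in active_params
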
@@ -751,12 +758,18 @@ class TrainLoop:
            return result

    def backward(self, result, optimizer, *args, **kwargs):
        result.closure_loss = self.trainer.accelerator_backend.backward(
            result.closure_loss,
            optimizer,
            *args,
            **kwargs
        )
        self.trainer.dev_debugger.track_event('backward_call')

        # backward can be called manually in the training loop.
        if isinstance(result, torch.Tensor):
            self.trainer.accelerator_backend.backward(result, optimizer, *args, **kwargs)
        else:
            result.closure_loss = self.trainer.accelerator_backend.backward(
                result.closure_loss,
                optimizer,
                *args,
                **kwargs
            )
    def update_train_loop_lr_schedulers(self, monitor_metrics=None):
        num_accumulated_batches_reached = (self.trainer.batch_idx + 1) % self.trainer.accumulate_grad_batches == 0


@@ -61,17 +61,17 @@ def test_multiple_optimizers_manual(tmpdir):
        def training_step(self, batch, batch_idx, optimizer_idx):
            # manual
            (opt_a, opt_b) = self.trainer.optimizers
            (opt_a, opt_b) = self.optimizers()
            loss_1 = self.step(batch[0])

            # fake generator
            loss_1.backward()
            self.backward(loss_1, opt_a)
            opt_a.step()
            opt_a.zero_grad()

            # fake discriminator
            loss_2 = self.step(batch[0])
            loss_2.backward()
            self.backward(loss_2, opt_b)
            opt_b.step()
            opt_b.zero_grad()
@@ -88,6 +88,7 @@ def test_multiple_optimizers_manual(tmpdir):
    model.val_dataloader = None

    trainer = Trainer(
        automatic_optimization=False,
        default_root_dir=tmpdir,
        limit_train_batches=2,
        limit_val_batches=2,


@@ -7,66 +7,67 @@ from pytorch_lightning.utilities import APEX_AVAILABLE
from pytorch_lightning.utilities.exceptions import MisconfigurationException
# def test_multiple_optimizers_manual(tmpdir):
#     os.environ['PL_DEV_DEBUG'] = '1'
#
#     """
#     Tests that only training_step can be used
#     """
#     class TestModel(BoringModel):
#         def training_step(self, batch, batch_idx, optimizer_idx):
#             # manual
#             (opt_a, opt_b) = self.optimizers()
#             loss_1 = self.step(batch[0])
#
#             # make sure there are no grads
#             if batch_idx > 0:
#                 assert torch.all(self.layer.weight.grad == 0)
#
#             self.backward(loss_1, opt_a)
#             opt_a.step()
#             opt_a.zero_grad()
#             assert torch.all(self.layer.weight.grad == 0)
#
#             # fake discriminator
#             loss_2 = self.step(batch[0])
#
#             # ensure we forward the correct params to the optimizer
#             # without retain_graph we can't do multiple backward passes
#             self.backward(loss_2, opt_b, retain_graph=True)
#             self.backward(loss_2, opt_a, retain_graph=True)
#
#             assert self.layer.weight.grad is not None
#             opt_b.step()
#             opt_b.zero_grad()
#             assert torch.all(self.layer.weight.grad == 0)
#
#         def training_epoch_end(self, outputs) -> None:
#             # outputs should be an array with an entry per optimizer
#             assert len(outputs) == 2
#
#         def configure_optimizers(self):
#             optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
#             optimizer_2 = torch.optim.SGD(self.layer.parameters(), lr=0.1)
#             return optimizer, optimizer_2
#
#     model = TestModel()
#     model.val_dataloader = None
#
#     limit_train_batches = 2
#     trainer = Trainer(
#         default_root_dir=tmpdir,
#         limit_train_batches=limit_train_batches,
#         limit_val_batches=2,
#         max_epochs=1,
#         log_every_n_steps=1,
#         weights_summary=None,
#     )
#
#     trainer.fit(model)
#
#     num_manual_backward_calls = 3
#     assert len(trainer.dev_debugger.backward_calls) == limit_train_batches * num_manual_backward_calls
def test_multiple_optimizers_manual(tmpdir):
    os.environ['PL_DEV_DEBUG'] = '1'
    """
    Tests that only training_step can be used
    """
    class TestModel(BoringModel):
        def training_step(self, batch, batch_idx, optimizer_idx):
            # manual
            (opt_a, opt_b) = self.optimizers()
            loss_1 = self.step(batch[0])

            # make sure there are no grads
            if batch_idx > 0:
                assert torch.all(self.layer.weight.grad == 0)

            self.backward(loss_1, opt_a)
            opt_a.step()
            opt_a.zero_grad()
            assert torch.all(self.layer.weight.grad == 0)

            # fake discriminator
            loss_2 = self.step(batch[0])

            # ensure we forward the correct params to the optimizer
            # without retain_graph we can't do multiple backward passes
            self.backward(loss_2, opt_b, retain_graph=True)
            self.backward(loss_2, opt_a, retain_graph=True)

            assert self.layer.weight.grad is not None
            opt_b.step()
            opt_b.zero_grad()
            assert torch.all(self.layer.weight.grad == 0)

        def training_epoch_end(self, outputs) -> None:
            # outputs should be an array with an entry per optimizer
            assert len(outputs) == 2

        def configure_optimizers(self):
            optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
            optimizer_2 = torch.optim.SGD(self.layer.parameters(), lr=0.1)
            return optimizer, optimizer_2

    model = TestModel()
    model.val_dataloader = None

    limit_train_batches = 2
    trainer = Trainer(
        automatic_optimization=False,
        default_root_dir=tmpdir,
        limit_train_batches=limit_train_batches,
        limit_val_batches=2,
        max_epochs=1,
        log_every_n_steps=1,
        weights_summary=None,
    )

    trainer.fit(model)

    num_manual_backward_calls = 3
    assert trainer.dev_debugger.count_events('backward_call') == limit_train_batches * num_manual_backward_calls
#
# def test_multiple_optimizers_manual_single_optimizer_called(tmpdir):