diff --git a/pytorch_lightning/trainer/__init__.py b/pytorch_lightning/trainer/__init__.py index caf2e85e03..657bd31534 100644 --- a/pytorch_lightning/trainer/__init__.py +++ b/pytorch_lightning/trainer/__init__.py @@ -285,6 +285,41 @@ Example:: # default used by the Trainer trainer = Trainer(amp_level='O2') +automatic_optimization +^^^^^^^^^^^^^^^^^^^^^^ +When set to False, Lightning does not automate the optimization process. This means you are responsible for your own +optimizer behavior + +Example:: + + def training_step(self, batch, batch_idx): + (opt) = self.optimizers() + + loss = ... + self.backward(loss, opt) + opt.step() + opt.zero_grad() + +This is not recommended when using a single optimizer, instead it's recommended when using 2+ optimizers +AND you are an expert user. Most useful for research like RL, sparse coding and GAN research. + +In the multi-optimizer case, ignore the optimizer_idx flag and use the optimizers directly + +Example:: + + def training_step(self, batch, batch_idx, optimizer_idx): + (opt_a, opt_b) = self.optimizers() + + gen_loss = ... + self.backward(gen_loss, opt_a) + opt_a.step() + opt_a.zero_grad() + + disc_loss = ... + self.backward(disc_loss, opt_b) + opt_b.step() + opt_b.zero_grad() + auto_scale_batch_size ^^^^^^^^^^^^^^^^^^^^^ diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index bac58e96dd..15a941071d 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -134,6 +134,7 @@ class Trainer( amp_backend: str = 'native', amp_level: str = 'O2', distributed_backend: Optional[str] = None, + automatic_optimization: bool = True, ): r""" Customize every aspect of training via flags @@ -201,6 +202,9 @@ class Trainer( log_every_n_steps: How often to log within steps (defaults to every 50 steps). + automatic_optimization: If False you are responsible for calling .backward, .step, zero_grad. + Meant to be used with multiple optimizers by advanced users. + prepare_data_per_node: If True, each LOCAL_RANK=0 will call prepare data. Otherwise only NODE_RANK=0, LOCAL_RANK=0 will prepare data @@ -337,7 +341,14 @@ class Trainer( ) # init train loop related flags - self.train_loop.on_trainer_init(max_epochs, min_epochs, max_steps, min_steps, num_sanity_val_steps) + self.train_loop.on_trainer_init( + max_epochs, + min_epochs, + max_steps, + min_steps, + num_sanity_val_steps, + automatic_optimization + ) self.evaluation_loop.on_trainer_init() # configure tuner diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 46badc5b4c..e655518567 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -44,8 +44,9 @@ class TrainLoop: self.warning_cache = WarningCache() self._teardown_already_run = False self.running_loss = TensorRunningAccum(window_length=20) + self.automatic_optimization = True - def on_trainer_init(self, max_epochs, min_epochs, max_steps, min_steps, num_sanity_val_steps): + def on_trainer_init(self, max_epochs, min_epochs, max_steps, min_steps, num_sanity_val_steps, automatic_optimization): self.trainer.global_step = 0 self.trainer.current_epoch = 0 self.trainer.interrupted = False @@ -56,6 +57,7 @@ class TrainLoop: self.trainer.batch_idx = 0 self.trainer.num_training_batches = 0 self.trainer.train_dataloader = None + self.automatic_optimization = automatic_optimization self.trainer.max_epochs = max_epochs self.trainer.min_epochs = min_epochs @@ -275,7 +277,7 @@ class TrainLoop: # find optimzier index by looking for the first {item > current_place} in the cumsum list opt_idx = np.argmax(optimizer_freq_cumsum > current_place_in_loop) - return [(opt_idx, self.trainer.optimizers[opt_idx])] + return [[opt_idx, self.trainer.optimizers[opt_idx]]] def on_after_backward(self, training_step_output, batch_idx, untouched_loss): is_result_obj = isinstance(training_step_output, Result) @@ -640,11 +642,16 @@ class TrainLoop: for split_idx, split_batch in enumerate(splits): self.trainer.split_idx = split_idx + # in manual optimization we loop over all optimizers at once + optimizers = self.get_optimizers_iterable() + if not self.automatic_optimization: + optimizers = [optimizers[0]] + # loop over optimizers - for opt_idx, optimizer in self.get_optimizers_iterable(): + for opt_idx, optimizer in optimizers: # make sure only the gradients of the current optimizer's parameters are calculated # in the training step to prevent dangling gradients in multiple-optimizer setup. - if len(self.trainer.optimizers) > 1: + if self.automatic_optimization and len(self.trainer.optimizers) > 1: model = self.trainer.get_model() model.toggle_optimizer(optimizer, opt_idx) @@ -751,12 +758,18 @@ class TrainLoop: return result def backward(self, result, optimizer, *args, **kwargs): - result.closure_loss = self.trainer.accelerator_backend.backward( - result.closure_loss, - optimizer, - *args, - **kwargs - ) + self.trainer.dev_debugger.track_event('backward_call') + + # backward can be called manually in the training loop. + if isinstance(result, torch.Tensor): + self.trainer.accelerator_backend.backward(result, optimizer, *args, **kwargs) + else: + result.closure_loss = self.trainer.accelerator_backend.backward( + result.closure_loss, + optimizer, + *args, + **kwargs + ) def update_train_loop_lr_schedulers(self, monitor_metrics=None): num_accumulated_batches_reached = (self.trainer.batch_idx + 1) % self.trainer.accumulate_grad_batches == 0 diff --git a/tests/trainer/dynamic_args/test_multiple_optimizers.py b/tests/trainer/dynamic_args/test_multiple_optimizers.py index 49eb654dce..6d6b9d5cde 100644 --- a/tests/trainer/dynamic_args/test_multiple_optimizers.py +++ b/tests/trainer/dynamic_args/test_multiple_optimizers.py @@ -61,17 +61,17 @@ def test_multiple_optimizers_manual(tmpdir): def training_step(self, batch, batch_idx, optimizer_idx): # manual - (opt_a, opt_b) = self.trainer.optimizers + (opt_a, opt_b) = self.optimizers() loss_1 = self.step(batch[0]) # fake generator - loss_1.backward() + self.backward(loss_1, opt_a) opt_a.step() opt_a.zero_grad() # fake discriminator loss_2 = self.step(batch[0]) - loss_2.backward() + self.backward(loss_2, opt_b) opt_b.step() opt_b.zero_grad() @@ -88,6 +88,7 @@ def test_multiple_optimizers_manual(tmpdir): model.val_dataloader = None trainer = Trainer( + automatic_optimization=False, default_root_dir=tmpdir, limit_train_batches=2, limit_val_batches=2, diff --git a/tests/trainer/optimization/test_manual_optimization.py b/tests/trainer/optimization/test_manual_optimization.py index 5e6b1802c5..54ad48b96d 100644 --- a/tests/trainer/optimization/test_manual_optimization.py +++ b/tests/trainer/optimization/test_manual_optimization.py @@ -7,66 +7,67 @@ from pytorch_lightning.utilities import APEX_AVAILABLE from pytorch_lightning.utilities.exceptions import MisconfigurationException -# def test_multiple_optimizers_manual(tmpdir): -# os.environ['PL_DEV_DEBUG'] = '1' -# -# """ -# Tests that only training_step can be used -# """ -# class TestModel(BoringModel): -# def training_step(self, batch, batch_idx, optimizer_idx): -# # manual -# (opt_a, opt_b) = self.optimizers() -# loss_1 = self.step(batch[0]) -# -# # make sure there are no grads -# if batch_idx > 0: -# assert torch.all(self.layer.weight.grad == 0) -# -# self.backward(loss_1, opt_a) -# opt_a.step() -# opt_a.zero_grad() -# assert torch.all(self.layer.weight.grad == 0) -# -# # fake discriminator -# loss_2 = self.step(batch[0]) -# -# # ensure we forward the correct params to the optimizer -# # without retain_graph we can't do multiple backward passes -# self.backward(loss_2, opt_b, retain_graph=True) -# self.backward(loss_2, opt_a, retain_graph=True) -# -# assert self.layer.weight.grad is not None -# opt_b.step() -# opt_b.zero_grad() -# assert torch.all(self.layer.weight.grad == 0) -# -# def training_epoch_end(self, outputs) -> None: -# # outputs should be an array with an entry per optimizer -# assert len(outputs) == 2 -# -# def configure_optimizers(self): -# optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1) -# optimizer_2 = torch.optim.SGD(self.layer.parameters(), lr=0.1) -# return optimizer, optimizer_2 -# -# model = TestModel() -# model.val_dataloader = None -# -# limit_train_batches = 2 -# trainer = Trainer( -# default_root_dir=tmpdir, -# limit_train_batches=limit_train_batches, -# limit_val_batches=2, -# max_epochs=1, -# log_every_n_steps=1, -# weights_summary=None, -# ) -# -# trainer.fit(model) -# -# num_manual_backward_calls = 3 -# assert len(trainer.dev_debugger.backward_calls) == limit_train_batches * num_manual_backward_calls +def test_multiple_optimizers_manual(tmpdir): + os.environ['PL_DEV_DEBUG'] = '1' + + """ + Tests that only training_step can be used + """ + class TestModel(BoringModel): + def training_step(self, batch, batch_idx, optimizer_idx): + # manual + (opt_a, opt_b) = self.optimizers() + loss_1 = self.step(batch[0]) + + # make sure there are no grads + if batch_idx > 0: + assert torch.all(self.layer.weight.grad == 0) + + self.backward(loss_1, opt_a) + opt_a.step() + opt_a.zero_grad() + assert torch.all(self.layer.weight.grad == 0) + + # fake discriminator + loss_2 = self.step(batch[0]) + + # ensure we forward the correct params to the optimizer + # without retain_graph we can't do multiple backward passes + self.backward(loss_2, opt_b, retain_graph=True) + self.backward(loss_2, opt_a, retain_graph=True) + + assert self.layer.weight.grad is not None + opt_b.step() + opt_b.zero_grad() + assert torch.all(self.layer.weight.grad == 0) + + def training_epoch_end(self, outputs) -> None: + # outputs should be an array with an entry per optimizer + assert len(outputs) == 2 + + def configure_optimizers(self): + optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1) + optimizer_2 = torch.optim.SGD(self.layer.parameters(), lr=0.1) + return optimizer, optimizer_2 + + model = TestModel() + model.val_dataloader = None + + limit_train_batches = 2 + trainer = Trainer( + automatic_optimization=False, + default_root_dir=tmpdir, + limit_train_batches=limit_train_batches, + limit_val_batches=2, + max_epochs=1, + log_every_n_steps=1, + weights_summary=None, + ) + + trainer.fit(model) + + num_manual_backward_calls = 3 + assert trainer.dev_debugger.count_events('backward_call') == limit_train_batches * num_manual_backward_calls # # def test_multiple_optimizers_manual_single_optimizer_called(tmpdir):