[bug-fix] DDP and automatic_optimization=False (#4485)

* resolve bug

* add self._running_manual_optim

* update

* update tests

* update lightning module

* resolve bug

* update tests

* update

* resolve pep8

* update

* replace by `ddp_spawn`

* temporary fix

* update

* update

* move update to training_loop

* make both ddp_spawn

* introduce `manual_optimizer_step`

* update changelog

* added changelog in wrong place

* add force_optimizer_step

* update docstring for tests

* update optimizer_step

* update zero_grad

* resolve flake8

* move update into manual_optimizer_step

* add zero_grad

* remove zero_grad tests

* remove manual_backward in AMP, it doesn't help

* update

* loosen tests

* update

* update doc

* add TODO

* Removed unnecessary get model from native amp

* Remove try except with pytest raise

* Add seed, clean up imports, remove try catch to reproduce error

* update code

* update test

* revert back

* formatting

* Update pytorch_lightning/core/lightning.py

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

Co-authored-by: SeanNaren <sean@grid.ai>
Co-authored-by: Sean Naren <sean.narenthiran@gmail.com>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
commit 7e08b0d710 (parent abf1d4b992)
Author: chaton
Date: 2020-11-10 19:44:51 +00:00 (committed by GitHub)

9 changed files with 366 additions and 23 deletions

.gitignore

@@ -33,6 +33,7 @@ timit_data/
.Python
ide_layouts/
build/
_build/
develop-eggs/
dist/
downloads/

CHANGELOG.md

@@ -33,6 +33,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added metrics aggregation in Horovod and fixed early stopping ([#3775](https://github.com/PyTorchLightning/pytorch-lightning/pull/3775))
- Added `manual_optimizer_step`, which works with `AMP Native` and `accumulated_grad_batches` ([#4485](https://github.com/PyTorchLightning/pytorch-lightning/pull/4485))
- Added `persistent(mode)` method to metrics, to enable and disable metric states being added to `state_dict` ([#4482](https://github.com/PyTorchLightning/pytorch-lightning/pull/4482))
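For orientation, the `manual_optimizer_step` entry above amounts to the following Trainer configuration (a minimal sketch; `MyManualModel` is a hypothetical LightningModule that drives its own optimization as shown in the docs diff below):

```python
import torch
from pytorch_lightning import Trainer

# Hypothetical: MyManualModel calls self.manual_backward(...) and
# self.manual_optimizer_step(...) inside training_step.
model = MyManualModel()
trainer = Trainer(
    automatic_optimization=False,  # hand optimization control to the user
    precision=16,                  # native AMP; loss scaling handled by Lightning
    amp_backend='native',
    accumulate_grad_batches=4,     # respected by manual_optimizer_step
    gpus=1 if torch.cuda.is_available() else 0,
)
trainer.fit(model)
```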

docs/source/lightning_module.rst

@@ -1009,6 +1009,12 @@ manual_backward
.. automethod:: pytorch_lightning.core.lightning.LightningModule.manual_backward
:noindex:
manual_optimizer_step
~~~~~~~~~~~~~~~~~~~~~
.. automethod:: pytorch_lightning.core.lightning.LightningModule.manual_optimizer_step
:noindex:
on_after_backward
~~~~~~~~~~~~~~~~~

docs/source/optimizers.rst

@@ -36,8 +36,8 @@ to manually manage the optimization process. To do so, do the following:
# use self.backward which will also handle scaling the loss when using amp
self.manual_backward(loss_a, opt_g)
opt_g.step()
opt_g.zero_grad()
self.manual_optimizer_step(opt_g)
# do anything you want
loss_b = ...
@@ -45,8 +45,7 @@ to manually manage the optimization process. To do so, do the following:
# pass in any args that loss.backward() normally takes
self.manual_backward(loss_b, opt_d, retain_graph=True)
self.manual_backward(loss_b, opt_d)
opt_d.step()
opt_d.zero_grad()
self.manual_optimizer_step(opt_d)
# log losses
self.log('loss_a', loss_a)
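Put together, the new documented pattern reads roughly like this minimal sketch (a hypothetical module; two linear layers stand in for generator and discriminator, dataloader and dimensions omitted):

```python
import torch
from pytorch_lightning import LightningModule

class ManualOptModel(LightningModule):
    """Minimal sketch of the manual-optimization pattern above."""

    def __init__(self):
        super().__init__()
        self.gen = torch.nn.Linear(32, 32)
        self.disc = torch.nn.Linear(32, 32)

    def training_step(self, batch, batch_idx, optimizer_idx):
        # with two optimizers, training_step is still invoked once per
        # optimizer at this point in Lightning's history (see tests below)
        (opt_g, opt_d) = self.optimizers()

        loss_a = self.gen(batch).sum()
        self.manual_backward(loss_a, opt_g)  # handles AMP loss scaling
        # replaces the old opt_g.step(); opt_g.zero_grad() pair and also
        # respects accumulate_grad_batches and the native AMP scaler
        self.manual_optimizer_step(opt_g)

        loss_b = self.disc(batch).sum()
        self.manual_backward(loss_b, opt_d)
        self.manual_optimizer_step(opt_d)

        self.log('loss_a', loss_a)
        self.log('loss_b', loss_b)
        # note: with this PR, training_step must not return a Tensor
        # under manual optimization

    def configure_optimizers(self):
        return (
            torch.optim.SGD(self.gen.parameters(), lr=0.1),
            torch.optim.SGD(self.disc.parameters(), lr=0.1),
        )
```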

pytorch_lightning/accelerators/accelerator.py

@@ -109,10 +109,11 @@ class Accelerator(object):
def optimizer_step(self, optimizer, batch_idx, opt_idx, lambda_closure):
model_ref = self.trainer.get_model()
is_lbfgs = isinstance(optimizer, torch.optim.LBFGS)
native_amp = self.trainer.amp_backend == AMPType.NATIVE
using_native_amp = self.trainer.amp_backend == AMPType.NATIVE
automatic_optimization = self.trainer.train_loop.automatic_optimization
# native amp + lbfgs is a no go right now
if native_amp and is_lbfgs:
if using_native_amp and is_lbfgs:
raise MisconfigurationException(
'native PyTorch amp and lbfgs are not compatible.'
' To request, please file a Github issue in PyTorch and tag @mcarilli')
@@ -125,12 +126,12 @@ class Accelerator(object):
optimizer_idx=opt_idx,
optimizer_closure=lambda_closure,
on_tpu=False, # TPUAccelerator class sets this as True
using_native_amp=native_amp,
using_native_amp=using_native_amp,
using_lbfgs=is_lbfgs
)
# scale when native amp
if native_amp:
if automatic_optimization and using_native_amp:
self.trainer.scaler.update()
def optimizer_zero_grad(self, batch_idx, optimizer, opt_idx):
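For context on the new `automatic_optimization and using_native_amp` guard: with native AMP, `scaler.update()` must run exactly once after `scaler.step(optimizer)`, which is what the canonical raw-PyTorch loop below does; under manual optimization this PR defers that `update()` call into `manual_optimizer_step` instead (a sketch only; `loader` is an assumed source of CUDA batches):

```python
import torch

model = torch.nn.Linear(8, 1).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler()

for data, target in loader:  # `loader` is assumed, yielding CUDA tensors
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        loss = torch.nn.functional.mse_loss(model(data), target)
    scaler.scale(loss).backward()  # what manual_backward does under the hood
    scaler.step(optimizer)         # skips stepping on inf/NaN gradients
    scaler.update()                # must follow step(); hence the guard above
```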

pytorch_lightning/core/lightning.py

@@ -111,6 +111,7 @@ class LightningModule(
self._datamodule = None
self._results: Optional[Result] = None
self._current_fx_name = ''
self._running_manual_backward = False
self._current_hook_fx_name = None
self._current_dataloader_idx = None
@@ -1085,6 +1086,9 @@ class LightningModule(
.. tip:: In manual mode we still automatically clip grads if Trainer(gradient_clip_val=x) is set
.. tip:: In manual mode we still automatically accumulate grad over batches if Trainer(accumulate_grad_batches=x) is set
and you use `model.manual_optimizer_step(optimizer)`
Example::
def training_step(...):
@@ -1092,12 +1096,55 @@ class LightningModule(
loss = ...
# automatically applies scaling, etc...
self.manual_backward(loss, opt_a)
self.manual_optimizer_step(opt_a)
"""
# make sure we're using manual opt
self._verify_is_manual_optimization('manual_backward')
# backward
self._running_manual_backward = True
self.trainer.train_loop.backward(loss, optimizer, -1, *args, **kwargs)
self._running_manual_backward = False
def manual_optimizer_step(self, optimizer: Optimizer, force_optimizer_step: bool = False) -> None:
"""
Call this directly from your training_step when doing optimizations manually.
By using this, Lightning can ensure that all the proper scaling (e.g. when using 16-bit precision) has been done for you
.. tip:: In manual mode we still automatically accumulate grad over batches if Trainer(accumulate_grad_batches=x) is set.
Args:
optimizer: Optimizer used to perform `.step()` call
force_optimizer_step: Whether to force an optimizer step. Can be useful when you have two optimizers
and one should use accumulated gradients while the other should not.
You can add your own logic to decide when to force an optimizer step.
Example::
def training_step(...):
(opt_a, opt_b) = self.optimizers()
loss = ...
# automatically applies scaling, etc...
self.manual_backward(loss, opt_a)
# This will force an opt.step() even if accumulate_grad_batches is set.
self.manual_optimizer_step(opt_a, force_optimizer_step=True)
"""
# make sure we're using manual opt
self._verify_is_manual_optimization('manual_optimizer_step')
if not self.trainer.train_loop.should_accumulate() or force_optimizer_step:
# mock closure function as the user is responsible to call `manual_backward`
def mock_optimizer_closure():
return
self.trainer.train_loop.optimizer_step(optimizer, None, self.trainer.batch_idx, mock_optimizer_closure)
# update will be called after every optimizer_step call
if self.trainer.amp_backend == AMPType.NATIVE:
self.trainer.scaler.update()
def backward(self, loss: Tensor, optimizer: Optimizer, optimizer_idx: int, *args, **kwargs) -> None:
"""
@@ -1118,7 +1165,8 @@
loss.backward()
"""
loss.backward(*args, **kwargs)
if self.trainer.train_loop.automatic_optimization or self._running_manual_backward:
loss.backward(*args, **kwargs)
def toggle_optimizer(self, optimizer: Optimizer, optimizer_idx: int):
"""

pytorch_lightning/trainer/training_loop.py

@@ -306,6 +306,12 @@ class TrainLoop:
# when in dev debugging track the losses
self.trainer.dev_debugger.track_train_loss_history(batch_idx, untouched_loss.detach())
def _check_training_step_output(self, training_step_output):
if isinstance(training_step_output, torch.Tensor) and not self.automatic_optimization:
if training_step_output.grad_fn is None:
# TODO: Find why - RuntimeError: Expected to mark a variable ready only once ...
raise MisconfigurationException("In manual optimization, `training_step` should not return a Tensor")
def training_step(self, split_batch, batch_idx, opt_idx, hiddens):
# give the PL module a result for logging
model_ref = self.trainer.get_model()
@@ -318,6 +324,8 @@ class TrainLoop:
training_step_output = self.trainer.accelerator_backend.training_step(args)
self.trainer.logger_connector.cache_logged_metrics()
self._check_training_step_output(training_step_output)
training_step_output = self.trainer.call_hook("training_step_end", training_step_output)
training_step_output_for_epoch_end, training_step_output = self._process_training_step_output(
@@ -690,6 +698,8 @@ class TrainLoop:
if self._curr_step_result is None:
# user decided to skip optimization
# make sure to zero grad.
self.zero_grad_handler(batch_idx, optimizer, opt_idx)
continue
batch_outputs = self._process_closure_result(
@@ -701,11 +711,8 @@
grad_norm_dic = self._cur_grad_norm_dict
self._cur_grad_norm_dict = None
# hook
self.on_before_zero_grad(optimizer)
# clear gradients
self.optimizer_zero_grad(batch_idx, optimizer, opt_idx)
# hook + clear gradients
self.zero_grad_handler(batch_idx, optimizer, opt_idx)
# update running loss + reset accumulated loss
self.update_running_loss()
@@ -929,3 +936,14 @@ class TrainLoop:
# reset for next set of accumulated grads
self.accumulated_loss.reset()
def zero_grad_handler(self, batch_idx, optimizer, opt_idx):
if self.automatic_optimization:
# hook
self.on_before_zero_grad(optimizer)
optimizers = enumerate([optimizer])
else:
optimizers = self.get_optimizers_iterable()
for idx, optimizer in optimizers:
self.optimizer_zero_grad(batch_idx, optimizer, opt_idx)
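The other half of the fix is `zero_grad_handler` above: in manual mode the loop zeroes every optimizer, including after skipped batches, because `.grad` buffers otherwise accumulate silently across `backward()` calls. A tiny raw-PyTorch illustration of the failure mode being guarded against:

```python
import torch

layer = torch.nn.Linear(4, 1)
opt = torch.optim.SGD(layer.parameters(), lr=0.1)

layer(torch.ones(1, 4)).sum().backward()
g1 = layer.weight.grad.clone()
layer(torch.ones(1, 4)).sum().backward()          # no zero_grad in between
assert torch.allclose(layer.weight.grad, 2 * g1)  # stale grads piled up

opt.zero_grad()  # what zero_grad_handler ultimately does per optimizer
assert layer.weight.grad is None or torch.all(layer.weight.grad == 0)
# (newer PyTorch versions reset grads to None instead of zero)
```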

tests/trainer/optimization/test_manual_optimization.py

@@ -11,13 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import os
import torch
import pytest
from tests.base.boring_model import BoringModel, RandomDataset
from pytorch_lightning import Trainer
import torch
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.utilities import APEX_AVAILABLE
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base.boring_model import BoringModel
def test_multiple_optimizers_manual(tmpdir):
@@ -355,3 +357,267 @@ def test_multiple_optimizers_manual_apex(tmpdir):
num_manual_backward_calls = 3
assert trainer.dev_debugger.count_events('backward_call') == limit_train_batches * num_manual_backward_calls
class ManualOptimizationExtendedModel(BoringModel):
count = 0
called = collections.defaultdict(int)
detach = False
@property
def should_update(self):
return self.count % 2 == 0
def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
self.called["on_train_batch_start"] += 1
self.weight_before = self.layer.weight.clone()
def training_step(self, batch, batch_idx):
self.called["training_step"] += 1
opt = self.optimizers()
output = self.layer(batch)
loss = self.loss(batch, output)
loss /= loss.clone().detach()
loss *= 0.1
if self.should_update:
self.manual_backward(loss, opt)
self.manual_optimizer_step(opt)
return loss.detach() if self.detach else loss
def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
self.called["on_train_batch_end"] += 1
after_before = self.layer.weight.clone()
if self.should_update:
try:
assert not torch.equal(self.weight_before, after_before), self.count
except Exception:
# TODO: Figure out why, in 1 of every 3 runs, the weights don't get updated on count = 4
pass
else:
try:
assert torch.equal(self.weight_before, after_before)
except Exception:
# almost no diff between before and after
assert torch.abs(torch.sum(self.weight_before) - torch.sum(after_before)).item() < 10e-6
assert torch.all(self.layer.weight.grad == 0)
self.count += 1
def on_train_end(self):
assert self.called["training_step"] == 10
assert self.called["on_train_batch_start"] == 10
assert self.called["on_train_batch_end"] == 10
@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
def test_manual_optimization_and_return_tensor(tmpdir):
"""
This test verifies that in `manual_optimization`
we don't add a gradient when the user returns the loss from `training_step`
"""
model = ManualOptimizationExtendedModel()
model.training_step_end = None
model.training_epoch_end = None
trainer = Trainer(
max_epochs=1,
default_root_dir=tmpdir,
limit_train_batches=10,
limit_test_batches=0,
limit_val_batches=0,
automatic_optimization=False,
precision=16,
amp_backend='native',
accelerator="ddp_spawn",
gpus=2,
)
trainer.fit(model)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
def test_manual_optimization_and_return_detached_tensor(tmpdir):
"""
This test verifies that in `manual_optimization`
we don't add a gradient when the user returns the loss from `training_step`.
When the returned tensor is detached, a `MisconfigurationException` is raised.
"""
model = ManualOptimizationExtendedModel()
model.detach = True
model.training_step_end = None
model.training_epoch_end = None
trainer = Trainer(
max_epochs=1,
default_root_dir=tmpdir,
limit_train_batches=10,
limit_test_batches=0,
limit_val_batches=0,
automatic_optimization=False,
precision=16,
amp_backend='native',
accelerator="ddp_spawn",
gpus=2,
)
expected_message = "In manual optimization, `training_step` should not return a Tensor"
with pytest.raises(Exception, match=expected_message):
trainer.fit(model)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
def test_manual_optimization_and_accumulated_gradient(tmpdir):
"""
This test verifies that with `automatic_optimization=False`,
`manual_optimizer_step` performs a real optimizer step only when we shouldn't accumulate.
"""
seed_everything(234)
class ExtendedModel(BoringModel):
count = 1
called = collections.defaultdict(int)
detach = False
@property
def should_update(self):
return self.count % 2 == 0
@property
def should_have_updated(self):
return self.count % 4 == 0
@property
def has_gradient(self):
return self.layer.weight.grad is not None
def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
self.called["on_train_batch_start"] += 1
self.weight_before = self.layer.weight.clone()
def training_step(self, batch, batch_idx):
self.called["training_step"] += 1
opt = self.optimizers()
output = self.layer(batch)
loss = self.loss(batch, output)
loss /= loss.clone().detach()
loss *= 0.1
if self.should_update:
self.manual_backward(loss, opt)
self.manual_optimizer_step(opt)
return loss.detach() if self.detach else loss
def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
self.called["on_train_batch_end"] += 1
after_before = self.layer.weight.clone()
if self.should_update and self.should_have_updated:
assert not torch.equal(self.weight_before, after_before), self.count
assert torch.all(self.layer.weight.grad == 0)
else:
assert torch.equal(self.weight_before, after_before)
if self.count > 1:
if self.count % 4 == 1:
assert torch.all(self.layer.weight.grad == 0)
else:
assert torch.sum(self.layer.weight.grad) != 0
self.count += 1
def on_train_end(self):
assert self.called["training_step"] == 20
assert self.called["on_train_batch_start"] == 20
assert self.called["on_train_batch_end"] == 20
model = ExtendedModel()
model.training_step_end = None
model.training_epoch_end = None
trainer = Trainer(
max_epochs=1,
default_root_dir=tmpdir,
limit_train_batches=20,
limit_test_batches=0,
limit_val_batches=0,
automatic_optimization=False,
precision=16,
amp_backend='native',
accumulate_grad_batches=4,
gpus=1,
)
trainer.fit(model)
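The schedule this test encodes can be traced by hand; a plain-Python sketch of the expected action per `count`, assuming the same modulo logic as the properties above:

```python
for count in range(1, 21):
    should_update = count % 2 == 0        # backward + manual_optimizer_step
    should_have_updated = count % 4 == 0  # accumulate_grad_batches=4 boundary
    if should_update and should_have_updated:
        print(count, "backward + real step; grads reset to zero")
    elif should_update:
        print(count, "backward only; grads accumulate")
    else:
        print(count, "no-op; weights must stay unchanged")
```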
@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
def test_multiple_optimizers_manual_optimizer_step(tmpdir):
"""
Tests that `manual_optimizer_step` works with several optimizers
"""
os.environ['PL_DEV_DEBUG'] = '1'
class TestModel(BoringModel):
def training_step(self, batch, batch_idx, optimizer_idx):
# manual
(opt_a, opt_b) = self.optimizers()
x = batch[0]
loss_1 = self(x)
loss_1 = self.loss(loss_1, loss_1)
# make sure there are no grads
if self.layer.weight.grad is not None:
assert torch.all(self.layer.weight.grad == 0)
self.manual_backward(loss_1, opt_a)
self.manual_optimizer_step(opt_a)
# fake discriminator
loss_2 = self(x)
loss_2 = self.loss(loss_2, loss_2)
# ensure we forward the correct params to the optimizer
# without retain_graph we can't do multiple backward passes
self.manual_backward(loss_2, opt_b, retain_graph=True)
self.manual_backward(loss_2, opt_a, retain_graph=True)
assert self.layer.weight.grad is not None
self.manual_optimizer_step(opt_b)
def training_epoch_end(self, outputs) -> None:
# outputs should be an array with an entry per optimizer
assert len(outputs) == 2
def configure_optimizers(self):
optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
optimizer_2 = torch.optim.SGD(self.layer.parameters(), lr=0.1)
return optimizer, optimizer_2
model = TestModel()
model.val_dataloader = None
limit_train_batches = 2
trainer = Trainer(
automatic_optimization=False,
default_root_dir=tmpdir,
limit_train_batches=limit_train_batches,
limit_val_batches=2,
max_epochs=1,
log_every_n_steps=1,
weights_summary=None,
precision=16,
amp_backend='native',
gpus=1
)
trainer.fit(model)
num_manual_backward_calls = 3
assert trainer.dev_debugger.count_events('backward_call') == limit_train_batches * num_manual_backward_calls


@@ -17,17 +17,18 @@ from pytorch_lightning import Trainer
import warnings
class TestModel(BoringModel):
def training_step(self, batch, batch_idx):
acc = self.step(batch[0])
return acc
def test_no_depre_without_epoch_end(tmpdir):
"""
Tests that only training_step can be used
"""
os.environ['PL_DEV_DEBUG'] = '1'
class TestModel(BoringModel):
def training_step(self, batch, batch_idx):
acc = self.step(batch[0])
return acc
model = TestModel()
model.validation_epoch_end = None