# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
2021-03-03 07:56:57 +00:00
|
|
|
import os
|
|
|
|
from unittest import mock
|
|
|
|
|
|
|
|
import pytest
|
|
|
|
import torch
|
|
|
|
|
|
|
|
from pytorch_lightning import Trainer
|
|
|
|
from pytorch_lightning.plugins import ApexMixedPrecisionPlugin, NativeMixedPrecisionPlugin
|
2021-08-24 09:47:21 +00:00
|
|
|
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
2021-03-03 07:56:57 +00:00
|
|
|
from tests.helpers import BoringModel
|
|
|
|
from tests.helpers.runif import RunIf
|
|
|
|
|
|
|
|
|
|
|
|
class MyNativeAMP(NativeMixedPrecisionPlugin):
    """A no-op subclass used to check that a user-provided native AMP plugin is passed through by the Trainer."""

    pass
|
|
|
|
|
|
|
|
|
|
|
|
class MyApexPlugin(ApexMixedPrecisionPlugin):
    """A no-op subclass used to check that a user-provided apex AMP plugin is passed through by the Trainer."""

    pass
|
|
|
|
|
|
|
|
|
|
|
|
@mock.patch.dict(
    os.environ,
    {
        "CUDA_VISIBLE_DEVICES": "0,1",
        "SLURM_NTASKS": "2",
        "SLURM_JOB_NAME": "SOME_NAME",
        "SLURM_NODEID": "0",
        "LOCAL_RANK": "0",
        "SLURM_PROCID": "0",
        "SLURM_LOCALID": "0",
    },
)
@mock.patch("torch.cuda.device_count", return_value=2)
@pytest.mark.parametrize("strategy,gpus", [("ddp", 2), ("ddp2", 2), ("ddp_spawn", 2)])
@pytest.mark.parametrize(
    "amp,custom_plugin,plugin_cls",
    [
        ("native", False, NativeMixedPrecisionPlugin),
        ("native", True, MyNativeAMP),
        pytest.param("apex", False, ApexMixedPrecisionPlugin, marks=RunIf(amp_apex=True)),
        pytest.param("apex", True, MyApexPlugin, marks=RunIf(amp_apex=True)),
    ],
)
def test_amp_apex_ddp(mocked_device_count, strategy, gpus, amp, custom_plugin, plugin_cls):
    """The Trainer must end up with the expected precision plugin, whether auto-selected or user-supplied."""
    if not custom_plugin:
        custom = None
    elif amp == "native":
        # the native plugin requires precision and device at construction time
        custom = plugin_cls(16, "cpu")
    else:
        custom = plugin_cls()
    trainer = Trainer(
        fast_dev_run=True, precision=16, amp_backend=amp, gpus=gpus, strategy=strategy, plugins=custom
    )
    assert isinstance(trainer.precision_plugin, plugin_cls)
|
|
|
|
|
|
|
|
|
2021-10-30 10:27:49 +00:00
|
|
|
class TestClippingOptimizer(torch.optim.SGD):
    """SGD subclass that asserts gradients are already clipped by the time ``step`` runs."""

    def step(self, *args, pl_module=None):
        # `pl_module` is injected by `TestPrecisionModel.optimizer_step` so the
        # model can verify clipping happened before the optimizer update.
        pl_module.check_grads_clipped()
        return super().step(*args)
|
|
|
|
|
|
|
|
|
|
|
|
class TestPrecisionModel(BoringModel):
    """Model that validates the grad scaling/unscaling/clipping sequence of the native AMP plugin.

    Hook order exercised: ``on_after_backward`` (grads still scaled) ->
    ``on_before_optimizer_step`` (grads unscaled; manual reference clipping) ->
    ``configure_gradient_clipping`` (Lightning clips; compare against reference) ->
    ``optimizer_step`` (optimizer re-checks clipping).
    """

    # sister test: tests/trainer/optimization/test_manual_optimization.py::test_multiple_optimizers_step
    def on_after_backward(self) -> None:
        # check grads are scaled
        scale = self.trainer.precision_plugin.scaler.get_scale()
        assert scale != 1.0  # the return value if not enabled
        grads = [p.grad for p in self.parameters()]
        inv_scale = 1 / scale
        # keep an unscaled copy to compare against after the plugin unscales
        self.original_grads = [p * inv_scale for p in grads]

    def check_grads_unscaled(self, optimizer=None):
        """Assert the scaler has unscaled the grads and they match our manually unscaled copies."""
        if optimizer is not None:
            scaler = self.trainer.precision_plugin.scaler
            # NOTE(review): relies on GradScaler internals — the per-optimizer state machine
            state = scaler._per_optimizer_states[id(optimizer)]
            assert state["stage"].name == "UNSCALED"

        grads = [p.grad for p in self.parameters()]
        assert len(grads) == len(self.original_grads)
        for actual, expected in zip(grads, self.original_grads):
            torch.testing.assert_allclose(actual, expected)

    def check_grads_clipped(self):
        """Assert the parameters' grads equal the manually clipped reference copies."""
        parameters = list(self.parameters())
        assert len(parameters) == len(self.clipped_parameters)
        for actual, expected in zip(parameters, self.clipped_parameters):
            torch.testing.assert_allclose(actual.grad, expected.grad)

    def on_before_optimizer_step(self, optimizer, *_):
        self.check_grads_unscaled(optimizer)
        # manually clip a detached copy of the parameters to build the reference
        self.clipped_parameters = []
        for p in self.parameters():
            copy = p.detach().clone()
            copy.grad = p.grad.clone()
            self.clipped_parameters.append(copy)
        clip_val = self.trainer.gradient_clip_val
        torch.nn.utils.clip_grad_value_(self.clipped_parameters, clip_val)

    def log_grad_norm(self, grad_norm_dict):
        # grad norms must be computed on the unscaled grads, and must be non-empty
        self.check_grads_unscaled()
        assert len(grad_norm_dict)

    def configure_gradient_clipping(self, *args, **kwargs):
        # let lightning clip
        super().configure_gradient_clipping(*args, **kwargs)
        # check clipping worked as expected
        self.check_grads_clipped()

    def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, closure, **_):
        # pass self as a kwarg so TestClippingOptimizer can call back into the model
        optimizer.step(closure, pl_module=self)

    def configure_optimizers(self):
        return TestClippingOptimizer(self.layer.parameters(), lr=0.1)
|
2021-03-03 07:56:57 +00:00
|
|
|
|
|
|
|
|
2021-09-29 13:34:26 +00:00
|
|
|
@RunIf(min_gpus=2)
@pytest.mark.parametrize("accum", [1, 2])
def test_amp_gradient_unscale(tmpdir, accum: int):
    """Run a short DDP-spawn AMP fit; ``TestPrecisionModel`` asserts the unscale/clip hook sequence."""
    model = TestPrecisionModel()

    trainer_kwargs = dict(
        max_epochs=2,
        default_root_dir=tmpdir,
        limit_train_batches=2,
        limit_val_batches=0,
        amp_backend="native",
        strategy="ddp_spawn",
        gpus=2,
        precision=16,
        track_grad_norm=2,
        # use a tiny value to make sure it works
        gradient_clip_val=1e-3,
        gradient_clip_algorithm="value",
        log_every_n_steps=1,
        accumulate_grad_batches=accum,
        enable_progress_bar=False,
    )
    trainer = Trainer(**trainer_kwargs)
    trainer.fit(model)
|
2021-04-30 17:16:28 +00:00
|
|
|
|
|
|
|
|
2021-09-29 13:34:26 +00:00
|
|
|
@RunIf(min_gpus=1)
def test_amp_skip_optimizer(tmpdir):
    """Test that optimizers can be skipped when using amp."""

    class SkipOptimizerModel(BoringModel):
        def __init__(self):
            super().__init__()
            self.layer1 = torch.nn.Linear(32, 32)
            self.layer2 = torch.nn.Linear(32, 2)

        def forward(self, x: torch.Tensor):
            return self.layer2(self.layer1(x))

        def training_step(self, batch, batch_idx, optimizer_idx):
            # the second optimizer is deliberately skipped by returning no loss
            if optimizer_idx == 1:
                return None
            return self.loss(batch, self(batch))

        def configure_optimizers(self):
            return [torch.optim.SGD(layer.parameters(), lr=0.1) for layer in (self.layer1, self.layer2)]

    model = SkipOptimizerModel()
    trainer = Trainer(default_root_dir=tmpdir, gpus=1, fast_dev_run=1, amp_backend="native", precision=16)
    trainer.fit(model)
|
|
|
|
|
|
|
|
|
2021-11-26 17:13:14 +00:00
|
|
|
@RunIf(min_gpus=2, amp_apex=True, standalone=True)
@pytest.mark.parametrize("amp_level", ["O2"])
def test_amp_apex_ddp_fit(amp_level, tmpdir):
    """Fit/test with apex AMP under DDP and check the model is connected and running in half precision."""

    class ApexCheckModel(BoringModel):
        def training_step(self, batch, batch_idx):
            # apex O2 casts the model weights to fp16, and the plugin must be connected by now
            assert self.layer.weight.dtype == torch.float16
            assert self.trainer.precision_plugin._connected
            return super().training_step(batch, batch_idx)

    model = ApexCheckModel()
    trainer = Trainer(
        default_root_dir=tmpdir,
        fast_dev_run=True,
        precision=16,
        amp_backend="apex",
        gpus=2,
        strategy="ddp",
        plugins=ApexMixedPrecisionPlugin(amp_level=amp_level),
    )
    assert isinstance(trainer.precision_plugin, ApexMixedPrecisionPlugin)
    trainer.fit(model)
    trainer.test(model)
|
|
|
|
|
|
|
|
|
|
|
|
@RunIf(min_gpus=2, amp_apex=True)
@pytest.mark.parametrize("amp_level", ["O2"])
def test_amp_apex_ddp_spawn_fit(amp_level, tmpdir):
    """Fit with apex AMP under ddp_spawn and check the apex plugin was selected."""
    apex_plugin = ApexMixedPrecisionPlugin(amp_level=amp_level)
    trainer = Trainer(
        default_root_dir=tmpdir,
        fast_dev_run=True,
        precision=16,
        amp_backend="apex",
        gpus=2,
        strategy="ddp_spawn",
        plugins=apex_plugin,
    )
    assert isinstance(trainer.precision_plugin, ApexMixedPrecisionPlugin)
    trainer.fit(BoringModel())
|
2021-08-24 09:47:21 +00:00
|
|
|
|
|
|
|
|
2021-10-27 12:38:39 +00:00
|
|
|
@RunIf(min_torch="1.10")
def test_cpu_amp_precision_context_manager(tmpdir):
    """The native plugin on CPU + bf16 must produce an autocast context with no grad scaler."""
    precision_plugin = NativeMixedPrecisionPlugin("bf16", "cpu")
    assert precision_plugin.device == "cpu"
    # bf16 on CPU does not use loss scaling
    assert precision_plugin.scaler is None
    autocast_ctx = precision_plugin.autocast_context_manager()
    assert isinstance(autocast_ctx, torch.autocast)
    # check with str due to a bug upstream: https://github.com/pytorch/pytorch/issues/65786
    assert str(autocast_ctx.fast_dtype) == str(torch.bfloat16)
|
2021-08-25 12:18:00 +00:00
|
|
|
|
|
|
|
|
2021-10-20 13:25:13 +00:00
|
|
|
def test_precision_selection_raises(monkeypatch):
|
2021-08-25 12:18:00 +00:00
|
|
|
with pytest.raises(
|
2021-10-20 13:25:13 +00:00
|
|
|
MisconfigurationException, match=r"precision=16, amp_type='apex'\)` but apex AMP not supported on CPU"
|
|
|
|
):
|
|
|
|
Trainer(amp_backend="apex", precision=16)
|
|
|
|
|
|
|
|
import pytorch_lightning.plugins.precision.native_amp as amp
|
|
|
|
|
2021-10-27 12:38:39 +00:00
|
|
|
monkeypatch.setattr(amp, "_TORCH_GREATER_EQUAL_1_10", False)
|
2021-10-20 13:25:13 +00:00
|
|
|
with pytest.warns(
|
|
|
|
UserWarning, match=r"precision=16\)` but native AMP is not supported on CPU. Using `precision='bf16"
|
|
|
|
), pytest.raises(MisconfigurationException, match="must install torch greater or equal to 1.10"):
|
|
|
|
Trainer(precision=16)
|
|
|
|
|
|
|
|
with pytest.raises(MisconfigurationException, match="must install torch greater or equal to 1.10"):
|
|
|
|
Trainer(precision="bf16")
|
|
|
|
|
|
|
|
with pytest.raises(MisconfigurationException, match=r"amp_type='apex', precision='bf16'\)` but it's not supported"):
|
|
|
|
Trainer(amp_backend="apex", precision="bf16")
|
|
|
|
|
|
|
|
with mock.patch("torch.cuda.device_count", return_value=1), pytest.raises(
|
|
|
|
MisconfigurationException, match="Sharded plugins are not supported with apex"
|
|
|
|
):
|
2021-11-17 22:41:50 +00:00
|
|
|
Trainer(amp_backend="apex", precision=16, gpus=1, strategy="ddp_fully_sharded")
|
2021-10-20 13:25:13 +00:00
|
|
|
|
|
|
|
import pytorch_lightning.plugins.precision.apex_amp as apex
|
|
|
|
|
|
|
|
monkeypatch.setattr(apex, "_APEX_AVAILABLE", False)
|
|
|
|
with mock.patch("torch.cuda.device_count", return_value=1), pytest.raises(
|
|
|
|
MisconfigurationException, match="asked for Apex AMP but you have not installed it"
|
2021-08-25 12:18:00 +00:00
|
|
|
):
|
2021-10-20 13:25:13 +00:00
|
|
|
Trainer(amp_backend="apex", precision=16, gpus=1)
|