lightning/tests/tests_pytorch/plugins/precision/test_amp_integration.py

# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch

from lightning.fabric import seed_everything
from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel

from tests_pytorch.helpers.runif import RunIf
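

# BoringModel variant whose Adam optimizer toggles the `fused` implementation,
# letting the same training run be compared across both code paths.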
class FusedOptimizerParityModel(BoringModel):
    def __init__(self, fused=False):
        super().__init__()
        self.fused = fused

    def configure_optimizers(self):
        assert isinstance(self.trainer.precision_plugin.scaler, torch.cuda.amp.GradScaler)
        return torch.optim.Adam(self.parameters(), lr=1.0, fused=self.fused)
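

# The fused Adam implementation is CUDA-only and was added in torch 1.13,
# hence the version and GPU guards on the test below.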
@RunIf(min_torch="1.13", min_cuda_gpus=1)
def test_amp_fused_optimizer_parity(tmpdir):
    def run(fused=False):
        seed_everything(1234)
        model = FusedOptimizerParityModel(fused)
        trainer = Trainer(
            default_root_dir=tmpdir,
            accelerator="cuda",
            devices=1,
            precision="16-mixed",
            max_steps=5,
            logger=False,
            enable_checkpointing=False,
            enable_progress_bar=False,
            enable_model_summary=False,
        )
        trainer.fit(model)
        return model.parameters()
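
    # Each call to run() re-seeds, so both runs start from identical weights and batches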
    params = run(fused=False)
    params_fused = run(fused=True)
    # Both the regular and the fused version of Adam produce the same losses and model weights
    for p, q in zip(params, params_fused):
        torch.testing.assert_close(p, q)