lightning/tests/plugins/test_amp_plugin.py

87 lines
2.6 KiB
Python

from pytorch_lightning.callbacks import Callback
from tests.base.boring_model import BoringModel
from pytorch_lightning import Trainer
import pytest
import os
from unittest import mock
from pytorch_lightning.plugins.native_amp import NativeAMPPlugin
from distutils.version import LooseVersion
import torch
@pytest.mark.skipif(
LooseVersion(torch.__version__) < LooseVersion("1.6.0"),
reason="Minimal PT version is set to 1.6",
)
@mock.patch.dict(os.environ, {
"CUDA_VISIBLE_DEVICES": "0,1",
"SLURM_NTASKS": "2",
"SLURM_JOB_NAME": "SOME_NAME",
"SLURM_NODEID": "0",
"LOCAL_RANK": "0",
"SLURM_LOCALID": "0"
})
@mock.patch('torch.cuda.device_count', return_value=2)
@pytest.mark.parametrize(['ddp_backend', 'gpus', 'num_processes'],
[('ddp_cpu', None, None), ('ddp', 2, 0), ('ddp2', 2, 0), ('ddp_spawn', 2, 0)])
def test_amp_choice_default_ddp_cpu(tmpdir, ddp_backend, gpus, num_processes):
class CB(Callback):
def on_fit_start(self, trainer, pl_module):
assert isinstance(trainer.precision_connector.backend, NativeAMPPlugin)
raise SystemExit()
model = BoringModel()
trainer = Trainer(
fast_dev_run=True,
precision=16,
amp_backend='native',
gpus=gpus,
num_processes=num_processes,
distributed_backend=ddp_backend,
callbacks=[CB()]
)
with pytest.raises(SystemExit):
trainer.fit(model)
@pytest.mark.skipif(
LooseVersion(torch.__version__) < LooseVersion("1.6.0"),
reason="Minimal PT version is set to 1.6",
)
@mock.patch.dict(os.environ, {
"CUDA_VISIBLE_DEVICES": "0,1",
"SLURM_NTASKS": "2",
"SLURM_JOB_NAME": "SOME_NAME",
"SLURM_NODEID": "0",
"LOCAL_RANK": "0",
"SLURM_LOCALID": "0"
})
@mock.patch('torch.cuda.device_count', return_value=2)
@pytest.mark.parametrize(['ddp_backend', 'gpus', 'num_processes'],
[('ddp_cpu', None, None), ('ddp', 2, 0), ('ddp2', 2, 0), ('ddp_spawn', 2, 0)])
def test_amp_choice_custom_ddp_cpu(tmpdir, ddp_backend, gpus, num_processes):
class MyNativeAMP(NativeAMPPlugin):
pass
class CB(Callback):
def on_fit_start(self, trainer, pl_module):
assert isinstance(trainer.precision_connector.backend, MyNativeAMP)
raise SystemExit()
model = BoringModel()
trainer = Trainer(
fast_dev_run=True,
precision=16,
amp_backend='native',
gpus=gpus,
num_processes=num_processes,
distributed_backend=ddp_backend,
plugins=[MyNativeAMP()],
callbacks=[CB()]
)
with pytest.raises(SystemExit):
trainer.fit(model)